define_model
This module defines Zoobot’s components. get_pytorch_encoder() and get_pytorch_dirichlet_head() define the encoder and head, respectively, in PyTorch.
zoobot.pytorch.estimators.define_model.ZoobotTree wraps these components in a PyTorch LightningModule describing how to train them.
- class zoobot.pytorch.estimators.define_model.ZoobotTree(output_dim: int, question_answer_pairs: Optional[dict] = None, dependencies: Optional[dict] = None, architecture_name='efficientnet_b0', channels=1, test_time_dropout=True, compile_encoder=False, timm_kwargs={}, dropout_rate=0.2, learning_rate=0.001, betas=(0.9, 0.999), weight_decay=0.01, scheduler_params={})
The Zoobot model: a PyTorch LightningModule describing how to train the encoder and head (described below). Train from scratch using zoobot.pytorch.training.train_with_pytorch_lightning.train_default_zoobot_from_scratch().
Trains using Dirichlet loss. Labels should be the number of volunteers giving each answer to each question.
See the code for the exact training step, logging, etc. There’s a lot of detail.
- Parameters
output_dim (int) – Output dimension of model’s head e.g. 34 for predicting a 34-answer decision tree.
architecture_name (str, optional) – Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to “efficientnet_b0”.
channels (int, optional) – Number of input channels, typically 3 or 1. Defaults to 1.
test_time_dropout (bool, optional) – Apply dropout at test time, to approximate Bayesian predictions. Defaults to True.
timm_kwargs (dict, optional) – Passed to timm.create_model, e.g. drop_path_rate=0.2 for EfficientNet. Defaults to {}.
learning_rate (float, optional) – AdamW learning rate. Defaults to 1e-3.
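To make the label layout and the loss concrete, here is a minimal standard-library sketch. It assumes that "Dirichlet loss" means the Dirichlet-Multinomial negative log-likelihood, computed per question and summed over questions. The two-question tree, its answer names, and the helper functions dirichlet_multinomial_nll and tree_loss are hypothetical and are not part of Zoobot.

```python
import math

# Hypothetical two-question decision tree: question name -> answer suffixes.
question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured", "_artifact"],
    "has-spiral-arms": ["_yes", "_no"],
}

# One output per answer, so output_dim is the total number of answers.
output_dim = sum(len(answers) for answers in question_answer_pairs.values())


def dirichlet_multinomial_nll(counts, concentrations):
    """Negative log-likelihood of vote counts under a Dirichlet-Multinomial."""
    total_votes = sum(counts)
    total_conc = sum(concentrations)
    log_prob = (
        math.lgamma(total_votes + 1)
        + math.lgamma(total_conc)
        - math.lgamma(total_votes + total_conc)
    )
    for k, a in zip(counts, concentrations):
        log_prob += math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
    return -log_prob


def tree_loss(labels, predictions, pairs):
    """Sum the per-question loss over contiguous slices of the flat vectors."""
    loss, start = 0.0, 0
    for answers in pairs.values():
        end = start + len(answers)
        loss += dirichlet_multinomial_nll(labels[start:end], predictions[start:end])
        start = end
    return loss


# Labels: number of volunteers giving each answer (3 votes, then 2 votes).
labels = [2, 1, 0, 1, 1]
# Predictions: one Dirichlet concentration per answer (all 1.0 is a flat prior).
predictions = [1.0] * output_dim

loss = tree_loss(labels, predictions, question_answer_pairs)
print(output_dim, round(loss, 4))
```

With all concentrations equal to 1, every possible split of the votes is equally likely, so the loss depends only on how many votes and answers each question has. Larger concentrations in the right proportions lower the loss for the observed counts.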
- zoobot.pytorch.estimators.define_model.get_pytorch_encoder(architecture_name='efficientnet_b0', channels=1, **timm_kwargs) → Module
Create a trainable efficientnet model. First layers are galaxy-appropriate augmentation layers - see zoobot.estimators.define_model.add_augmentation_layers(). Expects a single-channel image e.g. (300, 300, 1), likely with a leading batch dimension.
Optionally (by default) include the head (output layers) used for GZ DECaLS. Specifically, global average pooling followed by a dense layer suitable for predicting Dirichlet parameters. See efficientnet_custom.custom_top_dirichlet.
- Parameters
output_dim (int) – Dimension of head dense layer. No effect when include_top=False.
input_size (int) – Length of initial image e.g. 300 (assumed square).
crop_size (int) – Length to randomly crop image. See zoobot.estimators.define_model.add_augmentation_layers().
resize_size (int) – Length to resize image. See zoobot.estimators.define_model.add_augmentation_layers().
weights_loc (str, optional) – If str, load weights from efficientnet checkpoint at this location. Defaults to None.
include_top (bool, optional) – If True, include head used for GZ DECaLS: global pooling and dense layer. Defaults to True.
expect_partial (bool, optional) – If True, do not raise partial match error when loading weights (likely for optimizer state). Defaults to False.
channels (int, default 1) – Number of channels i.e. C in NHWC-dimension inputs.
- Returns
trainable efficientnet model including augmentations and optional head
- Return type
torch.nn.Sequential
- zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) → Sequential
Head to combine with encoder (above) when predicting Galaxy Zoo decision tree answers. PyTorch Sequential model. Predicts Dirichlet concentration parameters.
Also used when finetuning on a new decision tree - see zoobot.pytorch.training.finetune.FinetuneableZoobotTree.
- Parameters
encoder_dim (int) – dimensions of preceding encoder i.e. the input size expected by this submodel.
output_dim (int) – output dimensions of this head e.g. 34 to predict 34 answers.
test_time_dropout (bool) – Use dropout at test time.
dropout_rate (float) – Probability of dropping each input element. See torch.nn.Dropout docs.
- Returns
pytorch model expecting encoder_dim vector and predicting output_dim decision tree answers.
- Return type
torch.nn.Sequential
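A head of this kind can be sketched in plain PyTorch as dropout, then a linear layer, then an activation that keeps every output positive. This is an illustration under stated assumptions, not Zoobot's implementation: ScaledSigmoid, AlwaysOnDropout, sketch_dirichlet_head and the (1, 101) output range are all hypothetical choices made for this example.

```python
import torch
from torch import nn


class ScaledSigmoid(nn.Module):
    """Hypothetical activation mapping real values into the range (1, 101)."""

    def forward(self, x):
        return torch.sigmoid(x) * 100.0 + 1.0


class AlwaysOnDropout(nn.Dropout):
    """Dropout that stays active in eval mode, i.e. test-time dropout."""

    def forward(self, x):
        return nn.functional.dropout(x, p=self.p, training=True)


def sketch_dirichlet_head(encoder_dim, output_dim, test_time_dropout, dropout_rate):
    """Build a small Sequential head predicting one concentration per answer."""
    if test_time_dropout:
        dropout = AlwaysOnDropout(dropout_rate)
    else:
        dropout = nn.Dropout(dropout_rate)
    return nn.Sequential(dropout, nn.Linear(encoder_dim, output_dim), ScaledSigmoid())


torch.manual_seed(0)
head = sketch_dirichlet_head(
    encoder_dim=1280, output_dim=34, test_time_dropout=True, dropout_rate=0.2
)
head.eval()  # ordinary dropout would be disabled here; AlwaysOnDropout is not

features = torch.randn(4, 1280)  # stand-in for a batch of encoder outputs
with torch.no_grad():
    first = head(features)
    second = head(features)

print(first.shape, torch.equal(first, second))
```

Because dropout stays on in eval mode, repeated forward passes on the same features give different concentrations. Averaging over several passes is one common way to use this for uncertainty estimates.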