define_model
This module defines Zoobot’s components.
get_pytorch_encoder() and get_pytorch_dirichlet_head() define the encoder and head, respectively, in PyTorch.
zoobot.pytorch.estimators.define_model.ZoobotTree wraps these components in a PyTorch LightningModule describing how to train them.
- class zoobot.pytorch.estimators.define_model.ZoobotTree(output_dim: int, question_answer_pairs: dict = None, dependencies: dict = None, architecture_name='convnext_nano', channels=3, test_time_dropout=False, compile_encoder=False, timm_kwargs={}, dropout_rate=0.2, learning_rate=0.001, betas=(0.9, 0.999), weight_decay=0.01, scheduler_params={})
The Zoobot model. Train from scratch using zoobot.pytorch.training.train_with_pytorch_lightning.train_default_zoobot_from_scratch().
PyTorch LightningModule describing how to train the encoder and head (described below). Trains using Dirichlet loss. Labels should be the number of volunteers giving each answer to each question.
Most Zoobot users won’t use this training code directly; you probably just want the pretrained model for finetuning. I will likely continue splitting the pretraining into other repos and leave Zoobot for finetuning, but it’s here as a reference.
See the code for the exact training step, logging, etc.; there’s a lot of detail.
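The Dirichlet loss can be illustrated for one galaxy and one question. The sketch below assumes the loss is the negative log-likelihood of the volunteer vote counts k under a Dirichlet-Multinomial whose concentrations α come from the head: log P(k | α) = log(N! / ∏ k_i!) + log Γ(A) − log Γ(N + A) + Σ_i [log Γ(k_i + α_i) − log Γ(α_i)], with N = Σ k_i and A = Σ α_i. This page does not spell out the exact form, so treat it as an assumption and see the Zoobot source for the real loss; `dirichlet_multinomial_nll` is a hypothetical helper that omits batching and the sum over questions.

```python
import math
from typing import Sequence


def dirichlet_multinomial_nll(counts: Sequence[int], concentrations: Sequence[float]) -> float:
    """Negative log-likelihood of one question's vote counts under a Dirichlet-Multinomial.

    Hypothetical helper for illustration (one galaxy, one question):
    counts[i] is the number of volunteers who gave answer i,
    concentrations[i] is the predicted Dirichlet concentration for answer i (> 0).
    """
    if len(counts) != len(concentrations):
        raise ValueError("counts and concentrations must have the same length")
    if any(a <= 0 for a in concentrations):
        raise ValueError("concentrations must be positive")
    n = sum(counts)
    alpha_total = sum(concentrations)
    # log of the multinomial coefficient N! / prod(k_i!)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
    # log Gamma(A) - log Gamma(N + A), with A = sum of concentrations
    log_norm = math.lgamma(alpha_total) - math.lgamma(n + alpha_total)
    # sum_i [log Gamma(k_i + alpha_i) - log Gamma(alpha_i)]
    log_terms = sum(math.lgamma(k + a) - math.lgamma(a) for k, a in zip(counts, concentrations))
    return -(log_coef + log_norm + log_terms)


# 9 of 10 volunteers chose the first answer: concentrations favouring it give a lower loss
print(dirichlet_multinomial_nll([9, 1], [18.0, 2.0]) < dirichlet_multinomial_nll([9, 1], [2.0, 18.0]))  # True
# With two answers and both concentrations equal to 1, every split of N votes is
# equally likely, so the loss is log(N + 1)
print(dirichlet_multinomial_nll([3, 1], [1.0, 1.0]))  # ≈ 1.609, i.e. log(5)
```

Under this form, a question that received zero votes contributes exactly zero loss, which is one way such a loss can accommodate questions that volunteers were never asked.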
- Parameters
output_dim (int) – Output dimension of model’s head e.g. 34 for predicting a 34-answer decision tree.
question_answer_pairs (dict, optional) – Dictionary mapping questions to answers, e.g. {‘smooth-or-featured’: [‘smooth’, ‘featured’], …}. See schemas.py. Defaults to None.
dependencies (dict, optional) – Dictionary mapping questions to their dependencies, e.g. {‘disk-edge-on’: ‘smooth-or-featured_featured’}. See schemas.py. Defaults to None.
architecture_name (str, optional) – Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to “convnext_nano”.
channels (int, optional) – Number of input channels, typically 3 (e.g. RGB) or 1 (single band). Defaults to 3.
test_time_dropout (bool, optional) – Apply dropout at test time, to pretend to be Bayesian. Defaults to False.
compile_encoder (bool, optional) – Compile the encoder with torch.compile. Defaults to False.
timm_kwargs (dict, optional) – Passed to timm.create_model, e.g. drop_path_rate=0.2 for EfficientNet. Defaults to {}.
dropout_rate (float, optional) – Dropout rate for the head. Defaults to 0.2.
learning_rate (float, optional) – Learning rate for AdamW. Defaults to 1e-3.
betas (tuple, optional) – AdamW betas. Defaults to (0.9, 0.999).
weight_decay (float, optional) – AdamW weight decay. Defaults to 0.01.
scheduler_params (dict, optional) – kwargs to pass to the scheduler. If empty, no scheduler is used. Defaults to {}.
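For a decision tree, output_dim, question_answer_pairs, and dependencies are linked: output_dim is the total number of answers across all questions, and each question owns a contiguous slice of the head’s outputs. The sketch below assumes answers are laid out question by question in dict order and that answer names follow the `question_answer` pattern used in the dependencies example above; `answer_layout` and the toy two-question tree are hypothetical (the real schemas live in schemas.py).

```python
from typing import Dict, List, Optional, Tuple


def answer_layout(
    question_answer_pairs: Dict[str, List[str]],
    dependencies: Dict[str, Optional[str]],
) -> Tuple[int, Dict[str, Tuple[int, int]]]:
    """Return (output_dim, {question: (start, end)}) for a decision tree.

    Hypothetical helper: assumes answers are laid out question by question
    in dict order, and that answer columns are named f"{question}_{answer}".
    """
    slices: Dict[str, Tuple[int, int]] = {}
    answer_names: List[str] = []
    start = 0
    for question, answers in question_answer_pairs.items():
        end = start + len(answers)
        slices[question] = (start, end)  # half-open range [start, end)
        answer_names.extend(f"{question}_{answer}" for answer in answers)
        start = end
    # Each dependency must name a known question and an existing parent answer
    for question, parent_answer in dependencies.items():
        if question not in slices:
            raise ValueError(f"unknown question in dependencies: {question}")
        if parent_answer is not None and parent_answer not in answer_names:
            raise ValueError(f"unknown parent answer for {question}: {parent_answer}")
    return start, slices


# Toy two-question tree (not a real Galaxy Zoo schema)
pairs = {
    "smooth-or-featured": ["smooth", "featured"],
    "disk-edge-on": ["yes", "no"],
}
deps = {"smooth-or-featured": None, "disk-edge-on": "smooth-or-featured_featured"}
output_dim, slices = answer_layout(pairs, deps)
print(output_dim)  # 4
print(slices)  # {'smooth-or-featured': (0, 2), 'disk-edge-on': (2, 4)}
```

The same bookkeeping scales to larger schemas: a 34-answer tree would give output_dim = 34, matching the example in the output_dim description.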
- zoobot.pytorch.estimators.define_model.get_pytorch_encoder(architecture_name='convnext_nano', channels=3, **timm_kwargs) → Module
Create a trainable timm model (convnext_nano by default). Wrapper for timm.create_model.
- Parameters
architecture_name (str) – Name of the timm architecture to use.
channels (int) – Number of input channels (e.g. 3 for RGB images).
**timm_kwargs – Additional keyword arguments to pass to timm.create_model.
- Returns
A timm PyTorch model with the specified architecture and input channels.
- Return type
nn.Module
- zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) → Sequential
Head to combine with the encoder (above) when predicting Galaxy Zoo decision tree answers. PyTorch Sequential model. Predicts Dirichlet concentration parameters.
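Because the head outputs Dirichlet concentrations rather than probabilities, turning them into expected vote fractions takes one more step: within each question, the expected fraction for answer i is the Dirichlet mean α_i / Σ_j α_j. A minimal sketch, assuming each question’s answers occupy a known slice of the output vector; `expected_vote_fractions`, the `slices` mapping, and the concentration values are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def expected_vote_fractions(
    concentrations: Sequence[float],
    slices: Dict[str, Tuple[int, int]],
) -> Dict[str, List[float]]:
    """Dirichlet mean per question: alpha_i / sum(alpha) within each question's slice."""
    fractions: Dict[str, List[float]] = {}
    for question, (start, end) in slices.items():
        alphas = list(concentrations[start:end])
        total = sum(alphas)
        fractions[question] = [a / total for a in alphas]
    return fractions


# Made-up concentrations for a toy two-question tree
preds = [8.0, 2.0, 1.0, 3.0]
slices = {"smooth-or-featured": (0, 2), "disk-edge-on": (2, 4)}
fractions = expected_vote_fractions(preds, slices)
print(fractions)  # {'smooth-or-featured': [0.8, 0.2], 'disk-edge-on': [0.25, 0.75]}
```

The size of Σ α also carries information: larger totals mean a more confident (narrower) Dirichlet, which the mean alone discards.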
Also used when finetuning on a new decision tree; see zoobot.pytorch.training.finetune.FinetuneableZoobotTree.
- Parameters
encoder_dim (int) – Dimension of the preceding encoder, i.e. the input size expected by this submodel.
output_dim (int) – Output dimension of this head, e.g. 34 to predict 34 answers.
test_time_dropout (bool) – Use dropout at test time.
dropout_rate (float) – Dropout probability. See the torch.nn.Dropout docs.
- Returns
PyTorch model expecting an encoder_dim vector and predicting output_dim decision tree answers.
- Return type
torch.nn.Sequential
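A head with this signature can be sketched as a torch.nn.Sequential of dropout, a linear layer, and a positive activation. The sketch assumes a scaled sigmoid that keeps every concentration in (1, 101), and uses a dropout module that stays active in eval mode when test_time_dropout is set. The activation choice and the names AlwaysOnDropout, ScaledSigmoid, and make_dirichlet_head are assumptions for illustration, not Zoobot’s code.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AlwaysOnDropout(nn.Module):
    """Dropout that ignores eval mode, so it also applies at test time."""

    def __init__(self, p: float):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.dropout(x, p=self.p, training=True)


class ScaledSigmoid(nn.Module):
    """Map logits into (1, 101) so every Dirichlet concentration is positive."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 1.0 + 100.0 * torch.sigmoid(x)


def make_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) -> nn.Sequential:
    # nn.Dropout is switched off by .eval(); AlwaysOnDropout is not
    dropout = AlwaysOnDropout(dropout_rate) if test_time_dropout else nn.Dropout(dropout_rate)
    return nn.Sequential(dropout, nn.Linear(encoder_dim, output_dim), ScaledSigmoid())


torch.manual_seed(0)
head = make_dirichlet_head(encoder_dim=512, output_dim=34, test_time_dropout=True, dropout_rate=0.2)
head.eval()
x = torch.randn(4, 512)  # stand-in for a batch of encoder features
with torch.no_grad():
    a = head(x)
    b = head(x)
print(a.shape)  # torch.Size([4, 34])
```

With test_time_dropout=True, repeated passes in eval mode give different concentrations, so averaging several passes approximates a distribution over predictions; with test_time_dropout=False, head.eval() disables nn.Dropout and the output is deterministic.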