finetune
Use these classes and methods to finetune a pretrained Zoobot model.
See the README for a minimal example. See zoobot/pytorch/examples for more worked examples.
- class zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract(name=None, encoder=None, zoobot_checkpoint_loc=None, n_blocks=0, lr_decay=0.75, weight_decay=0.05, learning_rate=0.0001, dropout_prob=0.5, always_train_batchnorm=False, cosine_schedule=False, warmup_epochs=0, max_cosine_epochs=100, max_learning_rate_reduction_factor=0.01, from_scratch=False, prog_bar=True, visualize_images=False, seed=42, n_layers=None)
Parent class of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. You cannot use this class directly; you must use one of the child classes above instead.
This class defines the shared finetuning args and methods used by those child classes. For example:
- When provided name, it will load the HuggingFace encoder with that name (see below for more).
- When provided learning_rate, it will set the optimizer to use that learning rate.
Any FinetuneableZoobot model can be loaded in one of three ways:
- HuggingFace name, e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). Recommended.
- Any PyTorch model in memory, e.g. FinetuneableZoobotX(encoder=some_model, …)
- ZoobotTree checkpoint, e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', …)
You could subclass this class to solve new finetuning tasks - see Advanced Finetuning.
- Parameters
name (str, optional) – Name of a model on HuggingFace Hub, e.g. 'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
encoder (torch.nn.Module, optional) – A PyTorch model already loaded in memory. Defaults to None.
zoobot_checkpoint_loc (str, optional) – Path to a ZoobotTree Lightning checkpoint to load. Loaded with zoobot.pytorch.training.finetune.load_pretrained_zoobot(). Defaults to None.
n_blocks (int, optional) – Number of encoder blocks, counting down from the head, to finetune. If 0, the encoder is frozen and only the head is trained. Defaults to 0.
lr_decay (float, optional) – For each encoder block i below the head, multiply the learning rate by lr_decay^i. Defaults to 0.75.
weight_decay (float, optional) – AdamW weight decay arg (decoupled weight decay, which plays a similar role to an L2 penalty). Defaults to 0.05.
learning_rate (float, optional) – AdamW learning rate arg. Defaults to 1e-4.
dropout_prob (float, optional) – Probability of dropout before the final output layer. Defaults to 0.5.
always_train_batchnorm (bool, optional) – Temporarily deprecated. Previously, if True, batchnorm statistics were still updated during finetuning, even in otherwise frozen layers. Defaults to False.
cosine_schedule (bool, optional) – Reduce the learning rate each epoch according to a cosine schedule, after warmup_epochs. Defaults to False.
warmup_epochs (int, optional) – Linearly increase the learning rate from 0 to learning_rate over the first warmup_epochs epochs, before applying the cosine schedule. No effect if cosine_schedule=False. Defaults to 0.
max_cosine_epochs (int, optional) – Epochs for the scheduled learning rate to decay to the final learning rate (below). Warmup epochs don't count. No effect if cosine_schedule=False. Defaults to 100.
max_learning_rate_reduction_factor (float, optional) – Set the final learning rate as learning_rate * max_learning_rate_reduction_factor. No effect if cosine_schedule=False. Defaults to 0.01.
from_scratch (bool, optional) – Ignore all settings above and train from scratch at learning_rate for all layers. Useful for a quick baseline. Defaults to False.
prog_bar (bool, optional) – Print a progress bar during finetuning. Defaults to True.
visualize_images (bool, optional) – Upload example images to WandB. Good for debugging but slow. Defaults to False.
seed (int, optional) – random seed to use. Defaults to 42.
n_layers – No effect, deprecated. Use n_blocks instead.
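The cosine_schedule, warmup_epochs, max_cosine_epochs and max_learning_rate_reduction_factor args interact to give a per-epoch learning rate. The sketch below is a minimal pure-Python illustration, assuming a linear warmup followed by a standard half-cosine decay evaluated once per epoch; Zoobot's actual scheduler may step per batch or differ in detail, and the helper lr_at_epoch is hypothetical.

```python
import math


def lr_at_epoch(
    epoch: int,
    learning_rate: float = 1e-4,
    cosine_schedule: bool = False,
    warmup_epochs: int = 0,
    max_cosine_epochs: int = 100,
    max_learning_rate_reduction_factor: float = 0.01,
) -> float:
    """Hypothetical helper: learning rate for a given (0-indexed) epoch."""
    if not cosine_schedule:
        # Without a schedule the learning rate is constant;
        # warmup and cosine settings have no effect.
        return learning_rate

    if epoch < warmup_epochs:
        # Linear warmup from 0 towards learning_rate.
        return learning_rate * epoch / warmup_epochs

    final_lr = learning_rate * max_learning_rate_reduction_factor
    # Warmup epochs don't count towards the cosine decay.
    progress = min(epoch - warmup_epochs, max_cosine_epochs) / max_cosine_epochs
    # Half-cosine decay: learning_rate at progress=0, final_lr at progress=1 and after.
    return final_lr + (learning_rate - final_lr) * 0.5 * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for epoch in [0, 2, 5, 55, 105, 200]:
        lr = lr_at_epoch(epoch, cosine_schedule=True, warmup_epochs=5)
        print(f"epoch {epoch:>3}: lr = {lr:.2e}")
```

With warmup_epochs=5 the rate climbs to 1e-4 by epoch 5, halves its distance to the final rate by epoch 55, and settles at 1e-6 (1e-4 * 0.01) from epoch 105 onwards.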
- configure_optimizers()
This controls which parameters get optimized:
- self.head is always optimized, with no learning rate decay.
- When self.n_blocks == 0, only self.head is optimized (i.e. the encoder is frozen).
- For self.encoder, we enumerate the blocks (groups of layers) that could be finetuned, then pick the top self.n_blocks to finetune.
- weight_decay is applied to both the head and (if relevant) the encoder.
- Learning rate decay is applied to the encoder only: lr * (lr_decay^block_n), ignoring the head (block 0).
- What counts as a "block" is a bit fuzzy, but I generally use self.encoder.stages from timm. I also count the stem as a block.
- Batch norm layers may optionally still have updated statistics using always_train_batchnorm.
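The block-wise learning rate decay in configure_optimizers can be sketched in pure Python. This is a simplified illustration rather than Zoobot's implementation: it assumes the encoder blocks arrive as a list ordered from input (stem first) to output, and it returns one learning rate per block (None for frozen blocks) instead of building real optimizer parameter groups. The helper block_learning_rates and the block names are hypothetical.

```python
from typing import Dict, List, Optional


def block_learning_rates(
    encoder_blocks: List[str],
    n_blocks: int,
    learning_rate: float = 1e-4,
    lr_decay: float = 0.75,
) -> Dict[str, Optional[float]]:
    """Hypothetical helper: map each block name to its learning rate, or None if frozen.

    encoder_blocks is ordered from input to output (stem first, last stage last).
    """
    # The head is block 0: always trained, with no learning rate decay.
    rates: Dict[str, Optional[float]] = {"head": learning_rate}

    # Walk the encoder from the top (closest to the head) downwards.
    for depth, block in enumerate(reversed(encoder_blocks), start=1):
        if depth <= n_blocks:
            # Block `depth` below the head gets lr * lr_decay ** depth.
            rates[block] = learning_rate * lr_decay ** depth
        else:
            # Blocks below the top n_blocks stay frozen.
            rates[block] = None
    return rates


if __name__ == "__main__":
    blocks = ["stem", "stage_0", "stage_1", "stage_2", "stage_3"]
    for name, lr in block_learning_rates(blocks, n_blocks=2).items():
        print(name, lr)
```

In a real optimizer each non-None entry would become its own parameter group with that learning rate (and the shared weight_decay), while frozen blocks would simply be left out.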
- class zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier(num_classes: int, label_smoothing=0.0, class_weights=None, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a classification problem.
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
num_classes (int) – Number of target classes (e.g. 2 for binary classification).
label_smoothing (float, optional) – See the torch.nn.CrossEntropyLoss docs. Defaults to 0.
class_weights (arraylike, optional) – See the torch.nn.CrossEntropyLoss docs (weight argument). Defaults to None.
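To illustrate what label_smoothing does, here is a pure-Python sketch of label-smoothed cross-entropy for a single example. It follows the formulation described for torch.nn.CrossEntropyLoss, where the one-hot target is mixed with a uniform distribution over classes; class weighting and batch averaging are deliberately left out, and the helper smoothed_cross_entropy is hypothetical.

```python
import math
from typing import List


def smoothed_cross_entropy(logits: List[float], target: int, label_smoothing: float = 0.0) -> float:
    """Hypothetical helper: cross-entropy of one example with label smoothing."""
    num_classes = len(logits)
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]

    # Smoothed target: (1 - eps) on the true class, plus eps / C spread over all classes.
    loss = 0.0
    for c, log_p in enumerate(log_probs):
        q = label_smoothing / num_classes
        if c == target:
            q += 1.0 - label_smoothing
        loss -= q * log_p
    return loss


if __name__ == "__main__":
    logits = [5.0, 0.0]  # num_classes=2, confident in class 0
    print(smoothed_cross_entropy(logits, target=0))
    print(smoothed_cross_entropy(logits, target=0, label_smoothing=0.1))
```

Smoothing raises the loss on confident, correct predictions, which discourages the head from producing extreme logits.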
- class zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor(loss: str = 'mse', unit_interval: bool = False, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a regression problem.
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
loss (str, optional) – Loss function to use. Must be one of 'mse', 'mae'. Defaults to 'mse'.
unit_interval (bool, optional) – If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False.
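A minimal pure-Python sketch of how loss and unit_interval combine for a batch of scalar predictions, assuming unit_interval applies a sigmoid to the raw head output before the loss is computed and that the loss is averaged over the batch. The helper regression_loss is hypothetical; the real class works on PyTorch tensors.

```python
import math
from typing import List


def regression_loss(
    raw_outputs: List[float],
    targets: List[float],
    loss: str = "mse",
    unit_interval: bool = False,
) -> float:
    """Hypothetical helper: mean regression loss over a batch."""
    if loss not in ("mse", "mae"):
        raise ValueError(f"loss must be one of 'mse', 'mae', got {loss!r}")
    if unit_interval:
        # Sigmoid squashes predictions into (0, 1).
        preds = [1.0 / (1.0 + math.exp(-x)) for x in raw_outputs]
    else:
        preds = list(raw_outputs)
    errors = [p - t for p, t in zip(preds, targets)]
    if loss == "mse":
        # Mean squared error.
        return sum(e * e for e in errors) / len(errors)
    # Mean absolute error.
    return sum(abs(e) for e in errors) / len(errors)


if __name__ == "__main__":
    print(regression_loss([1.0, 3.0], [0.0, 1.0]))
    print(regression_loss([1.0, 3.0], [0.0, 1.0], loss="mae"))
    print(regression_loss([0.0, 2.0], [0.5, 0.9], unit_interval=True))
```

MAE is less sensitive than MSE to a few badly wrong targets; unit_interval is a natural fit when targets are fractions such as vote proportions.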
- class zoobot.pytorch.training.finetune.FinetuneableZoobotTree(schema: Schema, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. Uses Dirichlet-Multinomial loss introduced in GZ DECaLS. Briefly: predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, and uses the Dirichlet-Multinomial loss to compare the predicted distribution of votes (given k volunteers were asked) to the true distribution.
Does not produce accuracy or MSE metrics, as these are not relevant for this task. Loss logging only.
If you’re using this, you’re probably working on a Galaxy Zoo catalog, and you should Slack Mike!
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
schema (schemas.Schema) – Description of the layout of the decision tree. See zoobot.shared.schemas.Schema.
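The Dirichlet-Multinomial likelihood used by FinetuneableZoobotTree can be written with log-gamma functions. This pure-Python sketch computes the negative log-likelihood of the observed vote counts for a single question, given predicted Dirichlet concentrations. It illustrates the standard formula only; Zoobot's implementation works on batched tensors across every question in the schema and may drop constant terms. The helper dirichlet_multinomial_nll is hypothetical.

```python
import math
from typing import List


def dirichlet_multinomial_nll(votes: List[int], concentrations: List[float]) -> float:
    """Hypothetical helper: -log P(votes | concentrations) for one question."""
    total_votes = sum(votes)
    total_concentration = sum(concentrations)
    # Multinomial coefficient: n! / prod(k_i!)
    log_prob = math.lgamma(total_votes + 1) - sum(math.lgamma(k + 1) for k in votes)
    # Dirichlet normalisation: Gamma(alpha_0) / Gamma(n + alpha_0)
    log_prob += math.lgamma(total_concentration) - math.lgamma(total_votes + total_concentration)
    # Per-answer terms: Gamma(k_i + alpha_i) / Gamma(alpha_i)
    log_prob += sum(math.lgamma(k + a) - math.lgamma(a) for k, a in zip(votes, concentrations))
    return -log_prob


if __name__ == "__main__":
    votes = [8, 2]  # e.g. 8 volunteers answered "smooth", 2 answered "featured"
    print(dirichlet_multinomial_nll(votes, [4.0, 1.0]))  # concentrations agree with the votes
    print(dirichlet_multinomial_nll(votes, [1.0, 4.0]))  # concentrations disagree: higher loss
```

Because the loss depends on the total number of volunteers asked, questions answered by more volunteers constrain the predicted Dirichlet more tightly.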
- class zoobot.pytorch.training.finetune.LinearHead(input_dim: int, output_dim: int, dropout_prob=0.5, activation=None)
- forward(x)
returns logits, as recommended for CrossEntropy loss
- Parameters
x (torch.Tensor) – encoded representation
- Returns
result (see docstring of LinearHead)
- Return type
torch.Tensor
- zoobot.pytorch.training.finetune.load_pretrained_zoobot(checkpoint_loc: str) -> Module
- Parameters
checkpoint_loc (str) – Path to saved LightningModule checkpoint, likely of ZoobotTree, FinetuneableZoobotClassifier, or FinetuneableZoobotTree. Must have a .zoobot attribute.
- Returns
pretrained PyTorch encoder within that LightningModule.
- Return type
torch.nn.Module
- zoobot.pytorch.training.finetune.get_trainer(save_dir: str, file_template='{epoch}', save_top_k=1, max_epochs=100, patience=10, devices='auto', accelerator='auto', logger=None, **trainer_kwargs) -> Trainer
Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process. Use like so: trainer.fit(model, datamodule).
get_trainer args are for common Trainer settings, e.g. early stopping, checkpointing, etc. By default it:
- Saves the top-k models based on validation loss.
- Uses early stopping with patience, i.e. ends training if validation loss does not improve for patience epochs.
- Monitors the learning rate (useful when using a learning rate scheduler).
Any extra args not listed below are passed directly to the PyTorch Lightning Trainer. Use this to add any custom configuration not covered by the get_trainer args. See https://lightning.ai/docs/pytorch/stable/common/trainer.html
- Parameters
save_dir (str) – folder in which to save checkpoints and logs.
file_template (str, optional) – custom naming for checkpoint files. See Lightning docs. Defaults to "{epoch}".
save_top_k (int, optional) – save the top k checkpoints only. Defaults to 1.
max_epochs (int, optional) – train for up to this many epochs. Defaults to 100.
patience (int, optional) – wait up to this many epochs for decreasing loss before ending training. Defaults to 10.
devices (str, optional) – number of devices for training (typically, num. GPUs). Defaults to 'auto'.
accelerator (str, optional) – which device to use (typically 'gpu' or 'cpu'). Defaults to 'auto'.
logger (pl.loggers.WandbLogger, optional) – If a pl.loggers.WandbLogger, track the experiment on Weights and Biases. Defaults to None.
- Returns
PyTorch Lightning trainer object for finetuning a model on a GalaxyDataModule.
- Return type
pl.Trainer
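The default early stopping and top-k checkpointing behaviour can be sketched on a list of per-epoch validation losses. This pure-Python sketch assumes validation runs once per epoch and that "improvement" means a strictly lower validation loss (no minimum delta); Lightning's EarlyStopping and ModelCheckpoint callbacks offer more options, and the helper simulate_early_stopping is hypothetical.

```python
from typing import List, Tuple


def simulate_early_stopping(
    val_losses: List[float], patience: int = 10, save_top_k: int = 1
) -> Tuple[int, List[int]]:
    """Hypothetical helper: return (last epoch run, epochs whose checkpoints are kept)."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    last_epoch = len(val_losses) - 1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            # No improvement for `patience` epochs: stop training here.
            last_epoch = epoch
            break
    # Keep the save_top_k epochs with the lowest validation loss among those actually run.
    kept = sorted(range(last_epoch + 1), key=lambda e: val_losses[e])[:save_top_k]
    return last_epoch, sorted(kept)


if __name__ == "__main__":
    losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.5]
    print(simulate_early_stopping(losses, patience=3, save_top_k=2))
    print(simulate_early_stopping(losses, patience=10))
```

With patience=3, training stops at epoch 4 and never reaches the lower loss at epoch 5, which is the usual trade-off when choosing patience.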
- zoobot.pytorch.training.finetune.download_from_name(class_name: str, hub_name: str)
Download a finetuned model from the HuggingFace Hub by name. Used to load pretrained Zoobot models by name, e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …).
Downloaded models are saved to the HuggingFace cache directory for later use (typically ~/.cache/huggingface).
You shouldn’t need to call this; it’s used internally by the FinetuneableZoobot classes.
- Parameters
class_name (str) – one of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree
hub_name (str) – e.g. mwalmsley/zoobot-encoder-convnext_nano
- Returns
path to downloaded model (in HuggingFace cache directory). Likely then loaded by Lightning.
- Return type
str