finetune

Use these classes and methods to finetune a pretrained Zoobot model.

See the README for a minimal example. See zoobot/pytorch/examples for more worked examples.

class zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract(name=None, encoder=None, zoobot_checkpoint_loc=None, n_blocks=0, lr_decay=0.75, weight_decay=0.05, learning_rate=0.0001, dropout_prob=0.5, always_train_batchnorm=False, cosine_schedule=False, warmup_epochs=0, max_cosine_epochs=100, max_learning_rate_reduction_factor=0.01, from_scratch=False, prog_bar=True, visualize_images=False, seed=42, n_layers=None)

Parent class of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. You cannot use this class directly; use one of these child classes instead.

This class defines the shared finetuning args and methods used by those child classes. For example:
  • When provided name, it will load the HuggingFace encoder with that name (see below for more).

  • When provided learning_rate, it will set the optimizer to use that learning rate.

Any FinetuneableZoobot model can be loaded in one of three ways:
  • HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). Recommended.

  • Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, …)

  • ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', …)
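Since the three routes above are alternatives, the sketch below shows one way to check which route a set of arguments selects. resolve_encoder_source is a hypothetical helper, not part of Zoobot, and the rule that exactly one of name, encoder, or zoobot_checkpoint_loc is given is an assumption rather than documented behaviour.

```python
from typing import Optional


def resolve_encoder_source(
    name: Optional[str] = None,
    encoder: Optional[object] = None,
    zoobot_checkpoint_loc: Optional[str] = None,
) -> str:
    """Hypothetical helper: report which of the three loading routes applies.

    Assumes (not confirmed by Zoobot) that exactly one source must be given.
    """
    provided = {
        "name": name is not None,
        "encoder": encoder is not None,
        "zoobot_checkpoint_loc": zoobot_checkpoint_loc is not None,
    }
    chosen = [key for key, given in provided.items() if given]
    if len(chosen) != 1:
        raise ValueError(f"Expected exactly one of {list(provided)}, got {chosen or 'none'}")
    return chosen[0]


print(resolve_encoder_source(name="hf_hub:mwalmsley/zoobot-encoder-convnext_nano"))  # name
```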

You could subclass this class to solve new finetuning tasks - see Advanced Finetuning.

Parameters
  • name (str, optional) – Name of a model on the HuggingFace Hub, e.g. 'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.

  • encoder (torch.nn.Module, optional) – A PyTorch model already loaded in memory. Defaults to None.

  • zoobot_checkpoint_loc (str, optional) – Path to a ZoobotTree Lightning checkpoint to load. Loaded with zoobot.pytorch.training.finetune.load_pretrained_zoobot(). Defaults to None.

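  • n_blocks (int, optional) – Number of encoder blocks to finetune, counting down from the head. If 0, only the head is trained and the encoder is frozen. Defaults to 0.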

  • lr_decay (float, optional) – For each block i below the head, multiply the learning rate by lr_decay^i. Defaults to 0.75.

  • weight_decay (float, optional) – AdamW (decoupled) weight decay arg. Defaults to 0.05.

  • learning_rate (float, optional) – AdamW learning rate arg. Defaults to 1e-4.

  • dropout_prob (float, optional) – Dropout probability applied before the final output layer. Defaults to 0.5.

  • always_train_batchnorm (bool, optional) – Temporarily deprecated. Previously, if True, batchnorm statistics were updated during finetuning even in frozen blocks. Defaults to False.

  • cosine_schedule (bool, optional) – Reduce the learning rate each epoch according to a cosine schedule, after warmup_epochs. Defaults to False.

  • warmup_epochs (int, optional) – Linearly increase the learning rate from 0 to learning_rate over the first warmup_epochs epochs, before applying the cosine schedule. No effect if cosine_schedule=False. Defaults to 0.

  • max_cosine_epochs (int, optional) – Number of epochs for the scheduled learning rate to decay to the final learning rate (below). Warmup epochs don’t count. No effect if cosine_schedule=False. Defaults to 100.

  • max_learning_rate_reduction_factor (float, optional) – Set the final learning rate to learning_rate * max_learning_rate_reduction_factor. No effect if cosine_schedule=False. Defaults to 0.01.

  • from_scratch (bool, optional) – Ignore all settings above and train from scratch at learning_rate for all layers. Useful for a quick baseline. Defaults to False.

  • prog_bar (bool, optional) – Print progress bar during finetuning. Defaults to True.

  • visualize_images (bool, optional) – Upload example images to WandB. Good for debugging but slow. Defaults to False.

  • seed (int, optional) – random seed to use. Defaults to 42.

  • n_layers – No effect, deprecated. Use n_blocks instead.
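The scheduling arguments above interact as follows. This plain-Python sketch computes the learning rate for a given epoch from learning_rate, cosine_schedule, warmup_epochs, max_cosine_epochs, and max_learning_rate_reduction_factor. It assumes the schedule is evaluated once per epoch and uses the standard half-cosine shape; learning_rate_at_epoch is a hypothetical helper, and Zoobot's actual scheduler may differ in details such as per-step updates.

```python
import math


def learning_rate_at_epoch(
    epoch: int,
    learning_rate: float = 1e-4,
    cosine_schedule: bool = False,
    warmup_epochs: int = 0,
    max_cosine_epochs: int = 100,
    max_learning_rate_reduction_factor: float = 0.01,
) -> float:
    """Hypothetical sketch of the per-epoch learning rate implied by the args."""
    if not cosine_schedule:
        # No schedule: the learning rate stays constant
        return learning_rate
    if epoch < warmup_epochs:
        # Linear warmup from 0 towards learning_rate
        return learning_rate * epoch / warmup_epochs
    final_lr = learning_rate * max_learning_rate_reduction_factor
    # Warmup epochs don't count towards the cosine decay
    t = epoch - warmup_epochs
    if t >= max_cosine_epochs:
        return final_lr
    # Half-cosine from 1 (start of decay) down to 0 (end of decay)
    cosine = 0.5 * (1 + math.cos(math.pi * t / max_cosine_epochs))
    return final_lr + (learning_rate - final_lr) * cosine


for epoch in (0, 5, 55, 105):
    print(epoch, learning_rate_at_epoch(epoch, cosine_schedule=True, warmup_epochs=5))
```

With warmup_epochs=5, the rate starts at 0, reaches learning_rate at epoch 5, and then decays to learning_rate * 0.01 by epoch 105.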

configure_optimizers()

Controls which parameters get optimized, and with what learning rate.

self.head is always optimized, with no learning rate decay. When self.n_blocks == 0, only self.head is optimized (i.e. the encoder is frozen, though see the note on batch norm below).

For self.encoder, we enumerate the blocks (groups of layers) that could be finetuned, then pick the top self.n_blocks (those closest to the head) to finetune.

weight_decay is applied to both the head and (if relevant) the encoder. Learning rate decay is applied to the encoder only: lr * lr_decay^block_n, ignoring the head (block 0).

What counts as a “block” is a bit fuzzy, but I generally use the self.encoder.stages from timm. I also count the stem as a block.

Batch norm layers may optionally still have their statistics updated, using always_train_batchnorm.
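The learning rate rules above can be sketched in plain Python. The sketch assumes a hypothetical encoder with blocks named stem, stage_0 to stage_3 (mirroring a timm model's stem plus self.encoder.stages), and returns AdamW-style param group dicts with block names standing in for real parameter lists. It illustrates the rules only; build_param_groups is hypothetical and is not Zoobot's configure_optimizers.

```python
from typing import Dict, List, Sequence


def build_param_groups(
    encoder_blocks: Sequence[str],
    n_blocks: int = 0,
    learning_rate: float = 1e-4,
    lr_decay: float = 0.75,
    weight_decay: float = 0.05,
) -> List[Dict[str, object]]:
    """Hypothetical sketch of layer-wise learning rate decay over blocks.

    encoder_blocks is ordered from input to output (stem first).
    Block names stand in for each block's parameters.
    """
    # The head is block 0: always trained, with no learning rate decay
    groups: List[Dict[str, object]] = [
        {"params": "head", "lr": learning_rate, "weight_decay": weight_decay}
    ]
    # Enumerate encoder blocks from the top (nearest the head) downwards
    top_down = list(reversed(encoder_blocks))[:n_blocks]
    for block_n, block in enumerate(top_down, start=1):
        groups.append({
            "params": block,
            "lr": learning_rate * lr_decay ** block_n,
            "weight_decay": weight_decay,  # applied to the encoder too
        })
    # Blocks not listed are never passed to the optimizer, i.e. they stay frozen
    return groups


blocks = ["stem", "stage_0", "stage_1", "stage_2", "stage_3"]
for group in build_param_groups(blocks, n_blocks=2):
    print(group)
```

With n_blocks=2, only the head, stage_3 (lr * 0.75) and stage_2 (lr * 0.75^2) are optimized.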


class zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier(num_classes: int, label_smoothing=0.0, class_weights=None, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a classification problem.

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters
  • num_classes (int) – Number of target classes (e.g. 2 for binary classification).

  • label_smoothing (float, optional) – See the torch.nn.CrossEntropyLoss docs. Defaults to 0.0.

  • class_weights (arraylike, optional) – Per-class loss weights; see the weight argument in the torch.nn.CrossEntropyLoss docs. Defaults to None.
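To show what label_smoothing and class_weights do, here is a plain-Python sketch of cross-entropy following the semantics documented for torch.nn.CrossEntropyLoss. It assumes Zoobot passes these args through to torch's cross-entropy (as the parameter docs suggest); the helper names are hypothetical, and for simplicity the weighted version omits label smoothing.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Log-sum-exp trick: shift by the max logit so exp() cannot overflow
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def smoothed_cross_entropy(logits: Sequence[float], target: int,
                           label_smoothing: float = 0.0) -> float:
    """Cross-entropy for a single example, with optional label smoothing.

    The one-hot target is mixed with a uniform distribution over classes:
    q_c = (1 - label_smoothing) * [c == target] + label_smoothing / num_classes
    """
    num_classes = len(logits)
    loss = 0.0
    for c, log_p in enumerate(log_softmax(logits)):
        one_hot = 1.0 if c == target else 0.0
        q = (1 - label_smoothing) * one_hot + label_smoothing / num_classes
        loss -= q * log_p
    return loss


def weighted_mean_cross_entropy(batch_logits: Sequence[Sequence[float]],
                                targets: Sequence[int],
                                class_weights: Sequence[float]) -> float:
    """Batch cross-entropy with per-class weights (no label smoothing).

    Each example's loss is scaled by the weight of its true class, and the
    total is divided by the summed weights, as torch's 'mean' reduction does.
    """
    total, weight_sum = 0.0, 0.0
    for logits, y in zip(batch_logits, targets):
        w = class_weights[y]
        total += w * -log_softmax(logits)[y]
        weight_sum += w
    return total / weight_sum
```

Upweighting a rare class (e.g. class_weights=[1.0, 3.0]) makes its examples count three times as much towards the mean loss.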


class zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor(loss: str = 'mse', unit_interval: bool = False, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a regression problem.

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters
  • loss (str, optional) – Loss function to use. Must be one of 'mse' or 'mae'. Defaults to 'mse'.

  • unit_interval (bool, optional) – If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False.
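The two regressor options can be illustrated in plain Python. This sketch is not Zoobot's code: it assumes the loss is averaged over examples (torch's default 'mean' reduction) and that unit_interval applies a sigmoid to each raw output before the loss; the helper names are hypothetical.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def regression_loss(raw_outputs: Sequence[float], targets: Sequence[float],
                    loss: str = "mse", unit_interval: bool = False) -> float:
    """Mean squared or mean absolute error, optionally after a sigmoid."""
    if loss not in ("mse", "mae"):
        raise ValueError(f"loss must be 'mse' or 'mae', got {loss!r}")
    # unit_interval=True squashes each prediction into (0, 1)
    preds = [sigmoid(x) if unit_interval else x for x in raw_outputs]
    errors = [p - t for p, t in zip(preds, targets)]
    if loss == "mse":
        return sum(e * e for e in errors) / len(errors)
    return sum(abs(e) for e in errors) / len(errors)


print(regression_loss([0.0, 2.0], [1.0, 0.0]))              # 2.5
print(regression_loss([0.0, 2.0], [1.0, 0.0], loss="mae"))  # 1.5
```

MAE penalizes large errors less heavily than MSE, which can help when some targets are noisy.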


class zoobot.pytorch.training.finetune.FinetuneableZoobotTree(schema: Schema, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. Uses the Dirichlet-Multinomial loss introduced in GZ DECaLS. Briefly: the model predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, and the Dirichlet-Multinomial loss compares the predicted distribution of votes (given k volunteers were asked) to the observed votes.

Does not produce accuracy or MSE metrics, as these are not relevant for this task. Loss logging only.

If you’re using this, you’re probably working on a Galaxy Zoo catalog, and you should Slack Mike!

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters

schema (schemas.Schema) – description of the layout of the decision tree. See zoobot.shared.schemas.Schema.
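The Dirichlet-Multinomial loss described above can be written down for a single question using the log-gamma function. This plain-Python sketch computes the log-probability of the observed vote counts (k = sum of counts) given predicted Dirichlet concentrations, then sums the negative log-likelihood over questions. It is not Zoobot's implementation, which works on tensors and uses the schema to map answers to questions; the helper names are hypothetical.

```python
import math
from typing import Sequence


def dirichlet_multinomial_log_prob(counts: Sequence[int],
                                   concentrations: Sequence[float]) -> float:
    """Log-probability of observed vote counts for one question.

    counts[i]: volunteers giving answer i; concentrations[i]: predicted alpha_i > 0.
    """
    n = sum(counts)
    a = sum(concentrations)
    # Multinomial coefficient: number of orderings of the n votes
    log_p = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
    # Dirichlet normalization terms
    log_p += math.lgamma(a) - math.lgamma(n + a)
    log_p += sum(math.lgamma(k + alpha) - math.lgamma(alpha)
                 for k, alpha in zip(counts, concentrations))
    return log_p


def dirichlet_multinomial_loss(per_question_counts: Sequence[Sequence[int]],
                               per_question_concentrations: Sequence[Sequence[float]]) -> float:
    """Negative log-likelihood summed over the questions for one galaxy.

    A question with zero votes contributes exactly zero, so questions that
    volunteers were never asked do not affect the loss.
    """
    return -sum(dirichlet_multinomial_log_prob(c, a)
                for c, a in zip(per_question_counts, per_question_concentrations))
```

As a sanity check, with concentrations of all ones and two answers, every split of k votes is equally likely, so one vote for the first answer has probability 1/2.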


class zoobot.pytorch.training.finetune.LinearHead(input_dim: int, output_dim: int, dropout_prob=0.5, activation=None)
forward(x)

Returns logits, as recommended for cross-entropy loss.

Parameters

x (torch.Tensor) – encoded representation

Returns

result (see docstring of LinearHead)

Return type

torch.Tensor
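Returning raw logits lets the loss apply log-softmax with the log-sum-exp trick. The plain-Python comparison below shows why this matters: a naive softmax overflows on large logits, while the shifted form does not. This is a general illustration with hypothetical helper names, not LinearHead's code.

```python
import math
from typing import List, Sequence


def naive_log_softmax(logits: Sequence[float]) -> List[float]:
    # Exponentiates raw logits directly: math.exp overflows for large values
    total = sum(math.exp(z) for z in logits)
    return [math.log(math.exp(z) / total) for z in logits]


def stable_log_softmax(logits: Sequence[float]) -> List[float]:
    # Shift by the max logit first, so the largest exponent is exp(0) = 1
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


logits = [1000.0, 0.0]
print(stable_log_softmax(logits))  # [0.0, -1000.0]
try:
    naive_log_softmax(logits)
except OverflowError:
    print("naive version overflowed")
```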


zoobot.pytorch.training.finetune.load_pretrained_zoobot(checkpoint_loc: str) → Module
Parameters

checkpoint_loc (str) – path to saved LightningModule checkpoint, likely of ZoobotTree, FinetuneableZoobotClassifier, or FinetuneableZoobotTree. Must have a .zoobot attribute.

Returns

pretrained PyTorch encoder within that LightningModule.

Return type

torch.nn.Module


zoobot.pytorch.training.finetune.get_trainer(save_dir: str, file_template='{epoch}', save_top_k=1, max_epochs=100, patience=10, devices='auto', accelerator='auto', logger=None, **trainer_kwargs) → Trainer

Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process. Use like so: trainer.fit(model, datamodule)

get_trainer args are for common Trainer settings, e.g. early stopping, checkpointing, etc. By default, the trainer:
  • Saves the top-k models based on validation loss.

  • Uses early stopping with patience, i.e. ends training if validation loss does not improve for patience epochs.

  • Monitors the learning rate (useful when using a learning rate scheduler).

Any extra args not listed below are passed directly to the PyTorch Lightning Trainer. Use this to add any custom configuration not covered by the get_trainer args. See https://lightning.ai/docs/pytorch/stable/common/trainer.html

Parameters
  • save_dir (str) – folder in which to save checkpoints and logs.

  • file_template (str, optional) – custom naming for checkpoint files. See Lightning docs. Defaults to “{epoch}”.

  • save_top_k (int, optional) – save the top k checkpoints only. Defaults to 1.

  • max_epochs (int, optional) – train for up to this many epochs. Defaults to 100.

  • patience (int, optional) – wait up to this many epochs for the validation loss to decrease before ending training. Defaults to 10.

  • devices (str, optional) – number of devices for training (typically, num. GPUs). Defaults to ‘auto’.

  • accelerator (str, optional) – which device to use (typically ‘gpu’ or ‘cpu’). Defaults to ‘auto’.

  • logger (pl.loggers.WandbLogger, optional) – If a pl.loggers.WandbLogger, track the experiment on Weights and Biases. Defaults to None.

Returns

PyTorch Lightning trainer object for finetuning a model on a GalaxyDataModule.

Return type

pl.Trainer
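The default checkpointing and early stopping behaviour described above can be sketched in plain Python. In get_trainer this is handled by Lightning callbacks; the two hypothetical helpers below only illustrate the logic, assuming an improvement means a strictly lower validation loss (Lightning's callbacks add options such as min_delta that are not modelled here).

```python
from typing import List, Sequence, Tuple


def early_stopping_epoch(val_losses: Sequence[float], patience: int = 10) -> Tuple[int, int]:
    """Hypothetical: return (last_epoch_run, best_epoch) for per-epoch validation losses.

    Training stops once the loss has failed to improve for `patience` consecutive epochs.
    """
    best_loss, best_epoch, epochs_without_improvement = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, epochs_without_improvement = loss, epoch, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    # Ran out of epochs (max_epochs) without triggering early stopping
    return len(val_losses) - 1, best_epoch


def top_k_epochs(val_losses: Sequence[float], save_top_k: int = 1) -> List[int]:
    """Hypothetical: epochs whose checkpoints are kept (the save_top_k lowest losses)."""
    ranked = sorted(range(len(val_losses)), key=lambda e: val_losses[e])
    return ranked[:save_top_k]


losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.8]
print(early_stopping_epoch(losses, patience=3))  # (4, 1)
print(top_k_epochs(losses, save_top_k=2))        # [5, 1]
```

With patience=3, training stops at epoch 4 and never sees the later improvement at epoch 5, which is the usual trade-off when choosing patience.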


zoobot.pytorch.training.finetune.download_from_name(class_name: str, hub_name: str)

Download a finetuned model from the HuggingFace Hub by name. Used to load pretrained Zoobot models by name, e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', …).

Downloaded models are saved to the HuggingFace cache directory for later use (typically ~/.cache/huggingface).

You shouldn’t need to call this; it’s used internally by the FinetuneableZoobot classes.

Parameters
  • class_name (str) – one of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree

  • hub_name (str) – e.g. mwalmsley/zoobot-encoder-convnext_nano

Returns

path to downloaded model (in HuggingFace cache directory). Likely then loaded by Lightning.

Return type

str
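The name passed to the FinetuneableZoobot classes carries an 'hf_hub:' prefix, while the hub_name example above does not. A plausible mapping between the two is simply stripping the prefix, sketched below; hub_name_from_model_name is hypothetical and this mapping is an assumption, not confirmed Zoobot behaviour.

```python
HF_HUB_PREFIX = "hf_hub:"


def hub_name_from_model_name(name: str) -> str:
    """Hypothetical: strip the 'hf_hub:' prefix used by FinetuneableZoobotX(name=...)."""
    if not name.startswith(HF_HUB_PREFIX):
        raise ValueError(f"Expected a name starting with {HF_HUB_PREFIX!r}, got {name!r}")
    hub_name = name[len(HF_HUB_PREFIX):]
    # Hub repository ids have the form 'owner/repo'
    owner, sep, repo = hub_name.partition("/")
    if not (owner and sep and repo):
        raise ValueError(f"Expected 'owner/repo' after the prefix, got {hub_name!r}")
    return hub_name


print(hub_name_from_model_name("hf_hub:mwalmsley/zoobot-encoder-convnext_nano"))
# mwalmsley/zoobot-encoder-convnext_nano
```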