finetune

Use these classes and methods to finetune a pretrained Zoobot model.

See the README for a minimal example. See zoobot/pytorch/examples for more worked examples.

class zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract(name=None, encoder=None, zoobot_checkpoint_loc=None, training_mode='full', layer_decay=0.75, weight_decay=0.05, learning_rate=0.0001, head_dropout_prob=0.5, scheduler_kwargs=None, timm_kwargs={}, greyscale=False, prog_bar=True, visualize_images=False, seed=42)

Parent class of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. You cannot use this class directly; use one of those child classes instead.

This class defines the shared finetuning args and methods used by those child classes. For example:
  • When given name, it loads the HuggingFace encoder with that name (see below for more).

  • When given learning_rate, it sets the optimizer to use that learning rate.

Any FinetuneableZoobot model can be loaded in one of three ways:
  • HuggingFace name, e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.

  • Any PyTorch model in memory, e.g. FinetuneableZoobotX(encoder=some_model, ...).

  • ZoobotTree checkpoint, e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...).

You can subclass this class to solve new finetuning tasks; see Advanced Finetuning.

Parameters
  • name (str, optional) – Name of a model on HuggingFace Hub e.g. hf_hub:mwalmsley/zoobot-encoder-convnext_nano. Defaults to None.

  • encoder (torch.nn.Module, optional) – Instead of name, use a PyTorch model already loaded in memory. Defaults to None.

  • zoobot_checkpoint_loc (str, optional) – Instead of name, use a path to a ZoobotTree Lightning checkpoint to load. Loads with zoobot.pytorch.training.finetune.load_pretrained_zoobot(). Defaults to None.

  • training_mode (str, optional) – 'full' to train all parameters, 'head_only' to freeze encoder and only train head. Defaults to 'full'.

  • layer_decay (float, optional) – Per-layer learning rate decay for the encoder: the i-th encoder layer below the head (counting from i = 0 at the top encoder layer) uses learning_rate * layer_decay ** i. Defaults to 0.75.

  • weight_decay (float, optional) – AdamW weight decay arg (decoupled weight decay, which shrinks weights towards zero at each step). Defaults to 0.05.

  • learning_rate (float, optional) – AdamW learning rate arg. Defaults to 1e-4.

  • head_dropout_prob (float, optional) – Probability of dropout before final output layer. Defaults to 0.5.

  • scheduler_kwargs (dict, optional) – Arguments for the optional learning rate scheduler. Defaults to None (no scheduler).

  • timm_kwargs (dict, optional) – Additional arguments for timm.create_model. Defaults to {}.

  • greyscale (bool, optional) – If True, convert model to single channel version (adds {'in_chans': 1} to timm kwargs). Defaults to False.

  • prog_bar (bool, optional) – Show a progress bar during finetuning. Defaults to True.

  • visualize_images (bool, optional) – Upload example images to WandB. Good for debugging but potentially slow. Defaults to False.

  • seed (int, optional) – Random seed to use. Defaults to 42.

configure_optimizers()

Sets up the optimizer and, optionally, a learning rate scheduler.

When self.training_mode == 'head_only', only self.head is optimized (i.e. frozen encoder, linear finetuning). When self.training_mode == 'full', all parameters are optimized.

Learning rate decay is applied to the encoder only. Counterintuitively, a higher learning rate decay value (e.g. 0.9) causes less reduction in the learning rate: the learning rate is (from the top encoder layer down) lr, lr * layer_decay, lr * layer_decay**2, … I use timm’s definition of layers, which groups some torch layers together.
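A minimal plain-Python sketch of the layer decay arithmetic described above. The helper name layer_decay_lrs and the choice of 4 encoder layers are illustrative assumptions; the real number and grouping of layers follows timm's definition and depends on the encoder architecture.

```python
def layer_decay_lrs(learning_rate: float, layer_decay: float, num_encoder_layers: int) -> list[float]:
    """Hypothetical helper: per-layer learning rates, from the top encoder layer (i = 0) down."""
    return [learning_rate * layer_decay ** i for i in range(num_encoder_layers)]


# Decay applies to encoder layers only; the head is assumed to use learning_rate unchanged.
lrs_default = layer_decay_lrs(1e-4, 0.75, num_encoder_layers=4)
lrs_gentle = layer_decay_lrs(1e-4, 0.9, num_encoder_layers=4)

# A higher layer_decay means less reduction: the bottom layer keeps more of the base rate.
print(f"{lrs_default[-1]:.3e}")  # 4.219e-05
print(f"{lrs_gentle[-1]:.3e}")   # 7.290e-05
```

In PyTorch, each group of parameters could then be given its own rate by passing param groups such as {'params': layer.parameters(), 'lr': lr} to torch.optim.AdamW.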

Weight decay (penalizing large weights; AdamW applies it as decoupled weight decay rather than as an L2 loss term) is applied to both the head and (if relevant) the encoder.

For schedulers, I use the timm scheduler factory. See https://github.com/rwightman/timm/blob/main/timm/scheduler/scheduler_factory.py#L63. self.scheduler_kwargs (passed to the factory) should be a dict with the scheduler name and any additional args, e.g. {'name': 'cosine', 'warmup_epochs': 5, 'max_epochs': 100}.
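The example scheduler_kwargs above ({'name': 'cosine', 'warmup_epochs': 5, 'max_epochs': 100}) describe a warmup followed by cosine decay. The sketch below is a generic linear-warmup plus cosine-annealing schedule, included only to show the shape of such a curve; it is not timm's implementation, and timm's exact formulas (warmup start value, minimum learning rate, per-step vs per-epoch updates) may differ. The function name and the min_lr and warmup_lr_init arguments are assumptions.

```python
import math


def warmup_cosine_lr(epoch: int, base_lr: float, warmup_epochs: int, max_epochs: int,
                     min_lr: float = 0.0, warmup_lr_init: float = 0.0) -> float:
    """Generic linear warmup + cosine annealing (illustrative; not timm's exact code)."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr_init up to base_lr
        return warmup_lr_init + (base_lr - warmup_lr_init) * epoch / warmup_epochs
    # Cosine decay from base_lr (end of warmup) down to min_lr (at max_epochs)
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Mirrors the example kwargs: warmup_epochs=5, max_epochs=100
schedule = [warmup_cosine_lr(e, 1e-4, warmup_epochs=5, max_epochs=100) for e in range(101)]
print(f"{schedule[0]:.1e} {schedule[5]:.1e} {schedule[100]:.1e}")  # 0.0e+00 1.0e-04 0.0e+00
```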


class zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier(num_classes: int, label_col: str = 'label', label_smoothing=0.0, class_weights=None, run_linear_sanity_check: bool = False, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a classification problem.

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters
  • label_col (str, optional) – Name of the column in the batch dict (e.g. a column in your dataframe) containing the labels. Defaults to 'label'.

  • num_classes (int) – Number of target classes (e.g. 2 for binary classification).

  • label_smoothing (float, optional) – See the torch.nn.CrossEntropyLoss docs. Defaults to 0.0.

  • class_weights (arraylike, optional) – See the torch.nn.CrossEntropyLoss docs (weight argument). Defaults to None.

  • run_linear_sanity_check (bool, optional) – If True, use sklearn to fit a linear model as a sanity check before finetuning. Defaults to False.
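To illustrate what label_smoothing does, here is a plain-Python sketch of cross-entropy for a single example, following the definition in the torch.nn.CrossEntropyLoss docs, where the target becomes a mixture of the one-hot label and a uniform distribution over classes. The function name is hypothetical and class weights are deliberately omitted; this is not the code Zoobot or PyTorch runs.

```python
import math


def smoothed_cross_entropy(logits: list[float], target: int, label_smoothing: float = 0.0) -> float:
    """Cross-entropy for one example with label smoothing (class weights omitted)."""
    num_classes = len(logits)
    # Numerically stable log-softmax
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_z for z in logits]
    # Smoothed target: (1 - eps) on the true class plus eps spread uniformly over all classes
    q = [label_smoothing / num_classes] * num_classes
    q[target] += 1.0 - label_smoothing
    return -sum(qc * lp for qc, lp in zip(q, log_probs))


confident_logits = [5.0, 0.0]  # binary classification, true class 0
# Smoothing penalizes over-confident correct predictions slightly more
print(smoothed_cross_entropy(confident_logits, 0) < smoothed_cross_entropy(confident_logits, 0, 0.1))  # True
```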


class zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor(label_col: str = 'label', loss: str = 'mse', unit_interval: bool = False, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a regression problem.

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters
  • label_col (str, optional) – Name of the column in the batch dict (e.g. a column in your dataframe) containing the labels. Defaults to 'label'.

  • loss (str, optional) – Loss function to use. Must be one of 'mse', 'mae'. Defaults to 'mse'.

  • unit_interval (bool, optional) – If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False.
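A plain-Python sketch of the two loss options and the effect of unit_interval. It assumes, as the parameter description implies, that the sigmoid is applied to the raw final-layer output before the loss is computed; the function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def regression_loss(predictions: list[float], targets: list[float], loss: str = "mse") -> float:
    """Mean squared error ('mse') or mean absolute error ('mae') over a batch."""
    errors = [p - t for p, t in zip(predictions, targets)]
    if loss == "mse":
        return sum(e ** 2 for e in errors) / len(errors)
    if loss == "mae":
        return sum(abs(e) for e in errors) / len(errors)
    raise ValueError(f"loss must be 'mse' or 'mae', got {loss!r}")


raw_outputs = [-3.0, 0.0, 2.5]                # unbounded final-layer outputs
bounded = [sigmoid(x) for x in raw_outputs]  # what unit_interval=True would produce
targets = [0.1, 0.5, 0.9]
print(regression_loss(bounded, targets, "mse"), regression_loss(bounded, targets, "mae"))
```

MAE grows linearly with the error and so is less sensitive to outliers than MSE.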


class zoobot.pytorch.training.finetune.FinetuneableZoobotTree(schema: Schema, **super_kwargs)

Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. Uses the Dirichlet-Multinomial loss introduced in GZ DECaLS. Briefly, the model predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, and the Dirichlet-Multinomial loss compares the predicted distribution of votes (given k volunteers were asked) with the observed vote counts.
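For a single question, the Dirichlet-Multinomial likelihood of vote counts k_j (summing to N) under predicted concentrations alpha_j has a closed form in terms of the gamma function. The plain-Python sketch below computes its negative log-likelihood with math.lgamma; it illustrates the standard formula and is not Zoobot's loss implementation (which operates on batched tensors across all questions in the schema). The function name is hypothetical.

```python
import math


def dirichlet_multinomial_nll(counts: list[int], concentrations: list[float]) -> float:
    """Negative log-likelihood of observed vote counts for one question.

    counts[j]: number of volunteers who gave answer j (N = sum of counts).
    concentrations[j]: predicted Dirichlet concentration alpha_j > 0.
    """
    n = sum(counts)
    alpha_0 = sum(concentrations)
    # log P = log N! + log Gamma(alpha_0) - log Gamma(N + alpha_0)
    #         + sum_j [log Gamma(k_j + alpha_j) - log k_j! - log Gamma(alpha_j)]
    log_p = math.lgamma(n + 1) + math.lgamma(alpha_0) - math.lgamma(n + alpha_0)
    for k, a in zip(counts, concentrations):
        log_p += math.lgamma(k + a) - math.lgamma(k + 1) - math.lgamma(a)
    return -log_p


votes = [9, 1]  # e.g. 9 of 10 volunteers chose the first answer
# Concentrations that favour the popular answer give a lower loss
print(dirichlet_multinomial_nll(votes, [9.0, 1.0]) < dirichlet_multinomial_nll(votes, [1.0, 9.0]))  # True
```

When no volunteers were asked a question (all counts zero), every term cancels and the loss is exactly zero, so unasked questions contribute nothing.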

Does not produce accuracy or MSE metrics, as these are not relevant for this task; only the loss is logged.

If you’re using this, you’re probably working on a Galaxy Zoo catalog, and you should Slack Mike!

Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.

Models can be loaded with FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).

Parameters

schema (schemas.Schema) – Description of the layout of the decision tree. See zoobot.shared.schemas.Schema.


class zoobot.pytorch.training.finetune.LinearHead(input_dim: int, output_dim: int, head_dropout_prob=0.5, activation=None)
forward(x)

Returns logits, as recommended for CrossEntropy loss.

Parameters

x (torch.Tensor) – Encoded representation.

Returns

Result (see docstring of LinearHead).

Return type

torch.Tensor
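A minimal PyTorch sketch of a head matching this signature: dropout (head_dropout_prob) before the final linear layer, then an optional activation. This is a hypothetical re-implementation for illustration; the real LinearHead may differ in details (for example, output reshaping). The input dimension of 640 is an arbitrary example value.

```python
import torch
from torch import nn


class LinearHeadSketch(nn.Module):
    """Hypothetical re-implementation: dropout, then a linear layer, then an optional activation."""

    def __init__(self, input_dim: int, output_dim: int, head_dropout_prob: float = 0.5, activation=None):
        super().__init__()
        self.dropout = nn.Dropout(p=head_dropout_prob)
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = activation  # e.g. torch.sigmoid; None returns raw logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(self.dropout(x))
        if self.activation is not None:
            x = self.activation(x)
        return x


head = LinearHeadSketch(input_dim=640, output_dim=3)  # 640 is an arbitrary example dimension
head.eval()  # disables dropout, as at inference time
features = torch.randn(4, 640)  # a batch of 4 encoded representations
logits = head(features)
print(logits.shape)  # torch.Size([4, 3])
```

Leaving activation as None returns logits, which is what torch's cross-entropy loss expects as input.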


zoobot.pytorch.training.finetune.load_pretrained_zoobot(checkpoint_loc: str) → Module

Load a pretrained Zoobot encoder from a LightningModule checkpoint.

Parameters

checkpoint_loc (str) – Path to a saved LightningModule checkpoint, likely of ZoobotTree, FinetuneableZoobotClassifier, or FinetuneableZoobotTree. The LightningModule must have an .encoder attribute.

Returns

Pretrained PyTorch encoder within that LightningModule.

Return type

torch.nn.Module
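Lightning checkpoints store module weights in checkpoint["state_dict"], keyed by attribute path, so the weights of a module stored as .encoder have keys starting with "encoder.". The hypothetical helper below shows one way to pull out just those weights. It illustrates the checkpoint layout only and is not how load_pretrained_zoobot is implemented; the toy keys are made up.

```python
def extract_encoder_state_dict(state_dict: dict, prefix: str = "encoder.") -> dict:
    """Hypothetical helper: keep only encoder weights and strip the attribute prefix."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


# Toy stand-in for checkpoint["state_dict"]; real values would be torch tensors
toy_state_dict = {
    "encoder.stem.0.weight": 1,
    "encoder.stages.0.blocks.0.conv_dw.bias": 2,
    "head.linear.weight": 3,
}
print(extract_encoder_state_dict(toy_state_dict))
# {'stem.0.weight': 1, 'stages.0.blocks.0.conv_dw.bias': 2}
```

With real weights, the checkpoint could be read with torch.load(checkpoint_loc, map_location='cpu') and the result passed to encoder.load_state_dict(...) on an encoder of the same architecture.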


zoobot.pytorch.training.finetune.get_trainer(save_dir: str, file_template='{epoch}', save_top_k=1, max_epochs=100, patience=10, devices='auto', accelerator='auto', logger=None, **trainer_kwargs) → Trainer

Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process. Use like so: trainer.fit(model, datamodule)

get_trainer args are for common Trainer settings, e.g. early stopping and checkpointing. By default, the trainer:
  • Saves the top-k models based on validation loss.

  • Uses early stopping with patience (i.e. ends training if validation loss does not improve for patience epochs).

  • Monitors the learning rate (useful when using a learning rate scheduler).

Any extra args not listed below are passed directly to the PyTorch Lightning Trainer. Use this to add any custom configuration not covered by the get_trainer args. See https://lightning.ai/docs/pytorch/stable/common/trainer.html

Parameters
  • save_dir (str) – Folder in which to save checkpoints and logs.

  • file_template (str, optional) – Custom naming for checkpoint files. See Lightning docs. Defaults to "{epoch}".

  • save_top_k (int, optional) – Save the top k checkpoints only. Defaults to 1.

  • max_epochs (int, optional) – Train for up to this many epochs. Defaults to 100.

  • patience (int, optional) – Number of epochs to wait for the validation loss to improve before ending training. Defaults to 10.

  • devices (str, optional) – Number of devices for training (typically, num. GPUs). Defaults to 'auto'.

  • accelerator (str, optional) – Which device to use (typically 'gpu' or 'cpu'). Defaults to 'auto'.

  • logger (L.pytorch.loggers.WandbLogger, optional) – If a WandbLogger, track the experiment on Weights and Biases. Defaults to None.

Returns

PyTorch Lightning trainer object for finetuning a model on a GalaxyDataModule.

Return type

L.Trainer
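A simplified plain-Python simulation of the default behaviour listed above: early stopping on validation loss with patience, plus keeping the save_top_k best checkpoints. In Lightning itself this is handled by callbacks such as EarlyStopping and ModelCheckpoint (and LearningRateMonitor for the learning rate); the function below is hypothetical and ignores details such as min_delta. The loss values are made up.

```python
def simulate_early_stopping(val_losses: list[float], patience: int, save_top_k: int = 1):
    """Simplified sketch: returns (last epoch trained, epochs whose checkpoints are kept)."""
    if not val_losses:
        raise ValueError("need at least one epoch of validation loss")
    best = float("inf")
    epochs_without_improvement = 0
    history = []  # (val_loss, epoch) for every epoch actually trained
    for epoch, loss in enumerate(val_losses):
        history.append((loss, epoch))
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping: no improvement for `patience` epochs
    # save_top_k: keep the checkpoints with the lowest validation loss
    kept = sorted(epoch for _, epoch in sorted(history)[:save_top_k])
    return epoch, kept


last_epoch, kept = simulate_early_stopping([1.0, 0.9, 0.95, 0.97, 0.8], patience=2, save_top_k=2)
print(last_epoch, kept)  # 3 [1, 2]
```

Here training stops after epoch 3, so the lower loss at epoch 4 is never reached; a larger patience trades extra epochs for a chance at such late improvements.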


zoobot.pytorch.training.finetune.download_from_name(class_name: str, hub_name: str)

Download a finetuned model from the HuggingFace Hub by name. Used to load pretrained Zoobot models by name, e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...).

Downloaded models are saved to the HuggingFace cache directory for later use (typically ~/.cache/huggingface).

You shouldn’t need to call this; it’s used internally by the FinetuneableZoobot classes.

Parameters
  • class_name (str) – One of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree.

  • hub_name (str) – e.g. mwalmsley/zoobot-encoder-convnext_nano.

Returns

Path to downloaded model (in HuggingFace cache directory). Likely then loaded by Lightning.

Return type

str
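The docs above pass the model name with an 'hf_hub:' prefix (name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano') but give hub_name without it (mwalmsley/zoobot-encoder-convnext_nano). The hypothetical helper below shows that relationship; whether Zoobot strips the prefix exactly this way internally is an assumption.

```python
HF_HUB_PREFIX = "hf_hub:"


def hub_name_from_model_name(name: str) -> str:
    """Hypothetical helper: 'hf_hub:user/repo' -> 'user/repo'."""
    if not name.startswith(HF_HUB_PREFIX):
        raise ValueError(f"expected a name starting with {HF_HUB_PREFIX!r}, got {name!r}")
    return name[len(HF_HUB_PREFIX):]


print(hub_name_from_model_name("hf_hub:mwalmsley/zoobot-encoder-convnext_nano"))
# mwalmsley/zoobot-encoder-convnext_nano
```

The resulting repo id is the form the huggingface_hub library expects (e.g. as the repo_id argument of hf_hub_download); the checkpoint filename within the repo is not specified here.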