finetune
Use these classes and methods to finetune a pretrained Zoobot model.
See the README for a minimal example. See zoobot/pytorch/examples for more worked examples.
- class zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract(name=None, encoder=None, zoobot_checkpoint_loc=None, training_mode='full', layer_decay=0.75, weight_decay=0.05, learning_rate=0.0001, head_dropout_prob=0.5, scheduler_kwargs=None, timm_kwargs={}, greyscale=False, prog_bar=True, visualize_images=False, seed=42)
Parent class of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. You cannot use this class directly; you must use one of those child classes instead.
This class defines the shared finetuning args and methods used by those child classes. For example:
* When provided name, it will load the HuggingFace encoder with that name (see below for more).
* When provided learning_rate, it will set the optimizer to use that learning rate.
Any FinetuneableZoobot model can be loaded in one of three ways:
* HuggingFace name, e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
* Any PyTorch model in memory, e.g. FinetuneableZoobotX(encoder=some_model, ...).
* ZoobotTree checkpoint, e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...).
You could subclass this class to solve new finetuning tasks - see Advanced Finetuning.
- Parameters
name (str, optional) – Name of a model on HuggingFace Hub, e.g. hf_hub:mwalmsley/zoobot-encoder-convnext_nano. Defaults to None.
encoder (torch.nn.Module, optional) – Instead of name, use a PyTorch model already loaded in memory. Defaults to None.
zoobot_checkpoint_loc (str, optional) – Instead of name, use a path to a ZoobotTree Lightning checkpoint to load. Loads with zoobot.pytorch.training.finetune.load_pretrained_zoobot(). Defaults to None.
training_mode (str, optional) – 'full' to train all parameters, 'head_only' to freeze the encoder and train only the head. Defaults to 'full'.
layer_decay (float, optional) – Scales the learning rate by layer_decay ** i for encoder layers progressively further below the head (see configure_optimizers()). Defaults to 0.75.
weight_decay (float, optional) – AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
learning_rate (float, optional) – AdamW learning rate arg. Defaults to 1e-4.
head_dropout_prob (float, optional) – Probability of dropout before the final output layer. Defaults to 0.5.
scheduler_kwargs (dict, optional) – Arguments for the optional learning rate scheduler. Defaults to None (no scheduler).
timm_kwargs (dict, optional) – Additional arguments for timm.create_model.
greyscale (bool, optional) – If True, convert model to single channel version (adds {'in_chans': 1} to timm kwargs). Defaults to False.
prog_bar (bool, optional) – Print progress bar during finetuning. Defaults to True.
visualize_images (bool, optional) – Upload example images to WandB. Good for debugging but potentially slow. Defaults to False.
seed (int, optional) – Random seed to use. Defaults to 42.
- configure_optimizers()
Sets up the optimizer and, optionally, a learning rate scheduler.
When self.training_mode == 'head_only', only self.head is optimized (i.e. frozen encoder, linear finetuning). When self.training_mode == 'full', all parameters are optimized.
Learning rate decay is applied to the encoder only. Counterintuitively, a higher learning rate decay value (e.g. 0.9) causes less reduction in the learning rate: the learning rate is (from the top encoder layer down) lr, lr * layer_decay, lr * layer_decay**2, … I use timm’s definition of layers, which groups some torch layers together.
Weight decay (aka L2 regularization, penalizing large weights) is applied to both the head and (if relevant) the encoder.
For schedulers, I use the timm scheduler factory. See https://github.com/rwightman/timm/blob/main/timm/scheduler/scheduler_factory.py#L63. self.scheduler_kwargs (passed to the factory) should be a dict with the scheduler name and any additional args, e.g. {'name': 'cosine', 'warmup_epochs': 5, 'max_epochs': 100}.
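The layer-wise decay described above can be illustrated with plain arithmetic. This is a minimal sketch, not Zoobot's implementation: the helper name layerwise_learning_rates and the choice of four layers are assumptions made for illustration, and the real grouping of torch layers follows timm's definition.

```python
def layerwise_learning_rates(base_lr, layer_decay, num_layers):
    """Return per-layer learning rates, ordered from the top encoder layer down.

    Hypothetical helper: depth 0 is the top encoder layer (closest to the head),
    and each deeper layer is scaled by a further factor of layer_decay.
    """
    return [base_lr * layer_decay ** depth for depth in range(num_layers)]


# Default-like settings: the learning rate shrinks quickly with depth.
rates_default = layerwise_learning_rates(base_lr=1e-4, layer_decay=0.75, num_layers=4)

# A higher layer_decay value shrinks the learning rate less at every depth.
rates_gentle = layerwise_learning_rates(base_lr=1e-4, layer_decay=0.9, num_layers=4)

for depth, (fast, gentle) in enumerate(zip(rates_default, rates_gentle)):
    print(f"depth {depth}: layer_decay=0.75 -> {fast:.2e}, layer_decay=0.9 -> {gentle:.2e}")
```

Comparing the two lists shows why a larger layer_decay means less reduction: the lower encoder layers keep a learning rate closer to the base value.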
- class zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier(num_classes: int, label_col: str = 'label', label_smoothing=0.0, class_weights=None, run_linear_sanity_check: bool = False, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a classification problem.
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
label_col (str, optional) – Name of the column in the batch dict (e.g. a column in your dataframe) containing the labels. Defaults to 'label'.
num_classes (int) – Number of target classes (e.g. 2 for binary classification).
label_smoothing (float, optional) – See torch cross_entropy_loss docs. Defaults to 0.
class_weights (arraylike, optional) – See torch cross_entropy_loss docs. Defaults to None.
run_linear_sanity_check (bool, optional) – Before fitting, use sklearn to fit a linear model. Defaults to False.
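A usage sketch for the classifier, built only from the signatures documented on this page. The function name build_classifier_and_trainer and the particular argument values are illustrative assumptions, and zoobot is imported lazily inside the function so that the snippet can be loaded even where zoobot is not installed.

```python
def build_classifier_and_trainer(save_dir, num_classes=2):
    """Build a finetuneable classifier and a trainer for it (illustrative helper)."""
    # Lazy import: requires zoobot to be installed only when the function is called.
    from zoobot.pytorch.training import finetune

    model = finetune.FinetuneableZoobotClassifier(
        name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # HuggingFace name, as documented
        num_classes=num_classes,
        training_mode='head_only',  # freeze the encoder, train only the head
        learning_rate=1e-4,         # passed through to FinetuneableZoobotAbstract
    )
    trainer = finetune.get_trainer(save_dir, max_epochs=100, patience=10)
    return model, trainer
```

Calling the function downloads the encoder on first use; finetuning then follows the documented pattern trainer.fit(model, datamodule), where datamodule is a data module you have prepared separately.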
- class zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor(label_col: str = 'label', loss: str = 'mse', unit_interval: bool = False, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a regression problem.
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
label_col (str, optional) – Name of the column in the batch dict (e.g. a column in your dataframe) containing the labels. Defaults to 'label'.
loss (str, optional) – Loss function to use. Must be one of 'mse', 'mae'. Defaults to 'mse'.
unit_interval (bool, optional) – If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False.
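To make the loss and unit_interval options concrete, here is a small pure-Python sketch of what they mean numerically. It does not use Zoobot or torch; the helper names sigmoid and regression_loss and the sample values are assumptions for illustration only.

```python
import math


def sigmoid(x):
    """Squash a raw model output into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def regression_loss(predictions, targets, loss='mse'):
    """Mean squared error ('mse') or mean absolute error ('mae') over a batch."""
    if loss == 'mse':
        errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    elif loss == 'mae':
        errors = [abs(p - t) for p, t in zip(predictions, targets)]
    else:
        raise ValueError(f"loss must be 'mse' or 'mae', got {loss!r}")
    return sum(errors) / len(errors)


# With unit_interval=True the final activation is a sigmoid, so outputs stay in (0, 1).
raw_outputs = [-2.0, 0.0, 3.0]
bounded = [sigmoid(x) for x in raw_outputs]

targets = [0.0, 0.5, 1.0]
mse = regression_loss(bounded, targets, loss='mse')
mae = regression_loss(bounded, targets, loss='mae')
```

unit_interval is a natural fit when labels are fractions or probabilities; for unbounded targets the default of False leaves the output unconstrained.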
- class zoobot.pytorch.training.finetune.FinetuneableZoobotTree(schema: Schema, **super_kwargs)
Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. Uses Dirichlet-Multinomial loss introduced in GZ DECaLS. Briefly: predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, and uses the Dirichlet-Multinomial loss to compare the predicted distribution of votes (given k volunteers were asked) to the true distribution.
Does not produce accuracy or MSE metrics, as these are not relevant for this task. Loss logging only.
If you’re using this, you’re probably working on a Galaxy Zoo catalog, and you should Slack Mike!
Any args not listed below are passed to FinetuneableZoobotAbstract (for example, learning_rate). These are shared between classifier, regressor, and tree models. See the docstring of FinetuneableZoobotAbstract for more.
Models can be loaded with FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). See FinetuneableZoobotAbstract for other loading options (e.g. in-memory models or local checkpoints).
- Parameters
schema (schemas.Schema) – Description of the layout of the decision tree. See zoobot.shared.schemas.Schema.
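For intuition about the Dirichlet-Multinomial loss mentioned above, the sketch below evaluates the standard Dirichlet-Multinomial negative log-likelihood for a single question. It is a textbook formula written with the standard library, not Zoobot's own loss code; the function name dirichlet_multinomial_nll and the example vote counts are assumptions.

```python
from math import lgamma


def dirichlet_multinomial_nll(votes, concentrations):
    """Negative log-likelihood of observed vote counts under a Dirichlet-Multinomial.

    votes: number of volunteers who gave each answer (k volunteers in total).
    concentrations: predicted Dirichlet concentration for each answer (all > 0).
    """
    total_votes = sum(votes)
    total_conc = sum(concentrations)
    # Normalising terms that depend only on the totals.
    log_prob = lgamma(total_votes + 1) + lgamma(total_conc) - lgamma(total_votes + total_conc)
    # Per-answer terms.
    for k, alpha in zip(votes, concentrations):
        log_prob += lgamma(k + alpha) - lgamma(alpha) - lgamma(k + 1)
    return -log_prob


# Nine of ten volunteers chose the first answer.
observed = [9, 1]
loss_good = dirichlet_multinomial_nll(observed, concentrations=[9.0, 1.0])  # prediction agrees
loss_bad = dirichlet_multinomial_nll(observed, concentrations=[1.0, 9.0])   # prediction disagrees
```

A prediction whose concentrations lean toward the answers volunteers actually gave receives the lower loss, which is the behaviour the training objective rewards.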
- class zoobot.pytorch.training.finetune.LinearHead(input_dim: int, output_dim: int, head_dropout_prob=0.5, activation=None)
- forward(x)
Returns logits, as recommended for CrossEntropy loss.
- Parameters
x (torch.Tensor) – Encoded representation.
- Returns
Result (see docstring of LinearHead).
- Return type
torch.Tensor
- zoobot.pytorch.training.finetune.load_pretrained_zoobot(checkpoint_loc: str) → Module
Load a pretrained Zoobot encoder from a LightningModule checkpoint.
- Parameters
checkpoint_loc (str) – Path to saved LightningModule checkpoint, likely of ZoobotTree, FinetuneableZoobotClassifier, or FinetuneableZoobotTree. Must have .encoder attribute.
- Returns
Pretrained PyTorch encoder within that LightningModule.
- Return type
torch.nn.Module
- zoobot.pytorch.training.finetune.get_trainer(save_dir: str, file_template='{epoch}', save_top_k=1, max_epochs=100, patience=10, devices='auto', accelerator='auto', logger=None, **trainer_kwargs) → Trainer
Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process. Use like so: trainer.fit(model, datamodule)
get_trainer args are for common Trainer settings e.g. early stopping, checkpointing, etc. By default:
* Saves the top-k models based on validation loss
* Uses early stopping with patience (i.e. end training if validation loss does not improve after patience epochs)
* Monitors the learning rate (useful when using a learning rate scheduler)
Any extra args not listed below are passed directly to the PyTorch Lightning Trainer. Use this to add any custom configuration not covered by the get_trainer args. See https://lightning.ai/docs/pytorch/stable/common/trainer.html
- Parameters
save_dir (str) – Folder in which to save checkpoints and logs.
file_template (str, optional) – Custom naming for checkpoint files. See Lightning docs. Defaults to "{epoch}".
save_top_k (int, optional) – Save the top k checkpoints only. Defaults to 1.
max_epochs (int, optional) – Train for up to this many epochs. Defaults to 100.
patience (int, optional) – Wait up to this many epochs for decreasing loss before ending training. Defaults to 10.
devices (str, optional) – Number of devices for training (typically, num. GPUs). Defaults to 'auto'.
accelerator (str, optional) – Which device to use (typically 'gpu' or 'cpu'). Defaults to 'auto'.
logger (L.pytorch.loggers.wandb_logger, optional) – If L.pytorch.loggers.wandb_logger, track experiment on Weights and Biases. Defaults to None.
- Returns
PyTorch Lightning trainer object for finetuning a model on a GalaxyDataModule.
- Return type
L.Trainer
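The patience rule can be sketched in a few lines of plain Python. This shows the general idea only and is not Lightning's EarlyStopping callback, which has further options (such as a minimum improvement threshold) that are not modelled here; the function name epochs_until_early_stop and the loss values are assumptions.

```python
def epochs_until_early_stop(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does.

    Training stops once the validation loss has failed to improve on its best
    value for `patience` consecutive epochs.
    """
    best = float('inf')
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Validation loss improves for three epochs, then plateaus.
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.63]
stop_epoch = epochs_until_early_stop(losses, patience=3)
```

A larger patience tolerates longer plateaus at the cost of extra epochs; max_epochs still caps the run regardless.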
- zoobot.pytorch.training.finetune.download_from_name(class_name: str, hub_name: str)
Download a finetuned model from the HuggingFace Hub by name. Used to load pretrained Zoobot models by name, e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...).
Downloaded models are saved to the HuggingFace cache directory for later use (typically ~/.cache/huggingface).
You shouldn’t need to call this; it’s used internally by the FinetuneableZoobot classes.
- Parameters
class_name (str) – One of FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree.
hub_name (str) – e.g. mwalmsley/zoobot-encoder-convnext_nano.
- Returns
Path to downloaded model (in HuggingFace cache directory). Likely then loaded by Lightning.
- Return type
str