How the Code Fits Together#
The Zoobot package has many classes and methods. This guide aims to be a map summarising how they fit together.
The Map#
The Zoobot package has two roles:

Finetuning: pytorch/training/finetune.py is the heart of the package. You will use these classes to load pretrained models and finetune them on new data.

Training from scratch: pytorch/estimators/define_model.py and pytorch/training/train_with_pytorch_lightning.py create and train the Zoobot models from scratch. These are not required for finetuning and will eventually be migrated out.
Let’s zoom in on the finetuning part.
Finetuning with Zoobot Classes#
There are three Zoobot classes for finetuning:

FinetuneableZoobotClassifier for classification tasks (including multi-class).

FinetuneableZoobotRegressor for regression tasks (including on a unit interval, e.g. a fraction).

FinetuneableZoobotTree for training on a tree of labels (e.g. Galaxy Zoo vote counts).

Each user-facing class is actually a subclass of a non-user-facing abstract class, FinetuneableZoobotAbstract.
FinetuneableZoobotAbstract specifies how to finetune a general PyTorch model, and the user-facing classes inherit this behaviour. It controls the core finetuning process: loading a model, accepting arguments controlling the finetuning process, and running the finetuning.
Each user-facing class adds features specific to its type of task. For example, FinetuneableZoobotClassifier adds additional arguments like num_classes. It also specifies an appropriate head and a loss function.
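This inheritance pattern can be sketched in plain Python. The class and method names below (create_head, loss, and the *Sketch suffixes) are illustrative stand-ins, not the real Zoobot API:

```python
# Illustrative sketch only: a shared abstract class holds the generic
# finetuning logic, while task-specific subclasses supply the head and loss.

class FinetuneableZoobotAbstractSketch:
    """Stand-in for the generic finetuning logic shared by all tasks."""

    def __init__(self, learning_rate=1e-4):
        self.learning_rate = learning_rate
        self.head = self.create_head()  # subclasses decide the head

    def create_head(self):
        raise NotImplementedError("task-specific subclasses define the head")

    def loss(self, predictions, labels):
        raise NotImplementedError("task-specific subclasses define the loss")


class FinetuneableZoobotClassifierSketch(FinetuneableZoobotAbstractSketch):
    """Stand-in for a user-facing class adding task-specific options."""

    def __init__(self, num_classes, **kwargs):
        self.num_classes = num_classes
        super().__init__(**kwargs)

    def create_head(self):
        # In the real package this would be a small PyTorch module;
        # here we just return a description.
        return f"linear head with {self.num_classes} outputs"

    def loss(self, predictions, labels):
        return "cross-entropy"  # a classification-appropriate loss


model = FinetuneableZoobotClassifierSketch(num_classes=3)
print(model.head)  # -> linear head with 3 outputs
```

The point is the division of labour: the abstract parent owns the shared mechanics, and each subclass only answers "what head?" and "what loss?".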
Finetuning with PyTorch Lightning#
Zoobot is written in PyTorch, a popular deep learning library for Python. PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU). We use PyTorch Lightning, a third-party wrapper API, to make this boilerplate code someone else’s job.

FinetuneableZoobotAbstract, and the user-facing classes that inherit from it, are all “LightningModule” classes. These classes have (custom) methods like training_step, validation_step, etc., which specify what should happen at each training stage.
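The LightningModule contract can be mimicked in pure Python (no Lightning installed) to show the shape of the idea. The names here are toy stand-ins; the real hooks take extra arguments like batch_idx:

```python
# Conceptual sketch of the LightningModule contract: you implement the
# *_step hooks, and the framework owns the loop that calls them.

class ToyLightningModule:
    def training_step(self, batch):
        # compute and return the loss for one batch of training data
        x, y = batch
        return (x - y) ** 2

    def validation_step(self, batch):
        x, y = batch
        return abs(x - y)


def toy_fit(module, train_batches, val_batches):
    """Stand-in for what Lightning's Trainer does behind the scenes."""
    train_losses = [module.training_step(b) for b in train_batches]
    val_losses = [module.validation_step(b) for b in val_batches]
    return sum(train_losses), sum(val_losses)


train_loss, val_loss = toy_fit(ToyLightningModule(), [(1, 2), (3, 3)], [(0, 1)])
print(train_loss, val_loss)  # -> 1 1
```

In real Lightning the loop (toy_fit here) also handles devices, epochs, logging and checkpointing, which is exactly the boilerplate Zoobot avoids writing by hand.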
ZoobotTree is similar to FinetuneableZoobotAbstract but has methods for training from scratch. Some generic methods (like logging) are defined in define_model.py and called by both ZoobotTree and FinetuneableZoobotAbstract.
LightningModules can be passed to a Lightning Trainer object. This handles running the training in practice (e.g. how to distribute training onto a GPU, how many epochs to run, etc.).
So when we do:
model = FinetuneableZoobotTree(...)
trainer = get_trainer(...)
trainer.fit(model, datamodule)
We are:

1. Defining a PyTorch encoder and head (inside FinetuneableZoobotTree)
2. Wrapping them in a LightningModule specifying how to train them (FinetuneableZoobotTree)
3. Fitting the LightningModule using Lightning’s Trainer class
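The three layers above can be sketched with pure-Python stand-ins (none of these toy classes are real Zoobot or Lightning code):

```python
# Sketch of the three layers: an encoder and head, wrapped in a module
# that knows how to train them, driven by a trainer.

class ToyEncoder:
    def forward(self, x):
        return x * 2  # pretend feature extraction


class ToyHead:
    def forward(self, features):
        return features + 1  # pretend task-specific prediction


class ToyModule:
    """Plays the role of FinetuneableZoobotTree: owns encoder + head."""

    def __init__(self):
        self.encoder = ToyEncoder()
        self.head = ToyHead()

    def training_step(self, batch):
        return self.head.forward(self.encoder.forward(batch))


class ToyTrainer:
    """Plays the role of Lightning's Trainer: runs the loop."""

    def fit(self, model, data):
        return [model.training_step(batch) for batch in data]


trainer = ToyTrainer()
outputs = trainer.fit(ToyModule(), [1, 2])
print(outputs)  # -> [3, 5]
```

Each layer has one job: the module composes encoder and head, and the trainer decides when and how the module's steps get called.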
Slightly confusingly, Lightning’s Trainer can also be used to make predictions:

trainer.predict(model, datamodule)

and that’s how we make predictions with zoobot.pytorch.predictions.predict_on_catalog.predict().
As you can see, there are quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) frameworks you’re using.