How the Code Fits Together

The Zoobot package has many classes and methods. This guide aims to be a map summarising how they fit together.

The Map

The Zoobot package has two roles:

  1. Finetuning: pytorch/training/finetune.py is the heart of the package. You will use these classes to load pretrained models and finetune them on new data.

  2. Training from Scratch: pytorch/estimators/define_model.py and pytorch/training/train_with_pytorch_lightning.py create and train the Zoobot models from scratch. These are not required for finetuning and will eventually be migrated out.
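In practice, the two roles correspond to two different imports. A rough sketch (the exact names exported by each module may differ between versions, so treat these as illustrative):

# Finetuning: load a pretrained Zoobot model and adapt it to new data
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

# Training from scratch: not required for finetuning
from zoobot.pytorch.estimators.define_model import ZoobotTree
from zoobot.pytorch.training.train_with_pytorch_lightning import train_default_zoobot_from_scratch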

Let’s zoom in on the finetuning part.

Finetuning with Zoobot Classes

There are three Zoobot classes for finetuning:

  1. FinetuneableZoobotClassifier for classification tasks (including multi-class).

  2. FinetuneableZoobotRegressor for regression tasks (including on a unit interval e.g. a fraction).

  3. FinetuneableZoobotTree for training on a tree of labels (e.g. Galaxy Zoo vote counts).
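For example, loading the classifier might look like this. This is a sketch: the checkpoint name and constructor arguments are illustrative placeholders, not the definitive API (see the finetuning guide for the exact signature):

from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

model = FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # example pretrained encoder (placeholder)
    num_classes=2,  # a task-specific argument (see below)
)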

Each user-facing class is actually a subclass of a non-user-facing abstract class, FinetuneableZoobotAbstract. FinetuneableZoobotAbstract has methods specifying how to finetune a general PyTorch model, which the user-facing classes inherit.

FinetuneableZoobotAbstract controls the core finetuning process: loading a pretrained model, accepting arguments that control how finetuning runs, and carrying out the finetuning itself. Each user-facing class adds features specific to its type of task. For example, FinetuneableZoobotClassifier adds additional arguments like num_classes, and specifies an appropriate head and loss function.
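Schematically, the relationship looks something like the sketch below. This is a simplified illustration, not Zoobot's actual implementation; the class bodies and the timm backbone are stand-ins:

import pytorch_lightning as pl
import timm
import torch.nn.functional as F
from torch import nn

class SketchZoobotAbstract(pl.LightningModule):  # LightningModule is explained in the next section
    """Stands in for FinetuneableZoobotAbstract: owns the encoder and the generic finetuning logic."""

    def __init__(self, architecture='convnext_nano'):
        super().__init__()
        # the real class loads a Galaxy Zoo pretrained checkpoint;
        # here we create a fresh timm backbone just to keep the sketch self-contained
        self.encoder = timm.create_model(architecture, pretrained=False, num_classes=0)

class SketchZoobotClassifier(SketchZoobotAbstract):
    """Stands in for FinetuneableZoobotClassifier: adds task-specific arguments, a head, and a loss."""

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.head = nn.Linear(self.encoder.num_features, num_classes)  # classification head

    def loss(self, predictions, labels):
        return F.cross_entropy(predictions, labels)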

Finetuning with PyTorch Lightning

Zoobot is written in PyTorch, a popular deep learning library for Python. PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU). We use PyTorch Lightning, a third-party wrapper API, to make this boilerplate code someone else’s job.

FinetuneableZoobotAbstract, and hence all the user-facing finetuning classes, are “LightningModule” classes. These classes have (custom) methods like training_step, validation_step, etc., which specify what should happen at each training stage.
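The pattern, in generic Lightning terms, looks like this (standard Lightning usage with a placeholder network, not Zoobot's exact code):

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

class ExampleModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = nn.LazyLinear(10)  # placeholder network

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss  # Lightning calls backward() and optimizer.step() for us

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())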

ZoobotTree (defined in define_model.py) is similar to FinetuneableZoobotAbstract, but has methods for training from scratch.

Some generic methods (like logging) are defined in define_model.py and are called by both ZoobotTree and FinetuneableZoobotAbstract.

LightningModules can be passed to a Lightning Trainer object. This handles running the training in practice (e.g. how to distribute training onto a GPU, how many epochs to run, etc.).
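For example (these are standard Lightning Trainer arguments; the specific values are arbitrary):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator='gpu',  # train on GPU (or 'cpu')
    devices=2,          # distribute across two GPUs
    max_epochs=100,     # how many epochs to run
)

Zoobot’s get_trainer, used below, is essentially a convenience function returning a Trainer configured along these lines.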

So when we do:

from zoobot.pytorch.training.finetune import FinetuneableZoobotTree, get_trainer

model = FinetuneableZoobotTree(...)
trainer = get_trainer(...)
trainer.fit(model, datamodule)

We are:

  • Defining a PyTorch encoder and head (inside FinetuneableZoobotTree)

  • Wrapping them in a LightningModule specifying how to train them (FinetuneableZoobotTree)

  • Fitting the LightningModule using Lightning’s Trainer class

Slightly confusingly, Lightning’s Trainer can also be used to make predictions:

trainer.predict(model, datamodule)

and that’s how we make predictions with zoobot.pytorch.predictions.predict_on_catalog.predict().
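Under the hood, trainer.predict returns one output per batch, which can then be stacked into a single array of predictions. A generic Lightning sketch (not Zoobot's exact code):

import torch

# trainer.predict returns a list with one element per batch
batch_predictions = trainer.predict(model, datamodule)
# stack into a single (n_galaxies, n_outputs) tensor
predictions = torch.cat(batch_predictions)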

As you can see, there are quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) framework you’re using.