train_with_pytorch_lightning

zoobot.pytorch.training.train_with_pytorch_lightning.train_default_zoobot_from_scratch(save_dir: str, schema, catalog=None, train_catalog=None, val_catalog=None, test_catalog=None, train_urls=None, val_urls=None, test_urls=None, cache_dir=None, epochs=1000, patience=8, architecture_name='efficientnet_b0', timm_kwargs={}, batch_size=128, dropout_rate=0.2, learning_rate=0.001, betas=(0.9, 0.999), weight_decay=0.01, scheduler_params={}, color=False, resize_after_crop=224, crop_scale_bounds=(0.7, 0.8), crop_ratio_bounds=(0.9, 1.1), nodes=1, gpus=2, sync_batchnorm=False, num_workers=4, prefetch_factor=4, mixed_precision=False, compile_encoder=False, wandb_logger=None, checkpoint_file_template=None, auto_insert_metric_name=True, save_top_k=3, extra_callbacks=None, random_state=42) → Tuple[ZoobotTree, Trainer]

Train Zoobot from scratch on a big galaxy catalog.

You are unlikely to need this function directly. Training from scratch is becoming increasingly complicated (as the argument list shows) due to ongoing research on the best methods; it will eventually be refactored into a dedicated “foundation” repo.
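
A minimal sketch of a training run from a local catalog. The question/answer pairs, dependencies, and file paths below are placeholder assumptions; adapt them to your own decision tree and data.

import pandas as pd
from zoobot.shared.schemas import Schema
from zoobot.pytorch.training.train_with_pytorch_lightning import train_default_zoobot_from_scratch

# Hypothetical two-question decision tree; answers are suffixes appended to the question name
question_answer_pairs = {
    'smooth-or-featured': ['_smooth', '_featured-or-disk'],
    'has-spiral-arms': ['_yes', '_no']
}
# Each question maps to the answer it depends on (None for the root question)
dependencies = {
    'smooth-or-featured': None,
    'has-spiral-arms': 'smooth-or-featured_featured-or-disk'
}
schema = Schema(question_answer_pairs, dependencies)

# Catalog needs columns id_str and file_loc, plus the label columns in schema.label_cols
catalog = pd.read_csv('my_catalog.csv')  # placeholder path

model, trainer = train_default_zoobot_from_scratch(
    save_dir='results/my_run',
    schema=schema,
    catalog=catalog,  # automatically split into train and validation
    gpus=1
)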

Parameters
  • save_dir (str) – folder to save training logs and trained model checkpoints

  • schema (shared.schemas.Schema) – Schema object with label_cols, question_answer_pairs, and dependencies

  • catalog (pd.DataFrame, optional) – Galaxy catalog with columns id_str and file_loc. Will be automatically split into train and validation sets (no test set). Defaults to None.

  • train_catalog (pd.DataFrame, optional) – As above, but already split by you for training. Defaults to None.

  • val_catalog (pd.DataFrame, optional) – As above, for validation. Defaults to None.

  • test_catalog (pd.DataFrame, optional) – As above, for testing. Defaults to None.

  • train_urls (list, optional) – List of URLs to webdatasets for training. Defaults to None.

  • val_urls (list, optional) – List of URLs to webdatasets for validation. Defaults to None.

  • test_urls (list, optional) – List of URLs to webdatasets for testing. Defaults to None.

  • cache_dir (str, optional) – Directory to cache webdatasets. Defaults to None.

  • epochs (int, optional) – Max. number of epochs to train for. Defaults to 1000.

  • patience (int, optional) – Max. number of epochs to wait for any loss improvement before ending training. Defaults to 8.

  • architecture_name (str, optional) – Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to ‘efficientnet_b0’.

  • timm_kwargs (dict, optional) – Additional kwargs to pass to timm model init method, for example {'drop_connect_rate': 0.2}. Defaults to {}.

  • batch_size (int, optional) – Batch size. Defaults to 128.

  • dropout_rate (float, optional) – Randomly drop activations prior to the output layer, with this probability. Defaults to 0.2.

  • learning_rate (float, optional) – Base learning rate for AdamW. Defaults to 1e-3.

  • betas (tuple, optional) – Beta coefficients (i.e. momentum terms) for AdamW. Defaults to (0.9, 0.999).

  • weight_decay (float, optional) – Weight decay arg (i.e. L2 penalty) for AdamW. Defaults to 0.01.

  • scheduler_params (dict, optional) – Specify a learning rate scheduler. The accepted keys are defined in the training code; see the hedged sketch in the Examples section below. Defaults to {}.

  • color (bool, optional) – Train on RGB images rather than channel-averaged. Defaults to False.

  • resize_after_crop (int, optional) – Input image size. After all transforms, images will be resized to this size. Defaults to 224.

  • crop_scale_bounds (tuple, optional) – Lower and upper bounds on the area fraction of the random off-center crop (<1 means zoom in). Defaults to (0.7, 0.8).

  • crop_ratio_bounds (tuple, optional) – Lower and upper bounds on the aspect ratio of the crop above. Defaults to (0.9, 1.1).

  • nodes (int, optional) – Number of nodes for multi-node training. Unlikely to work on your cluster without tinkering. Defaults to 1 (i.e. one node).

  • gpus (int, optional) – Number of GPUs for multi-GPU training. Uses distributed data parallel: essentially, the full dataset is split across GPUs. See torch docs. Defaults to 2.

  • sync_batchnorm (bool, optional) – Use synchronized batchnorm. Defaults to False.

  • num_workers (int, optional) – Number of processes for loading data. See torch dataloader docs. Should be fewer than the number of available CPUs. Defaults to 4.

  • prefetch_factor (int, optional) – Number of batches to queue in memory per dataloader worker. See torch dataloader docs. Defaults to 4.

  • mixed_precision (bool, optional) – Use (mostly) half-precision to halve memory requirements. May cause instability. See Lightning Trainer docs. Defaults to False.

  • compile_encoder (bool, optional) – Compile the encoder with torch.compile (new in torch v2). Defaults to False.

  • wandb_logger (pl.loggers.wandb.WandbLogger, optional) – Logger to track experiments on Weights and Biases; see the sketch in the Examples section below. Defaults to None.

  • checkpoint_file_template (str, optional) – Format string for checkpoint filenames, e.g. ‘{epoch:02d}’. See Lightning docs. Defaults to None.

  • auto_insert_metric_name (bool, optional) – Whether to insert the monitored metric’s name into checkpoint filenames; set to False when metric names contain “/”, which would otherwise create subfolders. See Lightning docs. Defaults to True.

  • save_top_k (int, optional) – Keep the k best checkpoints. See Lightning docs. Defaults to 3.

  • extra_callbacks (list, optional) – Additional callbacks to pass to the Trainer. Defaults to None.

  • random_state (int, optional) – Seed. Defaults to 42.

Returns

Trained ZoobotTree model, and Trainer with which it was trained.

Return type

Tuple[define_model.ZoobotTree, pl.Trainer]
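
Examples

The accepted scheduler_params keys are defined in the training code and may change between versions; the keys below are assumptions based on one version of define_model.configure_optimizers, so check them against your installed Zoobot before relying on them.

# Assumed keys for a cosine learning rate schedule (verify against configure_optimizers)
scheduler_params = {
    'cosine_schedule': True,   # assumed flag: decay the learning rate along a cosine curve
    'warmup_epochs': 5,        # assumed: epochs of linear warmup before decay begins
    'max_cosine_epochs': 100   # assumed: epoch at which the decay bottoms out
}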
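
Experiment tracking and checkpoint naming use standard Lightning APIs. A sketch, where the project name, filename template, and pre-split catalogs (train_df, val_df, test_df) are placeholders:

from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project='zoobot-from-scratch')  # placeholder project name

model, trainer = train_default_zoobot_from_scratch(
    save_dir='results/my_run',
    schema=schema,
    train_catalog=train_df,  # catalogs already split by you
    val_catalog=val_df,
    test_catalog=test_df,
    wandb_logger=wandb_logger,
    checkpoint_file_template='{epoch:02d}',  # Lightning filename format string
    auto_insert_metric_name=False,  # metric names containing '/' would otherwise create subfolders
    save_top_k=1  # keep only the single best checkpoint
)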