predict_on_catalog

zoobot.pytorch.predictions.predict_on_catalog.predict(catalog: DataFrame, model: LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}) → None

Use a trained model to make predictions on a catalog of galaxies.

Parameters
  • catalog (pd.DataFrame) – catalog of galaxies to make predictions on. Must include file_loc and id_str columns.

  • model (pl.LightningModule) – model with which to make predictions. Typically ZoobotTree, FinetuneableZoobotClassifier, FinetuneableZoobotTree, or ZoobotEncoder.

  • n_samples (int) – number of forward passes to make per galaxy. Useful for marginalising over augmentations and test-time dropout.

  • label_cols (List) – Names for the prediction columns. For your convenience only; has no effect on the predictions.

  • save_loc (str) – path of the file in which to save the predictions.

  • datamodule_kwargs (dict, optional) – Passed to GalaxyDataModule. Use, e.g., to add custom augmentations. Defaults to {}.

  • trainer_kwargs (dict, optional) – Passed to pl.Trainer. Defaults to {}.
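Because predict() expects the catalog to carry file_loc and id_str columns, it can help to check for them before starting a potentially long prediction run. The sketch below does that with pandas. The helper check_catalog, the image paths, the label names, and the output filename are all hypothetical, chosen for illustration; the commented-out call uses only the parameters documented above, and how save_loc's file format is chosen is not described here.

```python
import pandas as pd

# Columns that predict() requires, per the parameter description above.
REQUIRED_COLUMNS = {"file_loc", "id_str"}


def check_catalog(catalog: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: fail early if the catalog lacks required columns."""
    missing = REQUIRED_COLUMNS - set(catalog.columns)
    if missing:
        raise ValueError(f"catalog is missing required columns: {sorted(missing)}")
    return catalog


# Hypothetical ids and image paths, for illustration only.
catalog = check_catalog(pd.DataFrame({
    "id_str": ["galaxy_0", "galaxy_1"],
    "file_loc": ["images/galaxy_0.jpg", "images/galaxy_1.jpg"],
}))

# With a trained model available, the call would then look something like:
# from zoobot.pytorch.predictions.predict_on_catalog import predict
# predict(catalog, model, n_samples=5,
#         label_cols=["label_a", "label_b"],  # hypothetical names
#         save_loc="predictions.csv")         # hypothetical filename
```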
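The n_samples parameter exists because a model with dropout or augmentations active at test time gives slightly different outputs on each pass over the same galaxy; repeating the pass and summarising the results marginalises over that randomness. The stdlib-only sketch below illustrates the idea with a toy stochastic "forward pass" (a fixed score of 0.7 plus Gaussian noise) standing in for a real model. The functions marginalise and toy_forward are hypothetical, and how predict() itself stores or combines its n_samples outputs is not described here.

```python
import random
from statistics import mean, stdev
from typing import Callable, List, Tuple


def marginalise(forward: Callable[[], float], n_samples: int) -> Tuple[float, float]:
    """Run a stochastic forward pass n_samples times; return the mean and spread."""
    outputs: List[float] = [forward() for _ in range(n_samples)]
    # stdev needs at least two samples; a single pass has no spread to measure.
    spread = stdev(outputs) if n_samples > 1 else 0.0
    return mean(outputs), spread


rng = random.Random(0)  # seeded so the toy example is reproducible


def toy_forward() -> float:
    # Hypothetical stand-in for one model pass with test-time dropout active:
    # a fixed "true" score of 0.7 perturbed by random noise.
    return 0.7 + rng.gauss(0.0, 0.05)


avg, spread = marginalise(toy_forward, n_samples=100)
```

More samples give a more stable average and a better estimate of the spread, at the cost of proportionally more compute per galaxy.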