predict_on_catalog
- zoobot.pytorch.predictions.predict_on_catalog.predict(catalog: DataFrame, model: LightningModule, n_samples: int, label_cols: List, save_loc: str, datamodule_kwargs={}, trainer_kwargs={}) → None
Use trained model to make predictions on a catalog of galaxies.
- Parameters
catalog (pd.DataFrame) – catalog of galaxies to make predictions on. Must include file_loc and id_str columns.
model (pl.LightningModule) – trained model with which to make predictions. Probably ZoobotTree, FinetuneableZoobotClassifier, FinetuneableZoobotTree, or ZoobotEncoder.
n_samples (int) – number of forward passes to make per galaxy. Useful to marginalise over augmentations/test-time dropout.
label_cols (List) – Names for prediction columns. Only for your convenience - has no effect on predictions.
save_loc (str) – desired name of file recording the predictions
datamodule_kwargs (dict, optional) – Passed to GalaxyDataModule. Use to e.g. add custom augmentations. Defaults to {}.
trainer_kwargs (dict, optional) – Passed to pl.Trainer. Defaults to {}.
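The parameters above can be put together as in the following minimal sketch. It is an illustration under stated assumptions, not the library's reference usage: the helper names `build_catalog_rows` and `run_predictions` are hypothetical, the image directory layout is invented, and `batch_size` is assumed to be a keyword that GalaxyDataModule accepts. The model is taken as an argument because how it is loaded depends on your checkpoint.

```python
from pathlib import Path


def build_catalog_rows(image_dir, pattern="*.jpg"):
    """Return one dict per image with the required id_str and file_loc keys.

    Hypothetical helper: uses each file's stem as its id_str.
    """
    paths = sorted(Path(image_dir).glob(pattern))
    return [{"id_str": p.stem, "file_loc": str(p)} for p in paths]


def run_predictions(image_dir, model, label_cols, save_loc):
    """Hypothetical wrapper around predict_on_catalog.predict."""
    # Imported lazily so build_catalog_rows works without pandas or zoobot installed.
    import pandas as pd
    from zoobot.pytorch.predictions import predict_on_catalog

    # The catalog must include file_loc and id_str columns.
    catalog = pd.DataFrame(build_catalog_rows(image_dir))

    predict_on_catalog.predict(
        catalog,
        model,                      # an already-trained pl.LightningModule
        n_samples=5,                # five forward passes per galaxy
        label_cols=label_cols,      # names for the prediction columns only
        save_loc=save_loc,
        datamodule_kwargs={"batch_size": 32},   # assumption: forwarded to GalaxyDataModule
        trainer_kwargs={"accelerator": "cpu"},  # forwarded to pl.Trainer
    )
```

A call might look like `run_predictions("images/", model, ["smooth", "featured"], "predictions.csv")`, where the directory, label names, and output filename are placeholders. Since `predict` returns None, the results are read back from `save_loc` afterwards.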