predict_on_catalog
- zoobot.pytorch.predictions.predict_on_catalog.predict(catalog: DataFrame, model: LightningModule, label_cols: List[str], inference_transform: Compose, save_loc=None, datamodule_kwargs={}, trainer_kwargs={}) → DataFrame
Use a trained model to make predictions on a catalog of galaxies.
- Parameters
catalog (pd.DataFrame) – catalog of galaxies to make predictions on. Must include file_loc and id_str columns.
model (L.LightningModule) – trained model with which to make predictions. Typically ZoobotTree, FinetuneableZoobotClassifier, FinetuneableZoobotTree, or ZoobotEncoder.
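label_cols (List[str]) – columns in the catalog to use as labels. Used to name the output columns.
inference_transform (Compose) – transform applied to each image before it is passed to the model.
save_loc (str, optional) – path of the file in which to record the predictions. Defaults to None.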
datamodule_kwargs (dict, optional) – Passed to CatalogDataModule. Use to, e.g., add custom augmentations. Defaults to {}.
trainer_kwargs (dict, optional) – Passed to L.Trainer. Defaults to {}.
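The catalog requirements and the role of label_cols described above can be illustrated with a minimal sketch using pandas and NumPy. It assumes predictions arrive as a 2D array with one column per label. validate_catalog and predictions_to_frame are hypothetical helpers written only for this illustration; they are not part of Zoobot and do not reproduce its internal implementation, and the probabilities are placeholder numbers rather than real model outputs.

```python
import numpy as np
import pandas as pd

# Columns the documentation above says the catalog must include.
REQUIRED_COLUMNS = ("file_loc", "id_str")


def validate_catalog(catalog: pd.DataFrame) -> None:
    """Hypothetical helper: raise ValueError if any required column is missing."""
    missing = [col for col in REQUIRED_COLUMNS if col not in catalog.columns]
    if missing:
        raise ValueError(f"catalog is missing required columns: {missing}")


def predictions_to_frame(id_strs, predictions, label_cols) -> pd.DataFrame:
    """Hypothetical helper (assumption, not Zoobot's code): build one row per
    galaxy and one column per entry of label_cols, keyed by id_str."""
    predictions = np.asarray(predictions)
    expected = (len(id_strs), len(label_cols))
    if predictions.shape != expected:
        raise ValueError(f"expected shape {expected}, got {predictions.shape}")
    frame = pd.DataFrame(predictions, columns=list(label_cols))
    frame.insert(0, "id_str", list(id_strs))  # identify each row by galaxy
    return frame


catalog = pd.DataFrame(
    {
        "id_str": ["galaxy_0", "galaxy_1"],
        "file_loc": ["images/galaxy_0.jpg", "images/galaxy_1.jpg"],
    }
)
validate_catalog(catalog)

label_cols = ["smooth", "featured"]
# Placeholder values standing in for model outputs (not real predictions).
fake_predictions = np.array([[0.8, 0.2], [0.3, 0.7]])
result = predictions_to_frame(catalog["id_str"], fake_predictions, label_cols)
print(result)
```

Checking the required columns up front surfaces a malformed catalog before any images are loaded, and naming the output columns from label_cols keeps each prediction aligned with the label it refers to.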