losses
- zoobot.pytorch.training.losses.calculate_multiquestion_loss(labels: Tensor, predictions: Tensor, question_index_groups: Tuple, careful=True) → Tensor
The full decision tree loss used for training GZ DECaLS models.
Negative log likelihood of observing labels (volunteer answers to all questions) under a Dirichlet-Multinomial distribution for each question, with concentrations given by predictions.
- Parameters
labels (torch.Tensor) – (galaxy, k successes), where the k successes dimension is indexed by question_index_groups.
predictions (torch.Tensor) – Dirichlet concentrations, matching shape of labels
question_index_groups (list) – Answer indices for each question i.e. [(question.start_index, question.end_index), …] for all questions. Useful for slicing model predictions by question. See schemas.
- Returns
negative log likelihood of shape (batch, question).
- Return type
torch.Tensor
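The per-question slicing described above can be sketched as follows. This is a minimal stdlib sketch under stated assumptions, not the zoobot implementation: the batch dimension is omitted (one galaxy), end indices are assumed inclusive as in the schema description, and dm_nll is a hypothetical stand-in for the per-question Dirichlet-Multinomial loss.

```python
import math

def dm_nll(labels_for_q, concentrations_for_q):
    # Negative log Dirichlet-Multinomial likelihood for one question, one galaxy.
    n = sum(labels_for_q)                     # total volunteer responses
    a0 = sum(concentrations_for_q)            # total concentration
    log_prob = math.lgamma(n + 1) + math.lgamma(a0) - math.lgamma(n + a0)
    for x, a in zip(labels_for_q, concentrations_for_q):
        log_prob += math.lgamma(x + a) - math.lgamma(a) - math.lgamma(x + 1)
    return -log_prob

def multiquestion_loss(labels, predictions, question_index_groups):
    # Slice the answer dimension by question and collect per-question losses.
    return [
        dm_nll(labels[start:end + 1], predictions[start:end + 1])
        for start, end in question_index_groups
    ]

# One galaxy, two questions: answers 0-1 and answers 2-4.
labels = [3, 7, 2, 5, 3]
concentrations = [1.0, 2.0, 0.5, 1.5, 1.0]
groups = [(0, 1), (2, 4)]
loss = multiquestion_loss(labels, concentrations, groups)
```

Summing (or averaging) the per-question losses gives a scalar training objective, matching the "sum over multiple questions" note on dirichlet_loss below.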
- zoobot.pytorch.training.losses.dirichlet_loss(labels_for_q, concentrations_for_q)
Negative log likelihood of labels_for_q being drawn from a Dirichlet-Multinomial distribution with concentrations concentrations_for_q. This loss is for one question; sum over multiple questions if needed (assumes questions are independent). Applied by calculate_multiquestion_loss(), above.
- Parameters
labels_for_q (torch.Tensor) – observed labels (count of volunteer responses) of shape (batch, answer)
concentrations_for_q (torch.Tensor) – concentrations for the Dirichlet-Multinomial distribution predicting the observed labels, of shape (batch, answer)
- Returns
negative log probability per galaxy, of shape (batch).
- Return type
torch.Tensor
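As a quick sanity check on the distribution used here, the Dirichlet-Multinomial probabilities over all possible count vectors with a fixed total must sum to one. The sketch below uses a hypothetical helper dm_log_prob (the standard Dirichlet-Multinomial log-pmf, not a zoobot function) and enumerates every outcome for a small question.

```python
import math

def dm_log_prob(counts, concentrations):
    # Log pmf of the Dirichlet-Multinomial for one vector of answer counts:
    # log[ n! * Gamma(a0)/Gamma(n+a0) * prod_k Gamma(x_k+a_k)/(x_k! Gamma(a_k)) ]
    n = sum(counts)
    a0 = sum(concentrations)
    lp = math.lgamma(n + 1) + math.lgamma(a0) - math.lgamma(n + a0)
    for x, a in zip(counts, concentrations):
        lp += math.lgamma(x + a) - math.lgamma(a) - math.lgamma(x + 1)
    return lp

# Three answers, 5 volunteers: enumerate every possible count vector.
n, conc = 5, [1.0, 2.0, 0.5]
total = sum(
    math.exp(dm_log_prob((i, j, n - i - j), conc))
    for i in range(n + 1)
    for j in range(n + 1 - i)
)
# total should be 1.0 up to floating-point error
```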