losses

zoobot.pytorch.training.losses.calculate_multiquestion_loss(labels: Tensor, predictions: Tensor, question_index_groups: Tuple, careful=True) → Tensor

The full decision tree loss used for training GZ DECaLS models.

Negative log likelihood of observing the labels (volunteer answers to all questions) under a Dirichlet-Multinomial distribution for each question, with concentrations given by predictions.

Parameters
  • labels (torch.Tensor) – Counts of volunteer answers, of shape (galaxy, k successes), where the k successes dimension is indexed by question_index_groups.

  • predictions (torch.Tensor) – Dirichlet concentrations, with the same shape as labels.

  • question_index_groups (Tuple) – Answer indices for each question, i.e. [(question.start_index, question.end_index), …] for all questions. Useful for slicing model predictions by question. See schemas.

Returns

Negative log likelihood, of shape (batch, question).

Return type

torch.Tensor
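
To illustrate how question_index_groups picks out each question's answer columns, here is a minimal, dependency-free sketch in which plain nested lists stand in for tensors. The helper slice_by_question, the two-question schema, and the counts are all hypothetical. The sketch also assumes that end_index is inclusive, which the description above does not state.

```python
def slice_by_question(rows, question_index_groups):
    """Split each row's answer columns into one slice per question.

    ASSUMPTION: end_index is inclusive, hence the `end + 1` below.
    """
    return [
        [row[start:end + 1] for row in rows]
        for start, end in question_index_groups
    ]


# Hypothetical schema: question 0 owns answer columns 0-1,
# question 1 owns answer columns 2-4.
question_index_groups = [(0, 1), (2, 4)]

# Hypothetical vote counts for two galaxies, laid out as (galaxy, answer).
labels = [
    [3, 7, 2, 1, 0],
    [5, 5, 0, 0, 0],
]

per_question = slice_by_question(labels, question_index_groups)
for q, block in enumerate(per_question):
    print(f"question {q}: {block}")
```

Under the same inclusive-end assumption, the equivalent slice on a 2D tensor would be labels[:, start:end + 1]. Computing one loss value per galaxy for each slice and stacking the results gives the (batch, question) shape described under Returns.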


zoobot.pytorch.training.losses.dirichlet_loss(labels_for_q, concentrations_for_q)

Negative log likelihood of labels_for_q being drawn from a Dirichlet-Multinomial distribution with concentrations concentrations_for_q. This loss is for a single question. Sum over multiple questions if needed (this assumes the questions are independent). Applied by calculate_multiquestion_loss(), above.

Parameters
  • labels_for_q (torch.Tensor) – observed labels (counts of volunteer responses), of shape (batch, answer)

  • concentrations_for_q (torch.Tensor) – concentrations of the Dirichlet-Multinomial distribution predicting the observed labels, of shape (batch, answer)

Returns

Negative log probability per galaxy, of shape (batch).

Return type

torch.Tensor
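
As a worked illustration of the quantity described above, the following sketch evaluates the standard Dirichlet-Multinomial negative log likelihood with Python's math.lgamma. It is not zoobot's implementation: the function names sketch_dirichlet_multinomial_nll and sketch_batch_nll and the example counts are hypothetical, and plain lists stand in for tensors.

```python
import math


def sketch_dirichlet_multinomial_nll(counts, concentrations):
    """Negative log likelihood of one galaxy's answer counts for one question.

    Uses the standard Dirichlet-Multinomial probability mass function:
    P(k | a) = n! * G(a0) / G(n + a0) * prod_i G(k_i + a_i) / (k_i! * G(a_i))
    where G is the gamma function, n = sum(k) and a0 = sum(a).
    """
    n = sum(counts)
    alpha_0 = sum(concentrations)
    log_prob = math.lgamma(n + 1) + math.lgamma(alpha_0) - math.lgamma(n + alpha_0)
    for k, alpha in zip(counts, concentrations):
        log_prob += math.lgamma(k + alpha) - math.lgamma(alpha) - math.lgamma(k + 1)
    return -log_prob


def sketch_batch_nll(labels_for_q, concentrations_for_q):
    """Apply the per-galaxy loss across a batch: one value per galaxy."""
    return [
        sketch_dirichlet_multinomial_nll(k, a)
        for k, a in zip(labels_for_q, concentrations_for_q)
    ]


# Hypothetical two-answer question for a batch of two galaxies.
labels_for_q = [[1, 2], [0, 3]]
concentrations_for_q = [[1.0, 1.0], [1.0, 1.0]]
losses = sketch_batch_nll(labels_for_q, concentrations_for_q)
print(losses)
```

With both concentrations equal to 1 and two answers, every split of n votes is equally likely, so each galaxy with 3 votes has probability 1/4 and a loss of log(4). Summing such per-question losses treats the questions as independent, as noted above.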