losses

class zoobot.pytorch.training.losses.CustomMultiQuestionLoss(question_answer_pairs: dict, question_functional_loss, careful=False, sum_over_questions=False)

Bases: torch.nn.Module

forward(inputs: Tensor, targets: dict) → Tensor

Compute the loss for multi-question predictions.

Parameters
  • inputs (torch.Tensor) – Prediction tensor of shape (batch_size, num_answer_keys). Contains the predicted fractions for every answer key of every question.

  • targets (dict) – Dictionary mapping each answer key to its target values. Each value has shape (batch_size,) and contains the target vote counts for that answer.

Returns

Loss tensor. If sum_over_questions is True, returns shape (batch_size,) with one loss value per galaxy. If False, returns shape (batch_size, num_questions) with one loss value per question per galaxy.

Return type

torch.Tensor
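The aggregation that forward describes can be sketched in plain PyTorch. This is a hypothetical helper, not zoobot's implementation. Several choices here are illustrative assumptions: the name multi_question_loss_sketch, the idea that question_answer_pairs maps each question to a list of answer keys, the column order of inputs following those keys question by question, and the (predictions, labels) argument order of question_loss. The careful flag is not modelled.

```python
import torch


def multi_question_loss_sketch(inputs, targets, question_answer_pairs, question_loss, sum_over_questions=False):
    # Hypothetical helper. Assumes question_answer_pairs maps question -> list of answer keys,
    # and that the columns of `inputs` follow those answer keys, question by question.
    per_question = []
    start = 0
    for answer_keys in question_answer_pairs.values():
        end = start + len(answer_keys)
        preds_for_q = inputs[:, start:end]  # (batch_size, n_answers_for_q)
        # Gather this question's per-answer targets into (batch_size, n_answers_for_q)
        labels_for_q = torch.stack([targets[key] for key in answer_keys], dim=1)
        per_question.append(question_loss(preds_for_q, labels_for_q))  # (batch_size,)
        start = end
    losses = torch.stack(per_question, dim=1)  # (batch_size, num_questions)
    return losses.sum(dim=1) if sum_over_questions else losses


def toy_loss(preds, labels):
    # Stand-in per-question loss (squared error summed over answers), for illustration only
    return ((preds - labels) ** 2).sum(dim=1)


# Usage with made-up question and answer keys
pairs = {"smooth-or-featured": ["smooth", "featured", "artifact"], "has-bar": ["bar", "no-bar"]}
inputs = torch.rand(4, 5)
targets = {key: torch.randint(0, 10, (4,)).float() for keys in pairs.values() for key in keys}
print(multi_question_loss_sketch(inputs, targets, pairs, toy_loss).shape)  # torch.Size([4, 2])
```

Stacking the per-question losses into a (batch_size, num_questions) tensor before the optional sum keeps per-question values available, which matches the two return shapes documented above.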


zoobot.pytorch.training.losses.get_dirichlet_neg_log_prob(concentrations_for_q, labels_for_q) → Tensor

Negative log-likelihood of labels_for_q under a Dirichlet-Multinomial distribution with concentrations concentrations_for_q. This loss covers a single question. To combine several questions, sum the per-question losses; this assumes the questions are independent. CustomMultiQuestionLoss, above, applies this function when it is passed as the question_functional_loss argument.

Parameters
  • concentrations_for_q (torch.Tensor) – concentrations of the Dirichlet-Multinomial distribution predicting the observed labels, of shape (batch, answer)

  • labels_for_q (torch.Tensor) – observed labels (counts of volunteer responses), of shape (batch, answer)

Returns

Negative log-probability per galaxy, of shape (batch,).

Return type

torch.Tensor
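For reference, the quantity returned for one question can be written using the standard Dirichlet-Multinomial probability mass function. In this notation, x_k stands for the labels_for_q counts, alpha_k for concentrations_for_q, K for the number of answers, and q indexes the Q questions. The last line is the independence assumption mentioned above, giving a per-galaxy total.

```latex
\begin{aligned}
n &= \sum_{k=1}^{K} x_k, \qquad A = \sum_{k=1}^{K} \alpha_k \\
-\log p(\mathbf{x} \mid \boldsymbol{\alpha}) &= -\log\Gamma(A) - \log\Gamma(n+1) + \log\Gamma(n+A)
  - \sum_{k=1}^{K} \left[ \log\Gamma(x_k+\alpha_k) - \log\Gamma(\alpha_k) - \log\Gamma(x_k+1) \right] \\
\mathcal{L} &= \sum_{q=1}^{Q} -\log p(\mathbf{x}_q \mid \boldsymbol{\alpha}_q)
\end{aligned}
```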


zoobot.pytorch.training.losses.log_prob(alpha, value) → Tensor

Compute the log-probability of value under a Dirichlet-Multinomial distribution with concentrations alpha.

A manual implementation equivalent to pyro.distributions.DirichletMultinomial.log_prob(), written to remove the dependency on Pyro.

Parameters
  • alpha (torch.Tensor) – Concentration parameters of shape (batch, categories).

  • value (torch.Tensor) – Observed counts of shape (batch, categories).

Returns

Log probability of shape (batch,).

Return type

torch.Tensor
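The log-probability described above can be sketched with torch.lgamma by following the Dirichlet-Multinomial probability mass function term by term. This is an illustrative re-derivation, not zoobot's source. The function name dirichlet_multinomial_log_prob is hypothetical, and the sketch assumes the total count n is the sum of value over the category axis.

```python
import torch


def dirichlet_multinomial_log_prob(alpha: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: log-pmf for alpha and value of shape (batch, categories)."""
    total_count = value.sum(dim=-1)  # n, shape (batch,)
    alpha_sum = alpha.sum(dim=-1)    # A, shape (batch,)
    # Terms that depend only on the totals n and A
    norm = (torch.lgamma(alpha_sum) + torch.lgamma(total_count + 1)
            - torch.lgamma(total_count + alpha_sum))
    # Per-category terms, summed over the category axis
    per_category = (torch.lgamma(value + alpha) - torch.lgamma(alpha)
                    - torch.lgamma(value + 1)).sum(dim=-1)
    return norm + per_category  # shape (batch,)


# Usage: with alpha = [1, 1] and one vote, both outcomes are equally likely
alpha = torch.tensor([[1.0, 1.0]])
value = torch.tensor([[1.0, 0.0]])
print(dirichlet_multinomial_log_prob(alpha, value))  # log(0.5), about -0.6931
```

Because every term is built from lgamma, the result is differentiable with respect to alpha, as a training loss requires. A galaxy with zero total votes gets log-probability 0, because there is exactly one way to observe no votes.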