losses
- class zoobot.pytorch.training.losses.CustomMultiQuestionLoss(question_answer_pairs: dict, question_functional_loss, careful=False, sum_over_questions=False)
Bases: Module
- forward(inputs: Tensor, targets: dict) → Tensor
Compute the loss for multi-question predictions.
- Parameters
inputs (torch.Tensor) – Prediction vector of shape (batch_size, num_answer_keys). Contains predicted fractions for all answer keys across all questions.
targets (dict) – Dictionary with answer keys as keys and target values as values. Each value has shape (batch_size,) containing the target counts/votes.
- Returns
Loss tensor. If sum_over_questions is True, returns shape (batch_size,) with one loss value per galaxy. If False, returns shape (batch_size, num_questions) with one loss value per question per galaxy.
- Return type
torch.Tensor
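The shape contract described above can be illustrated with a minimal pure-Python sketch. This is not Zoobot's implementation: the layout of question_answer_pairs (question name mapped to a list of answer keys), the column ordering of inputs, and the helper toy_question_loss are all assumptions made for illustration, and plain lists stand in for tensors.

```python
# Hypothetical layout (an assumption): question name -> list of answer keys.
question_answer_pairs = {
    "smooth-or-featured": ["smooth", "featured"],
    "has-spiral-arms": ["spiral-yes", "spiral-no"],
}


def toy_question_loss(predicted, votes):
    """Hypothetical stand-in for question_functional_loss (one galaxy, one question).

    Squared error between predicted fractions and observed vote fractions.
    """
    total = sum(votes)
    if total == 0:
        return 0.0  # nobody answered this question for this galaxy
    return sum((p - v / total) ** 2 for p, v in zip(predicted, votes))


def multi_question_loss(inputs, targets, sum_over_questions=False):
    """Apply the per-question loss to each question's slice of the predictions.

    inputs: list of rows, shape (batch_size, num_answer_keys). Columns are
        assumed to follow the answer order in question_answer_pairs.
    targets: dict mapping answer key -> list of length batch_size.
    """
    losses = []
    for i, row in enumerate(inputs):
        per_question = []
        start = 0
        for answers in question_answer_pairs.values():
            predicted = row[start:start + len(answers)]
            votes = [targets[answer][i] for answer in answers]
            per_question.append(toy_question_loss(predicted, votes))
            start += len(answers)
        # Either one value per galaxy, or one value per question per galaxy.
        losses.append(sum(per_question) if sum_over_questions else per_question)
    return losses


inputs = [[0.5, 0.5, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]]
targets = {
    "smooth": [5, 10],
    "featured": [5, 0],
    "spiral-yes": [4, 0],
    "spiral-no": [0, 0],
}

per_question = multi_question_loss(inputs, targets)  # 2 galaxies x 2 questions
per_galaxy = multi_question_loss(inputs, targets, sum_over_questions=True)  # 2 values
```

With sum_over_questions=False the result keeps one entry per question, which is useful for inspecting which questions dominate the loss. Summing gives a single value per galaxy.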
- zoobot.pytorch.training.losses.get_dirichlet_neg_log_prob(concentrations_for_q, labels_for_q) → Tensor
Negative log likelihood of labels_for_q being drawn from a Dirichlet-Multinomial distribution with concentrations concentrations_for_q. This loss is for one question. Sum over multiple questions if needed (this assumes the questions are independent). Applied by CustomMultiQuestionLoss, above, if passed as the question_functional_loss argument.
- Parameters
labels_for_q (torch.Tensor) – Observed labels (counts of volunteer responses), of shape (batch, answer).
concentrations_for_q (torch.Tensor) – Concentrations of the Dirichlet-Multinomial distribution predicting the observed labels, of shape (batch, answer).
- Returns
Negative log probability per galaxy, of shape (batch,).
- Return type
torch.Tensor
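For reference, the negative log likelihood of the standard Dirichlet-Multinomial distribution can be written as below. The notation is introduced here and is not taken from the Zoobot source: x_k is the vote count for answer k (labels_for_q), α_k is the matching concentration (concentrations_for_q), n = Σ_k x_k is the total number of votes, and A = Σ_k α_k is the total concentration.

```latex
-\log P(\mathbf{x} \mid \boldsymbol{\alpha})
  = -\log\Gamma(A) - \log\Gamma(n + 1) + \log\Gamma(n + A)
    - \sum_{k} \Big[ \log\Gamma(x_k + \alpha_k) - \log\Gamma(\alpha_k) - \log\Gamma(x_k + 1) \Big]
```

Because the questions are treated as independent, the joint negative log likelihood over several questions is the sum of this expression evaluated per question.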
- zoobot.pytorch.training.losses.log_prob(alpha, value) → Tensor
Compute log probability for Dirichlet-Multinomial distribution.
Manual implementation equivalent to pyro.distributions.DirichletMultinomial.log_prob() to remove pyro dependency.
- Parameters
alpha (torch.Tensor) – Concentration parameters of shape (batch, categories).
value (torch.Tensor) – Observed counts of shape (batch, categories).
- Returns
Log probability of shape (batch,).
- Return type
torch.Tensor
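A minimal sketch of the Dirichlet-Multinomial log probability, using only the Python standard library, may help make the computation concrete. It follows the textbook formula rather than Zoobot's tensor code; the function names dirichlet_multinomial_log_prob and batch_log_prob are hypothetical, and plain lists stand in for tensors of shape (batch, categories).

```python
import math


def dirichlet_multinomial_log_prob(alpha, value):
    """Log probability of counts `value` given concentrations `alpha`.

    Handles one galaxy and one question. Textbook formula, with
    A = sum(alpha) and n = sum(value):
        lgamma(A) + lgamma(n + 1) - lgamma(n + A)
        + sum_k [lgamma(x_k + a_k) - lgamma(a_k) - lgamma(x_k + 1)]
    """
    total_alpha = sum(alpha)
    total_count = sum(value)
    log_prob = (
        math.lgamma(total_alpha)
        + math.lgamma(total_count + 1)
        - math.lgamma(total_count + total_alpha)
    )
    for a, x in zip(alpha, value):
        log_prob += math.lgamma(x + a) - math.lgamma(a) - math.lgamma(x + 1)
    return log_prob


def batch_log_prob(alpha_batch, value_batch):
    """Apply the per-row formula to each row: (batch, categories) -> (batch,)."""
    return [
        dirichlet_multinomial_log_prob(alpha, value)
        for alpha, value in zip(alpha_batch, value_batch)
    ]


# Two galaxies, two answers each (illustrative numbers only).
alphas = [[1.0, 1.0], [2.0, 5.0]]
votes = [[1, 0], [3, 7]]
log_probs = batch_log_prob(alphas, votes)
neg_log_probs = [-lp for lp in log_probs]  # the quantity used as a loss
```

Working with log-gamma rather than the gamma function directly avoids overflow for large vote counts or concentrations.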