diff --git a/models.py b/models.py index 90bf38f..a8685c0 100644 --- a/models.py +++ b/models.py @@ -537,6 +537,12 @@ class FouriER(torch.nn.Module): loss = self.bceloss(pred, true_label) return loss + def score(self, pred, true_label=None, sub_samp=None): + label_pos = true_label[0] + label_neg = true_label[1:] + loss = self.bceloss(pred, true_label) + return loss + def forward_embeddings(self, x): x = self.patch_embed(x) return x