From ff855256e0c83c85857df2386bcfe9a2340b87fc Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Wed, 17 May 2023 13:45:46 +0700 Subject: [PATCH] add grid search --- models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/models.py b/models.py index 90bf38f..a8685c0 100644 --- a/models.py +++ b/models.py @@ -537,6 +537,12 @@ class FouriER(torch.nn.Module): loss = self.bceloss(pred, true_label) return loss + def score(self, pred, true_label=None, sub_samp=None): + label_pos = true_label[0] + label_neg = true_label[1:] + loss = self.bceloss(pred, true_label) + return loss + def forward_embeddings(self, x): x = self.patch_embed(x) return x