diff --git a/models.py b/models.py index f8d4680..37c75ef 100644 --- a/models.py +++ b/models.py @@ -570,6 +570,7 @@ class FouriER(torch.nn.Module): z = z.mean([-2, -1]) nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel)) + print(nt_rel) comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w) y_1 = self.bn0(y_1)