This commit is contained in:
thanhvc3 2024-06-18 21:11:18 +07:00
parent c2b17ec1ba
commit bb9856ecd1

View File

@ -570,6 +570,7 @@ class FouriER(torch.nn.Module):
z = z.mean([-2, -1]) z = z.mean([-2, -1])
nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel)) nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
print(nt_rel)
comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w) y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
y_1 = self.bn0(y_1) y_1 = self.bn0(y_1)