This commit is contained in:
thanhvc3 2024-06-19 00:24:20 +07:00
parent 2637f53848
commit 9502c8d009
2 changed files with 0 additions and 6 deletions

View File

@ -96,7 +96,6 @@ class Main(object):
# self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)} # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
# self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)} # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
print("Num rel1: " + str(len(self.rel2id)))
self.rel2id.update({rel+'_reverse': idx+len(self.rel2id) self.rel2id.update({rel+'_reverse': idx+len(self.rel2id)
for idx, rel in enumerate(rel_set)}) for idx, rel in enumerate(rel_set)})
@ -105,7 +104,6 @@ class Main(object):
self.p.num_ent = len(self.ent2id) self.p.num_ent = len(self.ent2id)
self.p.num_rel = len(self.rel2id) // 2 self.p.num_rel = len(self.rel2id) // 2
print("Num rel: " + str(self.p.num_rel))
self.p.embed_dim = self.p.k_w * \ self.p.embed_dim = self.p.k_w * \
self.p.k_h if self.p.embed_dim is None else self.p.embed_dim self.p.k_h if self.p.embed_dim is None else self.p.embed_dim
@ -280,11 +278,9 @@ class Main(object):
if self.p.train_strategy == 'one_to_x': if self.p.train_strategy == 'one_to_x':
triple, label, neg_ent, sub_samp = [ triple, label, neg_ent, sub_samp = [
_.to(self.device) for _ in batch] _.to(self.device) for _ in batch]
print(triple.shape)
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]
print(triple.shape)
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]
@ -483,7 +479,6 @@ class Main(object):
sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch( sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(
batch, 'train') batch, 'train')
print(nt_rel)
pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy) pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
loss = self.model.loss(pred, label, sub_samp) loss = self.model.loss(pred, label, sub_samp)

View File

@ -570,7 +570,6 @@ class FouriER(torch.nn.Module):
z = z.mean([-2, -1]) z = z.mean([-2, -1])
nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel)) nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
print(nt_rel)
comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w) y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
y_1 = self.bn0(y_1) y_1 = self.bn0(y_1)