test

parent 2637f53848
commit 9502c8d009

main.py (5)

@@ -96,7 +96,6 @@ class Main(object):
         # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
         # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
 
-        print("Num rel1: " + str(len(self.rel2id)))
         self.rel2id.update({rel+'_reverse': idx+len(self.rel2id)
                             for idx, rel in enumerate(rel_set)})
 
@@ -105,7 +104,6 @@ class Main(object):
 
         self.p.num_ent = len(self.ent2id)
         self.p.num_rel = len(self.rel2id) // 2
-        print("Num rel: " + str(self.p.num_rel))
         self.p.embed_dim = self.p.k_w * \
             self.p.k_h if self.p.embed_dim is None else self.p.embed_dim
 
@@ -280,11 +278,9 @@ class Main(object):
             if self.p.train_strategy == 'one_to_x':
                 triple, label, neg_ent, sub_samp = [
                     _.to(self.device) for _ in batch]
-                print(triple.shape)
                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
             else:
                 triple, label = [_.to(self.device) for _ in batch]
-                print(triple.shape)
                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
         else:
             triple, label = [_.to(self.device) for _ in batch]
@@ -483,7 +479,6 @@ class Main(object):
 
             sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(
                 batch, 'train')
-            print(nt_rel)
 
             pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
             loss = self.model.loss(pred, label, sub_samp)
@@ -570,7 +570,6 @@ class FouriER(torch.nn.Module):
         z = z.mean([-2, -1])
 
         nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
-        print(nt_rel)
         comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
         y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
         y_1 = self.bn0(y_1)