test
This commit is contained in:
		
							
								
								
									
										5
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								main.py
									
									
									
									
									
								
							@@ -96,7 +96,6 @@ class Main(object):
 | 
			
		||||
        # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
 | 
			
		||||
        # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
 | 
			
		||||
        
 | 
			
		||||
        print("Num rel1: " + str(len(self.rel2id)))
 | 
			
		||||
        self.rel2id.update({rel+'_reverse': idx+len(self.rel2id)
 | 
			
		||||
                           for idx, rel in enumerate(rel_set)})
 | 
			
		||||
 | 
			
		||||
@@ -105,7 +104,6 @@ class Main(object):
 | 
			
		||||
 | 
			
		||||
        self.p.num_ent = len(self.ent2id)
 | 
			
		||||
        self.p.num_rel = len(self.rel2id) // 2
 | 
			
		||||
        print("Num rel: " + str(self.p.num_rel))
 | 
			
		||||
        self.p.embed_dim = self.p.k_w * \
 | 
			
		||||
            self.p.k_h if self.p.embed_dim is None else self.p.embed_dim
 | 
			
		||||
 | 
			
		||||
@@ -280,11 +278,9 @@ class Main(object):
 | 
			
		||||
            if self.p.train_strategy == 'one_to_x':
 | 
			
		||||
                triple, label, neg_ent, sub_samp = [
 | 
			
		||||
                    _.to(self.device) for _ in batch]
 | 
			
		||||
                print(triple.shape)
 | 
			
		||||
                return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
 | 
			
		||||
            else:
 | 
			
		||||
                triple, label = [_.to(self.device) for _ in batch]
 | 
			
		||||
                print(triple.shape)
 | 
			
		||||
                return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
 | 
			
		||||
        else:
 | 
			
		||||
            triple, label = [_.to(self.device) for _ in batch]
 | 
			
		||||
@@ -483,7 +479,6 @@ class Main(object):
 | 
			
		||||
 | 
			
		||||
            sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(
 | 
			
		||||
                batch, 'train')
 | 
			
		||||
            print(nt_rel)
 | 
			
		||||
 | 
			
		||||
            pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
 | 
			
		||||
            loss = self.model.loss(pred, label, sub_samp)
 | 
			
		||||
 
 | 
			
		||||
@@ -570,7 +570,6 @@ class FouriER(torch.nn.Module):
 | 
			
		||||
        z = z.mean([-2, -1])
 | 
			
		||||
 | 
			
		||||
        nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
 | 
			
		||||
        print(nt_rel)
 | 
			
		||||
        comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
 | 
			
		||||
        y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
 | 
			
		||||
        y_1 = self.bn0(y_1)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user