test
This commit is contained in:
		
							
								
								
									
										6
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.py
									
									
									
									
									
								
							@@ -417,8 +417,8 @@ class Main(object):
 | 
				
			|||||||
            obj_pred = []
 | 
					            obj_pred = []
 | 
				
			||||||
            obj_pred_score = []
 | 
					            obj_pred_score = []
 | 
				
			||||||
            for step, batch in enumerate(train_iter):
 | 
					            for step, batch in enumerate(train_iter):
 | 
				
			||||||
                sub, rel, obj, label = self.read_batch(batch, split)
 | 
					                sub, rel, obj, nt_rel, label = self.read_batch(batch, split)
 | 
				
			||||||
                pred = self.model.forward(sub, rel, None, 'one_to_n')
 | 
					                pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n')
 | 
				
			||||||
                b_range = torch.arange(pred.size()[0], device=self.device)
 | 
					                b_range = torch.arange(pred.size()[0], device=self.device)
 | 
				
			||||||
                target_pred = pred[b_range, obj]
 | 
					                target_pred = pred[b_range, obj]
 | 
				
			||||||
                pred = torch.where(label.byte(), torch.zeros_like(pred), pred)
 | 
					                pred = torch.where(label.byte(), torch.zeros_like(pred), pred)
 | 
				
			||||||
@@ -691,7 +691,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
                collate_fn=TrainDataset.collate_fn
 | 
					                collate_fn=TrainDataset.collate_fn
 | 
				
			||||||
            ))
 | 
					            ))
 | 
				
			||||||
        for step, batch in enumerate(dataloader):
 | 
					        for step, batch in enumerate(dataloader):
 | 
				
			||||||
            sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
 | 
					            sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch(
 | 
				
			||||||
                batch, 'train')
 | 
					                batch, 'train')
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
            if (neg_ent is None):
 | 
					            if (neg_ent is None):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user