test
This commit is contained in:
		
							
								
								
									
										3
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.py
									
									
									
									
									
								
							| @@ -273,13 +273,14 @@ class Main(object): | |||||||
|         labels:		The label for each triple |         labels:		The label for each triple | ||||||
|         """ |         """ | ||||||
|         if split == 'train': |         if split == 'train': | ||||||
|             print(triple.shape) |  | ||||||
|             if self.p.train_strategy == 'one_to_x': |             if self.p.train_strategy == 'one_to_x': | ||||||
|                 triple, label, neg_ent, sub_samp = [ |                 triple, label, neg_ent, sub_samp = [ | ||||||
|                     _.to(self.device) for _ in batch] |                     _.to(self.device) for _ in batch] | ||||||
|  |                 print(triple.shape) | ||||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp |                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp | ||||||
|             else: |             else: | ||||||
|                 triple, label = [_.to(self.device) for _ in batch] |                 triple, label = [_.to(self.device) for _ in batch] | ||||||
|  |                 print(triple.shape) | ||||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None |                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None | ||||||
|         else: |         else: | ||||||
|             triple, label = [_.to(self.device) for _ in batch] |             triple, label = [_.to(self.device) for _ in batch] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user