test
This commit is contained in:
		
							
								
								
									
										3
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.py
									
									
									
									
									
								
							| @@ -273,13 +273,14 @@ class Main(object): | ||||
|         labels:		The label for each triple | ||||
|         """ | ||||
|         if split == 'train': | ||||
|             print(triple.shape) | ||||
|             if self.p.train_strategy == 'one_to_x': | ||||
|                 triple, label, neg_ent, sub_samp = [ | ||||
|                     _.to(self.device) for _ in batch] | ||||
|                 print(triple.shape) | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp | ||||
|             else: | ||||
|                 triple, label = [_.to(self.device) for _ in batch] | ||||
|                 print(triple.shape) | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None | ||||
|         else: | ||||
|             triple, label = [_.to(self.device) for _ in batch] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user