test
This commit is contained in:
		@@ -28,8 +28,6 @@ class TrainDataset(Dataset):
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        ele = self.triples[idx]
 | 
			
		||||
        if (idx == 0):
 | 
			
		||||
            print(ele)
 | 
			
		||||
        triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
 | 
			
		||||
            ele['label']), np.float32(ele['sub_samp'])
 | 
			
		||||
        trp_label = self.get_label(label)
 | 
			
		||||
@@ -51,7 +49,6 @@ class TrainDataset(Dataset):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def collate_fn(data):
 | 
			
		||||
        triple = torch.stack([_[0] for _ in data], dim=0)
 | 
			
		||||
        print(triple)
 | 
			
		||||
        trp_label = torch.stack([_[1] for _ in data], dim=0)
 | 
			
		||||
 | 
			
		||||
        if not data[0][2] is None:							# one_to_x
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user