add grid search
This commit is contained in:
		
							
								
								
									
										2
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								main.py
									
									
									
									
									
								
							@@ -659,6 +659,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
        grid = GridSearchCV(estimator=estimator, param_grid=paramsGrid, n_jobs=-1, cv=1)
 | 
					        grid = GridSearchCV(estimator=estimator, param_grid=paramsGrid, n_jobs=-1, cv=1)
 | 
				
			||||||
        data = np.array(model.triples['train'])
 | 
					        data = np.array(model.triples['train'])
 | 
				
			||||||
        data = data[np.random.choice(np.arange(len(data)), size=int(len(data) * 0.2), replace=False)]
 | 
					        data = data[np.random.choice(np.arange(len(data)), size=int(len(data) * 0.2), replace=False)]
 | 
				
			||||||
 | 
					        print(data[0])
 | 
				
			||||||
        dataloader = iter(DataLoader(
 | 
					        dataloader = iter(DataLoader(
 | 
				
			||||||
                TrainDataset(data, model.p),
 | 
					                TrainDataset(data, model.p),
 | 
				
			||||||
                batch_size=len(data),
 | 
					                batch_size=len(data),
 | 
				
			||||||
@@ -667,6 +668,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
                collate_fn=TrainDataset.collate_fn
 | 
					                collate_fn=TrainDataset.collate_fn
 | 
				
			||||||
            ))
 | 
					            ))
 | 
				
			||||||
        for step, batch in dataloader:
 | 
					        for step, batch in dataloader:
 | 
				
			||||||
 | 
					            print(batch[0])
 | 
				
			||||||
            print(batch.shape)
 | 
					            print(batch.shape)
 | 
				
			||||||
            sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
 | 
					            sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
 | 
				
			||||||
                batch, 'train')
 | 
					                batch, 'train')
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user