add grid search
This commit is contained in:
		
							
								
								
									
										10
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								main.py
									
									
									
									
									
								
							@@ -675,8 +675,14 @@ if __name__ == "__main__":
 | 
			
		||||
            else:
 | 
			
		||||
                neg_ent = neg_ent.cpu()
 | 
			
		||||
 | 
			
		||||
            dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
 | 
			
		||||
            search = grid.fit(dataset, label)
 | 
			
		||||
            inputs = []
 | 
			
		||||
            for i in len(sub):
 | 
			
		||||
                input = {}
 | 
			
		||||
                input['sub'] = sub[i]
 | 
			
		||||
                input['rel'] = rel[i]
 | 
			
		||||
                input['neg_ents'] = neg_ent[i]
 | 
			
		||||
                inputs.append(input)
 | 
			
		||||
            search = grid.fit(inputs, label)
 | 
			
		||||
            print("BEST SCORE: ", search.best_score_)
 | 
			
		||||
            print("BEST PARAMS: ", search.best_params_)
 | 
			
		||||
    if (args.test_only):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user