add grid search
This commit is contained in:
		
							
								
								
									
										21
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								main.py
									
									
									
									
									
								
							@@ -625,7 +625,7 @@ if __name__ == "__main__":
 | 
			
		||||
                        default='./config/', help='Config directory')
 | 
			
		||||
    
 | 
			
		||||
    parser.add_argument('--test_only', action='store_true', default=False)
 | 
			
		||||
    parser.add_argument('--filtered', action='store_true', default=False)
 | 
			
		||||
    parser.add_argument('--grid_search', action='store_true', default=False)
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
@@ -635,6 +635,25 @@ if __name__ == "__main__":
 | 
			
		||||
    set_seed(args.seed)
 | 
			
		||||
 | 
			
		||||
    model = Main(args)
 | 
			
		||||
 | 
			
		||||
    if (args.grid_search):
 | 
			
		||||
        from sklearn.model_selection import GridSearchCV
 | 
			
		||||
        from skorch import NeuralNet
 | 
			
		||||
 | 
			
		||||
        estimator = NeuralNet(
 | 
			
		||||
            module=FouriER,
 | 
			
		||||
            criterion=torch.nn.BCELoss,
 | 
			
		||||
            optimizer=torch.optim.Adam,
 | 
			
		||||
            max_epochs=100,
 | 
			
		||||
            batch_size=128,
 | 
			
		||||
            verbose=False
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        paramsGrid = {
 | 
			
		||||
            'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
 | 
			
		||||
    if (args.test_only):
 | 
			
		||||
        save_path = os.path.join('./torch_saved', args.name)
 | 
			
		||||
        model.load_model(save_path)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user