add grid search
This commit is contained in:
		
							
								
								
									
										21
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								main.py
									
									
									
									
									
								
							@@ -625,7 +625,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
                        default='./config/', help='Config directory')
 | 
					                        default='./config/', help='Config directory')
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    parser.add_argument('--test_only', action='store_true', default=False)
 | 
					    parser.add_argument('--test_only', action='store_true', default=False)
 | 
				
			||||||
    parser.add_argument('--filtered', action='store_true', default=False)
 | 
					    parser.add_argument('--grid_search', action='store_true', default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -635,6 +635,25 @@ if __name__ == "__main__":
 | 
				
			|||||||
    set_seed(args.seed)
 | 
					    set_seed(args.seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model = Main(args)
 | 
					    model = Main(args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (args.grid_search):
 | 
				
			||||||
 | 
					        from sklearn.model_selection import GridSearchCV
 | 
				
			||||||
 | 
					        from skorch import NeuralNet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        estimator = NeuralNet(
 | 
				
			||||||
 | 
					            module=FouriER,
 | 
				
			||||||
 | 
					            criterion=torch.nn.BCELoss,
 | 
				
			||||||
 | 
					            optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					            max_epochs=100,
 | 
				
			||||||
 | 
					            batch_size=128,
 | 
				
			||||||
 | 
					            verbose=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        paramsGrid = {
 | 
				
			||||||
 | 
					            'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
 | 
				
			||||||
    if (args.test_only):
 | 
					    if (args.test_only):
 | 
				
			||||||
        save_path = os.path.join('./torch_saved', args.name)
 | 
					        save_path = os.path.join('./torch_saved', args.name)
 | 
				
			||||||
        model.load_model(save_path)
 | 
					        model.load_model(save_path)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user