From dfec7ff331071bc25f19bd9a9e559fab30910ba2 Mon Sep 17 00:00:00 2001 From: Cong Thanh Vu Date: Wed, 17 May 2023 05:38:03 +0000 Subject: [PATCH] add grid search --- main.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index dd41ee7..fa3e39a 100644 --- a/main.py +++ b/main.py @@ -625,7 +625,7 @@ if __name__ == "__main__": default='./config/', help='Config directory') parser.add_argument('--test_only', action='store_true', default=False) - parser.add_argument('--filtered', action='store_true', default=False) + parser.add_argument('--grid_search', action='store_true', default=False) args = parser.parse_args() @@ -635,6 +635,25 @@ if __name__ == "__main__": set_seed(args.seed) model = Main(args) + + if (args.grid_search): + from sklearn.model_selection import GridSearchCV + from skorch import NeuralNet + + estimator = NeuralNet( + module=FouriER, + criterion=torch.nn.BCELoss, + optimizer=torch.optim.Adam, + max_epochs=100, + batch_size=128, + verbose=False + ) + + paramsGrid = { + 'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3], + } + + grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3) if (args.test_only): save_path = os.path.join('./torch_saved', args.name) model.load_model(save_path)