add grid search

This commit is contained in:
Cong Thanh Vu 2023-05-17 05:38:03 +00:00
parent 5f1518cfd9
commit dfec7ff331

21
main.py
View File

@ -625,7 +625,7 @@ if __name__ == "__main__":
default='./config/', help='Config directory') default='./config/', help='Config directory')
parser.add_argument('--test_only', action='store_true', default=False) parser.add_argument('--test_only', action='store_true', default=False)
parser.add_argument('--filtered', action='store_true', default=False) parser.add_argument('--grid_search', action='store_true', default=False)
args = parser.parse_args() args = parser.parse_args()
@ -635,6 +635,25 @@ if __name__ == "__main__":
set_seed(args.seed) set_seed(args.seed)
model = Main(args) model = Main(args)
if (args.grid_search):
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNet
estimator = NeuralNet(
module=FouriER,
criterion=torch.nn.BCELoss,
optimizer=torch.optim.Adam,
max_epochs=100,
batch_size=128,
verbose=False
)
paramsGrid = {
'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
}
grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
if (args.test_only): if (args.test_only):
save_path = os.path.join('./torch_saved', args.name) save_path = os.path.join('./torch_saved', args.name)
model.load_model(save_path) model.load_model(save_path)