add grid search
This commit is contained in:
parent
5f1518cfd9
commit
dfec7ff331
21
main.py
21
main.py
@ -625,7 +625,7 @@ if __name__ == "__main__":
|
||||
default='./config/', help='Config directory')
|
||||
|
||||
parser.add_argument('--test_only', action='store_true', default=False)
|
||||
parser.add_argument('--filtered', action='store_true', default=False)
|
||||
parser.add_argument('--grid_search', action='store_true', default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -635,6 +635,25 @@ if __name__ == "__main__":
|
||||
set_seed(args.seed)
|
||||
|
||||
model = Main(args)
|
||||
|
||||
if (args.grid_search):
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
from skorch import NeuralNet
|
||||
|
||||
estimator = NeuralNet(
|
||||
module=FouriER,
|
||||
criterion=torch.nn.BCELoss,
|
||||
optimizer=torch.optim.Adam,
|
||||
max_epochs=100,
|
||||
batch_size=128,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
paramsGrid = {
|
||||
'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
|
||||
}
|
||||
|
||||
grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
|
||||
if (args.test_only):
|
||||
save_path = os.path.join('./torch_saved', args.name)
|
||||
model.load_model(save_path)
|
||||
|
Loading…
Reference in New Issue
Block a user