add grid search
This commit is contained in:
parent
5f1518cfd9
commit
dfec7ff331
21
main.py
21
main.py
@ -625,7 +625,7 @@ if __name__ == "__main__":
|
|||||||
default='./config/', help='Config directory')
|
default='./config/', help='Config directory')
|
||||||
|
|
||||||
parser.add_argument('--test_only', action='store_true', default=False)
|
parser.add_argument('--test_only', action='store_true', default=False)
|
||||||
parser.add_argument('--filtered', action='store_true', default=False)
|
parser.add_argument('--grid_search', action='store_true', default=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -635,6 +635,25 @@ if __name__ == "__main__":
|
|||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
model = Main(args)
|
model = Main(args)
|
||||||
|
|
||||||
|
if (args.grid_search):
|
||||||
|
from sklearn.model_selection import GridSearchCV
|
||||||
|
from skorch import NeuralNet
|
||||||
|
|
||||||
|
estimator = NeuralNet(
|
||||||
|
module=FouriER,
|
||||||
|
criterion=torch.nn.BCELoss,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
max_epochs=100,
|
||||||
|
batch_size=128,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
paramsGrid = {
|
||||||
|
'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
|
||||||
|
}
|
||||||
|
|
||||||
|
grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
|
||||||
if (args.test_only):
|
if (args.test_only):
|
||||||
save_path = os.path.join('./torch_saved', args.name)
|
save_path = os.path.join('./torch_saved', args.name)
|
||||||
model.load_model(save_path)
|
model.load_model(save_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user