add grid search

This commit is contained in:
thanhvc3 2023-05-17 13:10:36 +07:00
parent dfec7ff331
commit 9aa85307af
2 changed files with 30 additions and 3 deletions

24
main.py
View File

@ -650,10 +650,30 @@ if __name__ == "__main__":
)
paramsGrid = {
'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
'optimizer__lr': [0.0001, 0.0003, 0.001],
'optimizer__weight_decay': [1e-4, 1e-5, 1e-6],
'module__hid_drop': [0.2, 0.5, 0.7],
'module__embed_dim': [300, 400, 500],
}
grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
grid = GridSearchCV(estimator=estimator, param_grid=paramsGrid, n_jobs=-1, cv=1)
data = np.array(model.triples['train'])
data = np.random.sample(0.2)
dataloader = iter(DataLoader(
TrainDataset(data, model.p),
batch_size=len(data),
shuffle=True,
num_workers=max(0, model.p.num_workers),
collate_fn=TrainDataset.collate_fn
))
for step, batch in dataloader:
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
batch, 'train')
dataset = np.stack([sub, rel, neg_ent, np.repeat(model.p.train_strategy)], axis = 1)
search = grid.fit(dataset, label)
print("BEST SCORE: ", search.best_score_)
print("BEST PARAMS: ", search.best_params_)
if (args.test_only):
save_path = os.path.join('./torch_saved', args.name)
model.load_model(save_path)

View File

@ -437,9 +437,16 @@ class TuckER(torch.nn.Module):
class FouriER(torch.nn.Module):
def __init__(self, params):
def __init__(self, params, hid_drop = None, embed_dim = None):
super(FouriER, self).__init__()
if hid_drop is not None:
self.p.hid_drop = hid_drop
if embed_dim is not None:
self.p.ent_vec_dim = embed_dim
self.p.rel_vec_dim = embed_dim
self.p.embed_dim = embed_dim
self.p = params
image_h, image_w = self.p.image_h, self.p.image_w