add grid search
This commit is contained in:
parent
dfec7ff331
commit
9aa85307af
24
main.py
24
main.py
@ -650,10 +650,30 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
paramsGrid = {
|
||||
'optimizer__lr': [0.001, 0.01, 0.1, 0.2, 0.3],
|
||||
'optimizer__lr': [0.0001, 0.0003, 0.001],
|
||||
'optimizer__weight_decay': [1e-4, 1e-5, 1e-6],
|
||||
'module__hid_drop': [0.2, 0.5, 0.7],
|
||||
'module__embed_dim': [300, 400, 500],
|
||||
}
|
||||
|
||||
grid = GridSearchCV(estimator=estimator, param_grid= paramsGrid, n_jobs=-1, cv=3)
|
||||
grid = GridSearchCV(estimator=estimator, param_grid=paramsGrid, n_jobs=-1, cv=1)
|
||||
data = np.array(model.triples['train'])
|
||||
data = np.random.sample(0.2)
|
||||
dataloader = iter(DataLoader(
|
||||
TrainDataset(data, model.p),
|
||||
batch_size=len(data),
|
||||
shuffle=True,
|
||||
num_workers=max(0, model.p.num_workers),
|
||||
collate_fn=TrainDataset.collate_fn
|
||||
))
|
||||
for step, batch in dataloader:
|
||||
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
|
||||
batch, 'train')
|
||||
|
||||
dataset = np.stack([sub, rel, neg_ent, np.repeat(model.p.train_strategy)], axis = 1)
|
||||
search = grid.fit(dataset, label)
|
||||
print("BEST SCORE: ", search.best_score_)
|
||||
print("BEST PARAMS: ", search.best_params_)
|
||||
if (args.test_only):
|
||||
save_path = os.path.join('./torch_saved', args.name)
|
||||
model.load_model(save_path)
|
||||
|
@ -437,9 +437,16 @@ class TuckER(torch.nn.Module):
|
||||
|
||||
|
||||
class FouriER(torch.nn.Module):
|
||||
def __init__(self, params):
|
||||
def __init__(self, params, hid_drop = None, embed_dim = None):
|
||||
super(FouriER, self).__init__()
|
||||
|
||||
if hid_drop is not None:
|
||||
self.p.hid_drop = hid_drop
|
||||
if embed_dim is not None:
|
||||
self.p.ent_vec_dim = embed_dim
|
||||
self.p.rel_vec_dim = embed_dim
|
||||
self.p.embed_dim = embed_dim
|
||||
|
||||
self.p = params
|
||||
|
||||
image_h, image_w = self.p.image_h, self.p.image_w
|
||||
|
Loading…
Reference in New Issue
Block a user