add grid search

This commit is contained in:
thanhvc3 2023-05-17 14:04:48 +07:00
parent ddbaa2781f
commit f2052c2839

10
main.py
View File

@ -675,8 +675,14 @@ if __name__ == "__main__":
else:
neg_ent = neg_ent.cpu()
dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
search = grid.fit(dataset, label)
inputs = []
for i in len(sub):
input = {}
input['sub'] = sub[i]
input['rel'] = rel[i]
input['neg_ents'] = neg_ent[i]
inputs.append(input)
search = grid.fit(inputs, label)
print("BEST SCORE: ", search.best_score_)
print("BEST PARAMS: ", search.best_params_)
if (args.test_only):