add grid search

This commit is contained in:
thanhvc3 2023-05-17 13:38:54 +07:00
parent e5a343b0c5
commit 9d182abadb

View File

@ -670,7 +670,7 @@ if __name__ == "__main__":
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
batch, 'train')
dataset = np.stack([sub, rel, neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent.cpu(), np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
search = grid.fit(dataset, label)
print("BEST SCORE: ", search.best_score_)
print("BEST PARAMS: ", search.best_params_)