From f2052c2839448666dbff65a15c410607776fea5b Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Wed, 17 May 2023 14:04:48 +0700 Subject: [PATCH] add grid search --- main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 625b6e8..d5d9329 100644 --- a/main.py +++ b/main.py @@ -675,8 +675,14 @@ if __name__ == "__main__": else: neg_ent = neg_ent.cpu() - dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1) - search = grid.fit(dataset, label) + inputs = [] + for i in len(sub): + input = {} + input['sub'] = sub[i] + input['rel'] = rel[i] + input['neg_ents'] = neg_ent[i] + inputs.append(input) + search = grid.fit(inputs, label) print("BEST SCORE: ", search.best_score_) print("BEST PARAMS: ", search.best_params_) if (args.test_only):