diff --git a/main.py b/main.py index 625b6e8..d5d9329 100644 --- a/main.py +++ b/main.py @@ -675,8 +675,14 @@ if __name__ == "__main__": else: neg_ent = neg_ent.cpu() - dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1) - search = grid.fit(dataset, label) + inputs = [] + for i in len(sub): + input = {} + input['sub'] = sub[i] + input['rel'] = rel[i] + input['neg_ents'] = neg_ent[i] + inputs.append(input) + search = grid.fit(inputs, label) print("BEST SCORE: ", search.best_score_) print("BEST PARAMS: ", search.best_params_) if (args.test_only):