add grid search
This commit is contained in:
parent
9d182abadb
commit
d912e0a225
7
main.py
7
main.py
@ -669,8 +669,13 @@ if __name__ == "__main__":
|
|||||||
for step, batch in enumerate(dataloader):
|
for step, batch in enumerate(dataloader):
|
||||||
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
|
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
|
||||||
batch, 'train')
|
batch, 'train')
|
||||||
|
|
||||||
|
if (neg_ent is None):
|
||||||
|
neg_ent = np.repeat(None, repeats=len(sub))
|
||||||
|
else:
|
||||||
|
neg_ent = neg_ent.cpu()
|
||||||
|
|
||||||
dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent.cpu(), np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
|
dataset = np.stack([sub.cpu(), rel.cpu(), neg_ent, np.repeat(model.p.train_strategy, repeats=len(sub))], axis = 1)
|
||||||
search = grid.fit(dataset, label)
|
search = grid.fit(dataset, label)
|
||||||
print("BEST SCORE: ", search.best_score_)
|
print("BEST SCORE: ", search.best_score_)
|
||||||
print("BEST PARAMS: ", search.best_params_)
|
print("BEST PARAMS: ", search.best_params_)
|
||||||
|
Loading…
Reference in New Issue
Block a user