Compare commits: main...negative_s (1 commit)

Commit SHA1: 6cc55301ad
.vscode/launch.json (vendored, new file, +36 lines)
@@ -0,0 +1,36 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--gpus", "1,",
+                "--max_epochs=16",
+                "--num_workers=32",
+                "--model_name_or_path", "bert-base-uncased",
+                "--accumulate_grad_batches", "1",
+                "--model_class", "BertKGC",
+                "--batch_size", "32",
+                "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt",
+                "--pretrain", "0",
+                "--bce", "0",
+                "--check_val_every_n_epoch", "1",
+                "--data_dir", "dataset/FB15k-237",
+                "--eval_batch_size", "128",
+                "--max_seq_length", "128",
+                "--lr", "3e-5",
+                "--max_triplet", "64",
+                "--add_attn_bias", "True",
+                "--use_global_node", "True",
+            ]
+        }
+    ]
+}
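The debug arguments largely mirror the fine-tuning script further down (scripts/fb15k-237/fb15k-237.sh), with a different --batch_size. A hypothetical helper, not part of this commit, that rebuilds the equivalent command line from launch.json so the two can be compared; it assumes main.py is the file open when debugging (since "program" is "${file}") and that the file lives at the default path:

```python
# Hypothetical helper (not part of this commit): rebuild the debugger invocation
# from .vscode/launch.json for comparison with the shell scripts.
# launch.json is JSON-with-comments, so // lines and trailing commas are stripped
# before parsing.
import json
import re
import shlex

def launch_config_to_cmdline(path=".vscode/launch.json", name="Python: Current File"):
    with open(path) as f:
        text = f.read()
    text = re.sub(r"^\s*//.*$", "", text, flags=re.MULTILINE)  # drop // comment lines
    text = re.sub(r",(\s*[\]}])", r"\1", text)                 # drop trailing commas
    cfg = next(c for c in json.loads(text)["configurations"] if c["name"] == name)
    return "python -u main.py " + " ".join(shlex.quote(a) for a in cfg["args"])

if __name__ == "__main__":
    print(launch_config_to_cmdline())
```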
@@ -73,9 +73,54 @@ class TransformerLitModel(BaseLitModel):
     def forward(self, x):
         return self.model(x)
 
+    def create_negatives(self, batch):
+        negativeBatches = []
+        label = batch['label']
+
+        for i in range(label.shape[0]):
+            newBatch = {}
+            newBatch['attention_mask'] = None
+            newBatch['input_ids'] = torch.clone(batch['input_ids'])
+            newBatch['label'] = torch.zeros_like(batch['label'])
+            negativeBatches.append(newBatch)
+
+        entity_idx = []
+        self_label = []
+        for idx, l in enumerate(label):
+            decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
+            for j in range(1, len(decoded)):
+                if (decoded[j].startswith("[ENTITY_")):
+                    entity_idx.append(j)
+                    self_label.append(batch['input_ids'][idx][j])
+                    break
+
+        for idx, lbl in enumerate(label):
+            for i in range(label.shape[0]):
+                if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl):
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
+                else:
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
+
+        return negativeBatches
+
     def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
         # embed();exit()
         # print(self.optimizers().param_groups[1]['lr'])
+
+        negativeBatches = self.create_negatives(batch)
+
+        loss = 0
+
+        for negativeBatch in negativeBatches:
+            label = negativeBatch.pop("label")
+            input_ids = batch['input_ids']
+
+            logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
+            _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
+            bs = input_ids.shape[0]
+            mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
+            loss += self.loss_fn(mask_logits, label)
+
         labels = batch.pop("labels")
         label = batch.pop("label")
         pos = batch.pop("pos")
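For orientation, a simplified, self-contained sketch of the corruption scheme create_negatives implements, using toy tensors and hypothetical entity positions (the real method locates the entity token by decoding the sequence and looking for an "[ENTITY_" prefix): for each example j in the batch it builds one negative copy of the batch in which every sequence's entity token is overwritten with example j's gold entity id (or left as its own id when the two coincide), and labels the negative copy 0.

```python
import torch

input_ids = torch.tensor([[101, 11, 5, 102],   # toy sequences; position 1 holds the
                          [101, 12, 7, 102],   # entity token in this illustration
                          [101, 13, 9, 102]])
label = torch.tensor([11, 12, 13])             # gold entity id per example
entity_idx = [1, 1, 1]                         # entity-token position per sequence
self_label = input_ids[torch.arange(3), torch.tensor(entity_idx)].clone()

negative_batches = []
for j, lbl in enumerate(label):
    neg_ids = input_ids.clone()
    for i in range(label.shape[0]):
        if neg_ids[i, entity_idx[i]] != lbl:
            neg_ids[i, entity_idx[i]] = lbl          # corrupt with example j's entity id
        else:
            neg_ids[i, entity_idx[i]] = self_label[i]  # keep the sequence's own entity id
    negative_batches.append({"input_ids": neg_ids,
                             "attention_mask": None,
                             "label": torch.zeros_like(label)})

print(negative_batches[0]["input_ids"])
# tensor([[101, 11,  5, 102],   <- already example 0's gold entity, kept
#         [101, 11,  7, 102],   <- corrupted to example 0's entity id
#         [101, 11,  9, 102]])
```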
@@ -110,9 +155,9 @@ class TransformerLitModel(BaseLitModel):
 
         assert mask_idx.shape[0] == bs, "only one mask in sequence!"
         if self.args.bce:
-            loss = self.loss_fn(mask_logits, labels)
+            loss += self.loss_fn(mask_logits, labels)
         else:
-            loss = self.loss_fn(mask_logits, label)
+            loss += self.loss_fn(mask_logits, label)
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
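The switch from "=" to "+=" folds the original positive-sample loss into the total that training_step has already accumulated over the negative batches, instead of overwriting it. A toy illustration with hypothetical shapes, assuming the cross-entropy case (args.bce == 0):

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
bs, num_entities = 4, 10

loss = 0
for _ in range(bs):                                 # one negative batch per example
    neg_logits = torch.randn(bs, num_entities)
    neg_labels = torch.zeros(bs, dtype=torch.long)  # negative copies carry label 0
    loss += loss_fn(neg_logits, neg_labels)

pos_logits = torch.randn(bs, num_entities)
pos_labels = torch.randint(num_entities, (bs,))
loss += loss_fn(pos_logits, pos_labels)             # "+=" keeps the negative terms
print(float(loss))
```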
main.py (2 changed lines)
@@ -120,7 +120,7 @@ def main():
     callbacks = [early_callback, model_checkpoint]
 
     # args.weights_summary = "full" # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
 
     if "EntityEmbedding" not in lit_model.__class__.__name__:
         trainer.fit(lit_model, datamodule=data)
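Dropping the hard-coded accelerator="ddp" leaves the distributed strategy to the command-line arguments, so single-process runs (for example from the launch.json configuration above) no longer force DDP. A minimal sketch, assuming the PyTorch Lightning 1.x argparse helpers that the Trainer.from_argparse_args call in this diff implies; how main.py actually builds its parser is an assumption here:

```python
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)      # expose Trainer flags on the CLI (PL 1.x)
args = parser.parse_args(["--max_epochs", "16"])   # no accelerator flag -> nothing forces DDP

trainer = pl.Trainer.from_argparse_args(args, default_root_dir="training/logs")
print(trainer.max_epochs)
```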
pretrain/scripts/pretrain_fb15k-237.sh (6 changed lines, Normal file → Executable file)
@@ -1,13 +1,13 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
 --model_name_or_path bert-base-uncased \
 --accumulate_grad_batches 1 \
 --model_class BertKGC \
---batch_size 128 \
+--batch_size 64 \
 --pretrain 1 \
 --bce 0 \
 --check_val_every_n_epoch 1 \
 --overwrite_cache \
---data_dir /kg_374/Relphormer/dataset/FB15k-237 \
+--data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \
 --eval_batch_size 256 \
 --max_seq_length 64 \
 --lr 1e-4 \
scripts/fb15k-237/fb15k-237.sh (7 changed lines, Normal file → Executable file)
@@ -1,13 +1,12 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
 --model_name_or_path bert-base-uncased \
 --accumulate_grad_batches 1 \
 --model_class BertKGC \
---batch_size 64 \
+--batch_size 16 \
---checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
+--checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
 --pretrain 0 \
 --bce 0 \
 --check_val_every_n_epoch 1 \
---overwrite_cache \
 --data_dir dataset/FB15k-237 \
 --eval_batch_size 128 \
 --max_seq_length 128 \
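Both shell scripts above, like the launch.json arguments, now pass --gpus "1," instead of --gpus "1". In PyTorch Lightning 1.x these mean different things: a bare integer is a GPU count, while a comma-separated string is a list of device ids, so "1," pins training to GPU index 1. A small illustration of that flag convention only (this is not Lightning's actual parser):

```python
# Illustration of the --gpus flag convention (not Lightning's real parser):
# "1"  -> a count: use one GPU, device id chosen automatically
# "1," -> a device-id list: use GPU index 1
def describe_gpus_flag(value: str) -> str:
    if "," in value:
        ids = [int(v) for v in value.split(",") if v.strip()]
        return f"use GPU ids {ids}"
    return f"use {int(value)} GPU(s), ids chosen automatically"

print(describe_gpus_flag("1"))    # use 1 GPU(s), ids chosen automatically
print(describe_gpus_flag("1,"))   # use GPU ids [1]
```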