From 6cc55301ad512f5e63e6250f521aafb9b11ecad7 Mon Sep 17 00:00:00 2001 From: Cong Thanh Vu Date: Sun, 12 Feb 2023 10:57:29 +0000 Subject: [PATCH] apply ns --- .vscode/launch.json | 36 +++++++++++++++++++ lit_models/transformer.py | 49 ++++++++++++++++++++++++-- main.py | 2 +- pretrain/scripts/pretrain_fb15k-237.sh | 6 ++-- scripts/fb15k-237/fb15k-237.sh | 7 ++-- 5 files changed, 90 insertions(+), 10 deletions(-) create mode 100644 .vscode/launch.json mode change 100644 => 100755 pretrain/scripts/pretrain_fb15k-237.sh mode change 100644 => 100755 scripts/fb15k-237/fb15k-237.sh diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..9a5e950 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,36 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "--gpus", "1,", + "--max_epochs=16", + "--num_workers=32", + "--model_name_or_path", "bert-base-uncased", + "--accumulate_grad_batches", "1", + "--model_class", "BertKGC", + "--batch_size", "32", + "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt", + "--pretrain", "0", + "--bce", "0", + "--check_val_every_n_epoch", "1", + "--data_dir", "dataset/FB15k-237", + "--eval_batch_size", "128", + "--max_seq_length", "128", + "--lr", "3e-5", + "--max_triplet", "64", + "--add_attn_bias", "True", + "--use_global_node", "True", + ] + } + ] +} \ No newline at end of file diff --git a/lit_models/transformer.py b/lit_models/transformer.py index 6075a28..eaddef0 100644 --- a/lit_models/transformer.py +++ b/lit_models/transformer.py @@ -73,9 +73,54 @@ class TransformerLitModel(BaseLitModel): def forward(self, x): return self.model(x) + def create_negatives(self, batch): + negativeBatches = [] + label = batch['label'] + + for i in range(label.shape[0]): + newBatch = {} + newBatch['attention_mask'] = None + newBatch['input_ids'] = torch.clone(batch['input_ids']) + newBatch['label'] = torch.zeros_like(batch['label']) + negativeBatches.append(newBatch) + + entity_idx = [] + self_label = [] + for idx, l in enumerate(label): + decoded = self.decode([batch['input_ids'][idx]])[0].split(' ') + for j in range(1, len(decoded)): + if (decoded[j].startswith("[ENTITY_")): + entity_idx.append(j) + self_label.append(batch['input_ids'][idx][j]) + break + + for idx, lbl in enumerate(label): + for i in range(label.shape[0]): + if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl): + negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl + else: + negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i] + + return negativeBatches + def training_step(self, batch, batch_idx): # pylint: disable=unused-argument # embed();exit() # print(self.optimizers().param_groups[1]['lr']) + + negativeBatches = self.create_negatives(batch) + + loss = 0 + + for negativeBatch in negativeBatches: + label = negativeBatch.pop("label") + input_ids = batch['input_ids'] + + logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits + _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True) + bs = input_ids.shape[0] + mask_logits = logits[torch.arange(bs), mask_idx][:, 
self.entity_id_st:self.entity_id_ed] + loss += self.loss_fn(mask_logits, label) + labels = batch.pop("labels") label = batch.pop("label") pos = batch.pop("pos") @@ -110,9 +155,9 @@ class TransformerLitModel(BaseLitModel): assert mask_idx.shape[0] == bs, "only one mask in sequence!" if self.args.bce: - loss = self.loss_fn(mask_logits, labels) + loss += self.loss_fn(mask_logits, labels) else: - loss = self.loss_fn(mask_logits, label) + loss += self.loss_fn(mask_logits, label) if batch_idx == 0: print('\n'.join(self.decode(batch['input_ids'][:4]))) diff --git a/main.py b/main.py index b9c8a71..7f3ffb4 100644 --- a/main.py +++ b/main.py @@ -120,7 +120,7 @@ def main(): callbacks = [early_callback, model_checkpoint] # args.weights_summary = "full" # Print full summary of the model - trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp") + trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs") if "EntityEmbedding" not in lit_model.__class__.__name__: trainer.fit(lit_model, datamodule=data) diff --git a/pretrain/scripts/pretrain_fb15k-237.sh b/pretrain/scripts/pretrain_fb15k-237.sh old mode 100644 new mode 100755 index 59d24c4..f303195 --- a/pretrain/scripts/pretrain_fb15k-237.sh +++ b/pretrain/scripts/pretrain_fb15k-237.sh @@ -1,13 +1,13 @@ -nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \ +nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \ --model_name_or_path bert-base-uncased \ --accumulate_grad_batches 1 \ --model_class BertKGC \ - --batch_size 128 \ + --batch_size 64 \ --pretrain 1 \ --bce 0 \ --check_val_every_n_epoch 1 \ --overwrite_cache \ - --data_dir /kg_374/Relphormer/dataset/FB15k-237 \ + --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \ --eval_batch_size 256 \ --max_seq_length 64 \ --lr 1e-4 \ diff --git a/scripts/fb15k-237/fb15k-237.sh b/scripts/fb15k-237/fb15k-237.sh old mode 100644 new mode 100755 index ac9bc99..a64f7d5 --- a/scripts/fb15k-237/fb15k-237.sh +++ b/scripts/fb15k-237/fb15k-237.sh @@ -1,13 +1,12 @@ -nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \ +nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \ --model_name_or_path bert-base-uncased \ --accumulate_grad_batches 1 \ --model_class BertKGC \ - --batch_size 64 \ - --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \ + --batch_size 16 \ + --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \ --pretrain 0 \ --bce 0 \ --check_val_every_n_epoch 1 \ - --overwrite_cache \ --data_dir dataset/FB15k-237 \ --eval_batch_size 128 \ --max_seq_length 128 \
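
The hunks above only show fragments of TransformerLitModel.training_step, so the overall flow of the negative-sampling ("ns") change can be hard to follow from the diff alone. Below is a minimal, self-contained sketch of that flow. ToyScorer, NUM_ENTITIES, ENTITY_SLOT and the toy tensors are hypothetical stand-ins, not the Relphormer implementation: in the actual patch the anchor-entity slot is located by decoding the input and finding the first "[ENTITY_*]" token, the logits are taken at the [MASK] position and restricted to the entity id range self.entity_id_st:self.entity_id_ed, and each corrupted copy of the batch is scored against an all-zero target (torch.zeros_like(batch['label'])).

import torch
import torch.nn as nn

NUM_ENTITIES = 10   # hypothetical size of the entity vocabulary
ENTITY_SLOT = 1     # hypothetical position of the anchor-entity token
loss_fn = nn.CrossEntropyLoss()

class ToyScorer(nn.Module):
    """Hypothetical stand-in for the masked-entity scoring head."""
    def __init__(self, vocab_size, num_entities, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, num_entities)

    def forward(self, input_ids):
        # mean-pool token embeddings and score every candidate entity
        return self.head(self.embed(input_ids).mean(dim=1))   # [bs, num_entities]

def create_negatives(input_ids, label):
    # One corrupted copy of the batch per gold label, mirroring create_negatives
    # in the patch: the anchor-entity slot of every example is overwritten with
    # that label, and the copy is paired with an all-zero target.
    negatives = []
    for lbl in label:
        corrupted = input_ids.clone()
        corrupted[:, ENTITY_SLOT] = lbl
        negatives.append({"input_ids": corrupted,
                          "label": torch.zeros_like(label)})
    return negatives

def training_step(model, input_ids, label):
    # positive term: the original batch scored against its gold entities
    loss = loss_fn(model(input_ids), label)
    # negative terms: one extra forward pass per corrupted copy, all accumulated
    # into the same scalar loss (as in the patched training_step)
    for neg in create_negatives(input_ids, label):
        loss = loss + loss_fn(model(neg["input_ids"]), neg["label"])
    return loss

if __name__ == "__main__":
    torch.manual_seed(0)
    bs, seq_len, vocab_size = 4, 6, 50
    model = ToyScorer(vocab_size, NUM_ENTITIES)
    input_ids = torch.randint(0, vocab_size, (bs, seq_len))
    label = torch.randint(0, NUM_ENTITIES, (bs,))
    print(training_step(model, input_ids, label))

Because every gold label in a batch of size B adds one more forward pass, a training step now costs roughly B + 1 forward passes; this is presumably why the shell scripts in this patch also lower --batch_size (128 -> 64 for pretraining, 64 -> 16 for fine-tuning).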