From b5ebc32ead4f8b590b0e5193099b07acc23acecf Mon Sep 17 00:00:00 2001
From: Cong Thanh Vu
Date: Sun, 12 Feb 2023 11:53:41 +0000
Subject: [PATCH] only add fro norm once

---
 lit_models/transformer.py      | 7 +++++--
 scripts/fb15k-237/fb15k-237.sh | 4 ++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 407b9cf..99653b8 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -135,7 +135,7 @@ class TransformerLitModel(BaseLitModel):
         _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
         bs = input_ids.shape[0]
         mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
-        loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+        loss += self.loss_fn(mask_logits, label)
 
         labels = batch.pop("labels")
         label = batch.pop("label")
@@ -173,7 +173,10 @@ class TransformerLitModel(BaseLitModel):
         if self.args.bce:
             loss += self.loss_fn(mask_logits, labels)
         else:
-            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+            loss += self.loss_fn(mask_logits, label)
+
+        if self.smoothing is not None and self.smoothing != 0.0:
+            loss += self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
diff --git a/scripts/fb15k-237/fb15k-237.sh b/scripts/fb15k-237/fb15k-237.sh
index eb8c31a..696fd4e 100755
--- a/scripts/fb15k-237/fb15k-237.sh
+++ b/scripts/fb15k-237/fb15k-237.sh
@@ -1,9 +1,9 @@
-nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "3," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 16 \
-   --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
+   --checkpoint /root/kg_374/Relphormer_instance_1/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
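
A minimal sketch of the loss assembly this patch arrives at: the task loss is computed per branch, and the Frobenius-norm regularizer is added exactly once, gated on label smoothing. The frobenius_norm_loss body below is an assumed stand-in (the real method lives in TransformerLitModel and is not shown in this diff); the weight argument and coefficient are hypothetical.

import torch

def frobenius_norm_loss(weight: torch.Tensor, coeff: float = 1e-4) -> torch.Tensor:
    # Assumed regularizer: coefficient times the squared Frobenius norm of a weight matrix.
    return coeff * torch.linalg.matrix_norm(weight, ord="fro") ** 2

def assemble_loss(mask_logits, label, loss_fn, weight, smoothing=0.1):
    loss = loss_fn(mask_logits, label)              # task loss, added once per branch
    if smoothing is not None and smoothing != 0.0:  # mirrors the new guard in the patch
        loss = loss + frobenius_norm_loss(weight)   # regularizer added a single time
    return loss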