only add fro norm once
This commit is contained in:
parent
f73bb5f1fd
commit
b5ebc32ead
@ -135,7 +135,7 @@ class TransformerLitModel(BaseLitModel):
|
|||||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||||
bs = input_ids.shape[0]
|
bs = input_ids.shape[0]
|
||||||
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
|
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
|
||||||
loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
|
loss += self.loss_fn(mask_logits, label)
|
||||||
|
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
label = batch.pop("label")
|
label = batch.pop("label")
|
||||||
@ -173,7 +173,10 @@ class TransformerLitModel(BaseLitModel):
|
|||||||
if self.args.bce:
|
if self.args.bce:
|
||||||
loss += self.loss_fn(mask_logits, labels)
|
loss += self.loss_fn(mask_logits, labels)
|
||||||
else:
|
else:
|
||||||
loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
|
loss += self.loss_fn(mask_logits, label)
|
||||||
|
|
||||||
|
if self.smoothing is not None and self.smoothing != 0.0:
|
||||||
|
loss += self.frobenius_norm_loss()
|
||||||
|
|
||||||
if batch_idx == 0:
|
if batch_idx == 0:
|
||||||
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
|
nohup python -u main.py --gpus "3," --max_epochs=16 --num_workers=32 \
|
||||||
--model_name_or_path bert-base-uncased \
|
--model_name_or_path bert-base-uncased \
|
||||||
--accumulate_grad_batches 1 \
|
--accumulate_grad_batches 1 \
|
||||||
--model_class BertKGC \
|
--model_class BertKGC \
|
||||||
--batch_size 16 \
|
--batch_size 16 \
|
||||||
--checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
|
--checkpoint /root/kg_374/Relphormer_instance_1/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
|
||||||
--pretrain 0 \
|
--pretrain 0 \
|
||||||
--bce 0 \
|
--bce 0 \
|
||||||
--check_val_every_n_epoch 1 \
|
--check_val_every_n_epoch 1 \
|
||||||
|
Loading…
Reference in New Issue
Block a user