Only add the Frobenius norm loss once

2023-02-12 11:53:41 +00:00
parent f73bb5f1fd
commit b5ebc32ead
2 changed files with 7 additions and 4 deletions

@@ -135,7 +135,7 @@ class TransformerLitModel(BaseLitModel):
         _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
         bs = input_ids.shape[0]
         mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
-        loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+        loss += self.loss_fn(mask_logits, label)
         labels = batch.pop("labels")
         label = batch.pop("label")
@@ -173,7 +173,10 @@ class TransformerLitModel(BaseLitModel):
         if self.args.bce:
             loss += self.loss_fn(mask_logits, labels)
         else:
-            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+            loss += self.loss_fn(mask_logits, label)
+            if self.smoothing is not None and self.smoothing != 0.0:
+                loss += self.frobenius_norm_loss()
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
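
The body of frobenius_norm_loss() is not part of this diff. As a rough sketch of what such a regularizer typically computes, assuming a generic implementation (the standalone helper below and its coeff parameter are illustrative, not taken from this repository):

import torch
from torch import nn

def frobenius_norm_loss(model: nn.Module, coeff: float = 1e-4):
    # Illustrative sketch, not the repository's implementation:
    # sum of squared Frobenius norms over the model's weight matrices,
    # scaled by a small coefficient; bias vectors (dim < 2) are skipped.
    reg = 0.0
    for param in model.parameters():
        if param.dim() >= 2:
            reg = reg + param.pow(2).sum()  # squared Frobenius norm
    return coeff * reg

With the commit applied, the penalty enters the loss once per training step instead of being summed into both loss_fn lines, and only when label smoothing is active (self.smoothing is set and non-zero).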