Add a hyperparameter (`frobenius_reg`, taken from `args.weight_decay`) to scale the Frobenius-norm regularization term in the loss

This commit is contained in:
Cong Thanh Vu 2023-02-13 05:31:32 +00:00
parent b5ebc32ead
commit 937b6fcebe

View File

@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.smoothing = args.label_smoothing
self.loss_fn = self.label_smoothed_cross_entropy
self.frobenius_reg = args.weight_decay
# self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
else:
self.loss_fn = nn.CrossEntropyLoss()
@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
loss += self.loss_fn(mask_logits, label)
if self.smoothing is not None and self.smoothing != 0.0:
loss += self.frobenius_norm_loss()
loss += self.frobenius_reg * self.frobenius_norm_loss()
if batch_idx == 0:
print('\n'.join(self.decode(batch['input_ids'][:4])))