diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 99653b8..41826e5 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
             self.cross_entropy_loss = nn.CrossEntropyLoss()
             self.smoothing = args.label_smoothing
             self.loss_fn = self.label_smoothed_cross_entropy
+            self.frobenius_reg = args.weight_decay
             # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
@@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
         loss += self.loss_fn(mask_logits, label)
 
         if self.smoothing is not None and self.smoothing != 0.0:
-            loss += self.frobenius_norm_loss()
+            loss += self.frobenius_reg * self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
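
Note: this change scales the Frobenius-norm regularizer by a coefficient reused from args.weight_decay instead of adding it unscaled, making the regularization strength tunable from the command line. The body of frobenius_norm_loss is not part of this diff; the sketch below is one plausible implementation, assuming it sums the Frobenius norms of the model's 2-D weight matrices (the standalone function signature and the named_parameters filter are assumptions for illustration, not the repository's actual code).

import torch
import torch.nn as nn

def frobenius_norm_loss(model: nn.Module) -> torch.Tensor:
    """Sum of Frobenius norms over all 2-D weight matrices.

    A hypothetical stand-in for TransformerLitModel.frobenius_norm_loss,
    whose body is not shown in this diff.
    """
    norms = [
        torch.linalg.norm(p, ord="fro")
        for name, p in model.named_parameters()
        if p.dim() == 2 and "weight" in name
    ]
    return torch.stack(norms).sum()

# Usage mirroring the training step above, where frobenius_reg = args.weight_decay:
#   loss = loss + frobenius_reg * frobenius_norm_loss(self.model)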