From 937b6fcebe344601fd260836b288b5343d889a1e Mon Sep 17 00:00:00 2001
From: Cong Thanh Vu
Date: Mon, 13 Feb 2023 05:31:32 +0000
Subject: [PATCH] add hyperparam to control frobenius

---
 lit_models/transformer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 99653b8..41826e5 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
             self.cross_entropy_loss = nn.CrossEntropyLoss()
             self.smoothing = args.label_smoothing
             self.loss_fn = self.label_smoothed_cross_entropy
+            self.frobenius_reg = args.weight_decay
             # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
@@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
             loss += self.loss_fn(mask_logits, label)
 
         if self.smoothing is not None and self.smoothing != 0.0:
-            loss += self.frobenius_norm_loss()
+            loss += self.frobenius_reg * self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))