Add hyperparameter to control the weight of the Frobenius-norm loss term
This commit is contained in:
parent
b5ebc32ead
commit
937b6fcebe
@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
|
||||
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
||||
self.smoothing = args.label_smoothing
|
||||
self.loss_fn = self.label_smoothed_cross_entropy
|
||||
self.frobenius_reg = args.weight_decay
|
||||
# self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
|
||||
else:
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
|
||||
loss += self.loss_fn(mask_logits, label)
|
||||
|
||||
if self.smoothing is not None and self.smoothing != 0.0:
|
||||
loss += self.frobenius_norm_loss()
|
||||
loss += self.frobenius_reg * self.frobenius_norm_loss()
|
||||
|
||||
if batch_idx == 0:
|
||||
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
||||
|
Loading…
Reference in New Issue
Block a user