add hyperparam to control frobenius
This commit is contained in:
parent
b5ebc32ead
commit
937b6fcebe
@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
|
|||||||
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
self.cross_entropy_loss = nn.CrossEntropyLoss()
|
||||||
self.smoothing = args.label_smoothing
|
self.smoothing = args.label_smoothing
|
||||||
self.loss_fn = self.label_smoothed_cross_entropy
|
self.loss_fn = self.label_smoothed_cross_entropy
|
||||||
|
self.frobenius_reg = args.weight_decay
|
||||||
# self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
|
# self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
|
||||||
else:
|
else:
|
||||||
self.loss_fn = nn.CrossEntropyLoss()
|
self.loss_fn = nn.CrossEntropyLoss()
|
||||||
@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
|
|||||||
loss += self.loss_fn(mask_logits, label)
|
loss += self.loss_fn(mask_logits, label)
|
||||||
|
|
||||||
if self.smoothing is not None and self.smoothing != 0.0:
|
if self.smoothing is not None and self.smoothing != 0.0:
|
||||||
loss += self.frobenius_norm_loss()
|
loss += self.frobenius_reg * self.frobenius_norm_loss()
|
||||||
|
|
||||||
if batch_idx == 0:
|
if batch_idx == 0:
|
||||||
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
||||||
|
Loading…
Reference in New Issue
Block a user