Add a hyperparameter to control the weight of the Frobenius-norm regularization loss
This commit is contained in:
		@@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
 | 
			
		||||
            self.cross_entropy_loss = nn.CrossEntropyLoss()
 | 
			
		||||
            self.smoothing = args.label_smoothing
 | 
			
		||||
            self.loss_fn = self.label_smoothed_cross_entropy
 | 
			
		||||
            self.frobenius_reg = args.weight_decay
 | 
			
		||||
            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
 | 
			
		||||
        else:
 | 
			
		||||
            self.loss_fn = nn.CrossEntropyLoss()
 | 
			
		||||
@@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
 | 
			
		||||
            loss += self.loss_fn(mask_logits, label)
 | 
			
		||||
 | 
			
		||||
        if self.smoothing is not None and self.smoothing != 0.0:
 | 
			
		||||
            loss += self.frobenius_norm_loss()
 | 
			
		||||
            loss += self.frobenius_reg * self.frobenius_norm_loss()
 | 
			
		||||
 | 
			
		||||
        if batch_idx == 0:
 | 
			
		||||
            print('\n'.join(self.decode(batch['input_ids'][:4])))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user