add hyperparam to control Frobenius norm regularization
@@ -51,6 +51,7 @@ class TransformerLitModel(BaseLitModel):
             self.cross_entropy_loss = nn.CrossEntropyLoss()
             self.smoothing = args.label_smoothing
             self.loss_fn = self.label_smoothed_cross_entropy
+            self.frobenius_reg = args.weight_decay
             # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
@@ -176,7 +177,7 @@ class TransformerLitModel(BaseLitModel):
             loss += self.loss_fn(mask_logits, label)
 
         if self.smoothing is not None and self.smoothing != 0.0:
-            loss += self.frobenius_norm_loss()
+            loss += self.frobenius_reg * self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
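For context, a minimal sketch of the pattern this commit introduces: the Frobenius norm penalty is now multiplied by a frobenius_reg coefficient (initialized from args.weight_decay) instead of being added to the loss unscaled. Only the attribute assignment and the scaled addition come from the diff; the class below, and the body of frobenius_norm_loss in particular, are assumptions for illustration, not the repository's actual implementation.

import torch
import torch.nn as nn

class TransformerLitModelSketch(nn.Module):
    """Illustrative stand-in for TransformerLitModel(BaseLitModel); not the repo code."""

    def __init__(self, model, args):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.frobenius_reg = args.weight_decay  # new hyperparameter from this commit

    def frobenius_norm_loss(self):
        # Assumption: sum of Frobenius norms over all weight tensors with >= 2 dims.
        reg = 0.0
        for param in self.model.parameters():
            if param.dim() >= 2:
                reg = reg + torch.norm(param, p='fro')
        return reg

    def training_loss(self, mask_logits, label):
        loss = self.loss_fn(mask_logits, label)
        # The change: scale the penalty by frobenius_reg instead of adding it unweighted.
        loss = loss + self.frobenius_reg * self.frobenius_norm_loss()
        return loss

Reusing args.weight_decay as the coefficient means the penalty strength can be tuned from the existing command-line argument without adding a new flag.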