only add fro norm once
@@ -135,7 +135,7 @@ class TransformerLitModel(BaseLitModel):
             _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
             bs = input_ids.shape[0]
             mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
-            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+            loss += self.loss_fn(mask_logits, label)
 
         labels = batch.pop("labels")
         label = batch.pop("label")
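For reference, the context lines above gather each sequence's [MASK]-position logits and keep only the entity-vocabulary slice. The following standalone sketch reproduces that indexing pattern with made-up shapes; the mask token id, vocabulary size, and the entity_id_st/entity_id_ed bounds are illustrative assumptions, not values from this repository.

import torch

bs, seq_len, vocab_size = 4, 16, 30522         # hypothetical sizes
mask_token_id = 103                            # placeholder [MASK] id (assumption)
entity_id_st, entity_id_ed = 28000, 30000      # hypothetical entity-token id range

# Build a batch in which every sequence contains exactly one [MASK] token.
input_ids = torch.randint(1000, 2000, (bs, seq_len))
input_ids[:, 5] = mask_token_id
logits = torch.randn(bs, seq_len, vocab_size)  # stand-in for the model output

# Same pattern as the diff: locate the [MASK] position per sequence,
# take the logits at that position, then slice out the entity-vocabulary range.
_, mask_idx = (input_ids == mask_token_id).nonzero(as_tuple=True)
mask_logits = logits[torch.arange(bs), mask_idx][:, entity_id_st:entity_id_ed]
print(mask_logits.shape)                       # torch.Size([4, 2000])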
@@ -173,7 +173,10 @@ class TransformerLitModel(BaseLitModel):
         if self.args.bce:
             loss += self.loss_fn(mask_logits, labels)
         else:
-            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+            loss += self.loss_fn(mask_logits, label)
 
+        if self.smoothing is not None and self.smoothing != 0.0:
+            loss += self.frobenius_norm_loss()
+
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
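The body of self.frobenius_norm_loss() is not part of this diff; the commit only changes where its result is added, so the penalty enters the loss once instead of being folded into each cross-entropy term. As a hedged sketch of what such a regularizer typically computes (the weighting factor, the parameter selection, and the standalone helper name are assumptions, not the repository's implementation):

import torch
import torch.nn as nn

def frobenius_norm_loss(module: nn.Module, weight: float = 1e-4) -> torch.Tensor:
    # Hypothetical sketch: a scaled sum of Frobenius norms over the module's
    # trainable 2-D weight matrices. The actual method in the repository may
    # select different parameters or use a different scale.
    reg = torch.zeros(())
    for _, param in module.named_parameters():
        if param.requires_grad and param.dim() == 2:
            reg = reg + torch.norm(param, p="fro")
    return weight * reg

# Usage: add the penalty to the task loss in exactly one place,
# mirroring the intent of this commit ("only add fro norm once").
model = nn.Linear(8, 4)
task_loss = torch.tensor(0.0)
loss = task_loss + frobenius_norm_loss(model)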