add label smoothing for cross entropy and add Frobenius norm regularization
@@ -48,7 +48,10 @@ class TransformerLitModel(BaseLitModel):
         if args.bce:
             self.loss_fn = torch.nn.BCEWithLogitsLoss()
         elif args.label_smoothing != 0.0:
-            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
+            self.cross_entropy_loss = nn.CrossEntropyLoss()
+            self.smoothing = args.label_smoothing
+            self.loss_fn = self.label_smoothed_cross_entropy
+            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
         self.best_acc = 0
@@ -69,13 +72,71 @@ class TransformerLitModel(BaseLitModel):
         self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
         self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)

+    def label_smoothed_cross_entropy(self, logits, labels):
+        num_classes = logits.size(1)
+        # smoothed targets: (1 - smoothing) on the gold class, smoothing / (K - 1) on the rest
+        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
+        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
+        # cross entropy between the smoothed targets and the predicted log-probabilities
+        loss = torch.mean(torch.sum(-one_hot * torch.log_softmax(logits, dim=1), dim=1))
+        return loss
+
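The smoothing scheme above puts mass 1 - eps on the gold class and eps / (K - 1) on every other class, which is slightly different from torch's built-in nn.CrossEntropyLoss(label_smoothing=eps), where eps / K is spread uniformly over all K classes. A minimal standalone sketch with made-up shapes (smoothed_ce is a hypothetical helper, not part of the model above):

import torch

def smoothed_ce(logits, labels, smoothing=0.1):
    num_classes = logits.size(1)
    one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
    target = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (num_classes - 1)
    return torch.mean(torch.sum(-target * torch.log_softmax(logits, dim=1), dim=1))

logits = torch.randn(4, 10)          # batch of 4, 10 classes
labels = torch.tensor([1, 3, 0, 7])
print(smoothed_ce(logits, labels, smoothing=0.1))
# with smoothing=0.0 this reduces to plain cross entropy
assert torch.allclose(smoothed_ce(logits, labels, 0.0),
                      torch.nn.functional.cross_entropy(logits, labels))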
+    def frobenius_norm_loss(self):
+        frobenius_norm = 0.0
+        for name, param in self.model.named_parameters():
+            if 'bias' not in name:
+                frobenius_norm += torch.norm(param, p='fro')
+        return frobenius_norm
+
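frobenius_norm_loss sums the (unsquared) Frobenius norm of every non-bias parameter tensor, so its magnitude grows with model size, and in the training step below it is added to the classification term at full strength. A hedged sketch of a weighted variant (fro_weight is a hypothetical knob, not an argument in this commit):

import torch
import torch.nn as nn

def frobenius_penalty(model: nn.Module, fro_weight: float = 1e-4) -> torch.Tensor:
    penalty = torch.tensor(0.0)
    for name, param in model.named_parameters():
        if 'bias' not in name:                     # skip bias vectors, as above
            penalty = penalty + torch.norm(param, p='fro')
    return fro_weight * penalty

model = nn.Linear(8, 4)                            # toy stand-in for self.model
loss = nn.functional.cross_entropy(model(torch.randn(2, 8)), torch.tensor([0, 3]))
loss = loss + frobenius_penalty(model)             # regularizer added to the task loss
loss.backward()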
     def forward(self, x):
         return self.model(x)

+    def create_negatives(self, batch):
+        negativeBatches = []
+        label = batch['label']
+
+        # one candidate negative batch per example; corrupted batches carry label 0
+        for i in range(label.shape[0]):
+            newBatch = {}
+            newBatch['attention_mask'] = None
+            newBatch['input_ids'] = torch.clone(batch['input_ids'])
+            newBatch['label'] = torch.zeros_like(batch['label'])
+            negativeBatches.append(newBatch)
+
+        # locate the first [ENTITY_*] token in each sequence and remember its original id
+        entity_idx = []
+        self_label = []
+        for idx, l in enumerate(label):
+            decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
+            for j in range(1, len(decoded)):
+                if decoded[j].startswith("[ENTITY_"):
+                    entity_idx.append(j)
+                    self_label.append(batch['input_ids'][idx][j])
+                    break
+
+        # in the idx-th negative batch, overwrite every sequence's entity slot with the
+        # idx-th example's label id (keep the original id when they already match)
+        for idx, lbl in enumerate(label):
+            for i in range(label.shape[0]):
+                if negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl:
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
+                else:
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
+
+        return negativeBatches
+
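create_negatives builds one corrupted copy of the batch per example: the first [ENTITY_*] slot of every sequence is overwritten with the idx-th example's label id, and the corrupted batches keep an all-zero label. A simplified toy illustration of the tensor mechanics (made-up token ids, no tokenizer, and without the else branch that leaves the slot unchanged when it already equals the label):

import torch

input_ids = torch.tensor([[101, 7, 5, 103, 102],    # 7 / 9 stand in for [ENTITY_*] ids
                          [101, 9, 6, 103, 102]])
labels = torch.tensor([11, 12])                      # gold ids the positives should predict
entity_pos = [1, 1]                                  # first [ENTITY_*] position per row

negatives = []
for idx, lbl in enumerate(labels):
    neg = input_ids.clone()
    for i in range(input_ids.shape[0]):
        if neg[i, entity_pos[i]] != lbl:
            neg[i, entity_pos[i]] = lbl              # corrupt the entity slot
    negatives.append(neg)

print(negatives[0])
# tensor([[101, 11, 5, 103, 102],
#         [101, 11, 6, 103, 102]])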
     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
         # embed();exit()
         # print(self.optimizers().param_groups[1]['lr'])

+        negativeBatches = self.create_negatives(batch)
+
+        loss = 0
+
+        for negativeBatch in negativeBatches:
+            label = negativeBatch.pop("label")
+            input_ids = batch['input_ids']
+
+            logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
+            _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
+            bs = input_ids.shape[0]
+            # logits at each sequence's [MASK] position, restricted to the entity token range
+            mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+
         labels = batch.pop("labels")
         label = batch.pop("label")
         pos = batch.pop("pos")
@@ -110,9 +171,9 @@ class TransformerLitModel(BaseLitModel):

         assert mask_idx.shape[0] == bs, "only one mask in sequence!"
         if self.args.bce:
-            loss = self.loss_fn(mask_logits, labels)
+            loss += self.loss_fn(mask_logits, labels)
         else:
-            loss = self.loss_fn(mask_logits, label)
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()

         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
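With this hunk the loss on the original batch is accumulated rather than overwritten: each corrupted batch contributes a classification term plus the Frobenius penalty, and the uncorrupted batch adds one more of each (the penalty is skipped in the BCE branch). A compact sketch of the resulting objective, with hypothetical stand-in callables instead of the real model and batch plumbing:

def total_loss(loss_fn, neg_terms, pos_logits, pos_label, fro_penalty):
    # neg_terms: list of (mask_logits, label) pairs, one per corrupted batch
    loss = 0
    for logits, lbl in neg_terms:
        loss = loss + loss_fn(logits, lbl) + fro_penalty()
    # original batch, non-BCE branch
    loss = loss + loss_fn(pos_logits, pos_label) + fro_penalty()
    return loss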