add label smoothing for cross entropy and add Frobenius-norm regularization

.vscode/launch.json (new file, vendored)
@@ -0,0 +1,36 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--gpus", "1,",
+                "--max_epochs=16",
+                "--num_workers=32",
+                "--model_name_or_path",  "bert-base-uncased",
+                "--accumulate_grad_batches", "1",
+                "--model_class", "BertKGC",
+                "--batch_size", "32",
+                "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt",
+                "--pretrain", "0",
+                "--bce", "0",
+                "--check_val_every_n_epoch", "1",
+                "--data_dir", "dataset/FB15k-237",
+                "--eval_batch_size", "128",
+                "--max_seq_length", "128",
+                "--lr", "3e-5",
+                "--max_triplet", "64",
+                "--add_attn_bias", "True",
+                "--use_global_node", "True",
+            ]
+        }
+    ]
+}
@@ -48,7 +48,10 @@ class TransformerLitModel(BaseLitModel):
         if args.bce:
             self.loss_fn = torch.nn.BCEWithLogitsLoss()
         elif args.label_smoothing != 0.0:
-            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
+            self.cross_entropy_loss = nn.CrossEntropyLoss()
+            self.smoothing = args.label_smoothing
+            self.loss_fn = self.label_smoothed_cross_entropy
+            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
         self.best_acc = 0
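
Note: the new branch stores the smoothing factor and points loss_fn at the label_smoothed_cross_entropy helper added further down. As a hedged alternative (not what this commit does), PyTorch 1.10 and later ships label smoothing directly in nn.CrossEntropyLoss, which would make the custom helper unnecessary:

    # Hedged alternative sketch, assuming PyTorch >= 1.10: built-in label smoothing.
    self.loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)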
@@ -69,13 +72,71 @@ class TransformerLitModel(BaseLitModel):
         self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
         self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)
 
+    def label_smoothed_cross_entropy(self, logits, labels):
+        num_classes = logits.size(1)
+        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
+        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
+        loss = self.cross_entropy_loss(logits, labels)
+        return loss
+
+    def frobenius_norm_loss(self):
+        frobenius_norm = 0.0
+        for name, param in self.model.named_parameters():
+            if 'bias' not in name:
+                frobenius_norm += torch.norm(param, p='fro')
+        return frobenius_norm
+
     def forward(self, x):
         return self.model(x)
 
+    def create_negatives(self, batch):
+        negativeBatches = []
+        label = batch['label']
+
+        for i in range(label.shape[0]):
+            newBatch = {}
+            newBatch['attention_mask'] = None
+            newBatch['input_ids'] = torch.clone(batch['input_ids'])
+            newBatch['label'] = torch.zeros_like(batch['label'])
+            negativeBatches.append(newBatch)
+
+        entity_idx = []
+        self_label = []
+        for idx, l in enumerate(label):
+            decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
+            for j in range(1, len(decoded)):
+                if (decoded[j].startswith("[ENTITY_")):
+                    entity_idx.append(j)
+                    self_label.append(batch['input_ids'][idx][j])
+                    break
+
+        for idx, lbl in enumerate(label):
+            for i in range(label.shape[0]):
+                if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl):
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
+                else:
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
+
+        return negativeBatches
+
     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
         # embed();exit()
         # print(self.optimizers().param_groups[1]['lr'])
 
+        negativeBatches = self.create_negatives(batch)
+
+        loss = 0
+
+        for negativeBatch in negativeBatches:
+            label = negativeBatch.pop("label")
+            input_ids = batch['input_ids']
+
+            logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
+            _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
+            bs = input_ids.shape[0]
+            mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+
         labels = batch.pop("labels")
         label = batch.pop("label")
         pos = batch.pop("pos")
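
Note on the helper above: label_smoothed_cross_entropy builds the smoothed one_hot target distribution but then returns self.cross_entropy_loss(logits, labels), i.e. the plain hard-label cross entropy, so the smoothed targets are currently unused. A minimal standalone sketch of a variant that actually applies them (an assumption about the intended behaviour, not the committed code):

    import torch
    import torch.nn.functional as F

    def label_smoothed_cross_entropy(logits, labels, smoothing):
        # Smoothed target distribution, mirroring the committed helper.
        num_classes = logits.size(1)
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (num_classes - 1)
        # Cross entropy against the smoothed targets instead of the hard labels.
        log_probs = F.log_softmax(logits, dim=-1)
        return -(one_hot * log_probs).sum(dim=-1).mean()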
@@ -110,9 +171,9 @@ class TransformerLitModel(BaseLitModel):
 
         assert mask_idx.shape[0] == bs, "only one mask in sequence!"
         if self.args.bce:
-            loss = self.loss_fn(mask_logits, labels)
+            loss += self.loss_fn(mask_logits, labels)
         else:
-            loss = self.loss_fn(mask_logits, label)
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
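
Note: frobenius_norm_loss sums the unscaled Frobenius norms of every non-bias parameter, and training_step now adds it once per negative batch and once more for the real batch, so the penalty can easily dominate the task loss. A hedged sketch with an explicit weighting coefficient (fro_weight is a hypothetical hyperparameter, not part of this commit):

    import torch

    def frobenius_norm_loss(model, fro_weight=1e-5):
        # Sum of Frobenius norms over all non-bias parameters, scaled down so the
        # regularizer stays small relative to the cross-entropy term.
        penalty = 0.0
        for name, param in model.named_parameters():
            if 'bias' not in name:
                penalty = penalty + torch.norm(param, p='fro')
        return fro_weight * penalty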
							
								
								
									
main.py
@@ -120,7 +120,7 @@ def main():
     callbacks = [early_callback, model_checkpoint]
 
     # args.weights_summary = "full"  # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
 
     if "EntityEmbedding" not in lit_model.__class__.__name__:
         trainer.fit(lit_model, datamodule=data)
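
Note: dropping accelerator="ddp" means the Trainer falls back to single-process training unless a distributed strategy is supplied some other way. A hedged sketch of restoring distributed training on newer PyTorch Lightning versions (assuming PL >= 1.6, where the option is called strategy; not part of this commit):

    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=callbacks,
        logger=logger,
        default_root_dir="training/logs",
        strategy="ddp",  # replaces the removed accelerator="ddp" on newer PL releases
    )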
							
								
								
									
pretrain/scripts/pretrain_fb15k-237.sh (Normal file → Executable file)
@@ -1,13 +1,13 @@
-nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \
+nohup python -u main.py --gpus "1," --max_epochs=16  --num_workers=32 \
    --model_name_or_path  bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
-   --batch_size 128 \
+   --batch_size 64 \
    --pretrain 1 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
-   --data_dir /kg_374/Relphormer/dataset/FB15k-237 \
+   --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \
    --eval_batch_size 256 \
    --max_seq_length 64 \
    --lr 1e-4 \
							
								
								
									
										7
									
								
								scripts/fb15k-237/fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										7
									
								
								scripts/fb15k-237/fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							@@ -1,13 +1,12 @@
 | 
				
			|||||||
nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \
 | 
					nohup python -u main.py --gpus "2," --max_epochs=16  --num_workers=32 \
 | 
				
			||||||
   --model_name_or_path  bert-base-uncased \
 | 
					   --model_name_or_path  bert-base-uncased \
 | 
				
			||||||
   --accumulate_grad_batches 1 \
 | 
					   --accumulate_grad_batches 1 \
 | 
				
			||||||
   --model_class BertKGC \
 | 
					   --model_class BertKGC \
 | 
				
			||||||
   --batch_size 64 \
 | 
					   --batch_size 16 \
 | 
				
			||||||
   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
 | 
					   --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
 | 
				
			||||||
   --pretrain 0 \
 | 
					   --pretrain 0 \
 | 
				
			||||||
   --bce 0 \
 | 
					   --bce 0 \
 | 
				
			||||||
   --check_val_every_n_epoch 1 \
 | 
					   --check_val_every_n_epoch 1 \
 | 
				
			||||||
   --overwrite_cache \
 | 
					 | 
				
			||||||
   --data_dir dataset/FB15k-237 \
 | 
					   --data_dir dataset/FB15k-237 \
 | 
				
			||||||
   --eval_batch_size 128 \
 | 
					   --eval_batch_size 128 \
 | 
				
			||||||
   --max_seq_length 128 \
 | 
					   --max_seq_length 128 \
 | 
				
			||||||
 
 | 
				
			|||||||