diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..9a5e950
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,36 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--gpus", "1,",
+                "--max_epochs=16",
+                "--num_workers=32",
+                "--model_name_or_path", "bert-base-uncased",
+                "--accumulate_grad_batches", "1",
+                "--model_class", "BertKGC",
+                "--batch_size", "32",
+                "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt",
+                "--pretrain", "0",
+                "--bce", "0",
+                "--check_val_every_n_epoch", "1",
+                "--data_dir", "dataset/FB15k-237",
+                "--eval_batch_size", "128",
+                "--max_seq_length", "128",
+                "--lr", "3e-5",
+                "--max_triplet", "64",
+                "--add_attn_bias", "True",
+                "--use_global_node", "True",
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 6075a28..407b9cf 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -48,7 +48,10 @@ class TransformerLitModel(BaseLitModel):
         if args.bce:
             self.loss_fn = torch.nn.BCEWithLogitsLoss()
         elif args.label_smoothing != 0.0:
-            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
+            self.cross_entropy_loss = nn.CrossEntropyLoss()
+            self.smoothing = args.label_smoothing
+            self.loss_fn = self.label_smoothed_cross_entropy
+            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
         self.best_acc = 0
@@ -69,13 +72,71 @@ class TransformerLitModel(BaseLitModel):
             self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
             self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)
 
+    def label_smoothed_cross_entropy(self, logits, labels):
+        num_classes = logits.size(1)
+        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
+        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
+        loss = self.cross_entropy_loss(logits, labels)
+        return loss
+
+    def frobenius_norm_loss(self):
+        frobenius_norm = 0.0
+        for name, param in self.model.named_parameters():
+            if 'bias' not in name:
+                frobenius_norm += torch.norm(param, p='fro')
+        return frobenius_norm
 
     def forward(self, x):
         return self.model(x)
 
+    def create_negatives(self, batch):
+        negativeBatches = []
+        label = batch['label']
+
+        for i in range(label.shape[0]):
+            newBatch = {}
+            newBatch['attention_mask'] = None
+            newBatch['input_ids'] = torch.clone(batch['input_ids'])
+            newBatch['label'] = torch.zeros_like(batch['label'])
+            negativeBatches.append(newBatch)
+
+        entity_idx = []
+        self_label = []
+        for idx, l in enumerate(label):
+            decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
+            for j in range(1, len(decoded)):
+                if (decoded[j].startswith("[ENTITY_")):
+                    entity_idx.append(j)
+                    self_label.append(batch['input_ids'][idx][j])
+                    break
+
+        for idx, lbl in enumerate(label):
+            for i in range(label.shape[0]):
+                if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl):
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
+                else:
+                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
+
+        return negativeBatches
+
     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
         # embed();exit()
         # print(self.optimizers().param_groups[1]['lr'])
+
+        negativeBatches = self.create_negatives(batch)
+
+        loss = 0
+
+        for negativeBatch in negativeBatches:
+            label = negativeBatch.pop("label")
+            input_ids = batch['input_ids']
+
+            logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
+            _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
+            bs = input_ids.shape[0]
+            mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
+
         labels = batch.pop("labels")
         label = batch.pop("label")
         pos = batch.pop("pos")
@@ -110,9 +171,9 @@ class TransformerLitModel(BaseLitModel):
         assert mask_idx.shape[0] == bs, "only one mask in sequence!"
 
         if self.args.bce:
-            loss = self.loss_fn(mask_logits, labels)
+            loss += self.loss_fn(mask_logits, labels)
         else:
-            loss = self.loss_fn(mask_logits, label)
+            loss += self.loss_fn(mask_logits, label) + self.frobenius_norm_loss()
 
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
diff --git a/main.py b/main.py
index b9c8a71..7f3ffb4 100644
--- a/main.py
+++ b/main.py
@@ -120,7 +120,7 @@ def main():
     callbacks = [early_callback, model_checkpoint]
 
     # args.weights_summary = "full"  # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
 
     if "EntityEmbedding" not in lit_model.__class__.__name__:
         trainer.fit(lit_model, datamodule=data)
diff --git a/pretrain/scripts/pretrain_fb15k-237.sh b/pretrain/scripts/pretrain_fb15k-237.sh
old mode 100644
new mode 100755
index 59d24c4..f303195
--- a/pretrain/scripts/pretrain_fb15k-237.sh
+++ b/pretrain/scripts/pretrain_fb15k-237.sh
@@ -1,13 +1,13 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
-   --batch_size 128 \
+   --batch_size 64 \
    --pretrain 1 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
-   --data_dir /kg_374/Relphormer/dataset/FB15k-237 \
+   --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \
    --eval_batch_size 256 \
    --max_seq_length 64 \
    --lr 1e-4 \
diff --git a/scripts/fb15k-237/fb15k-237.sh b/scripts/fb15k-237/fb15k-237.sh
old mode 100644
new mode 100755
index ac9bc99..eb8c31a
--- a/scripts/fb15k-237/fb15k-237.sh
+++ b/scripts/fb15k-237/fb15k-237.sh
@@ -1,13 +1,12 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
-   --batch_size 64 \
-   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
+   --batch_size 16 \
+   --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
-   --overwrite_cache \
    --data_dir dataset/FB15k-237 \
    --eval_batch_size 128 \
    --max_seq_length 128 \
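
Note on the lit_models/transformer.py hunk: the new label_smoothed_cross_entropy helper builds the smoothed one_hot targets but then returns the plain nn.CrossEntropyLoss value, so self.smoothing never actually influences the loss. A minimal sketch, not part of the patch, of a variant that consumes the smoothed targets (the standalone signature and the default smoothing=0.1 are illustrative assumptions):

    import torch
    import torch.nn.functional as F

    def label_smoothed_cross_entropy(logits, labels, smoothing=0.1):
        # Smoothed one-hot targets: the gold class keeps 1 - smoothing,
        # the remaining probability mass is spread over the other classes.
        num_classes = logits.size(1)
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (num_classes - 1)
        # Cross entropy against the soft targets, averaged over the batch.
        log_probs = F.log_softmax(logits, dim=-1)
        return -(one_hot * log_probs).sum(dim=-1).mean()

On PyTorch 1.10 and later, nn.CrossEntropyLoss(label_smoothing=...) offers the same behavior without a custom helper.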
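
The frobenius_norm_loss regularizer is added to the training loss at full strength, once per negative batch and once more for the original batch. If it needs to be balanced against the cross-entropy term, scaling it by a small coefficient is the usual approach; a sketch under that assumption (the standalone signature and the weight value are illustrative, not taken from the patch):

    import torch

    def frobenius_norm_loss(model, weight=1e-5):
        # Sum of Frobenius norms over all non-bias parameters, scaled by a
        # small coefficient so it does not swamp the task loss.
        # The patch itself adds the unscaled sum.
        reg = 0.0
        for name, param in model.named_parameters():
            if 'bias' not in name:
                reg = reg + torch.norm(param, p='fro')
        return weight * reg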