Compare commits

negative_sampling_fro_reg_cross_entropy_loss..main

No commits in common. "negative_sampling_fro_reg_cross_entropy_loss" and "main" have entirely different histories.

5 changed files with 11 additions and 111 deletions

36
.vscode/launch.json vendored

@@ -1,36 +0,0 @@
-{
-    // Use IntelliSense to learn about possible attributes.
-    // Hover to view descriptions of existing attributes.
-    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Python: Current File",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "justMyCode": true,
-            "args": [
-                "--gpus", "1,",
-                "--max_epochs=16",
-                "--num_workers=32",
-                "--model_name_or_path", "bert-base-uncased",
-                "--accumulate_grad_batches", "1",
-                "--model_class", "BertKGC",
-                "--batch_size", "32",
-                "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt",
-                "--pretrain", "0",
-                "--bce", "0",
-                "--check_val_every_n_epoch", "1",
-                "--data_dir", "dataset/FB15k-237",
-                "--eval_batch_size", "128",
-                "--max_seq_length", "128",
-                "--lr", "3e-5",
-                "--max_triplet", "64",
-                "--add_attn_bias", "True",
-                "--use_global_node", "True",
-            ]
-        }
-    ]
-}
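
The deleted file is only a VS Code debug configuration; its "args" array carries the same fine-tuning flags that the shell scripts further down pass to main.py. As a minimal sketch (abbreviated; main.py and the flag values are taken from the config above), the equivalent command line can be rebuilt from that list:

import shlex

# Sketch: reproduce the deleted debug configuration from a plain shell.
# The flag values below are copied (abbreviated) from launch.json's "args".
args = [
    "--gpus", "1,",
    "--max_epochs=16",
    "--model_name_or_path", "bert-base-uncased",
    "--model_class", "BertKGC",
    "--batch_size", "32",
    "--pretrain", "0",
    "--data_dir", "dataset/FB15k-237",
    "--lr", "3e-5",
]
print("python -u main.py " + " ".join(shlex.quote(a) for a in args))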


@@ -48,11 +48,7 @@ class TransformerLitModel(BaseLitModel):
         if args.bce:
             self.loss_fn = torch.nn.BCEWithLogitsLoss()
         elif args.label_smoothing != 0.0:
-            self.cross_entropy_loss = nn.CrossEntropyLoss()
-            self.smoothing = args.label_smoothing
-            self.loss_fn = self.label_smoothed_cross_entropy
-            self.frobenius_reg = args.weight_decay
-            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
+            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
         self.best_acc = 0
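
For context: label smoothing replaces the one-hot target with a mixture of (1 - eps) on the gold class and eps / (C - 1) elsewhere. The branch's label_smoothed_cross_entropy helper (removed in the next hunk) builds exactly that smoothed one_hot tensor but then still returns the plain CrossEntropyLoss, so the smoothing never reaches the loss. A minimal sketch that actually uses the smoothed targets (assuming logits of shape [batch, num_classes]):

import torch
import torch.nn.functional as F

def label_smoothed_ce(logits, labels, smoothing=0.1):
    # Smoothed targets: (1 - eps) on the gold class, eps / (C - 1) on the rest.
    num_classes = logits.size(1)
    one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
    smoothed = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (num_classes - 1)
    log_probs = F.log_softmax(logits, dim=1)
    # Cross entropy against the smoothed distribution, averaged over the batch.
    return -(smoothed * log_probs).sum(dim=1).mean()

(Recent PyTorch releases, 1.10 and later, also accept nn.CrossEntropyLoss(label_smoothing=...) directly.)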
@@ -73,71 +69,13 @@ class TransformerLitModel(BaseLitModel):
         self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
         self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)
-    def label_smoothed_cross_entropy(self, logits, labels):
-        num_classes = logits.size(1)
-        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
-        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
-        loss = self.cross_entropy_loss(logits, labels)
-        return loss
-    def frobenius_norm_loss(self):
-        frobenius_norm = 0.0
-        for name, param in self.model.named_parameters():
-            if 'bias' not in name:
-                frobenius_norm += torch.norm(param, p='fro')
-        return frobenius_norm
     def forward(self, x):
         return self.model(x)
-    def create_negatives(self, batch):
-        negativeBatches = []
-        label = batch['label']
-        for i in range(label.shape[0]):
-            newBatch = {}
-            newBatch['attention_mask'] = None
-            newBatch['input_ids'] = torch.clone(batch['input_ids'])
-            newBatch['label'] = torch.zeros_like(batch['label'])
-            negativeBatches.append(newBatch)
-        entity_idx = []
-        self_label = []
-        for idx, l in enumerate(label):
-            decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
-            for j in range(1, len(decoded)):
-                if (decoded[j].startswith("[ENTITY_")):
-                    entity_idx.append(j)
-                    self_label.append(batch['input_ids'][idx][j])
-                    break
-        for idx, lbl in enumerate(label):
-            for i in range(label.shape[0]):
-                if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl):
-                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
-                else:
-                    negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
-        return negativeBatches
     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
         # embed();exit()
         # print(self.optimizers().param_groups[1]['lr'])
-        negativeBatches = self.create_negatives(batch)
-        loss = 0
-        for negativeBatch in negativeBatches:
-            label = negativeBatch.pop("label")
-            input_ids = batch['input_ids']
-            logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
-            _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
-            bs = input_ids.shape[0]
-            mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
-            loss += self.loss_fn(mask_logits, label)
         labels = batch.pop("labels")
         label = batch.pop("label")
         pos = batch.pop("pos")
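
The block removed above is in-batch negative sampling: create_negatives clones input_ids once per gold label, locates the first [ENTITY_...] token in each sequence, and overwrites that slot with the label (keeping the sequence's own entity token when the two already coincide); training_step then scores every corrupted batch against an all-zeros label and accumulates the loss before the positive pass. A compressed sketch of the same idea (tensor shapes, the precomputed entity positions, and the zero-label convention are assumptions carried over from the removed code):

import torch

def make_negative_batches(input_ids, labels, entity_positions):
    # input_ids:        LongTensor [bs, seq_len]
    # labels:           LongTensor [bs], gold entity id per sequence
    # entity_positions: LongTensor [bs], index of the [ENTITY_*] slot per sequence
    negatives = []
    rows = torch.arange(labels.shape[0])
    for lbl in labels:                       # one corrupted batch per gold entity
        corrupted = input_ids.clone()
        own = corrupted[rows, entity_positions]
        # Overwrite every sequence's entity slot with this label; where a
        # sequence already holds that entity, keep its original token.
        corrupted[rows, entity_positions] = torch.where(own != lbl, lbl, own)
        negatives.append({"input_ids": corrupted,
                          "attention_mask": None,
                          "label": torch.zeros_like(labels)})
    return negatives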
@@ -172,12 +110,9 @@ class TransformerLitModel(BaseLitModel):
         assert mask_idx.shape[0] == bs, "only one mask in sequence!"
         if self.args.bce:
-            loss += self.loss_fn(mask_logits, labels)
+            loss = self.loss_fn(mask_logits, labels)
         else:
-            loss += self.loss_fn(mask_logits, label)
-            if self.smoothing is not None and self.smoothing != 0.0:
-                loss += self.frobenius_reg * self.frobenius_norm_loss()
+            loss = self.loss_fn(mask_logits, label)
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
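
The other removed term is a Frobenius-norm penalty: when label smoothing is active, the branch adds args.weight_decay times the summed Frobenius norm of every non-bias parameter to the loss. A standalone sketch of that penalty (the coefficient name mirrors the removed frobenius_reg attribute):

import torch
import torch.nn as nn

def frobenius_penalty(model: nn.Module, coeff: float):
    # Sum of Frobenius norms over all non-bias parameters, scaled by coeff.
    penalty = 0.0
    for name, param in model.named_parameters():
        if "bias" not in name:
            penalty = penalty + torch.norm(param, p="fro")
    return coeff * penalty

Note this is not the same as the optimizer's weight_decay, which corresponds to a squared-L2 penalty applied in the update step; here the un-squared norm is added directly to the training loss and reuses the same coefficient.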


@@ -120,7 +120,7 @@ def main():
     callbacks = [early_callback, model_checkpoint]
     # args.weights_summary = "full" # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
     if "EntityEmbedding" not in lit_model.__class__.__name__:
         trainer.fit(lit_model, datamodule=data)
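
accelerator="ddp" follows the PyTorch Lightning 1.x API, where the distributed backend was chosen through the accelerator argument. In later releases (roughly 1.6 onward) the backend moved to strategy="ddp" and accelerator names the hardware; a hedged sketch of the newer spelling, assuming a recent Lightning install:

import pytorch_lightning as pl

# Equivalent setup under the newer Lightning API (assumption: PL >= 1.6).
# callbacks/logger would be the objects main() builds above; plain defaults
# are used here so the sketch stands alone.
trainer = pl.Trainer(
    accelerator="gpu",                 # hardware backend
    devices=1,                         # mirrors --gpus "1"
    strategy="ddp",                    # what accelerator="ddp" meant in PL 1.x
    max_epochs=16,
    default_root_dir="training/logs",
    callbacks=[],
    logger=True,
)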

6
pretrain/scripts/pretrain_fb15k-237.sh Executable file → Normal file

@@ -1,13 +1,13 @@
-nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
-   --batch_size 64 \
+   --batch_size 128 \
    --pretrain 1 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
-   --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \
+   --data_dir /kg_374/Relphormer/dataset/FB15k-237 \
    --eval_batch_size 256 \
    --max_seq_length 64 \
    --lr 1e-4 \
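
The switch from --gpus "1," to --gpus "1" is not just cosmetic under Lightning 1.x: a comma-terminated string is parsed as a list of device indices (train on GPU 1), while a bare number is a device count (train on one GPU, normally the first visible one). A hypothetical re-implementation of that parsing rule, for illustration only:

def sketch_parse_gpus(value: str):
    # Hypothetical sketch of the PyTorch Lightning 1.x --gpus parsing rule.
    #   "1,"  -> [1]      explicit device indices (use GPU 1)
    #   "0,2" -> [0, 2]   explicit device indices
    #   "1"   -> 1        a count: use one GPU (Lightning picks the first)
    if "," in value:
        return [int(v) for v in value.split(",") if v.strip()]
    return int(value)

print(sketch_parse_gpus("1,"), sketch_parse_gpus("1"))  # [1] 1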

7
scripts/fb15k-237/fb15k-237.sh Executable file → Normal file

@@ -1,12 +1,13 @@
-nohup python -u main.py --gpus "3," --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
-   --batch_size 16 \
-   --checkpoint /root/kg_374/Relphormer_instance_1/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
+   --batch_size 64 \
+   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
    --data_dir dataset/FB15k-237 \
    --eval_batch_size 128 \
    --max_seq_length 128 \
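
The fine-tuning script points --checkpoint at the pretraining output; the new path also drops the backslash-escaped \= of the old one (harmless but unnecessary, since = is not special in a bash argument). Loading such a Lightning checkpoint typically looks like the sketch below (the "state_dict" key is the standard Lightning layout; how the project maps those weights onto the model is not shown here and is an assumption):

import torch

# Sketch: inspect and load the pretrained checkpoint referenced by --checkpoint.
ckpt_path = "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt"
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["state_dict"]        # Lightning stores module weights here
print(len(state_dict), "tensors in the pretrained checkpoint")
# lit_model.load_state_dict(state_dict)      # then fine-tuning resumes from these weights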