3 Commits

2 changed files with 23 additions and 3 deletions

View File

@@ -48,7 +48,11 @@ class TransformerLitModel(BaseLitModel):
        if args.bce:
            self.loss_fn = torch.nn.BCEWithLogitsLoss()
        elif args.label_smoothing != 0.0:
            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
            self.cross_entropy_loss = nn.CrossEntropyLoss()
            self.smoothing = args.label_smoothing
            self.loss_fn = self.label_smoothed_cross_entropy
            self.frobenius_reg = args.weight_decay
            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
        else:
            self.loss_fn = nn.CrossEntropyLoss()
        self.best_acc = 0
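Note on the smoothing setup above: on PyTorch >= 1.10 a closely related alternative is the built-in label_smoothing argument of nn.CrossEntropyLoss. It is not identical to the helper added in this commit (the built-in spreads the smoothing mass over all classes, including the gold one, while the helper below spreads it over the other num_classes - 1 classes), but it avoids the manual target construction. A minimal standalone sketch with toy tensors:

import torch
import torch.nn as nn

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, labels)  # built-in smoothing
print(loss.item())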
@@ -69,6 +73,19 @@ class TransformerLitModel(BaseLitModel):
        self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
        self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)
    def label_smoothed_cross_entropy(self, logits, labels):
        # Build smoothed targets: 1 - smoothing on the gold class and
        # smoothing / (num_classes - 1) spread over the other classes.
        num_classes = logits.size(1)
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
        # Cross entropy against the smoothed targets; the unsmoothed
        # self.cross_entropy_loss(logits, labels) would ignore them.
        loss = -(one_hot * logits.log_softmax(dim=-1)).sum(dim=1).mean()
        return loss
    def frobenius_norm_loss(self):
        # Sum of the Frobenius norms of all non-bias parameters; used as an
        # extra regularization term added to the training loss.
        frobenius_norm = 0.0
        for name, param in self.model.named_parameters():
            if 'bias' not in name:
                frobenius_norm += torch.norm(param, p='fro')
        return frobenius_norm
    def forward(self, x):
        return self.model(x)
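A standalone sanity check of the smoothed-target formulation used by label_smoothed_cross_entropy above (toy tensors and a hypothetical helper name, not the repo's real batches): with smoothing set to 0 it reduces to plain cross entropy.

import torch
import torch.nn.functional as F

def smoothed_ce(logits, labels, smoothing):
    # 1 - smoothing on the gold class, smoothing / (K - 1) on the rest.
    k = logits.size(1)
    one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
    soft = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (k - 1)
    return -(soft * F.log_softmax(logits, dim=-1)).sum(dim=1).mean()

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
assert torch.allclose(smoothed_ce(logits, labels, 0.0),
                      F.cross_entropy(logits, labels), atol=1e-6)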
@@ -159,6 +176,9 @@ class TransformerLitModel(BaseLitModel):
        else:
            loss += self.loss_fn(mask_logits, label)
        # self.smoothing only exists when label smoothing was enabled in __init__,
        # so guard the lookup before adding the Frobenius penalty.
        if getattr(self, "smoothing", 0.0):
            loss += self.frobenius_reg * self.frobenius_norm_loss()
        if batch_idx == 0:
            print('\n'.join(self.decode(batch['input_ids'][:4])))
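For reference, a toy sketch of the penalty term added to the loss above, with a hypothetical nn.Linear standing in for self.model and a made-up coefficient in place of args.weight_decay: the term is the sum of unsquared Frobenius norms of the non-bias weights, so it acts as an extra norm penalty on top of whatever weight decay the optimizer already applies.

import torch
import torch.nn as nn

model = nn.Linear(8, 4)   # stand-in for self.model
coeff = 0.01              # stand-in for args.weight_decay

penalty = sum(torch.norm(p, p='fro')
              for name, p in model.named_parameters() if 'bias' not in name)
task_loss = model(torch.randn(2, 8)).pow(2).mean()  # dummy task loss
loss = task_loss + coeff * penalty
loss.backward()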

View File

@@ -1,9 +1,9 @@
nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
nohup python -u main.py --gpus "3," --max_epochs=16 --num_workers=32 \
--model_name_or_path bert-base-uncased \
--accumulate_grad_batches 1 \
--model_class BertKGC \
--batch_size 16 \
--checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
--checkpoint /root/kg_374/Relphormer_instance_1/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
--pretrain 0 \
--bce 0 \
--check_val_every_n_epoch 1 \
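Aside on the changed flags (a sketch assuming PyTorch Lightning 1.x device parsing; main.py's own argument handling may differ): the trailing comma in --gpus "3," selects the single device with index 3, whereas a bare --gpus 3 would request three devices, and the updated --checkpoint path points at a copy of the same FB15k-237 checkpoint under Relphormer_instance_1.

from pytorch_lightning import Trainer

# "3," -> train on GPU index 3; the integer 3 would mean "use three GPUs".
trainer = Trainer(gpus='3,', max_epochs=16, check_val_every_n_epoch=1)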