Compare commits: negative_s...negative_s
1 commit: 6cc55301ad
@@ -48,11 +48,7 @@ class TransformerLitModel(BaseLitModel):
         if args.bce:
             self.loss_fn = torch.nn.BCEWithLogitsLoss()
         elif args.label_smoothing != 0.0:
-            self.cross_entropy_loss = nn.CrossEntropyLoss()
-            self.smoothing = args.label_smoothing
-            self.loss_fn = self.label_smoothed_cross_entropy
-            self.frobenius_reg = args.weight_decay
-            # self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
+            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
         else:
             self.loss_fn = nn.CrossEntropyLoss()
         self.best_acc = 0
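Note on this hunk: the manual label-smoothing setup is dropped in favor of the external LabelSmoothSoftmaxCEV1. For reference, a minimal sketch (not part of this commit) of the built-in alternative: since PyTorch 1.10, nn.CrossEntropyLoss accepts a label_smoothing argument directly.

    import torch
    import torch.nn as nn

    # Minimal sketch: built-in label smoothing (PyTorch >= 1.10), an
    # alternative to the external LabelSmoothSoftmaxCEV1 used above.
    logits = torch.randn(4, 10)          # (batch, num_classes)
    labels = torch.randint(0, 10, (4,))  # hard class indices

    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    loss = loss_fn(logits, labels)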
@@ -73,19 +69,6 @@ class TransformerLitModel(BaseLitModel):
         self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
         self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)
 
-    def label_smoothed_cross_entropy(self, logits, labels):
-        num_classes = logits.size(1)
-        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
-        one_hot = one_hot * (1 - self.smoothing) + (1 - one_hot) * self.smoothing / (num_classes - 1)
-        loss = self.cross_entropy_loss(logits, labels)
-        return loss
-
-    def frobenius_norm_loss(self):
-        frobenius_norm = 0.0
-        for name, param in self.model.named_parameters():
-            if 'bias' not in name:
-                frobenius_norm += torch.norm(param, p='fro')
-        return frobenius_norm
 
     def forward(self, x):
         return self.model(x)
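The removed label_smoothed_cross_entropy built a smoothed one_hot target but never used it; the value it returned was plain cross-entropy on the hard labels, so the smoothing had no effect. A corrected sketch of what the helper appears to have intended (hypothetical, not code from this repo):

    import torch
    import torch.nn.functional as F

    def label_smoothed_cross_entropy(logits, labels, smoothing=0.1):
        # Build smoothed targets: (1 - smoothing) on the true class,
        # smoothing / (num_classes - 1) spread over the other classes.
        num_classes = logits.size(1)
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (num_classes - 1)
        # Actually apply the smoothed targets to the log-probabilities.
        log_probs = F.log_softmax(logits, dim=1)
        return -(one_hot * log_probs).sum(dim=1).mean()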
@@ -176,9 +159,6 @@ class TransformerLitModel(BaseLitModel):
             else:
                 loss += self.loss_fn(mask_logits, label)
 
-        if self.smoothing is not None and self.smoothing != 0.0:
-            loss += self.frobenius_reg * self.frobenius_norm_loss()
-
         if batch_idx == 0:
             print('\n'.join(self.decode(batch['input_ids'][:4])))
 
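The term removed here added the sum of per-parameter Frobenius norms, scaled by args.weight_decay, on top of the training loss. A sketch of the more conventional route (an assumption about intent, not code from this repo): let the optimizer apply weight decay, which gives a similar, though not identical, shrinkage, since decay penalizes squared norms while the removed term penalized unsquared ones.

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)  # stand-in module for illustration
    # AdamW applies decoupled L2 weight decay at each optimizer step,
    # so no explicit norm term is needed inside the training loss.
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.01)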
@@ -1,9 +1,9 @@
-nohup python -u main.py --gpus "3," --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "1," --max_epochs=16 --num_workers=32 \
     --model_name_or_path bert-base-uncased \
     --accumulate_grad_batches 1 \
     --model_class BertKGC \
     --batch_size 16 \
-    --checkpoint /root/kg_374/Relphormer_instance_1/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
+    --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \
     --pretrain 0 \
     --bce 0 \
     --check_val_every_n_epoch 1 \
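On the script change: only the GPU index and the checkpoint path differ. In the pre-2.0 PyTorch Lightning argument convention this repo uses, the trailing comma is significant: the string "1," selects the device with index 1, while a bare 1 would mean one unspecified GPU. A minimal sketch of the equivalent Trainer call (illustrative, assuming the older Lightning API):

    from pytorch_lightning import Trainer

    # "1," parses to the device list [1], pinning training to cuda:1;
    # the integer 1 would instead mean "use one (unspecified) GPU".
    trainer = Trainer(gpus="1,", max_epochs=16)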