Relphormer baseline

This commit is contained in:
2022-12-26 04:54:46 +00:00
commit c0d0be076f
117 changed files with 1574716 additions and 0 deletions

View File

@ -0,0 +1,2 @@
from .transformer import *
from .base import *

View File

@ -0,0 +1,97 @@
import argparse
import pytorch_lightning as pl
import torch
from typing import Dict, Any
OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100
class Config(dict):
def __getattr__(self, name):
return self.get(name)
def __setattr__(self, name, val):
self[name] = val
class BaseLitModel(pl.LightningModule):
"""
Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
"""
def __init__(self, model, args: argparse.Namespace = None):
super().__init__()
self.model = model
self.args = Config(vars(args)) if args is not None else {}
optimizer = self.args.get("optimizer", OPTIMIZER)
self.optimizer_class = getattr(torch.optim, optimizer)
self.lr = self.args.get("lr", LR)
@staticmethod
def add_to_argparse(parser):
parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
parser.add_argument("--lr", type=float, default=LR)
parser.add_argument("--weight_decay", type=float, default=0.01)
return parser
def configure_optimizers(self):
optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
if self.one_cycle_max_lr is None:
return optimizer
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log("train_loss", loss)
self.train_acc(logits, y)
self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log("val_loss", loss, prog_bar=True)
self.val_acc(logits, y)
self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
@property
def num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.trainer.datamodule.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.trainer.datamodule.train_dataloader())
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)
effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps

View File

@ -0,0 +1,503 @@
from logging import debug
import random
import pytorch_lightning as pl
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor
from .base import BaseLitModel
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from functools import partial
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1
from typing import Callable, Iterable, List
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def multilabel_categorical_crossentropy(y_pred, y_true):
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = y_pred - (1 - y_true) * 1e12
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return (neg_loss + pos_loss).mean()
def decode(output_ids, tokenizer):
return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))
class TransformerLitModel(BaseLitModel):
def __init__(self, model, args, tokenizer=None, data_config={}):
super().__init__(model, args)
self.save_hyperparameters(args)
if args.bce:
self.loss_fn = torch.nn.BCEWithLogitsLoss()
elif args.label_smoothing != 0.0:
self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
else:
self.loss_fn = nn.CrossEntropyLoss()
self.best_acc = 0
self.first = True
self.tokenizer = tokenizer
self.__dict__.update(data_config)
# resize the word embedding layer
self.model.resize_token_embeddings(len(self.tokenizer))
self.decode = partial(decode, tokenizer=self.tokenizer)
if args.pretrain:
self._freaze_attention()
elif "ind" in args.data_dir:
# for inductive setting, use feeaze the word embedding
self._freaze_word_embedding()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
# embed();exit()
# print(self.optimizers().param_groups[1]['lr'])
labels = batch.pop("labels")
label = batch.pop("label")
pos = batch.pop("pos")
try:
en = batch.pop("en")
rel = batch.pop("rel")
except KeyError:
pass
input_ids = batch['input_ids']
logits = self.model(**batch, return_dict=True).logits
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bs = input_ids.shape[0]
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
assert mask_idx.shape[0] == bs, "only one mask in sequence!"
if self.args.bce:
loss = self.loss_fn(mask_logits, labels)
else:
loss = self.loss_fn(mask_logits, label)
if batch_idx == 0:
print('\n'.join(self.decode(batch['input_ids'][:4])))
return loss
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
pos = batch.pop('pos')
my_keys = list(batch.keys())
for k in my_keys:
if k not in ["input_ids", "attention_mask", "token_type_ids"]:
batch.pop(k)
logits = self.model(**batch, return_dict=True).logits[:, :, self.entity_id_st:self.entity_id_ed]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
logits = logits[torch.arange(bsz), mask_idx]
# get the entity ranks
# filter the entity
assert labels[0][label[0]], "correct ids must in filiter!"
labels[torch.arange(bsz), label] = 0
assert logits.shape == labels.shape
logits += labels * -100 # mask entityj
# for i in range(bsz):
# logits[i][labels]
_, outputs = torch.sort(logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks))
def validation_step(self, batch, batch_idx):
result = self._eval(batch, batch_idx)
return result
def validation_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
total_ranks = ranks.shape[0]
if not self.args.pretrain:
l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
self.log("Eval/lhits10", (l_ranks<=10).mean())
self.log("Eval/rhits10", (r_ranks<=10).mean())
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Eval/hits10", hits10)
self.log("Eval/hits20", hits20)
self.log("Eval/hits3", hits3)
self.log("Eval/hits1", hits1)
self.log("Eval/mean_rank", ranks.mean())
self.log("Eval/mrr", (1. / ranks).mean())
self.log("hits10", hits10, prog_bar=True)
self.log("hits1", hits1, prog_bar=True)
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
# ranks = self._eval(batch, batch_idx)
result = self._eval(batch, batch_idx)
# self.log("Test/ranks", np.mean(ranks))
return result
def test_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Test/hits10", hits10)
self.log("Test/hits20", hits20)
self.log("Test/hits3", hits3)
self.log("Test/hits1", hits1)
self.log("Test/mean_rank", ranks.mean())
self.log("Test/mrr", (1. / ranks).mean())
def configure_optimizers(self):
no_decay_param = ["bias", "LayerNorm.weight"]
optimizer_group_parameters = [
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
]
optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
return {
"optimizer": optimizer,
"lr_scheduler":{
'scheduler': scheduler,
'interval': 'step', # or 'epoch'
'frequency': 1,
}
}
def _freaze_attention(self):
for k, v in self.model.named_parameters():
if "word" not in k:
v.requires_grad = False
else:
print(k)
def _freaze_word_embedding(self):
for k, v in self.model.named_parameters():
if "word" in k:
print(k)
v.requires_grad = False
@staticmethod
def add_to_argparse(parser):
parser = BaseLitModel.add_to_argparse(parser)
parser.add_argument("--label_smoothing", type=float, default=0.1, help="")
parser.add_argument("--bce", type=int, default=0, help="")
return parser
import faiss
import os
class GetEntityEmbeddingLitModel(TransformerLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config)
self.faissid2entityid = {}
# self.index = faiss.IndexFlatL2(d) # build the index
d, measure = self.model.config.hidden_size, faiss.METRIC_L2
# param = 'HNSW64'
# self.index = faiss.index_factory(d, param, measure)
self.index = faiss.IndexFlatL2(d) # build the index
# print(self.index.is_trained) # 此时输出为True
# index.add(xb)
self.cnt_batch = 0
self.total_embedding = []
def test_step(self, batch, batch_idx):
labels = batch.pop("labels")
mask_idx = batch.pop("pos")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
# last layer
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
# _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
# use normalize or not ?
# entity_embedding = F.normalize(entity_embedding, dim=-1, p = 2)
self.total_embedding.append(entity_embedding)
# self.index.add(np.array(entity_embedding, dtype=np.float32))
for i, l in zip(range(bsz), label):
self.faissid2entityid[i+self.cnt_batch] = l.cpu()
self.cnt_batch += bsz
def test_epoch_end(self, outputs) -> None:
self.total_embedding = np.concatenate(self.total_embedding, axis=0)
# self.index.train(self.total_embedding)
print(faiss.MatrixStats(self.total_embedding).comments)
self.index.add(self.total_embedding)
faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'wb') as file:
pickle.dump(self.faissid2entityid, file)
with open(os.path.join(self.args.data_dir, "total_embedding.pkl") ,'wb') as file:
pickle.dump(self.total_embedding, file)
# print(f"number of entity embedding : {len(self.faissid2entityid)}")
@staticmethod
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--faiss_init", type=int, default=1, help="get the embedding and save it the file.")
return parser
class UseEntityEmbeddingLitModel(TransformerLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config)
self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'rb'))
self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))
self.dis2logits = distance2logits_2
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
pos = batch.pop("pos")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
topk = 200
D, I = self.index.search(mask_embedding, topk)
labels[torch.arange(bsz), label] = 0
# logits = torch.zeros_like(labels)
# D = torch.softmax(torch.exp(-1. * torch.tensor(D)), dim=-1)
# for i in range(bsz):
# for j in range(topk):
# logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
# # get the entity ranks
# # filter the entity
# assert labels[0][label[0]], "correct ids must in filiter!"
# labels[torch.arange(bsz), label] = 0
# assert logits.shape == labels.shape
# logits += labels * -100 # mask entityj
# _, outputs = torch.sort(logits, dim=1, descending=True)
# _, outputs = torch.sort(outputs, dim=1)
# ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
entity_logits = torch.full(labels.shape, -100.).to(self.device)
D = self.dis2logits(D)
for i in range(bsz):
for j in range(topk):
# filter entity in labels
if I[i][j] not in self.faissid2entityid:
print(I[i][j])
break
# assert I[i][j] in self.faissid2entityid, print(I[i][j])
if labels[i][self.faissid2entityid[I[i][j]]]: continue
if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
# no added together
# else:
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
# get the entity ranks
# filter the entity
assert entity_logits.shape == labels.shape
_, outputs = torch.sort(entity_logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks))
@staticmethod
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--faiss_init", type=int, default=0, help="get the embedding and save it the file.")
parser.add_argument("--faiss_use", type=int, default=1, help="get the embedding and save it the file.")
return parser
class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config=data_config)
self.dis2logits = distance2logits_2
self.id2entity = {}
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
cnt = 0
for line in file.readlines():
e, d = line.strip().split("\t")
self.id2entity[cnt] = e
cnt += 1
self.id2entity_t = {}
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
for line in file.readlines():
e, d = line.strip().split("\t")
self.id2entity_t[e] = d
for k, v in self.id2entity.items():
self.id2entity[k] = self.id2entity_t[v]
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
pos = batch.pop("pos")
result = self.model(**batch, return_dict=True, output_hidden_states=True)
hidden_states = result.hidden_states[-1]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
# mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
topk = self.args.knn_topk
D, I = self.index.search(mask_embedding, topk)
D = torch.from_numpy(D).to(self.device)
assert labels[0][label[0]], "correct ids must in filiter!"
labels[torch.arange(bsz), label] = 0
mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
mask_logits = mask_logits[torch.arange(bsz), mask_idx]
entity_logits = torch.full(labels.shape, 1000.).to(self.device)
# D = self.dis2logits(D)
for i in range(bsz):
for j in range(topk):
# filter entity in labels
if labels[i][self.faissid2entityid[I[i][j]]]: continue
if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
# else:
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
entity_logits = self.dis2logits(entity_logits)
# get the entity ranks
# filter the entity
assert entity_logits.shape == labels.shape
assert mask_logits.shape == labels.shape
# entity_logits = torch.softmax(entity_logits + labels * -100, dim=-1) # mask entityj
entity_logits = entity_logits + labels* -100.
mask_logits = torch.softmax(mask_logits + labels* -100, dim=-1)
# logits = mask_logits
logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)
# logits = entity_logits + mask_logits
knn_topk_logits, knn_topk_id = entity_logits.topk(20)
mask_topk_logits, mask_topk_id = mask_logits.topk(20)
union_topk = []
for i in range(bsz):
num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
union_topk.append(num_same/ 20.)
knn_topk_id = knn_topk_id.to("cpu")
mask_topk_id = mask_topk_id.to("cpu")
mask_topk_logits = mask_topk_logits.to("cpu")
knn_topk_logits = knn_topk_logits.to("cpu")
label = label.to("cpu")
for t in range(bsz):
if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
print(knn_topk_logits[t], knn_topk_id[t])
print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
print(mask_topk_logits[t], mask_topk_id[t])
print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
print(label[t])
print()
_, outputs = torch.sort(logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same = np.array(union_topk))
def test_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
num_same = np.concatenate([_['num_same'] for _ in outputs])
results_keys = list(outputs[0].keys())
results = {}
# for k in results_keys:
# results.
self.log("Test/num_same", num_same.mean())
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Test/hits10", hits10)
self.log("Test/hits20", hits20)
self.log("Test/hits3", hits3)
self.log("Test/hits1", hits1)
self.log("Test/mean_rank", ranks.mean())
self.log("Test/mrr", (1. / ranks).mean())
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--knn_lambda", type=float, default=0.5, help="lambda * knn + (1-lambda) * mask logits , lambda of knn logits and mask logits.")
parser.add_argument("--knn_topk", type=int, default=100, help="")
return parser
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
combine_probs = torch.stack([vocab_p, knn_p], dim=0)
coeffs = torch.ones_like(combine_probs)
coeffs[0] = np.log(1 - coeff)
coeffs[1] = np.log(coeff)
curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)
return curr_prob
def distance2logits(D):
return torch.softmax( -1. * torch.tensor(D) / 30., dim=-1)
def distance2logits_2(D, n=10):
if not isinstance(D, torch.Tensor):
D = torch.tensor(D)
if torch.sum(D) != 0.0:
distances = torch.exp(-D/n) / torch.sum(torch.exp(-D/n), dim=-1, keepdim=True)
return distances

View File

@ -0,0 +1,66 @@
import json
import numpy as np
def rank_score(ranks):
# prepare the dataset
len_samples = len(ranks)
hits10 = [0] * len_samples
hits5 = [0] * len_samples
hits1 = [0] * len_samples
mrr = []
for idx, rank in enumerate(ranks):
if rank <= 10:
hits10[idx] = 1.
if rank <= 5:
hits5[idx] = 1.
if rank <= 1:
hits1[idx] = 1.
mrr.append(1./rank)
return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)
def acc(logits, labels):
preds = np.argmax(logits, axis=-1)
return (preds == labels).mean()
import torch.nn as nn
import torch
class LabelSmoothSoftmaxCEV1(nn.Module):
'''
This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
'''
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothSoftmaxCEV1, self).__init__()
self.lb_smooth = lb_smooth
self.reduction = reduction
self.lb_ignore = ignore_index
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, logits, label):
'''
args: logits: tensor of shape (N, C, H, W)
args: label: tensor of shape(N, H, W)
'''
# overcome ignored label
with torch.no_grad():
num_classes = logits.size(1)
label = label.clone().detach()
ignore = label == self.lb_ignore
n_valid = (ignore == 0).sum()
label[ignore] = 0
lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
label = torch.empty_like(logits).fill_(
lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
logs = self.log_softmax(logits)
loss = -torch.sum(logs * label, dim=1)
loss[ignore] = 0
if self.reduction == 'mean':
loss = loss.sum() / n_valid
if self.reduction == 'sum':
loss = loss.sum()
return loss