Relphormer baseline
This commit is contained in:
2
pretrain/lit_models/__init__.py
Normal file
2
pretrain/lit_models/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .transformer import *
|
||||
from .base import *
|
97
pretrain/lit_models/base.py
Normal file
97
pretrain/lit_models/base.py
Normal file
@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
OPTIMIZER = "AdamW"
|
||||
LR = 5e-5
|
||||
LOSS = "cross_entropy"
|
||||
ONE_CYCLE_TOTAL_STEPS = 100
|
||||
|
||||
class Config(dict):
|
||||
def __getattr__(self, name):
|
||||
return self.get(name)
|
||||
|
||||
def __setattr__(self, name, val):
|
||||
self[name] = val
|
||||
|
||||
|
||||
class BaseLitModel(pl.LightningModule):
|
||||
"""
|
||||
Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
|
||||
"""
|
||||
|
||||
def __init__(self, model, args: argparse.Namespace = None):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.args = Config(vars(args)) if args is not None else {}
|
||||
|
||||
optimizer = self.args.get("optimizer", OPTIMIZER)
|
||||
self.optimizer_class = getattr(torch.optim, optimizer)
|
||||
self.lr = self.args.get("lr", LR)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def add_to_argparse(parser):
|
||||
parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
|
||||
parser.add_argument("--lr", type=float, default=LR)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01)
|
||||
return parser
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
|
||||
if self.one_cycle_max_lr is None:
|
||||
return optimizer
|
||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
|
||||
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
self.log("train_loss", loss)
|
||||
self.train_acc(logits, y)
|
||||
self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
self.log("val_loss", loss, prog_bar=True)
|
||||
self.val_acc(logits, y)
|
||||
self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
|
||||
|
||||
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
self.test_acc(logits, y)
|
||||
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
|
||||
|
||||
@property
|
||||
def num_training_steps(self) -> int:
|
||||
"""Total training steps inferred from datamodule and devices."""
|
||||
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
|
||||
dataset_size = self.trainer.limit_train_batches
|
||||
elif isinstance(self.trainer.limit_train_batches, float):
|
||||
# limit_train_batches is a percentage of batches
|
||||
dataset_size = len(self.trainer.datamodule.train_dataloader())
|
||||
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
|
||||
else:
|
||||
dataset_size = len(self.trainer.datamodule.train_dataloader())
|
||||
|
||||
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
|
||||
if self.trainer.tpu_cores:
|
||||
num_devices = max(num_devices, self.trainer.tpu_cores)
|
||||
|
||||
effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
|
||||
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
|
||||
|
||||
if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
|
||||
return self.trainer.max_steps
|
||||
return max_estimated_steps
|
||||
|
503
pretrain/lit_models/transformer.py
Normal file
503
pretrain/lit_models/transformer.py
Normal file
@ -0,0 +1,503 @@
|
||||
from logging import debug
|
||||
import random
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import pickle
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import json
|
||||
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor
|
||||
|
||||
from .base import BaseLitModel
|
||||
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
|
||||
|
||||
from functools import partial
|
||||
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1
|
||||
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
def multilabel_categorical_crossentropy(y_pred, y_true):
|
||||
y_pred = (1 - 2 * y_true) * y_pred
|
||||
y_pred_neg = y_pred - y_true * 1e12
|
||||
y_pred_pos = y_pred - (1 - y_true) * 1e12
|
||||
zeros = torch.zeros_like(y_pred[..., :1])
|
||||
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
|
||||
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
|
||||
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
|
||||
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
|
||||
return (neg_loss + pos_loss).mean()
|
||||
|
||||
def decode(output_ids, tokenizer):
|
||||
return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
||||
|
||||
class TransformerLitModel(BaseLitModel):
|
||||
def __init__(self, model, args, tokenizer=None, data_config={}):
|
||||
super().__init__(model, args)
|
||||
self.save_hyperparameters(args)
|
||||
if args.bce:
|
||||
self.loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
elif args.label_smoothing != 0.0:
|
||||
self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
|
||||
else:
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
self.best_acc = 0
|
||||
self.first = True
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.__dict__.update(data_config)
|
||||
# resize the word embedding layer
|
||||
self.model.resize_token_embeddings(len(self.tokenizer))
|
||||
self.decode = partial(decode, tokenizer=self.tokenizer)
|
||||
if args.pretrain:
|
||||
self._freaze_attention()
|
||||
elif "ind" in args.data_dir:
|
||||
# for inductive setting, use feeaze the word embedding
|
||||
self._freaze_word_embedding()
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
# embed();exit()
|
||||
# print(self.optimizers().param_groups[1]['lr'])
|
||||
labels = batch.pop("labels")
|
||||
label = batch.pop("label")
|
||||
pos = batch.pop("pos")
|
||||
try:
|
||||
en = batch.pop("en")
|
||||
rel = batch.pop("rel")
|
||||
except KeyError:
|
||||
pass
|
||||
input_ids = batch['input_ids']
|
||||
logits = self.model(**batch, return_dict=True).logits
|
||||
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bs = input_ids.shape[0]
|
||||
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
|
||||
|
||||
assert mask_idx.shape[0] == bs, "only one mask in sequence!"
|
||||
if self.args.bce:
|
||||
loss = self.loss_fn(mask_logits, labels)
|
||||
else:
|
||||
loss = self.loss_fn(mask_logits, label)
|
||||
|
||||
if batch_idx == 0:
|
||||
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
def _eval(self, batch, batch_idx, ):
|
||||
labels = batch.pop("labels")
|
||||
input_ids = batch['input_ids']
|
||||
# single label
|
||||
label = batch.pop('label')
|
||||
pos = batch.pop('pos')
|
||||
my_keys = list(batch.keys())
|
||||
for k in my_keys:
|
||||
if k not in ["input_ids", "attention_mask", "token_type_ids"]:
|
||||
batch.pop(k)
|
||||
logits = self.model(**batch, return_dict=True).logits[:, :, self.entity_id_st:self.entity_id_ed]
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bsz = input_ids.shape[0]
|
||||
logits = logits[torch.arange(bsz), mask_idx]
|
||||
# get the entity ranks
|
||||
# filter the entity
|
||||
assert labels[0][label[0]], "correct ids must in filiter!"
|
||||
labels[torch.arange(bsz), label] = 0
|
||||
assert logits.shape == labels.shape
|
||||
logits += labels * -100 # mask entityj
|
||||
# for i in range(bsz):
|
||||
# logits[i][labels]
|
||||
|
||||
_, outputs = torch.sort(logits, dim=1, descending=True)
|
||||
_, outputs = torch.sort(outputs, dim=1)
|
||||
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
|
||||
|
||||
|
||||
return dict(ranks = np.array(ranks))
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
result = self._eval(batch, batch_idx)
|
||||
return result
|
||||
|
||||
def validation_epoch_end(self, outputs) -> None:
|
||||
ranks = np.concatenate([_['ranks'] for _ in outputs])
|
||||
total_ranks = ranks.shape[0]
|
||||
|
||||
if not self.args.pretrain:
|
||||
l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
|
||||
r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
|
||||
self.log("Eval/lhits10", (l_ranks<=10).mean())
|
||||
self.log("Eval/rhits10", (r_ranks<=10).mean())
|
||||
|
||||
hits20 = (ranks<=20).mean()
|
||||
hits10 = (ranks<=10).mean()
|
||||
hits3 = (ranks<=3).mean()
|
||||
hits1 = (ranks<=1).mean()
|
||||
|
||||
self.log("Eval/hits10", hits10)
|
||||
self.log("Eval/hits20", hits20)
|
||||
self.log("Eval/hits3", hits3)
|
||||
self.log("Eval/hits1", hits1)
|
||||
self.log("Eval/mean_rank", ranks.mean())
|
||||
self.log("Eval/mrr", (1. / ranks).mean())
|
||||
self.log("hits10", hits10, prog_bar=True)
|
||||
self.log("hits1", hits1, prog_bar=True)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
# ranks = self._eval(batch, batch_idx)
|
||||
result = self._eval(batch, batch_idx)
|
||||
# self.log("Test/ranks", np.mean(ranks))
|
||||
|
||||
return result
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
ranks = np.concatenate([_['ranks'] for _ in outputs])
|
||||
|
||||
hits20 = (ranks<=20).mean()
|
||||
hits10 = (ranks<=10).mean()
|
||||
hits3 = (ranks<=3).mean()
|
||||
hits1 = (ranks<=1).mean()
|
||||
|
||||
|
||||
self.log("Test/hits10", hits10)
|
||||
self.log("Test/hits20", hits20)
|
||||
self.log("Test/hits3", hits3)
|
||||
self.log("Test/hits1", hits1)
|
||||
self.log("Test/mean_rank", ranks.mean())
|
||||
self.log("Test/mrr", (1. / ranks).mean())
|
||||
|
||||
def configure_optimizers(self):
|
||||
no_decay_param = ["bias", "LayerNorm.weight"]
|
||||
|
||||
optimizer_group_parameters = [
|
||||
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
|
||||
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
|
||||
]
|
||||
|
||||
optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler":{
|
||||
'scheduler': scheduler,
|
||||
'interval': 'step', # or 'epoch'
|
||||
'frequency': 1,
|
||||
}
|
||||
}
|
||||
|
||||
def _freaze_attention(self):
|
||||
for k, v in self.model.named_parameters():
|
||||
if "word" not in k:
|
||||
v.requires_grad = False
|
||||
else:
|
||||
print(k)
|
||||
|
||||
def _freaze_word_embedding(self):
|
||||
for k, v in self.model.named_parameters():
|
||||
if "word" in k:
|
||||
print(k)
|
||||
v.requires_grad = False
|
||||
|
||||
@staticmethod
|
||||
def add_to_argparse(parser):
|
||||
parser = BaseLitModel.add_to_argparse(parser)
|
||||
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.1, help="")
|
||||
parser.add_argument("--bce", type=int, default=0, help="")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
import faiss
|
||||
import os
|
||||
class GetEntityEmbeddingLitModel(TransformerLitModel):
|
||||
def __init__(self, model, args, tokenizer, data_config={}):
|
||||
super().__init__(model, args, tokenizer, data_config)
|
||||
|
||||
self.faissid2entityid = {}
|
||||
# self.index = faiss.IndexFlatL2(d) # build the index
|
||||
|
||||
d, measure = self.model.config.hidden_size, faiss.METRIC_L2
|
||||
# param = 'HNSW64'
|
||||
# self.index = faiss.index_factory(d, param, measure)
|
||||
self.index = faiss.IndexFlatL2(d) # build the index
|
||||
# print(self.index.is_trained) # 此时输出为True
|
||||
# index.add(xb)
|
||||
self.cnt_batch = 0
|
||||
self.total_embedding = []
|
||||
|
||||
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
labels = batch.pop("labels")
|
||||
mask_idx = batch.pop("pos")
|
||||
input_ids = batch['input_ids']
|
||||
# single label
|
||||
label = batch.pop('label')
|
||||
# last layer
|
||||
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
|
||||
# _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bsz = input_ids.shape[0]
|
||||
entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
|
||||
# use normalize or not ?
|
||||
# entity_embedding = F.normalize(entity_embedding, dim=-1, p = 2)
|
||||
self.total_embedding.append(entity_embedding)
|
||||
# self.index.add(np.array(entity_embedding, dtype=np.float32))
|
||||
for i, l in zip(range(bsz), label):
|
||||
self.faissid2entityid[i+self.cnt_batch] = l.cpu()
|
||||
self.cnt_batch += bsz
|
||||
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
self.total_embedding = np.concatenate(self.total_embedding, axis=0)
|
||||
# self.index.train(self.total_embedding)
|
||||
print(faiss.MatrixStats(self.total_embedding).comments)
|
||||
self.index.add(self.total_embedding)
|
||||
faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
|
||||
with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'wb') as file:
|
||||
pickle.dump(self.faissid2entityid, file)
|
||||
|
||||
with open(os.path.join(self.args.data_dir, "total_embedding.pkl") ,'wb') as file:
|
||||
pickle.dump(self.total_embedding, file)
|
||||
# print(f"number of entity embedding : {len(self.faissid2entityid)}")
|
||||
|
||||
@staticmethod
|
||||
def add_to_argparse(parser):
|
||||
parser = TransformerLitModel.add_to_argparse(parser)
|
||||
parser.add_argument("--faiss_init", type=int, default=1, help="get the embedding and save it the file.")
|
||||
return parser
|
||||
|
||||
class UseEntityEmbeddingLitModel(TransformerLitModel):
|
||||
def __init__(self, model, args, tokenizer, data_config={}):
|
||||
super().__init__(model, args, tokenizer, data_config)
|
||||
|
||||
self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'rb'))
|
||||
self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))
|
||||
|
||||
|
||||
self.dis2logits = distance2logits_2
|
||||
|
||||
def _eval(self, batch, batch_idx, ):
|
||||
labels = batch.pop("labels")
|
||||
pos = batch.pop("pos")
|
||||
input_ids = batch['input_ids']
|
||||
# single label
|
||||
label = batch.pop('label')
|
||||
|
||||
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bsz = input_ids.shape[0]
|
||||
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
|
||||
topk = 200
|
||||
D, I = self.index.search(mask_embedding, topk)
|
||||
labels[torch.arange(bsz), label] = 0
|
||||
|
||||
# logits = torch.zeros_like(labels)
|
||||
# D = torch.softmax(torch.exp(-1. * torch.tensor(D)), dim=-1)
|
||||
# for i in range(bsz):
|
||||
# for j in range(topk):
|
||||
# logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
|
||||
# # get the entity ranks
|
||||
# # filter the entity
|
||||
# assert labels[0][label[0]], "correct ids must in filiter!"
|
||||
# labels[torch.arange(bsz), label] = 0
|
||||
# assert logits.shape == labels.shape
|
||||
# logits += labels * -100 # mask entityj
|
||||
|
||||
# _, outputs = torch.sort(logits, dim=1, descending=True)
|
||||
# _, outputs = torch.sort(outputs, dim=1)
|
||||
# ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
|
||||
|
||||
|
||||
entity_logits = torch.full(labels.shape, -100.).to(self.device)
|
||||
D = self.dis2logits(D)
|
||||
for i in range(bsz):
|
||||
for j in range(topk):
|
||||
# filter entity in labels
|
||||
if I[i][j] not in self.faissid2entityid:
|
||||
print(I[i][j])
|
||||
break
|
||||
# assert I[i][j] in self.faissid2entityid, print(I[i][j])
|
||||
if labels[i][self.faissid2entityid[I[i][j]]]: continue
|
||||
if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
|
||||
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
|
||||
# no added together
|
||||
# else:
|
||||
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
|
||||
# get the entity ranks
|
||||
# filter the entity
|
||||
|
||||
assert entity_logits.shape == labels.shape
|
||||
|
||||
_, outputs = torch.sort(entity_logits, dim=1, descending=True)
|
||||
_, outputs = torch.sort(outputs, dim=1)
|
||||
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
|
||||
|
||||
|
||||
return dict(ranks = np.array(ranks))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def add_to_argparse(parser):
|
||||
parser = TransformerLitModel.add_to_argparse(parser)
|
||||
parser.add_argument("--faiss_init", type=int, default=0, help="get the embedding and save it the file.")
|
||||
parser.add_argument("--faiss_use", type=int, default=1, help="get the embedding and save it the file.")
|
||||
return parser
|
||||
|
||||
|
||||
class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
|
||||
def __init__(self, model, args, tokenizer, data_config={}):
|
||||
super().__init__(model, args, tokenizer, data_config=data_config)
|
||||
self.dis2logits = distance2logits_2
|
||||
self.id2entity = {}
|
||||
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
|
||||
cnt = 0
|
||||
for line in file.readlines():
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity[cnt] = e
|
||||
cnt += 1
|
||||
self.id2entity_t = {}
|
||||
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
|
||||
for line in file.readlines():
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity_t[e] = d
|
||||
for k, v in self.id2entity.items():
|
||||
self.id2entity[k] = self.id2entity_t[v]
|
||||
def _eval(self, batch, batch_idx, ):
|
||||
labels = batch.pop("labels")
|
||||
input_ids = batch['input_ids']
|
||||
# single label
|
||||
label = batch.pop('label')
|
||||
pos = batch.pop("pos")
|
||||
|
||||
result = self.model(**batch, return_dict=True, output_hidden_states=True)
|
||||
hidden_states = result.hidden_states[-1]
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bsz = input_ids.shape[0]
|
||||
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
|
||||
# mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
|
||||
topk = self.args.knn_topk
|
||||
D, I = self.index.search(mask_embedding, topk)
|
||||
D = torch.from_numpy(D).to(self.device)
|
||||
assert labels[0][label[0]], "correct ids must in filiter!"
|
||||
labels[torch.arange(bsz), label] = 0
|
||||
|
||||
|
||||
mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
|
||||
mask_logits = mask_logits[torch.arange(bsz), mask_idx]
|
||||
entity_logits = torch.full(labels.shape, 1000.).to(self.device)
|
||||
# D = self.dis2logits(D)
|
||||
for i in range(bsz):
|
||||
for j in range(topk):
|
||||
# filter entity in labels
|
||||
if labels[i][self.faissid2entityid[I[i][j]]]: continue
|
||||
if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
|
||||
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
|
||||
# else:
|
||||
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
|
||||
entity_logits = self.dis2logits(entity_logits)
|
||||
# get the entity ranks
|
||||
# filter the entity
|
||||
assert entity_logits.shape == labels.shape
|
||||
assert mask_logits.shape == labels.shape
|
||||
# entity_logits = torch.softmax(entity_logits + labels * -100, dim=-1) # mask entityj
|
||||
entity_logits = entity_logits + labels* -100.
|
||||
mask_logits = torch.softmax(mask_logits + labels* -100, dim=-1)
|
||||
# logits = mask_logits
|
||||
logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)
|
||||
# logits = entity_logits + mask_logits
|
||||
|
||||
|
||||
knn_topk_logits, knn_topk_id = entity_logits.topk(20)
|
||||
mask_topk_logits, mask_topk_id = mask_logits.topk(20)
|
||||
union_topk = []
|
||||
for i in range(bsz):
|
||||
num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
|
||||
union_topk.append(num_same/ 20.)
|
||||
|
||||
knn_topk_id = knn_topk_id.to("cpu")
|
||||
mask_topk_id = mask_topk_id.to("cpu")
|
||||
mask_topk_logits = mask_topk_logits.to("cpu")
|
||||
knn_topk_logits = knn_topk_logits.to("cpu")
|
||||
label = label.to("cpu")
|
||||
|
||||
|
||||
|
||||
for t in range(bsz):
|
||||
if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
|
||||
print(knn_topk_logits[t], knn_topk_id[t])
|
||||
print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
|
||||
print(mask_topk_logits[t], mask_topk_id[t])
|
||||
print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
|
||||
print(label[t])
|
||||
print()
|
||||
|
||||
_, outputs = torch.sort(logits, dim=1, descending=True)
|
||||
_, outputs = torch.sort(outputs, dim=1)
|
||||
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
|
||||
|
||||
|
||||
return dict(ranks = np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
|
||||
mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same = np.array(union_topk))
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
|
||||
ranks = np.concatenate([_['ranks'] for _ in outputs])
|
||||
num_same = np.concatenate([_['num_same'] for _ in outputs])
|
||||
results_keys = list(outputs[0].keys())
|
||||
results = {}
|
||||
# for k in results_keys:
|
||||
# results.
|
||||
|
||||
self.log("Test/num_same", num_same.mean())
|
||||
|
||||
hits20 = (ranks<=20).mean()
|
||||
hits10 = (ranks<=10).mean()
|
||||
hits3 = (ranks<=3).mean()
|
||||
hits1 = (ranks<=1).mean()
|
||||
|
||||
|
||||
self.log("Test/hits10", hits10)
|
||||
self.log("Test/hits20", hits20)
|
||||
self.log("Test/hits3", hits3)
|
||||
self.log("Test/hits1", hits1)
|
||||
self.log("Test/mean_rank", ranks.mean())
|
||||
self.log("Test/mrr", (1. / ranks).mean())
|
||||
|
||||
def add_to_argparse(parser):
|
||||
parser = TransformerLitModel.add_to_argparse(parser)
|
||||
parser.add_argument("--knn_lambda", type=float, default=0.5, help="lambda * knn + (1-lambda) * mask logits , lambda of knn logits and mask logits.")
|
||||
parser.add_argument("--knn_topk", type=int, default=100, help="")
|
||||
|
||||
return parser
|
||||
|
||||
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
|
||||
combine_probs = torch.stack([vocab_p, knn_p], dim=0)
|
||||
coeffs = torch.ones_like(combine_probs)
|
||||
coeffs[0] = np.log(1 - coeff)
|
||||
coeffs[1] = np.log(coeff)
|
||||
curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)
|
||||
|
||||
return curr_prob
|
||||
|
||||
def distance2logits(D):
|
||||
return torch.softmax( -1. * torch.tensor(D) / 30., dim=-1)
|
||||
|
||||
def distance2logits_2(D, n=10):
|
||||
if not isinstance(D, torch.Tensor):
|
||||
D = torch.tensor(D)
|
||||
if torch.sum(D) != 0.0:
|
||||
distances = torch.exp(-D/n) / torch.sum(torch.exp(-D/n), dim=-1, keepdim=True)
|
||||
return distances
|
66
pretrain/lit_models/utils.py
Normal file
66
pretrain/lit_models/utils.py
Normal file
@ -0,0 +1,66 @@
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
def rank_score(ranks):
|
||||
# prepare the dataset
|
||||
len_samples = len(ranks)
|
||||
hits10 = [0] * len_samples
|
||||
hits5 = [0] * len_samples
|
||||
hits1 = [0] * len_samples
|
||||
mrr = []
|
||||
|
||||
|
||||
for idx, rank in enumerate(ranks):
|
||||
if rank <= 10:
|
||||
hits10[idx] = 1.
|
||||
if rank <= 5:
|
||||
hits5[idx] = 1.
|
||||
if rank <= 1:
|
||||
hits1[idx] = 1.
|
||||
mrr.append(1./rank)
|
||||
|
||||
|
||||
return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)
|
||||
|
||||
def acc(logits, labels):
|
||||
preds = np.argmax(logits, axis=-1)
|
||||
return (preds == labels).mean()
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
class LabelSmoothSoftmaxCEV1(nn.Module):
|
||||
'''
|
||||
This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
|
||||
'''
|
||||
|
||||
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
|
||||
super(LabelSmoothSoftmaxCEV1, self).__init__()
|
||||
self.lb_smooth = lb_smooth
|
||||
self.reduction = reduction
|
||||
self.lb_ignore = ignore_index
|
||||
self.log_softmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, logits, label):
|
||||
'''
|
||||
args: logits: tensor of shape (N, C, H, W)
|
||||
args: label: tensor of shape(N, H, W)
|
||||
'''
|
||||
# overcome ignored label
|
||||
with torch.no_grad():
|
||||
num_classes = logits.size(1)
|
||||
label = label.clone().detach()
|
||||
ignore = label == self.lb_ignore
|
||||
n_valid = (ignore == 0).sum()
|
||||
label[ignore] = 0
|
||||
lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
|
||||
label = torch.empty_like(logits).fill_(
|
||||
lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
|
||||
|
||||
logs = self.log_softmax(logits)
|
||||
loss = -torch.sum(logs * label, dim=1)
|
||||
loss[ignore] = 0
|
||||
if self.reduction == 'mean':
|
||||
loss = loss.sum() / n_valid
|
||||
if self.reduction == 'sum':
|
||||
loss = loss.sum()
|
||||
|
||||
return loss
|
Reference in New Issue
Block a user