# Thesis/lit_models/transformer.py
# Snapshot taken 2023-02-12 10:57:29 +00:00 (original listing: 566 lines, 24 KiB, Python).

from logging import debug
import random
from turtle import distance
import pytorch_lightning as pl
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor
from .base import BaseLitModel
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from functools import partial
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1
from typing import Callable, Iterable, List
def pad_distance(pad_length, distance):
    """Pad the last two dimensions of ``distance`` with ``-inf``.

    ``pad_length`` columns are appended on the right and ``pad_length`` rows
    at the bottom; nothing is added on the left or top
    (``nn.ConstantPad2d`` takes ``(left, right, top, bottom)``).
    """
    right_and_bottom = (0, pad_length, 0, pad_length)
    padder = nn.ConstantPad2d(padding=right_and_bottom, value=float('-inf'))
    return padder(distance)
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to each element of ``x`` and return the results as a list.

    Eager equivalent of ``list(map(f, x))``.
    """
    return [f(item) for item in x]
def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy on raw scores.

    ``y_true`` is a 0/1 tensor shaped like ``y_pred``. Per row the loss is
    ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))``: the ``1`` comes
    from the zero column appended before ``logsumexp``, and scores of the
    opposite class are pushed down by ``1e12`` so they drop out of each sum.
    Returns the mean over all rows.
    """
    signed = y_pred * (1 - 2 * y_true)        # negate the positive-class scores
    big = 1e12
    neg_scores = signed - big * y_true        # only negatives survive
    pos_scores = signed - big * (1 - y_true)  # only positives survive
    zero_col = torch.zeros_like(signed[..., :1])
    neg_term = torch.logsumexp(torch.cat([neg_scores, zero_col], dim=-1), dim=-1)
    pos_term = torch.logsumexp(torch.cat([pos_scores, zero_col], dim=-1), dim=-1)
    return (neg_term + pos_term).mean()
def decode(output_ids, tokenizer):
    """Decode a batch of token-id sequences into whitespace-stripped strings.

    Special tokens are kept (``skip_special_tokens=False``) and tokenization
    spaces are cleaned up by the tokenizer's ``batch_decode``.
    """
    texts = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
    )
    return lmap(str.strip, texts)
class TransformerLitModel(BaseLitModel):
    """Lightning module that ranks entities from the ``[MASK]`` logits of ``model``.

    ``model`` is called with ``return_dict=True`` and an extra
    ``distance_attention`` keyword (an additive attention bias or ``None``).
    Entity scores are the logits at the single ``[MASK]`` position restricted
    to ``[entity_id_st, entity_id_ed)``. Those two attributes are presumably
    supplied through ``data_config``, whose entries are copied onto the
    instance — TODO confirm against the data module.
    """

    def __init__(self, model, args, tokenizer=None, data_config=None):
        super().__init__(model, args)
        self.save_hyperparameters(args)
        if args.bce:
            self.loss_fn = torch.nn.BCEWithLogitsLoss()
        elif args.label_smoothing != 0.0:
            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
        else:
            self.loss_fn = nn.CrossEntropyLoss()
        self.best_acc = 0
        self.first = True
        self.tokenizer = tokenizer
        # NOTE(review): hard-coded; presumably must match the backbone's number
        # of attention heads — confirm.
        self.num_heads = 12
        # Every data_config entry becomes an instance attribute. A ``None``
        # default avoids sharing one mutable ``{}`` between calls.
        self.__dict__.update(data_config if data_config is not None else {})
        # Resize the word-embedding matrix to the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.decode = partial(decode, tokenizer=self.tokenizer)
        if args.pretrain:
            self._freaze_attention()
        elif "ind" in args.data_dir:
            # Inductive setting: freeze the word embeddings.
            self._freaze_word_embedding()
        # Learned per-head bias for integer distance ids 0-4; id 0 is the
        # padding index, whose embedding stays zero.
        self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
        # Learned per-head bias linking position 0 (the global node) to every position.
        self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)

    def forward(self, x):
        return self.model(x)

    def create_negatives(self, batch):
        """Build one corrupted copy of the batch per example.

        For each row ``i`` the first decoded word (from word 1 on) starting
        with ``"[ENTITY_"`` is located. Negative batch ``idx`` is a clone of
        ``batch['input_ids']`` in which that slot of every row is overwritten
        with ``label[idx]``, or with the row's own token when it already equals
        ``label[idx]``. Each negative has ``attention_mask=None`` and an
        all-zero ``label`` tensor.

        NOTE(review): ``label`` holds entity indices relative to
        ``entity_id_st`` (see ``_eval``), yet it is compared with and written
        into ``input_ids`` without adding ``entity_id_st`` — confirm intent.
        NOTE(review): word positions from ``split(' ')`` are used as token
        positions, which only lines up if each token decodes to exactly one
        space-separated word. A row without an ``[ENTITY_`` word misaligns
        ``entity_idx`` with the rows (typically raising ``IndexError``).
        """
        label = batch['label']
        input_ids = batch['input_ids']
        num_rows = label.shape[0]
        negative_batches = [
            {
                'attention_mask': None,
                'input_ids': torch.clone(input_ids),
                'label': torch.zeros_like(label),
            }
            for _ in range(num_rows)
        ]

        entity_idx = []
        self_label = []
        for row in range(len(label)):
            words = self.decode([input_ids[row]])[0].split(' ')
            for j in range(1, len(words)):
                if words[j].startswith("[ENTITY_"):
                    entity_idx.append(j)
                    self_label.append(input_ids[row][j])
                    break

        for idx, lbl in enumerate(label):
            neg_ids = negative_batches[idx]['input_ids']
            for i in range(num_rows):
                slot = entity_idx[i]
                if neg_ids[i][slot] != lbl:
                    neg_ids[i][slot] = lbl
                else:
                    neg_ids[i][slot] = self_label[i]
        return negative_batches

    def _build_graph_attn_bias(self, input_ids, distances):
        """Return a ``(batch, num_heads, seq_len, seq_len)`` additive attention bias.

        ``distances`` holds one square matrix per example for the positions
        after index 0. Each is padded with ``-inf`` (``pad_distance``) up to
        ``seq_len - 1``. Padded cells get ``-inf`` in the bias. Real cells get
        the learned per-head ``spatial_pos_encoder`` value for their distance id.
        With ``args.use_global_node`` the learned virtual-node term is added to
        row 0 and to column 0 below it. Shared by training and evaluation so
        both see the same bias.
        """
        bsz, seq_len = input_ids.size(0), input_ids.size(1)
        distance_attention = torch.stack(
            [pad_distance(seq_len - len(dist) - 1, dist) for dist in distances])
        pad_mask = distance_attention == float('-inf')

        graph_attn_bias = torch.zeros(bsz, seq_len, seq_len, device=input_ids.device)
        graph_attn_bias[:, 1:, 1:][pad_mask] = float('-inf')
        graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # -inf has no valid integer embedding index; send padded cells to
        # padding_idx 0 (a zero vector). They are already -inf in the bias.
        spatial_ids = distance_attention.masked_fill(pad_mask, 0).long()
        spatial_bias = self.spatial_pos_encoder(spatial_ids).permute(0, 3, 1, 2)
        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_bias

        if self.args.use_global_node:
            t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
            graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
            graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
        return graph_attn_bias

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Return the summed loss over the corrupted copies and the real batch.

        One extra forward pass runs per example (see ``create_negatives``)
        without an attention bias. The real batch gets the graph attention
        bias when ``args.add_attn_bias`` is set.
        """
        input_ids = batch['input_ids']
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        assert mask_idx.shape[0] == bs, "only one mask in sequence!"
        rows = torch.arange(bs)

        loss = 0
        # Negatives only rewrite entity slots, so their [MASK] positions match the batch.
        for negative_batch in self.create_negatives(batch):
            # NOTE(review): with args.bce the loss expects targets shaped like
            # the logits, but negatives carry the (presumably 1-D) all-zero label.
            neg_label = negative_batch.pop("label")
            logits = self.model(**negative_batch, return_dict=True, distance_attention=None).logits
            mask_logits = logits[rows, mask_idx][:, self.entity_id_st:self.entity_id_ed]
            loss += self.loss_fn(mask_logits, neg_label)

        labels = batch.pop("labels")  # per-entity 0/1 targets, used with BCE
        label = batch.pop("label")    # gold entity index per example
        batch.pop("pos")
        batch.pop("en", None)
        batch.pop("rel", None)
        distances = batch.pop("distance_attention")
        attn_bias = (self._build_graph_attn_bias(input_ids, distances)
                     if self.args.add_attn_bias else None)
        logits = self.model(**batch, return_dict=True, distance_attention=attn_bias).logits
        mask_logits = logits[rows, mask_idx][:, self.entity_id_st:self.entity_id_ed]
        loss += self.loss_fn(mask_logits, labels if self.args.bce else label)
        if batch_idx == 0:
            print('\n'.join(self.decode(batch['input_ids'][:4])))
        return loss

    def _eval(self, batch, batch_idx):
        """Return ``{"ranks": np.ndarray}`` of 1-based filtered ranks of the gold entity.

        Every entity flagged in ``labels`` except the gold ``label`` gets 100
        subtracted from its logit before ranking (filtered setting).
        """
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        label = batch.pop('label')  # single gold entity index per example
        batch.pop('pos')
        distances = batch.pop("distance_attention")
        for key in list(batch.keys()):
            if key not in ["input_ids", "attention_mask", "token_type_ids"]:
                batch.pop(key)
        attn_bias = (self._build_graph_attn_bias(input_ids, distances)
                     if self.args.add_attn_bias else None)
        logits = self.model(**batch, return_dict=True, distance_attention=attn_bias).logits
        logits = logits[:, :, self.entity_id_st:self.entity_id_ed]

        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        rows = torch.arange(bsz)
        logits = logits[rows, mask_idx]

        assert labels[0][label[0]], "correct ids must in filiter!"
        labels[rows, label] = 0  # keep the gold entity itself unpenalised
        assert logits.shape == labels.shape
        logits += labels * -100  # push the other known-true entities down

        # argsort of the descending argsort = each entity's 0-based position.
        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[rows, label].detach().cpu() + 1
        return dict(ranks=np.array(ranks))

    def validation_step(self, batch, batch_idx):
        return self._eval(batch, batch_idx)

    def _log_rank_metrics(self, prefix, ranks):
        """Log hits@{10,20,3,1}, mean rank and MRR under ``prefix``; return ``(hits10, hits1)``."""
        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()
        self.log(f"{prefix}/hits10", hits10)
        self.log(f"{prefix}/hits20", hits20)
        self.log(f"{prefix}/hits3", hits3)
        self.log(f"{prefix}/hits1", hits1)
        self.log(f"{prefix}/mean_rank", ranks.mean())
        self.log(f"{prefix}/mrr", (1. / ranks).mean())
        return hits10, hits1

    def validation_epoch_end(self, outputs) -> None:
        """Aggregate validation ranks and log the ranking metrics."""
        ranks = np.concatenate([out['ranks'] for out in outputs])
        if not self.args.pretrain:
            # Even/odd positions are reported as the "l"/"r" split (presumably
            # head- vs tail-prediction queries — confirm with the data module).
            # Strided slicing stays in bounds when the total count is odd.
            l_ranks = ranks[0::2]
            r_ranks = ranks[1::2]
            self.log("Eval/lhits10", (l_ranks <= 10).mean())
            self.log("Eval/rhits10", (r_ranks <= 10).mean())
        hits10, hits1 = self._log_rank_metrics("Eval", ranks)
        self.log("hits10", hits10, prog_bar=True)
        self.log("hits1", hits1, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        return self._eval(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        """Aggregate test ranks and log the ranking metrics."""
        ranks = np.concatenate([out['ranks'] for out in outputs])
        self._log_rank_metrics("Test", ranks)

    def configure_optimizers(self):
        """Build ``self.optimizer_class`` plus a per-step linear warmup/decay schedule.

        Biases and ``LayerNorm.weight`` get no weight decay. Warmup lasts
        ``num_training_steps * args.warm_up_radio`` steps.

        NOTE(review): only ``self.model`` parameters are optimised, so
        ``spatial_pos_encoder`` and ``graph_token_virtual_distance`` (defined
        on ``self``) keep their initial values — confirm this is intended.
        """
        no_decay_param = ["bias", "LayerNorm.weight"]
        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters()
                        if p.requires_grad and not any(nd in n for nd in no_decay_param)],
             "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in self.model.named_parameters()
                        if p.requires_grad and any(nd in n for nd in no_decay_param)],
             "weight_decay": 0},
        ]
        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_training_steps * self.args.warm_up_radio,
            num_training_steps=self.num_training_steps,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            },
        }

    def _freaze_attention(self):
        """Freeze every backbone parameter whose name lacks ``"word"``; print the trainable ones."""
        for name, param in self.model.named_parameters():
            if "word" in name:
                print(name)
            else:
                param.requires_grad = False

    def _freaze_word_embedding(self):
        """Freeze (and print) every backbone parameter whose name contains ``"word"``."""
        for name, param in self.model.named_parameters():
            if "word" in name:
                print(name)
                param.requires_grad = False

    @staticmethod
    def add_to_argparse(parser):
        parser = BaseLitModel.add_to_argparse(parser)
        parser.add_argument("--label_smoothing", type=float, default=0.1, help="")
        parser.add_argument("--bce", type=int, default=0, help="")
        return parser
import faiss
import os
class GetEntityEmbeddingLitModel(TransformerLitModel):
    """Collect ``[MASK]``-position hidden states on the test set into a FAISS index.

    Each test example's last-layer hidden state at ``batch["pos"]`` is stored
    with its gold entity index. At the end of the test epoch the vectors go
    into an exact L2 index, and the index, the FAISS-id -> entity-id map and
    the raw embedding matrix are written to ``args.data_dir``.
    """

    def __init__(self, model, args, tokenizer, data_config=None):
        super().__init__(model, args, tokenizer,
                         {} if data_config is None else data_config)
        self.faissid2entityid = {}
        d = self.model.config.hidden_size
        # Brute-force L2 index; is_trained is already True, so no train() call is needed.
        self.index = faiss.IndexFlatL2(d)
        self.cnt_batch = 0
        self.total_embedding = []

    def test_step(self, batch, batch_idx):
        batch.pop("labels")
        mask_idx = batch.pop("pos")
        input_ids = batch['input_ids']
        label = batch.pop('label')  # single gold entity index per example
        hidden_states = self.model(**batch, return_dict=True,
                                   output_hidden_states=True).hidden_states[-1]
        bsz = input_ids.shape[0]
        # FAISS only accepts float32, so cast here (fp16/bf16 models would otherwise fail).
        entity_embedding = hidden_states[torch.arange(bsz), mask_idx].detach().float().cpu()
        self.total_embedding.append(entity_embedding)
        for i, entity in zip(range(bsz), label):
            self.faissid2entityid[i + self.cnt_batch] = entity.cpu()
        self.cnt_batch += bsz

    def test_epoch_end(self, outputs) -> None:
        self.total_embedding = np.ascontiguousarray(
            torch.cat(self.total_embedding, dim=0).numpy(), dtype=np.float32)
        print(faiss.MatrixStats(self.total_embedding).comments)
        self.index.add(self.total_embedding)
        data_dir = self.args.data_dir
        faiss.write_index(self.index, os.path.join(data_dir, "faiss_dump.index"))
        with open(os.path.join(data_dir, "faissid2entityid.pkl"), 'wb') as file:
            pickle.dump(self.faissid2entityid, file)
        with open(os.path.join(data_dir, "total_embedding.pkl"), 'wb') as file:
            pickle.dump(self.total_embedding, file)

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=1, help="get the embedding and save it the file.")
        return parser
class UseEntityEmbeddingLitModel(TransformerLitModel):
    """Rank entities by nearest-neighbour search in a previously dumped FAISS index.

    Loads ``faiss_dump.index`` and ``faissid2entityid.pkl`` from
    ``args.data_dir``. Each entity is scored by the transformed distance of
    its closest retrieved embedding to the query's ``[MASK]`` hidden state.
    Entities that are not retrieved keep a score of ``-100``.
    """

    def __init__(self, model, args, tokenizer, data_config=None):
        super().__init__(model, args, tokenizer,
                         {} if data_config is None else data_config)
        # NOTE(review): pickle.load can execute code from the file; only load
        # dumps produced locally by GetEntityEmbeddingLitModel.
        with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'rb') as file:
            self.faissid2entityid = pickle.load(file)
        self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))
        self.dis2logits = distance2logits_2

    def _eval(self, batch, batch_idx):
        labels = batch.pop("labels")
        batch.pop("pos")
        input_ids = batch['input_ids']
        label = batch.pop('label')  # single gold entity index per example
        hidden_states = self.model(**batch, return_dict=True,
                                   output_hidden_states=True).hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        rows = torch.arange(bsz)
        mask_embedding = np.array(hidden_states[rows, mask_idx].cpu(), dtype=np.float32)
        topk = 200
        D, I = self.index.search(mask_embedding, topk)
        labels[rows, label] = 0  # never filter out the gold entity
        entity_logits = torch.full(labels.shape, -100.).to(self.device)
        D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                faiss_id = I[i][j]
                if faiss_id not in self.faissid2entityid:
                    # Unknown id (e.g. the -1 FAISS returns when it has fewer hits).
                    print(faiss_id)
                    break
                entity = self.faissid2entityid[faiss_id]
                if labels[i][entity]:
                    continue  # filtered: another known-true entity
                if entity_logits[i][entity] == -100.:
                    # Hits come back nearest-first, so keep the first score per entity.
                    entity_logits[i][entity] = D[i][j]
        assert entity_logits.shape == labels.shape
        # argsort of the descending argsort = each entity's 0-based position.
        _, outputs = torch.sort(entity_logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[rows, label].detach().cpu() + 1
        return dict(ranks=np.array(ranks))

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=0, help="get the embedding and save it the file.")
        parser.add_argument("--faiss_use", type=int, default=1, help="get the embedding and save it the file.")
        return parser
class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
    """Blend kNN entity scores from the FAISS index with the ``[MASK]`` softmax.

    Also builds ``id2entity`` (line number in ``entity2textlong.txt`` -> the
    entity's text from ``entity2text.txt``), which is used to print debug
    examples during evaluation.
    """

    def __init__(self, model, args, tokenizer, data_config=None):
        super().__init__(model, args, tokenizer,
                         data_config={} if data_config is None else data_config)
        self.dis2logits = distance2logits_2
        # NOTE(review): FB15k-237 paths are hard-coded relative to the working
        # directory, independently of args.data_dir.
        self.id2entity = {}
        with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
            for cnt, line in enumerate(file):
                # maxsplit=1 so a tab inside the description cannot break unpacking.
                e, _ = line.strip().split("\t", 1)
                self.id2entity[cnt] = e
        self.id2entity_t = {}
        with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
            for line in file:
                e, d = line.strip().split("\t", 1)
                self.id2entity_t[e] = d
        # Replace each entity id by its text.
        self.id2entity = {k: self.id2entity_t[v] for k, v in self.id2entity.items()}

    def _eval(self, batch, batch_idx):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        label = batch.pop('label')  # single gold entity index per example
        batch.pop("pos")
        result = self.model(**batch, return_dict=True, output_hidden_states=True)
        hidden_states = result.hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        rows = torch.arange(bsz)
        mask_embedding = np.array(hidden_states[rows, mask_idx].cpu(), dtype=np.float32)
        topk = self.args.knn_topk
        D, I = self.index.search(mask_embedding, topk)
        D = torch.from_numpy(D).to(self.device)
        assert labels[0][label[0]], "correct ids must in filiter!"
        labels[rows, label] = 0  # never filter out the gold entity
        mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed][rows, mask_idx]

        # 1000. stands for "not retrieved": a very large distance for self.dis2logits.
        entity_logits = torch.full(labels.shape, 1000.).to(self.device)
        for i in range(bsz):
            for j in range(topk):
                faiss_id = I[i][j]
                if faiss_id not in self.faissid2entityid:
                    # Same guard as UseEntityEmbeddingLitModel (e.g. FAISS's -1 filler id).
                    print(faiss_id)
                    break
                entity = self.faissid2entityid[faiss_id]
                if labels[i][entity]:
                    continue  # filtered: another known-true entity
                if entity_logits[i][entity] == 1000.:
                    entity_logits[i][entity] = D[i][j]  # nearest hit per entity
        entity_logits = self.dis2logits(entity_logits)

        assert entity_logits.shape == labels.shape
        assert mask_logits.shape == labels.shape
        entity_logits = entity_logits + labels * -100.
        mask_logits = torch.softmax(mask_logits + labels * -100, dim=-1)
        logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)

        # Fraction of each example's top-20 ids shared by the kNN and [MASK] scores.
        knn_topk_logits, knn_topk_id = entity_logits.topk(20)
        mask_topk_logits, mask_topk_id = mask_logits.topk(20)
        top20_overlap = []
        for i in range(bsz):
            shared = set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())
            top20_overlap.append(len(shared) / 20.)

        knn_topk_id = knn_topk_id.to("cpu")
        mask_topk_id = mask_topk_id.to("cpu")
        mask_topk_logits = mask_topk_logits.to("cpu")
        knn_topk_logits = knn_topk_logits.to("cpu")
        label = label.to("cpu")
        # Debug output: cases where kNN top-1 is correct and beats a low-confidence [MASK] top-1.
        for t in range(bsz):
            if (knn_topk_id[t][0] == label[t]
                    and knn_topk_logits[t][0] > mask_topk_logits[t][0]
                    and mask_topk_logits[t][0] <= 0.4):
                print(knn_topk_logits[t], knn_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
                print(mask_topk_logits[t], mask_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
                print(label[t])
                print()

        # argsort of the descending argsort = each entity's 0-based position.
        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[rows, label].detach().cpu() + 1
        return dict(ranks=np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
                    mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits,
                    num_same=np.array(top20_overlap))

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([out['ranks'] for out in outputs])
        num_same = np.concatenate([out['num_same'] for out in outputs])
        self.log("Test/num_same", num_same.mean())
        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()
        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--knn_lambda", type=float, default=0.5, help="lambda * knn + (1-lambda) * mask logits , lambda of knn logits and mask logits.")
        parser.add_argument("--knn_topk", type=int, default=100, help="")
        return parser
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
    """Interpolate two score tensors in log space, element-wise.

    Returns ``log((1 - coeff) * exp(vocab_p) + coeff * exp(knn_p))`` computed
    with ``torch.logsumexp``; the result has the inputs' shape.
    """
    stacked = torch.stack([vocab_p, knn_p], dim=0)
    log_weights = torch.tensor([np.log(1 - coeff), np.log(coeff)],
                               dtype=stacked.dtype, device=stacked.device)
    # Shape (2, 1, ..., 1) so each weight broadcasts over its whole slice.
    log_weights = log_weights.view(-1, *([1] * (stacked.dim() - 1)))
    return torch.logsumexp(stacked + log_weights, dim=0)
def distance2logits(D):
    """Softmax over the last dimension of ``-D / 30`` (smaller distance -> larger weight)."""
    scaled = torch.tensor(D) * -1. / 30.
    return torch.softmax(scaled, dim=-1)
def distance2logits_2(D, n=10):
    """Turn distances into normalised weights with a temperature-``n`` softmax.

    Computes ``exp(-D / n) / sum(exp(-D / n))`` along the last dimension, so
    smaller distances get larger weights and each row sums to 1. ``D`` may be
    a tensor or anything ``torch.as_tensor`` accepts (e.g. a NumPy array
    from a FAISS search). An all-zero ``D`` yields uniform weights.
    """
    if not isinstance(D, torch.Tensor):
        D = torch.as_tensor(D)
    # softmax subtracts the row max before exponentiating, so large distances
    # cannot underflow every term to 0 and produce 0/0 = NaN.
    return torch.softmax(-D / n, dim=-1)