Relphormer baseline
This commit is contained in:

lit_models/__init__.py (2 lines, new file)
@@ -0,0 +1,2 @@
from .transformer import *
from .base import *

lit_models/base.py (97 lines, new file)
@@ -0,0 +1,97 @@
import argparse
import pytorch_lightning as pl
import torch
from typing import Dict, Any


OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100


class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.

    Note: the training/validation/test steps below expect `self.loss_fn` and the
    `self.*_acc` metrics to be provided; TransformerLitModel overrides these hooks.
    """

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__()
        self.model = model
        self.args = Config(vars(args)) if args is not None else {}

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)
        # OneCycleLR settings used by configure_optimizers; the scheduler stays
        # disabled unless one_cycle_max_lr is supplied via args.
        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        return parser

    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        self.train_acc(logits, y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a fraction of the training batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps
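To make the argparse plumbing in BaseLitModel concrete, here is a minimal, hypothetical usage sketch (not part of the commit): the nn.Linear stand-in and the argument values are invented, and it assumes pytorch_lightning is installed and the lit_models package is importable.

import argparse
import torch.nn as nn
from lit_models.base import BaseLitModel

parser = argparse.ArgumentParser()
parser = BaseLitModel.add_to_argparse(parser)
args = parser.parse_args(["--optimizer", "AdamW", "--lr", "1e-4"])

toy_model = nn.Linear(8, 2)                     # stand-in for the real encoder
lit_model = BaseLitModel(toy_model, args=args)
print(lit_model.optimizer_class, lit_model.lr)  # torch.optim.AdamW and 0.0001
print(lit_model.configure_optimizers())         # a bare AdamW instance, since one_cycle_max_lr is unset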
							
								
								
									
lit_models/transformer.py (521 lines, new file)
@@ -0,0 +1,521 @@
from logging import debug
import random
import pytorch_lightning as pl
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor

from .base import BaseLitModel
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

from functools import partial
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1

from typing import Callable, Iterable, List


def pad_distance(pad_length, distance):
    """Pad the square distance matrix on the right and bottom with -inf up to the target size."""
    pad = nn.ConstantPad2d(padding=(0, pad_length, 0, pad_length), value=float('-inf'))
    distance = pad(distance)
    return distance


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy over raw logits; y_true is a 0/1 indicator tensor."""
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()


def decode(output_ids, tokenizer):
    return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))
class TransformerLitModel(BaseLitModel):
    def __init__(self, model, args, tokenizer=None, data_config={}):
        super().__init__(model, args)
        self.save_hyperparameters(args)
        if args.bce:
            self.loss_fn = torch.nn.BCEWithLogitsLoss()
        elif args.label_smoothing != 0.0:
            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
        else:
            self.loss_fn = nn.CrossEntropyLoss()
        self.best_acc = 0
        self.first = True

        self.tokenizer = tokenizer
        self.num_heads = 12
        self.__dict__.update(data_config)
        # resize the word embedding layer
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.decode = partial(decode, tokenizer=self.tokenizer)
        if args.pretrain:
            self._freaze_attention()
        elif "ind" in args.data_dir:
            # for the inductive setting, freeze the word embeddings
            self._freaze_word_embedding()

        self.spatial_pos_encoder = nn.Embedding(5, self.num_heads, padding_idx=0)
        self.graph_token_virtual_distance = nn.Embedding(1, self.num_heads)

    def forward(self, x):
        return self.model(x)
    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        # embed();exit()
        # print(self.optimizers().param_groups[1]['lr'])
        labels = batch.pop("labels")
        label = batch.pop("label")
        pos = batch.pop("pos")
        try:
            en = batch.pop("en")
            rel = batch.pop("rel")
        except KeyError:
            pass
        input_ids = batch['input_ids']

        # pad each per-sample distance matrix to the sequence length and build an
        # additive per-head attention bias from it
        distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
        distance = batch.pop("distance_attention")
        graph_attn_bias = torch.zeros(input_ids.size(0), input_ids.size(1), input_ids.size(1), device='cuda')
        graph_attn_bias[:, 1:, 1:][distance_attention == float('-inf')] = float('-inf')
        graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
        distance_attention = self.spatial_pos_encoder(distance_attention.long()).permute(0, 3, 1, 2)
        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + distance_attention

        if self.args.use_global_node:
            t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
            graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
            graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t

        if self.args.add_attn_bias:
            logits = self.model(**batch, return_dict=True, distance_attention=graph_attn_bias).logits
        else:
            logits = self.model(**batch, return_dict=True, distance_attention=None).logits

        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]

        assert mask_idx.shape[0] == bs, "only one mask in sequence!"
        if self.args.bce:
            loss = self.loss_fn(mask_logits, labels)
        else:
            loss = self.loss_fn(mask_logits, label)

        if batch_idx == 0:
            print('\n'.join(self.decode(batch['input_ids'][:4])))

        return loss
    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')
        pos = batch.pop('pos')
        distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
        distance = batch.pop("distance_attention")
        graph_attn_bias = torch.zeros(input_ids.size(0), input_ids.size(1), input_ids.size(1), device='cuda')
        graph_attn_bias[:, 1:, 1:][distance_attention == float('-inf')] = float('-inf')
        graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
        distance_attention = self.spatial_pos_encoder(distance_attention.long()).permute(0, 3, 1, 2)
        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + distance_attention
        # distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance), distance) for i, distance in enumerate(batch['distance_attention'])])
        # distance = batch.pop("distance_attention")
        # distance_attention = self.spatial_pos_encoder(distance_attention.long()).permute(0, 3, 1, 2)
        my_keys = list(batch.keys())
        for k in my_keys:
            if k not in ["input_ids", "attention_mask", "token_type_ids"]:
                batch.pop(k)

        if self.args.add_attn_bias:
            logits = self.model(**batch, return_dict=True, distance_attention=graph_attn_bias).logits[:, :, self.entity_id_st:self.entity_id_ed]
        else:
            logits = self.model(**batch, return_dict=True, distance_attention=None).logits[:, :, self.entity_id_st:self.entity_id_ed]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        logits = logits[torch.arange(bsz), mask_idx]
        # get the entity ranks
        # filter out the other known true entities (filtered ranking)
        assert labels[0][label[0]], "the correct entity id must be in the filter!"
        labels[torch.arange(bsz), label] = 0
        assert logits.shape == labels.shape
        logits += labels * -100  # mask out the filtered entities
        # for i in range(bsz):
        #     logits[i][labels]

        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks))
    def validation_step(self, batch, batch_idx):
        result = self._eval(batch, batch_idx)
        return result

    def validation_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])
        total_ranks = ranks.shape[0]

        if not self.args.pretrain:
            l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
            r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
            self.log("Eval/lhits10", (l_ranks <= 10).mean())
            self.log("Eval/rhits10", (r_ranks <= 10).mean())

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Eval/hits10", hits10)
        self.log("Eval/hits20", hits20)
        self.log("Eval/hits3", hits3)
        self.log("Eval/hits1", hits1)
        self.log("Eval/mean_rank", ranks.mean())
        self.log("Eval/mrr", (1. / ranks).mean())
        self.log("hits10", hits10, prog_bar=True)
        self.log("hits1", hits1, prog_bar=True)
    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        # ranks = self._eval(batch, batch_idx)
        result = self._eval(batch, batch_idx)
        # self.log("Test/ranks", np.mean(ranks))

        return result

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())
    def configure_optimizers(self):
        no_decay_param = ["bias", "LayerNorm.weight"]

        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            }
        }

    def _freaze_attention(self):
        # freeze everything except the word embeddings
        for k, v in self.model.named_parameters():
            if "word" not in k:
                v.requires_grad = False
            else:
                print(k)

    def _freaze_word_embedding(self):
        # freeze only the word embeddings
        for k, v in self.model.named_parameters():
            if "word" in k:
                print(k)
                v.requires_grad = False

    @staticmethod
    def add_to_argparse(parser):
        parser = BaseLitModel.add_to_argparse(parser)

        parser.add_argument("--label_smoothing", type=float, default=0.1, help="label smoothing factor (0.0 disables smoothing)")
        parser.add_argument("--bce", type=int, default=0, help="use BCEWithLogitsLoss over all candidate entities instead of cross entropy")
        return parser
import faiss
import os


class GetEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = {}
        # self.index = faiss.IndexFlatL2(d)   # build the index

        d, measure = self.model.config.hidden_size, faiss.METRIC_L2
        # param = 'HNSW64'
        # self.index = faiss.index_factory(d, param, measure)
        self.index = faiss.IndexFlatL2(d)   # build the index
        # print(self.index.is_trained)      # prints True here, IndexFlatL2 needs no training
        # index.add(xb)
        self.cnt_batch = 0
        self.total_embedding = []

    def test_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        mask_idx = batch.pop("pos")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')
        # last layer
        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        # _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
        # use normalize or not ?
        # entity_embedding = F.normalize(entity_embedding, dim=-1, p=2)
        self.total_embedding.append(entity_embedding)
        # self.index.add(np.array(entity_embedding, dtype=np.float32))
        for i, l in zip(range(bsz), label):
            self.faissid2entityid[i + self.cnt_batch] = l.cpu()
        self.cnt_batch += bsz

    def test_epoch_end(self, outputs) -> None:
        self.total_embedding = np.concatenate(self.total_embedding, axis=0)
        # self.index.train(self.total_embedding)
        print(faiss.MatrixStats(self.total_embedding).comments)
        self.index.add(self.total_embedding)
        faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
        with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'wb') as file:
            pickle.dump(self.faissid2entityid, file)

        with open(os.path.join(self.args.data_dir, "total_embedding.pkl"), 'wb') as file:
            pickle.dump(self.total_embedding, file)
        # print(f"number of entity embeddings: {len(self.faissid2entityid)}")

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=1, help="compute the entity embeddings and save them to a file.")
        return parser
class UseEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'rb'))
        self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))

        self.dis2logits = distance2logits_2

    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        pos = batch.pop("pos")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')

        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = 200
        D, I = self.index.search(mask_embedding, topk)
        labels[torch.arange(bsz), label] = 0

        entity_logits = torch.full(labels.shape, -100.).to(self.device)
        D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                # skip entities that are filtered out by the labels
                if I[i][j] not in self.faissid2entityid:
                    print(I[i][j])
                    break
                # assert I[i][j] in self.faissid2entityid, print(I[i][j])
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
                # scores are not added together
                # else:
                #     entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
        # get the entity ranks
        # filter the entity
        assert entity_logits.shape == labels.shape

        _, outputs = torch.sort(entity_logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks))

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=0, help="compute the entity embeddings and save them to a file.")
        parser.add_argument("--faiss_use", type=int, default=1, help="rank entities with the precomputed faiss index.")
        return parser
class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config=data_config)
        self.dis2logits = distance2logits_2
        self.id2entity = {}
        with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
            cnt = 0
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity[cnt] = e
                cnt += 1
        self.id2entity_t = {}
        with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity_t[e] = d
        for k, v in self.id2entity.items():
            self.id2entity[k] = self.id2entity_t[v]

    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')
        pos = batch.pop("pos")

        result = self.model(**batch, return_dict=True, output_hidden_states=True)
        hidden_states = result.hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = self.args.knn_topk
        D, I = self.index.search(mask_embedding, topk)
        D = torch.from_numpy(D).to(self.device)
        assert labels[0][label[0]], "the correct entity id must be in the filter!"
        labels[torch.arange(bsz), label] = 0

        mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
        mask_logits = mask_logits[torch.arange(bsz), mask_idx]
        entity_logits = torch.full(labels.shape, 1000.).to(self.device)
        # D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                # skip entities that are filtered out by the labels
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
                # else:
                #     entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
        entity_logits = self.dis2logits(entity_logits)
        # get the entity ranks
        # filter the entity
        assert entity_logits.shape == labels.shape
        assert mask_logits.shape == labels.shape
        # entity_logits = torch.softmax(entity_logits + labels * -100, dim=-1)  # mask filtered entities
        entity_logits = entity_logits + labels * -100.
        mask_logits = torch.softmax(mask_logits + labels * -100, dim=-1)
        # logits = mask_logits
        logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)
        # logits = entity_logits + mask_logits

        knn_topk_logits, knn_topk_id = entity_logits.topk(20)
        mask_topk_logits, mask_topk_id = mask_logits.topk(20)
        union_topk = []
        for i in range(bsz):
            num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
            union_topk.append(num_same / 20.)

        knn_topk_id = knn_topk_id.to("cpu")
        mask_topk_id = mask_topk_id.to("cpu")
        mask_topk_logits = mask_topk_logits.to("cpu")
        knn_topk_logits = knn_topk_logits.to("cpu")
        label = label.to("cpu")

        for t in range(bsz):
            if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
                print(knn_topk_logits[t], knn_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
                print(mask_topk_logits[t], mask_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
                print(label[t])
                print()

        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
            mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same=np.array(union_topk))

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])
        num_same = np.concatenate([_['num_same'] for _ in outputs])
        results_keys = list(outputs[0].keys())
        results = {}
        # for k in results_keys:
        #     results.

        self.log("Test/num_same", num_same.mean())

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--knn_lambda", type=float, default=0.5, help="lambda * knn + (1 - lambda) * mask logits; interpolation weight of the knn scores against the MLM scores.")
        parser.add_argument("--knn_topk", type=int, default=100, help="number of nearest neighbours retrieved from the faiss index.")

        return parser
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
    """Interpolate two score tensors in log space: logsumexp([vocab_p + log(1 - coeff), knn_p + log(coeff)])."""
    combine_probs = torch.stack([vocab_p, knn_p], dim=0)
    coeffs = torch.ones_like(combine_probs)
    coeffs[0] = np.log(1 - coeff)
    coeffs[1] = np.log(coeff)
    curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)

    return curr_prob


def distance2logits(D):
    return torch.softmax(-1. * torch.tensor(D) / 30., dim=-1)


def distance2logits_2(D, n=10):
    if not isinstance(D, torch.Tensor):
        D = torch.tensor(D)
    # guard against the degenerate all-zero case, which previously left `distances` unbound
    distances = torch.zeros_like(D)
    if torch.sum(D) != 0.0:
        distances = torch.exp(-D / n) / torch.sum(torch.exp(-D / n), dim=-1, keepdim=True)
    return distances
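combine_knn_and_vocab_probs mixes the kNN scores and the MLM scores in log space. A toy check (made-up numbers, not from the commit; it assumes the module's dependencies such as transformers and faiss are installed so that lit_models.transformer can be imported):

import torch
from lit_models.transformer import combine_knn_and_vocab_probs

# two toy distributions over three candidate entities, given as log-probabilities
knn_p = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))
vocab_p = torch.log(torch.tensor([[0.1, 0.8, 0.1]]))

combined = combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5)
expected = torch.log(0.5 * knn_p.exp() + 0.5 * vocab_p.exp())
print(torch.allclose(combined, expected))  # True: logsumexp reproduces the weighted mixture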
							
								
								
									
lit_models/utils.py (66 lines, new file)
@@ -0,0 +1,66 @@
import json
import numpy as np
import torch
import torch.nn as nn


def rank_score(ranks):
    # per-sample hit indicators and reciprocal ranks
    len_samples = len(ranks)
    hits10 = [0] * len_samples
    hits5 = [0] * len_samples
    hits1 = [0] * len_samples
    mrr = []

    for idx, rank in enumerate(ranks):
        if rank <= 10:
            hits10[idx] = 1.
            if rank <= 5:
                hits5[idx] = 1.
                if rank <= 1:
                    hits1[idx] = 1.
        mrr.append(1. / rank)

    return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)


def acc(logits, labels):
    preds = np.argmax(logits, axis=-1)
    return (preds == labels).mean()


class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version; you can also try LabelSmoothSoftmaxCEV2, which uses derived gradients.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C, H, W) or (N, C)
        args: label: tensor of shape (N, H, W) or (N,)
        '''
        # build the smoothed target distribution, ignoring entries equal to ignore_index
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label == self.lb_ignore
            n_valid = (ignore == 0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            label = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * label, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss
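A short, hypothetical usage sketch for LabelSmoothSoftmaxCEV1 (not part of the commit; it assumes lit_models.utils is importable), mainly to show the (N, C) logits / (N,) labels call pattern and the ignore_index handling:

import torch
from lit_models.utils import LabelSmoothSoftmaxCEV1

torch.manual_seed(0)
logits = torch.randn(4, 10)              # (N, C) scores, e.g. masked-entity logits
target = torch.tensor([1, 3, -100, 7])   # entries equal to -100 are ignored

criterion = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, reduction='mean')
loss = criterion(logits, target)
print(loss)  # mean smoothed cross entropy over the three non-ignored samples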