first version of negative sampling
This commit is contained in:
		@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
 | 
			
		||||
                                                  PreTrainedTokenizerBase)
 | 
			
		||||
 | 
			
		||||
from .base_data_module import BaseDataModule
 | 
			
		||||
from .processor import KGProcessor, get_dataset
 | 
			
		||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
 | 
			
		||||
import transformers
 | 
			
		||||
transformers.logging.set_verbosity_error()
 | 
			
		||||
 | 
			
		||||
@@ -105,8 +105,9 @@ class DataCollatorForSeq2Seq:
 | 
			
		||||
                if isinstance(l, int): 
 | 
			
		||||
                    new_labels[i][l] = 1
 | 
			
		||||
                else:
 | 
			
		||||
                    for j in l:
 | 
			
		||||
                        new_labels[i][j] = 1
 | 
			
		||||
                    if (l[0] != getNegativeEntityId()):
 | 
			
		||||
                        for j in l:
 | 
			
		||||
                            new_labels[i][j] = 1
 | 
			
		||||
            labels = new_labels
 | 
			
		||||
 | 
			
		||||
        features = self.tokenizer.pad(
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ import contextlib
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
from collections import Counter
 | 
			
		||||
from multiprocessing import Pool
 | 
			
		||||
from multiprocessing import Pool, synchronize
 | 
			
		||||
from torch._C import HOIST_CONV_PACKED_PARAMS
 | 
			
		||||
from torch.utils.data import Dataset, Sampler, IterableDataset
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
@@ -235,6 +235,31 @@ class DataProcessor(object):
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
from collections import deque
 | 
			
		||||
import threading
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _LiveState(type):
 | 
			
		||||
    _instances = {}
 | 
			
		||||
    _lock = threading.Lock()
 | 
			
		||||
 | 
			
		||||
    def __call__(cls, *args, **kwargs):
 | 
			
		||||
        if cls not in cls._instances:
 | 
			
		||||
            with cls._lock:
 | 
			
		||||
                if cls not in cls._instances:
 | 
			
		||||
                    cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs)
 | 
			
		||||
        
 | 
			
		||||
        return cls._instances[cls]
 | 
			
		||||
 | 
			
		||||
class LiveState(metaclass=_LiveState):
    """Bounded pool holding the most recently stored items.

    The class is created through the ``_LiveState`` metaclass, so every
    ``LiveState(...)`` call returns the same object; constructor arguments
    only take effect on the call that actually builds it.
    """

    def __init__(self, pool_size=16):
        """Create the pool.

        Args:
            pool_size: Maximum number of items kept. Defaults to 16, the
                value that used to be hard-coded.
        """
        self._pool_size = pool_size
        # A deque with ``maxlen`` drops its oldest entry automatically once
        # it is full, so the pool never grows past ``pool_size``.
        self._deq = deque(maxlen=self._pool_size)

    def put(self, item):
        """Append ``item`` to the pool, evicting the oldest entry if full."""
        self._deq.append(item)

    def get(self):
        """Return the current pool contents as a new list, oldest first."""
        return list(self._deq)
 | 
			
		||||
 | 
			
		||||
def solve_get_knowledge_store(line, set_type="train", pretrain=1):
 | 
			
		||||
    """
 | 
			
		||||
@@ -364,6 +389,53 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32):
 | 
			
		||||
            InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list))
 | 
			
		||||
        examples.append(
 | 
			
		||||
            InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list))
 | 
			
		||||
        
 | 
			
		||||
        liveState = LiveState()
 | 
			
		||||
        _prev = liveState.get()
 | 
			
		||||
 | 
			
		||||
        if (set_type == "train" and len(_prev) > 0):
 | 
			
		||||
 | 
			
		||||
            for prev_ent in _prev:
 | 
			
		||||
 | 
			
		||||
                z = head_filter_entities["\t".join([prev_ent,line[1]])]
 | 
			
		||||
                if (len(z) == 0):
 | 
			
		||||
                    z.append('[NEG]')
 | 
			
		||||
                    z.append(line[2])
 | 
			
		||||
                    z.append(line[0])
 | 
			
		||||
                masked_neg_seq = set()
 | 
			
		||||
                masked_neg_seq_id = set()
 | 
			
		||||
 | 
			
		||||
                masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
 | 
			
		||||
                    random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
 | 
			
		||||
 | 
			
		||||
                for item in masked_neg_graph_list:
 | 
			
		||||
                    masked_neg_seq.add(item[0])
 | 
			
		||||
                    masked_neg_seq.add(item[1])
 | 
			
		||||
                    masked_neg_seq.add(item[2])
 | 
			
		||||
                    masked_neg_seq_id.add(ent2id[item[0]])
 | 
			
		||||
                    masked_neg_seq_id.add(rel2id[item[1]])
 | 
			
		||||
                    masked_neg_seq_id.add(ent2id[item[2]])
 | 
			
		||||
                    
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[0]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[2]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[1]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference(prev_ent)
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent)
 | 
			
		||||
                # examples.append(
 | 
			
		||||
                #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
 | 
			
		||||
                # examples.append(
 | 
			
		||||
                #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
 | 
			
		||||
                examples.append(
 | 
			
		||||
                    InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
 | 
			
		||||
                examples.append(    
 | 
			
		||||
                    InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        liveState.put(line[0])
 | 
			
		||||
        liveState.put(line[2])
 | 
			
		||||
    return examples
 | 
			
		||||
 | 
			
		||||
def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_):
 | 
			
		||||
@@ -377,6 +449,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
 | 
			
		||||
    global masked_head_neighbor
 | 
			
		||||
    global masked_tail_neighbor
 | 
			
		||||
    global rel2token
 | 
			
		||||
    global negativeEntity
 | 
			
		||||
 | 
			
		||||
    head_filter_entities = head
 | 
			
		||||
    tail_filter_entities = tail
 | 
			
		||||
@@ -388,11 +461,20 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
 | 
			
		||||
    masked_head_neighbor = masked_head_neighbor_
 | 
			
		||||
    masked_tail_neighbor = masked_tail_neighbor_
 | 
			
		||||
    rel2token = rel2token_
 | 
			
		||||
    negativeEntity = ent2id['[NEG]']
 | 
			
		||||
 | 
			
		||||
def delete_init(ent2text_):
    """Publish ``ent2text_`` as the module-level ``ent2text`` global.

    Args:
        ent2text_: Object to bind to the module-level name ``ent2text``.
            It is stored as-is (no copy), replacing any previous binding.
    """
    # Rebinding a module-level name from inside a function requires the
    # explicit ``global`` declaration.
    global ent2text
    ent2text = ent2text_
 | 
			
		||||
 | 
			
		||||
def getEntityIdByName(name):
    """Look up ``name`` in the module-level ``ent2id`` mapping.

    Args:
        name: Key to look up in ``ent2id``.

    Returns:
        The value stored under ``name``.

    NOTE(review): ``ent2id`` is not assigned in this function; it presumably
    has to be populated (e.g. by ``filter_init``) before this is called --
    confirm against the callers.
    """
    # Reading a module-level name needs no ``global`` declaration.
    return ent2id[name]
 | 
			
		||||
 | 
			
		||||
def getNegativeEntityId():
    """Return the module-level ``negativeEntity`` value.

    ``filter_init`` assigns ``negativeEntity`` from ``ent2id['[NEG]']``.

    NOTE(review): no module-level default for ``negativeEntity`` is visible
    here, so calling this before ``filter_init`` has run in the current
    process presumably raises ``NameError`` -- confirm.
    """
    # Reading a module-level name needs no ``global`` declaration.
    return negativeEntity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KGProcessor(DataProcessor):
 | 
			
		||||
    """Processor for knowledge graph data set."""
 | 
			
		||||
@@ -443,6 +525,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        """Gets all entities in the knowledge graph."""
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            lines = f.readlines()
 | 
			
		||||
            lines.append('[NEG]\t')
 | 
			
		||||
            entities = []
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                entities.append(line.strip().split("\t")[0])
 | 
			
		||||
@@ -469,6 +552,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        ent2text_with_type = {}
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            ent_lines = f.readlines()
 | 
			
		||||
            ent_lines.append('[NEG]\t')
 | 
			
		||||
            for line in ent_lines:
 | 
			
		||||
                temp = line.strip().split('\t')
 | 
			
		||||
                try:
 | 
			
		||||
@@ -579,6 +663,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        else:
 | 
			
		||||
            annotate_ = partial(
 | 
			
		||||
                solve,
 | 
			
		||||
                set_type=set_type,
 | 
			
		||||
                pretrain=self.args.pretrain,
 | 
			
		||||
                max_triplet=self.args.max_triplet
 | 
			
		||||
            )
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user