first version of negative sampling
This commit is contained in:
		@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
 | 
			
		||||
                                                  PreTrainedTokenizerBase)
 | 
			
		||||
 | 
			
		||||
from .base_data_module import BaseDataModule
 | 
			
		||||
from .processor import KGProcessor, get_dataset
 | 
			
		||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
 | 
			
		||||
import transformers
 | 
			
		||||
transformers.logging.set_verbosity_error()
 | 
			
		||||
 | 
			
		||||
@@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
 | 
			
		||||
                if isinstance(l, int): 
 | 
			
		||||
                    new_labels[i][l] = 1
 | 
			
		||||
                else:
 | 
			
		||||
                    for j in l:
 | 
			
		||||
                        new_labels[i][j] = 1
 | 
			
		||||
                    if (l[0] != getNegativeEntityId()):
 | 
			
		||||
                        for j in l:
 | 
			
		||||
                            new_labels[i][j] = 1
 | 
			
		||||
            labels = new_labels
 | 
			
		||||
 | 
			
		||||
        features = self.tokenizer.pad(
 | 
			
		||||
 
 | 
			
		||||
@@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
 | 
			
		||||
    global ent2id
 | 
			
		||||
    global ent2token
 | 
			
		||||
    global rel2id
 | 
			
		||||
    global negativeEntity
 | 
			
		||||
 | 
			
		||||
    head_filter_entities = head
 | 
			
		||||
    tail_filter_entities = tail
 | 
			
		||||
@@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
 | 
			
		||||
    ent2id = ent2id_
 | 
			
		||||
    ent2token = ent2token_
 | 
			
		||||
    rel2id = rel2id_
 | 
			
		||||
    negativeEntity = ent2id['[NEG]']
 | 
			
		||||
 | 
			
		||||
def delete_init(ent2text_):
 | 
			
		||||
    global ent2text
 | 
			
		||||
    ent2text = ent2text_
 | 
			
		||||
 | 
			
		||||
def getEntityIdByName(name):
 | 
			
		||||
    global ent2id
 | 
			
		||||
    return ent2id[name]
 | 
			
		||||
 | 
			
		||||
def getNegativeEntityId():
 | 
			
		||||
    global negativeEntity
 | 
			
		||||
    return negativeEntity
 | 
			
		||||
 | 
			
		||||
class KGProcessor(DataProcessor):
 | 
			
		||||
    """Processor for knowledge graph data set."""
 | 
			
		||||
@@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        """Gets all entities in the knowledge graph."""
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            lines = f.readlines()
 | 
			
		||||
            lines.append('[NEG]\t')
 | 
			
		||||
            entities = []
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                entities.append(line.strip().split("\t")[0])
 | 
			
		||||
@@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        ent2text_with_type = {}
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            ent_lines = f.readlines()
 | 
			
		||||
            ent_lines.append('[NEG]\t')
 | 
			
		||||
            for line in ent_lines:
 | 
			
		||||
                temp = line.strip().split('\t')
 | 
			
		||||
                try:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user