first version of negative sampling
This commit is contained in:
		@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
 | 
			
		||||
                                                  PreTrainedTokenizerBase)
 | 
			
		||||
 | 
			
		||||
from .base_data_module import BaseDataModule
 | 
			
		||||
from .processor import KGProcessor, get_dataset
 | 
			
		||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
 | 
			
		||||
import transformers
 | 
			
		||||
transformers.logging.set_verbosity_error()
 | 
			
		||||
 | 
			
		||||
@@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
 | 
			
		||||
                if isinstance(l, int): 
 | 
			
		||||
                    new_labels[i][l] = 1
 | 
			
		||||
                else:
 | 
			
		||||
                    for j in l:
 | 
			
		||||
                        new_labels[i][j] = 1
 | 
			
		||||
                    if (l[0] != getNegativeEntityId()):
 | 
			
		||||
                        for j in l:
 | 
			
		||||
                            new_labels[i][j] = 1
 | 
			
		||||
            labels = new_labels
 | 
			
		||||
 | 
			
		||||
        features = self.tokenizer.pad(
 | 
			
		||||
 
 | 
			
		||||
@@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
 | 
			
		||||
    global ent2id
 | 
			
		||||
    global ent2token
 | 
			
		||||
    global rel2id
 | 
			
		||||
    global negativeEntity
 | 
			
		||||
 | 
			
		||||
    head_filter_entities = head
 | 
			
		||||
    tail_filter_entities = tail
 | 
			
		||||
@@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
 | 
			
		||||
    ent2id = ent2id_
 | 
			
		||||
    ent2token = ent2token_
 | 
			
		||||
    rel2id = rel2id_
 | 
			
		||||
    negativeEntity = ent2id['[NEG]']
 | 
			
		||||
 | 
			
		||||
def delete_init(ent2text_):
    """Publish ``ent2text_`` as the module-level global ``ent2text``.

    After this call, code in this module that reads the global ``ent2text``
    sees the object passed here (the same object, not a copy).

    NOTE(review): the ``*_init`` name suggests this is used as a
    worker/pool initializer -- confirm against the call site.

    Args:
        ent2text_: Object to expose as ``ent2text``; presumably an
            entity -> text mapping (verify against caller).
    """
    global ent2text
    ent2text = ent2text_
 | 
			
		||||
 | 
			
		||||
def getEntityIdByName(name):
    """Return the id stored for ``name`` in the module-level ``ent2id``.

    ``ent2id`` is assigned by ``filter_init``; it must have been set before
    this function is called.

    Args:
        name: Key to look up in ``ent2id`` (e.g. ``'[NEG]'``).

    Returns:
        The value ``ent2id[name]``.

    Raises:
        NameError: If ``ent2id`` has not been assigned yet.
        KeyError: If ``name`` is not present (assuming ``ent2id`` is a
            dict-like mapping).
    """
    # No ``global`` declaration is needed: the name is only read, never
    # rebound, so normal scope lookup already resolves to the module global.
    return ent2id[name]
 | 
			
		||||
 | 
			
		||||
def getNegativeEntityId():
    """Return the module-level ``negativeEntity`` value.

    ``negativeEntity`` is assigned in ``filter_init`` as ``ent2id['[NEG]']``,
    i.e. the id of the ``'[NEG]'`` placeholder entity, so ``filter_init``
    must have run before this function is called.

    Returns:
        The current value of the global ``negativeEntity``.

    Raises:
        NameError: If ``negativeEntity`` has not been assigned yet.
    """
    # Read-only access: a ``global`` declaration is unnecessary here.
    return negativeEntity
 | 
			
		||||
 | 
			
		||||
class KGProcessor(DataProcessor):
 | 
			
		||||
    """Processor for knowledge graph data set."""
 | 
			
		||||
@@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        """Gets all entities in the knowledge graph."""
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            lines = f.readlines()
 | 
			
		||||
            lines.append('[NEG]\t')
 | 
			
		||||
            entities = []
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                entities.append(line.strip().split("\t")[0])
 | 
			
		||||
@@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        ent2text_with_type = {}
 | 
			
		||||
        with open(self.entity_path, 'r') as f:
 | 
			
		||||
            ent_lines = f.readlines()
 | 
			
		||||
            ent_lines.append('[NEG]\t')
 | 
			
		||||
            for line in ent_lines:
 | 
			
		||||
                temp = line.strip().split('\t')
 | 
			
		||||
                try:
 | 
			
		||||
 
 | 
			
		||||
@@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
 | 
			
		||||
        self.id2entity = {}
 | 
			
		||||
        with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
 | 
			
		||||
            cnt = 0
 | 
			
		||||
            for line in file.readlines():
 | 
			
		||||
            lines = file.readlines()
 | 
			
		||||
            lines.append('[NEG]\t')
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                e, d = line.strip().split("\t")
 | 
			
		||||
                self.id2entity[cnt] = e
 | 
			
		||||
                cnt += 1
 | 
			
		||||
        self.id2entity_t = {}
 | 
			
		||||
        with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
 | 
			
		||||
            for line in file.readlines():
 | 
			
		||||
            lines = file.readlines()
 | 
			
		||||
            lines.append('[NEG]\t')
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                e, d = line.strip().split("\t")
 | 
			
		||||
                self.id2entity_t[e] = d
 | 
			
		||||
        for k, v in self.id2entity.items():
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										0
									
								
								pretrain/scripts/pretrain_fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										0
									
								
								pretrain/scripts/pretrain_fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
		Reference in New Issue
	
	Block a user