This commit is contained in:
2023-02-12 10:57:29 +00:00
parent c0d0be076f
commit 6cc55301ad
5 changed files with 90 additions and 10 deletions

View File

@ -73,9 +73,54 @@ class TransformerLitModel(BaseLitModel):
def forward(self, x):
return self.model(x)
def create_negatives(self, batch):
negativeBatches = []
label = batch['label']
for i in range(label.shape[0]):
newBatch = {}
newBatch['attention_mask'] = None
newBatch['input_ids'] = torch.clone(batch['input_ids'])
newBatch['label'] = torch.zeros_like(batch['label'])
negativeBatches.append(newBatch)
entity_idx = []
self_label = []
for idx, l in enumerate(label):
decoded = self.decode([batch['input_ids'][idx]])[0].split(' ')
for j in range(1, len(decoded)):
if (decoded[j].startswith("[ENTITY_")):
entity_idx.append(j)
self_label.append(batch['input_ids'][idx][j])
break
for idx, lbl in enumerate(label):
for i in range(label.shape[0]):
if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl):
negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl
else:
negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i]
return negativeBatches
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
# embed();exit()
# print(self.optimizers().param_groups[1]['lr'])
negativeBatches = self.create_negatives(batch)
loss = 0
for negativeBatch in negativeBatches:
label = negativeBatch.pop("label")
input_ids = batch['input_ids']
logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bs = input_ids.shape[0]
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
loss += self.loss_fn(mask_logits, label)
labels = batch.pop("labels")
label = batch.pop("label")
pos = batch.pop("pos")
@ -110,9 +155,9 @@ class TransformerLitModel(BaseLitModel):
assert mask_idx.shape[0] == bs, "only one mask in sequence!"
if self.args.bce:
loss = self.loss_fn(mask_logits, labels)
loss += self.loss_fn(mask_logits, labels)
else:
loss = self.loss_fn(mask_logits, label)
loss += self.loss_fn(mask_logits, label)
if batch_idx == 0:
print('\n'.join(self.decode(batch['input_ids'][:4])))