try to add attn

This commit is contained in:
Cong Thanh Vu 2024-06-16 19:09:47 +07:00
parent f8e969cbd1
commit c2b17ec1ba
3 changed files with 275 additions and 22 deletions

View File

@ -12407,3 +12407,233 @@
12406 Carry out roadside bombing[65] 12406 Carry out roadside bombing[65]
12407 Appeal for target to allow international involvement (non-mediation)[1] 12407 Appeal for target to allow international involvement (non-mediation)[1]
12408 Reject request for change in leadership[179] 12408 Reject request for change in leadership[179]
12409 Criticize or denounce
12410 Express intent to meet or negotiate
12411 Consult
12412 Make an appeal or request
12413 Abduct, hijack, or take hostage
12414 Praise or endorse
12415 Engage in negotiation
12416 Use unconventional violence
12417 Make statement
12418 Arrest, detain, or charge with legal action
12419 Use conventional military force
12420 Complain officially
12421 Impose administrative sanctions
12422 Express intent to cooperate
12423 Make a visit
12424 Appeal for de-escalation of military engagement
12425 Sign formal agreement
12426 Attempt to assassinate
12427 Host a visit
12428 Increase military alert status
12429 Impose embargo, boycott, or sanctions
12430 Provide economic aid
12431 Demonstrate or rally
12432 Express intent to engage in diplomatic cooperation (such as policy support)
12433 Appeal for intelligence
12434 Demand
12435 Carry out suicide bombing
12436 Threaten
12437 Express intent to provide material aid
12438 Grant diplomatic recognition
12439 Meet at a 'third' location
12440 Accuse
12441 Investigate
12442 Reject
12443 Appeal for diplomatic cooperation (such as policy support)
12444 Engage in symbolic act
12445 Defy norms, law
12446 Consider policy option
12447 Provide aid
12448 Sexually assault
12449 Make empathetic comment
12450 Bring lawsuit against
12451 Impose blockade, restrict movement
12452 Make pessimistic comment
12453 Protest violently, riot
12454 Reduce or break diplomatic relations
12455 Grant asylum
12456 Engage in diplomatic cooperation
12457 Make optimistic comment
12458 Torture
12459 Refuse to yield
12460 Appeal for change in leadership
12461 Cooperate militarily
12462 Mobilize or increase armed forces
12463 fight with small arms and light weapons
12464 Ease administrative sanctions
12465 Appeal for political reform
12466 Return, release person(s)
12467 Discuss by telephone
12468 Demonstrate for leadership change
12469 Impose restrictions on political freedoms
12470 Reduce relations
12471 Investigate crime, corruption
12472 Engage in material cooperation
12473 Appeal to others to meet or negotiate
12474 Provide humanitarian aid
12475 Use tactics of violent repression
12476 Occupy territory
12477 Demand humanitarian aid
12478 Threaten non-force
12479 Express intent to cooperate economically
12480 Conduct suicide, car, or other non-military bombing
12481 Demand diplomatic cooperation (such as policy support)
12482 Demand meeting, negotiation
12483 Deny responsibility
12484 Express intent to change institutions, regime
12485 Give ultimatum
12486 Appeal for judicial cooperation
12487 Rally support on behalf of
12488 Obstruct passage, block
12489 Share intelligence or information
12490 Expel or deport individuals
12491 Confiscate property
12492 Accuse of aggression
12493 Physically assault
12494 Retreat or surrender militarily
12495 Veto
12496 Kill by physical assault
12497 Assassinate
12498 Appeal for change in institutions, regime
12499 Forgive
12500 Reject proposal to meet, discuss, or negotiate
12501 Express intent to provide humanitarian aid
12502 Appeal for release of persons or property
12503 Acknowledge or claim responsibility
12504 Ease economic sanctions, boycott, embargo
12505 Express intent to cooperate militarily
12506 Cooperate economically
12507 Express intent to provide economic aid
12508 Mobilize or increase police power
12509 Employ aerial weapons
12510 Accuse of human rights abuses
12511 Conduct strike or boycott
12512 Appeal for policy change
12513 Demonstrate military or police power
12514 Provide military aid
12515 Reject plan, agreement to settle dispute
12516 Yield
12517 Appeal for easing of administrative sanctions
12518 Mediate
12519 Apologize
12520 Express intent to release persons or property
12521 Express intent to de-escalate military engagement
12522 Accede to demands for rights
12523 Demand economic aid
12524 Impose state of emergency or martial law
12525 Receive deployment of peacekeepers
12526 Demand de-escalation of military engagement
12527 Declare truce, ceasefire
12528 Reduce or stop humanitarian assistance
12529 Appeal to others to settle dispute
12530 Reject request for military aid
12531 Threaten with political dissent, protest
12532 Appeal to engage in or accept mediation
12533 Express intent to ease economic sanctions, boycott, or embargo
12534 Coerce
12535 fight with artillery and tanks
12536 Express intent to cooperate on intelligence
12537 Express intent to settle dispute
12538 Express accord
12539 Decline comment
12540 Rally opposition against
12541 Halt negotiations
12542 Demand that target yields
12543 Appeal for military aid
12544 Threaten with military force
12545 Express intent to provide military protection or peacekeeping
12546 Threaten with sanctions, boycott, embargo
12547 Express intent to provide military aid
12548 Demand change in leadership
12549 Appeal for economic aid
12550 Refuse to de-escalate military engagement
12551 Refuse to release persons or property
12552 Increase police alert status
12553 Return, release property
12554 Ease military blockade
12555 Appeal for material cooperation
12556 Express intent to cooperate on judicial matters
12557 Appeal for economic cooperation
12558 Demand settling of dispute
12559 Accuse of crime, corruption
12560 Defend verbally
12561 Provide military protection or peacekeeping
12562 Accuse of espionage, treason
12563 Seize or damage property
12564 Accede to requests or demands for political reform
12565 Appeal for easing of economic sanctions, boycott, or embargo
12566 Threaten to reduce or stop aid
12567 Engage in judicial cooperation
12568 Appeal to yield
12569 Demand military aid
12570 Refuse to ease administrative sanctions
12571 Demand release of persons or property
12572 Accede to demands for change in leadership
12573 Appeal for humanitarian aid
12574 Threaten with repression
12575 Demand change in institutions, regime
12576 Demonstrate for policy change
12577 Appeal for aid
12578 Appeal for rights
12579 Engage in violent protest for rights
12580 Express intent to mediate
12581 Expel or withdraw peacekeepers
12582 Appeal for military protection or peacekeeping
12583 Engage in mass killings
12584 Accuse of war crimes
12585 Reject military cooperation
12586 Threaten to halt negotiations
12587 Ban political parties or politicians
12588 Express intent to change leadership
12589 Demand material cooperation
12590 Express intent to institute political reform
12591 Demand easing of administrative sanctions
12592 Express intent to engage in material cooperation
12593 Reduce or stop economic assistance
12594 Express intent to ease administrative sanctions
12595 Demand intelligence cooperation
12596 Ease curfew
12597 Receive inspectors
12598 Demand rights
12599 Demand political reform
12600 Demand judicial cooperation
12601 Engage in political dissent
12602 Detonate nuclear weapons
12603 Violate ceasefire
12604 Express intent to accept mediation
12605 Refuse to ease economic sanctions, boycott, or embargo
12606 Demand mediation
12607 Obstruct passage to demand leadership change
12608 Express intent to yield
12609 Conduct hunger strike
12610 Threaten to halt mediation
12611 Reject judicial cooperation
12612 Reduce or stop military assistance
12613 Ease political dissent
12614 Threaten to reduce or break relations
12615 Demobilize armed forces
12616 Use as human shield
12617 Demand policy change
12618 Accede to demands for change in institutions, regime
12619 Reject economic cooperation
12620 Reject material cooperation
12621 Halt mediation
12622 Accede to demands for change in policy
12623 Investigate war crimes
12624 Threaten with administrative sanctions
12625 Reduce or stop material aid
12626 Destroy property
12627 Express intent to change policy
12628 Use chemical, biological, or radiological weapons
12629 Reject request for military protection or peacekeeping
12630 Demand material aid
12631 Engage in mass expulsion
12632 Investigate human rights abuses
12633 Carry out car bombing
12634 Expel or withdraw
12635 Ease state of emergency or martial law
12636 Carry out roadside bombing
12637 Appeal for target to allow international involvement (non-mediation)
12638 Reject request for change in leadership

43
main.py
View File

@ -111,47 +111,48 @@ class Main(object):
for split in ['train', 'test', 'valid']: for split in ['train', 'test', 'valid']:
for line in open('./data/{}/{}.txt'.format(self.p.dataset, split)): for line in open('./data/{}/{}.txt'.format(self.p.dataset, split)):
sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t')) sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t'))
sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj] nt_rel = rel.split('[')[0]
self.data[split].append((sub, rel, obj)) sub, rel, obj, nt_rel = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj], self.rel2id[nt_rel]
self.data[split].append((sub, rel, obj, nt_rel))
if split == 'train': if split == 'train':
sr2o[(sub, rel)].add(obj) sr2o[(sub, rel, nt_rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel)].add(sub) sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub)
self.data = dict(self.data) self.data = dict(self.data)
self.sr2o = {k: list(v) for k, v in sr2o.items()} self.sr2o = {k: list(v) for k, v in sr2o.items()}
for split in ['test', 'valid']: for split in ['test', 'valid']:
for sub, rel, obj in self.data[split]: for sub, rel, obj, nt_rel in self.data[split]:
sr2o[(sub, rel)].add(obj) sr2o[(sub, rel, nt_rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel)].add(sub) sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub)
self.sr2o_all = {k: list(v) for k, v in sr2o.items()} self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
self.triples = ddict(list) self.triples = ddict(list)
if self.p.train_strategy == 'one_to_n': if self.p.train_strategy == 'one_to_n':
for (sub, rel), obj in self.sr2o.items(): for (sub, rel, nt_rel), obj in self.sr2o.items():
self.triples['train'].append( self.triples['train'].append(
{'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1}) {'triple': (sub, rel, -1, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': 1})
else: else:
for sub, rel, obj in self.data['train']: for sub, rel, obj, nt_rel in self.data['train']:
rel_inv = rel + self.p.num_rel rel_inv = rel + self.p.num_rel
sub_samp = len(self.sr2o[(sub, rel)]) + \ sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \
len(self.sr2o[(obj, rel_inv)]) len(self.sr2o[(obj, rel_inv)])
sub_samp = np.sqrt(1/sub_samp) sub_samp = np.sqrt(1/sub_samp)
self.triples['train'].append({'triple': ( self.triples['train'].append({'triple': (
sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp}) sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp})
self.triples['train'].append({'triple': ( self.triples['train'].append({'triple': (
obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp}) obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel)], 'sub_samp': sub_samp})
for split in ['test', 'valid']: for split in ['test', 'valid']:
for sub, rel, obj in self.data[split]: for sub, rel, obj, nt_rel in self.data[split]:
rel_inv = rel + self.p.num_rel rel_inv = rel + self.p.num_rel
self.triples['{}_{}'.format(split, 'tail')].append( self.triples['{}_{}'.format(split, 'tail')].append(
{'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]}) {'triple': (sub, rel, obj, nt_rel), 'label': self.sr2o_all[(sub, rel, nt_rel)]})
self.triples['{}_{}'.format(split, 'head')].append( self.triples['{}_{}'.format(split, 'head')].append(
{'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]}) {'triple': (obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o_all[(obj, rel_inv, nt_rel + self.p.num_rel)]})
self.triples = dict(self.triples) self.triples = dict(self.triples)
@ -275,13 +276,13 @@ class Main(object):
if self.p.train_strategy == 'one_to_x': if self.p.train_strategy == 'one_to_x':
triple, label, neg_ent, sub_samp = [ triple, label, neg_ent, sub_samp = [
_.to(self.device) for _ in batch] _.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], label return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label
def save_model(self, save_path): def save_model(self, save_path):
""" """
@ -474,10 +475,10 @@ class Main(object):
for step, batch in enumerate(train_iter): for step, batch in enumerate(train_iter):
self.optimizer.zero_grad() self.optimizer.zero_grad()
sub, rel, obj, label, neg_ent, sub_samp = self.read_batch( sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(
batch, 'train') batch, 'train')
pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy) pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
loss = self.model.loss(pred, label, sub_samp) loss = self.model.loss(pred, label, sub_samp)
loss.backward() loss.backward()

View File

@ -466,6 +466,10 @@ class FouriER(torch.nn.Module):
self.p.ent_vec_dim, image_h*image_w) self.p.ent_vec_dim, image_h*image_w)
torch.nn.init.xavier_normal_(self.ent_fusion.weight) torch.nn.init.xavier_normal_(self.ent_fusion.weight)
self.ent_attn = torch.nn.Linear(
128, 128)
torch.nn.init.xavier_normal_(self.ent_attn.weight)
self.rel_fusion = torch.nn.Linear( self.rel_fusion = torch.nn.Linear(
self.p.rel_vec_dim, image_h*image_w) self.p.rel_vec_dim, image_h*image_w)
torch.nn.init.xavier_normal_(self.rel_fusion.weight) torch.nn.init.xavier_normal_(self.rel_fusion.weight)
@ -547,8 +551,15 @@ class FouriER(torch.nn.Module):
x = block(x) x = block(x)
# output only the features of last layer for image classification # output only the features of last layer for image classification
return x return x
def fuse_attention(self, s_embedding, l_embedding):
w1 = self.ent_attn(torch.tanh(s_embedding))
w2 = self.ent_attn(torch.tanh(l_embedding))
aff = F.softmax(torch.cat((w1,w2),1), 1)
en_embedding = aff[:,0].unsqueeze(1) * s_embedding + aff[:, 1].unsqueeze(1) * l_embedding
return en_embedding
def forward(self, sub, rel, neg_ents, strategy='one_to_x'): def forward(self, sub, rel, nt_rel, neg_ents, strategy='one_to_x'):
sub_emb = self.ent_fusion(self.ent_embed(sub)) sub_emb = self.ent_fusion(self.ent_embed(sub))
rel_emb = self.rel_fusion(self.rel_embed(rel)) rel_emb = self.rel_fusion(self.rel_embed(rel))
comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
@ -557,6 +568,17 @@ class FouriER(torch.nn.Module):
z = self.forward_embeddings(y) z = self.forward_embeddings(y)
z = self.forward_tokens(z) z = self.forward_tokens(z)
z = z.mean([-2, -1]) z = z.mean([-2, -1])
nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
y_1 = self.bn0(y_1)
z_1 = self.forward_embeddings(y_1)
z_1 = self.forward_tokens(z_1)
z_1 = z_1.mean([-2, -1])
z = self.fuse_attention(z, z_1)
z = self.norm(z) z = self.norm(z)
x = self.head(z) x = self.head(z)
x = self.hidden_drop(x) x = self.hidden_drop(x)