From c2b17ec1ba3f8ab5533d4b44cece1a9e058d9a05 Mon Sep 17 00:00:00 2001 From: Cong Thanh Vu Date: Sun, 16 Jun 2024 19:09:47 +0700 Subject: [PATCH] try to add attn --- data/icews14/relations.dict | 230 ++++++++++++++++++++++++++++++++++++ main.py | 43 +++---- models.py | 24 +++- 3 files changed, 275 insertions(+), 22 deletions(-) diff --git a/data/icews14/relations.dict b/data/icews14/relations.dict index cb17ee6..75c6433 100644 --- a/data/icews14/relations.dict +++ b/data/icews14/relations.dict @@ -12407,3 +12407,233 @@ 12406 Carry out roadside bombing[65] 12407 Appeal for target to allow international involvement (non-mediation)[1] 12408 Reject request for change in leadership[179] +12409 Criticize or denounce +12410 Express intent to meet or negotiate +12411 Consult +12412 Make an appeal or request +12413 Abduct, hijack, or take hostage +12414 Praise or endorse +12415 Engage in negotiation +12416 Use unconventional violence +12417 Make statement +12418 Arrest, detain, or charge with legal action +12419 Use conventional military force +12420 Complain officially +12421 Impose administrative sanctions +12422 Express intent to cooperate +12423 Make a visit +12424 Appeal for de-escalation of military engagement +12425 Sign formal agreement +12426 Attempt to assassinate +12427 Host a visit +12428 Increase military alert status +12429 Impose embargo, boycott, or sanctions +12430 Provide economic aid +12431 Demonstrate or rally +12432 Express intent to engage in diplomatic cooperation (such as policy support) +12433 Appeal for intelligence +12434 Demand +12435 Carry out suicide bombing +12436 Threaten +12437 Express intent to provide material aid +12438 Grant diplomatic recognition +12439 Meet at a 'third' location +12440 Accuse +12441 Investigate +12442 Reject +12443 Appeal for diplomatic cooperation (such as policy support) +12444 Engage in symbolic act +12445 Defy norms, law +12446 Consider policy option +12447 Provide aid +12448 Sexually assault +12449 Make empathetic comment +12450 Bring lawsuit against +12451 Impose blockade, restrict movement +12452 Make pessimistic comment +12453 Protest violently, riot +12454 Reduce or break diplomatic relations +12455 Grant asylum +12456 Engage in diplomatic cooperation +12457 Make optimistic comment +12458 Torture +12459 Refuse to yield +12460 Appeal for change in leadership +12461 Cooperate militarily +12462 Mobilize or increase armed forces +12463 fight with small arms and light weapons +12464 Ease administrative sanctions +12465 Appeal for political reform +12466 Return, release person(s) +12467 Discuss by telephone +12468 Demonstrate for leadership change +12469 Impose restrictions on political freedoms +12470 Reduce relations +12471 Investigate crime, corruption +12472 Engage in material cooperation +12473 Appeal to others to meet or negotiate +12474 Provide humanitarian aid +12475 Use tactics of violent repression +12476 Occupy territory +12477 Demand humanitarian aid +12478 Threaten non-force +12479 Express intent to cooperate economically +12480 Conduct suicide, car, or other non-military bombing +12481 Demand diplomatic cooperation (such as policy support) +12482 Demand meeting, negotiation +12483 Deny responsibility +12484 Express intent to change institutions, regime +12485 Give ultimatum +12486 Appeal for judicial cooperation +12487 Rally support on behalf of +12488 Obstruct passage, block +12489 Share intelligence or information +12490 Expel or deport individuals +12491 Confiscate property +12492 Accuse of aggression +12493 Physically assault +12494 Retreat or surrender militarily +12495 Veto +12496 Kill by physical assault +12497 Assassinate +12498 Appeal for change in institutions, regime +12499 Forgive +12500 Reject proposal to meet, discuss, or negotiate +12501 Express intent to provide humanitarian aid +12502 Appeal for release of persons or property +12503 Acknowledge or claim responsibility +12504 Ease economic sanctions, boycott, embargo +12505 Express intent to cooperate militarily +12506 Cooperate economically +12507 Express intent to provide economic aid +12508 Mobilize or increase police power +12509 Employ aerial weapons +12510 Accuse of human rights abuses +12511 Conduct strike or boycott +12512 Appeal for policy change +12513 Demonstrate military or police power +12514 Provide military aid +12515 Reject plan, agreement to settle dispute +12516 Yield +12517 Appeal for easing of administrative sanctions +12518 Mediate +12519 Apologize +12520 Express intent to release persons or property +12521 Express intent to de-escalate military engagement +12522 Accede to demands for rights +12523 Demand economic aid +12524 Impose state of emergency or martial law +12525 Receive deployment of peacekeepers +12526 Demand de-escalation of military engagement +12527 Declare truce, ceasefire +12528 Reduce or stop humanitarian assistance +12529 Appeal to others to settle dispute +12530 Reject request for military aid +12531 Threaten with political dissent, protest +12532 Appeal to engage in or accept mediation +12533 Express intent to ease economic sanctions, boycott, or embargo +12534 Coerce +12535 fight with artillery and tanks +12536 Express intent to cooperate on intelligence +12537 Express intent to settle dispute +12538 Express accord +12539 Decline comment +12540 Rally opposition against +12541 Halt negotiations +12542 Demand that target yields +12543 Appeal for military aid +12544 Threaten with military force +12545 Express intent to provide military protection or peacekeeping +12546 Threaten with sanctions, boycott, embargo +12547 Express intent to provide military aid +12548 Demand change in leadership +12549 Appeal for economic aid +12550 Refuse to de-escalate military engagement +12551 Refuse to release persons or property +12552 Increase police alert status +12553 Return, release property +12554 Ease military blockade +12555 Appeal for material cooperation +12556 Express intent to cooperate on judicial matters +12557 Appeal for economic cooperation +12558 Demand settling of dispute +12559 Accuse of crime, corruption +12560 Defend verbally +12561 Provide military protection or peacekeeping +12562 Accuse of espionage, treason +12563 Seize or damage property +12564 Accede to requests or demands for political reform +12565 Appeal for easing of economic sanctions, boycott, or embargo +12566 Threaten to reduce or stop aid +12567 Engage in judicial cooperation +12568 Appeal to yield +12569 Demand military aid +12570 Refuse to ease administrative sanctions +12571 Demand release of persons or property +12572 Accede to demands for change in leadership +12573 Appeal for humanitarian aid +12574 Threaten with repression +12575 Demand change in institutions, regime +12576 Demonstrate for policy change +12577 Appeal for aid +12578 Appeal for rights +12579 Engage in violent protest for rights +12580 Express intent to mediate +12581 Expel or withdraw peacekeepers +12582 Appeal for military protection or peacekeeping +12583 Engage in mass killings +12584 Accuse of war crimes +12585 Reject military cooperation +12586 Threaten to halt negotiations +12587 Ban political parties or politicians +12588 Express intent to change leadership +12589 Demand material cooperation +12590 Express intent to institute political reform +12591 Demand easing of administrative sanctions +12592 Express intent to engage in material cooperation +12593 Reduce or stop economic assistance +12594 Express intent to ease administrative sanctions +12595 Demand intelligence cooperation +12596 Ease curfew +12597 Receive inspectors +12598 Demand rights +12599 Demand political reform +12600 Demand judicial cooperation +12601 Engage in political dissent +12602 Detonate nuclear weapons +12603 Violate ceasefire +12604 Express intent to accept mediation +12605 Refuse to ease economic sanctions, boycott, or embargo +12606 Demand mediation +12607 Obstruct passage to demand leadership change +12608 Express intent to yield +12609 Conduct hunger strike +12610 Threaten to halt mediation +12611 Reject judicial cooperation +12612 Reduce or stop military assistance +12613 Ease political dissent +12614 Threaten to reduce or break relations +12615 Demobilize armed forces +12616 Use as human shield +12617 Demand policy change +12618 Accede to demands for change in institutions, regime +12619 Reject economic cooperation +12620 Reject material cooperation +12621 Halt mediation +12622 Accede to demands for change in policy +12623 Investigate war crimes +12624 Threaten with administrative sanctions +12625 Reduce or stop material aid +12626 Destroy property +12627 Express intent to change policy +12628 Use chemical, biological, or radiological weapons +12629 Reject request for military protection or peacekeeping +12630 Demand material aid +12631 Engage in mass expulsion +12632 Investigate human rights abuses +12633 Carry out car bombing +12634 Expel or withdraw +12635 Ease state of emergency or martial law +12636 Carry out roadside bombing +12637 Appeal for target to allow international involvement (non-mediation) +12638 Reject request for change in leadership \ No newline at end of file diff --git a/main.py b/main.py index b9349b9..0cc5d42 100644 --- a/main.py +++ b/main.py @@ -111,47 +111,48 @@ class Main(object): for split in ['train', 'test', 'valid']: for line in open('./data/{}/{}.txt'.format(self.p.dataset, split)): sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t')) - sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj] - self.data[split].append((sub, rel, obj)) + nt_rel = rel.split('[')[0] + sub, rel, obj, nt_rel = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj], self.rel2id[nt_rel] + self.data[split].append((sub, rel, obj, nt_rel)) if split == 'train': - sr2o[(sub, rel)].add(obj) - sr2o[(obj, rel+self.p.num_rel)].add(sub) + sr2o[(sub, rel, nt_rel)].add(obj) + sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub) self.data = dict(self.data) self.sr2o = {k: list(v) for k, v in sr2o.items()} for split in ['test', 'valid']: - for sub, rel, obj in self.data[split]: - sr2o[(sub, rel)].add(obj) - sr2o[(obj, rel+self.p.num_rel)].add(sub) + for sub, rel, obj, nt_rel in self.data[split]: + sr2o[(sub, rel, nt_rel)].add(obj) + sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub) self.sr2o_all = {k: list(v) for k, v in sr2o.items()} self.triples = ddict(list) if self.p.train_strategy == 'one_to_n': - for (sub, rel), obj in self.sr2o.items(): + for (sub, rel, nt_rel), obj in self.sr2o.items(): self.triples['train'].append( - {'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1}) + {'triple': (sub, rel, -1, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': 1}) else: - for sub, rel, obj in self.data['train']: + for sub, rel, obj, nt_rel in self.data['train']: rel_inv = rel + self.p.num_rel - sub_samp = len(self.sr2o[(sub, rel)]) + \ + sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \ len(self.sr2o[(obj, rel_inv)]) sub_samp = np.sqrt(1/sub_samp) self.triples['train'].append({'triple': ( - sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp}) + sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp}) self.triples['train'].append({'triple': ( - obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp}) + obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel)], 'sub_samp': sub_samp}) for split in ['test', 'valid']: - for sub, rel, obj in self.data[split]: + for sub, rel, obj, nt_rel in self.data[split]: rel_inv = rel + self.p.num_rel self.triples['{}_{}'.format(split, 'tail')].append( - {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]}) + {'triple': (sub, rel, obj, nt_rel), 'label': self.sr2o_all[(sub, rel, nt_rel)]}) self.triples['{}_{}'.format(split, 'head')].append( - {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]}) + {'triple': (obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o_all[(obj, rel_inv, nt_rel + self.p.num_rel)]}) self.triples = dict(self.triples) @@ -275,13 +276,13 @@ class Main(object): if self.p.train_strategy == 'one_to_x': triple, label, neg_ent, sub_samp = [ _.to(self.device) for _ in batch] - return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp + return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp else: triple, label = [_.to(self.device) for _ in batch] - return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None + return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None else: triple, label = [_.to(self.device) for _ in batch] - return triple[:, 0], triple[:, 1], triple[:, 2], label + return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label def save_model(self, save_path): """ @@ -474,10 +475,10 @@ class Main(object): for step, batch in enumerate(train_iter): self.optimizer.zero_grad() - sub, rel, obj, label, neg_ent, sub_samp = self.read_batch( + sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch( batch, 'train') - pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy) + pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy) loss = self.model.loss(pred, label, sub_samp) loss.backward() diff --git a/models.py b/models.py index b262ffd..f8d4680 100644 --- a/models.py +++ b/models.py @@ -466,6 +466,10 @@ class FouriER(torch.nn.Module): self.p.ent_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.ent_fusion.weight) + self.ent_attn = torch.nn.Linear( + 128, 128) + torch.nn.init.xavier_normal_(self.ent_attn.weight) + self.rel_fusion = torch.nn.Linear( self.p.rel_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.rel_fusion.weight) @@ -547,8 +551,15 @@ class FouriER(torch.nn.Module): x = block(x) # output only the features of last layer for image classification return x + + def fuse_attention(self, s_embedding, l_embedding): + w1 = self.ent_attn(torch.tanh(s_embedding)) + w2 = self.ent_attn(torch.tanh(l_embedding)) + aff = F.softmax(torch.cat((w1,w2),1), 1) + en_embedding = aff[:,0].unsqueeze(1) * s_embedding + aff[:, 1].unsqueeze(1) * l_embedding + return en_embedding - def forward(self, sub, rel, neg_ents, strategy='one_to_x'): + def forward(self, sub, rel, nt_rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_fusion(self.ent_embed(sub)) rel_emb = self.rel_fusion(self.rel_embed(rel)) comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) @@ -557,6 +568,17 @@ class FouriER(torch.nn.Module): z = self.forward_embeddings(y) z = self.forward_tokens(z) z = z.mean([-2, -1]) + + nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel)) + comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) + y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w) + y_1 = self.bn0(y_1) + z_1 = self.forward_embeddings(y_1) + z_1 = self.forward_tokens(z_1) + z_1 = z_1.mean([-2, -1]) + + z = self.fuse_attention(z, z_1) + z = self.norm(z) x = self.head(z) x = self.hidden_drop(x)