Compare commits
	
		
			16 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 47bc661a91 | ||
|  | 3b6db89be1 | ||
|  | 352f5f9da9 | ||
|  | b9273b6696 | ||
|  | d0e4630dd6 | ||
|  | 08a3780ba6 | ||
|  | 6fc56b920f | ||
|  | fddea4769f | ||
|  | d9209a7ef1 | ||
|  | 0f986d7517 | ||
|  | 4daa40527b | ||
|  | 541c4fa2b3 | ||
|  | 68a94bd1e2 | ||
|  | b01e504874 | ||
|  | 23c44d3582 | ||
|  | 41a5c7b05a | 
| @@ -12407,233 +12407,3 @@ | ||||
| 12406	Carry out roadside bombing[65] | ||||
| 12407	Appeal for target to allow international involvement (non-mediation)[1] | ||||
| 12408	Reject request for change in leadership[179] | ||||
| 12409	Criticize or denounce | ||||
| 12410	Express intent to meet or negotiate | ||||
| 12411	Consult | ||||
| 12412	Make an appeal or request | ||||
| 12413	Abduct, hijack, or take hostage | ||||
| 12414	Praise or endorse | ||||
| 12415	Engage in negotiation | ||||
| 12416	Use unconventional violence | ||||
| 12417	Make statement | ||||
| 12418	Arrest, detain, or charge with legal action | ||||
| 12419	Use conventional military force | ||||
| 12420	Complain officially | ||||
| 12421	Impose administrative sanctions | ||||
| 12422	Express intent to cooperate | ||||
| 12423	Make a visit | ||||
| 12424	Appeal for de-escalation of military engagement | ||||
| 12425	Sign formal agreement | ||||
| 12426	Attempt to assassinate | ||||
| 12427	Host a visit | ||||
| 12428	Increase military alert status | ||||
| 12429	Impose embargo, boycott, or sanctions | ||||
| 12430	Provide economic aid | ||||
| 12431	Demonstrate or rally | ||||
| 12432	Express intent to engage in diplomatic cooperation (such as policy support) | ||||
| 12433	Appeal for intelligence | ||||
| 12434	Demand | ||||
| 12435	Carry out suicide bombing | ||||
| 12436	Threaten | ||||
| 12437	Express intent to provide material aid | ||||
| 12438	Grant diplomatic recognition | ||||
| 12439	Meet at a 'third' location | ||||
| 12440	Accuse | ||||
| 12441	Investigate | ||||
| 12442	Reject | ||||
| 12443	Appeal for diplomatic cooperation (such as policy support) | ||||
| 12444	Engage in symbolic act | ||||
| 12445	Defy norms, law | ||||
| 12446	Consider policy option | ||||
| 12447	Provide aid | ||||
| 12448	Sexually assault | ||||
| 12449	Make empathetic comment | ||||
| 12450	Bring lawsuit against | ||||
| 12451	Impose blockade, restrict movement | ||||
| 12452	Make pessimistic comment | ||||
| 12453	Protest violently, riot | ||||
| 12454	Reduce or break diplomatic relations | ||||
| 12455	Grant asylum | ||||
| 12456	Engage in diplomatic cooperation | ||||
| 12457	Make optimistic comment | ||||
| 12458	Torture | ||||
| 12459	Refuse to yield | ||||
| 12460	Appeal for change in leadership | ||||
| 12461	Cooperate militarily | ||||
| 12462	Mobilize or increase armed forces | ||||
| 12463	fight with small arms and light weapons | ||||
| 12464	Ease administrative sanctions | ||||
| 12465	Appeal for political reform | ||||
| 12466	Return, release person(s) | ||||
| 12467	Discuss by telephone | ||||
| 12468	Demonstrate for leadership change | ||||
| 12469	Impose restrictions on political freedoms | ||||
| 12470	Reduce relations | ||||
| 12471	Investigate crime, corruption | ||||
| 12472	Engage in material cooperation | ||||
| 12473	Appeal to others to meet or negotiate | ||||
| 12474	Provide humanitarian aid | ||||
| 12475	Use tactics of violent repression | ||||
| 12476	Occupy territory | ||||
| 12477	Demand humanitarian aid | ||||
| 12478	Threaten non-force | ||||
| 12479	Express intent to cooperate economically | ||||
| 12480	Conduct suicide, car, or other non-military bombing | ||||
| 12481	Demand diplomatic cooperation (such as policy support) | ||||
| 12482	Demand meeting, negotiation | ||||
| 12483	Deny responsibility | ||||
| 12484	Express intent to change institutions, regime | ||||
| 12485	Give ultimatum | ||||
| 12486	Appeal for judicial cooperation | ||||
| 12487	Rally support on behalf of | ||||
| 12488	Obstruct passage, block | ||||
| 12489	Share intelligence or information | ||||
| 12490	Expel or deport individuals | ||||
| 12491	Confiscate property | ||||
| 12492	Accuse of aggression | ||||
| 12493	Physically assault | ||||
| 12494	Retreat or surrender militarily | ||||
| 12495	Veto | ||||
| 12496	Kill by physical assault | ||||
| 12497	Assassinate | ||||
| 12498	Appeal for change in institutions, regime | ||||
| 12499	Forgive | ||||
| 12500	Reject proposal to meet, discuss, or negotiate | ||||
| 12501	Express intent to provide humanitarian aid | ||||
| 12502	Appeal for release of persons or property | ||||
| 12503	Acknowledge or claim responsibility | ||||
| 12504	Ease economic sanctions, boycott, embargo | ||||
| 12505	Express intent to cooperate militarily | ||||
| 12506	Cooperate economically | ||||
| 12507	Express intent to provide economic aid | ||||
| 12508	Mobilize or increase police power | ||||
| 12509	Employ aerial weapons | ||||
| 12510	Accuse of human rights abuses | ||||
| 12511	Conduct strike or boycott | ||||
| 12512	Appeal for policy change | ||||
| 12513	Demonstrate military or police power | ||||
| 12514	Provide military aid | ||||
| 12515	Reject plan, agreement to settle dispute | ||||
| 12516	Yield | ||||
| 12517	Appeal for easing of administrative sanctions | ||||
| 12518	Mediate | ||||
| 12519	Apologize | ||||
| 12520	Express intent to release persons or property | ||||
| 12521	Express intent to de-escalate military engagement | ||||
| 12522	Accede to demands for rights | ||||
| 12523	Demand economic aid | ||||
| 12524	Impose state of emergency or martial law | ||||
| 12525	Receive deployment of peacekeepers | ||||
| 12526	Demand de-escalation of military engagement | ||||
| 12527	Declare truce, ceasefire | ||||
| 12528	Reduce or stop humanitarian assistance | ||||
| 12529	Appeal to others to settle dispute | ||||
| 12530	Reject request for military aid | ||||
| 12531	Threaten with political dissent, protest | ||||
| 12532	Appeal to engage in or accept mediation | ||||
| 12533	Express intent to ease economic sanctions, boycott, or embargo | ||||
| 12534	Coerce | ||||
| 12535	fight with artillery and tanks | ||||
| 12536	Express intent to cooperate on intelligence | ||||
| 12537	Express intent to settle dispute | ||||
| 12538	Express accord | ||||
| 12539	Decline comment | ||||
| 12540	Rally opposition against | ||||
| 12541	Halt negotiations | ||||
| 12542	Demand that target yields | ||||
| 12543	Appeal for military aid | ||||
| 12544	Threaten with military force | ||||
| 12545	Express intent to provide military protection or peacekeeping | ||||
| 12546	Threaten with sanctions, boycott, embargo | ||||
| 12547	Express intent to provide military aid | ||||
| 12548	Demand change in leadership | ||||
| 12549	Appeal for economic aid | ||||
| 12550	Refuse to de-escalate military engagement | ||||
| 12551	Refuse to release persons or property | ||||
| 12552	Increase police alert status | ||||
| 12553	Return, release property | ||||
| 12554	Ease military blockade | ||||
| 12555	Appeal for material cooperation | ||||
| 12556	Express intent to cooperate on judicial matters | ||||
| 12557	Appeal for economic cooperation | ||||
| 12558	Demand settling of dispute | ||||
| 12559	Accuse of crime, corruption | ||||
| 12560	Defend verbally | ||||
| 12561	Provide military protection or peacekeeping | ||||
| 12562	Accuse of espionage, treason | ||||
| 12563	Seize or damage property | ||||
| 12564	Accede to requests or demands for political reform | ||||
| 12565	Appeal for easing of economic sanctions, boycott, or embargo | ||||
| 12566	Threaten to reduce or stop aid | ||||
| 12567	Engage in judicial cooperation | ||||
| 12568	Appeal to yield | ||||
| 12569	Demand military aid | ||||
| 12570	Refuse to ease administrative sanctions | ||||
| 12571	Demand release of persons or property | ||||
| 12572	Accede to demands for change in leadership | ||||
| 12573	Appeal for humanitarian aid | ||||
| 12574	Threaten with repression | ||||
| 12575	Demand change in institutions, regime | ||||
| 12576	Demonstrate for policy change | ||||
| 12577	Appeal for aid | ||||
| 12578	Appeal for rights | ||||
| 12579	Engage in violent protest for rights | ||||
| 12580	Express intent to mediate | ||||
| 12581	Expel or withdraw peacekeepers | ||||
| 12582	Appeal for military protection or peacekeeping | ||||
| 12583	Engage in mass killings | ||||
| 12584	Accuse of war crimes | ||||
| 12585	Reject military cooperation | ||||
| 12586	Threaten to halt negotiations | ||||
| 12587	Ban political parties or politicians | ||||
| 12588	Express intent to change leadership | ||||
| 12589	Demand material cooperation | ||||
| 12590	Express intent to institute political reform | ||||
| 12591	Demand easing of administrative sanctions | ||||
| 12592	Express intent to engage in material cooperation | ||||
| 12593	Reduce or stop economic assistance | ||||
| 12594	Express intent to ease administrative sanctions | ||||
| 12595	Demand intelligence cooperation | ||||
| 12596	Ease curfew | ||||
| 12597	Receive inspectors | ||||
| 12598	Demand rights | ||||
| 12599	Demand political reform | ||||
| 12600	Demand judicial cooperation | ||||
| 12601	Engage in political dissent | ||||
| 12602	Detonate nuclear weapons | ||||
| 12603	Violate ceasefire | ||||
| 12604	Express intent to accept mediation | ||||
| 12605	Refuse to ease economic sanctions, boycott, or embargo | ||||
| 12606	Demand mediation | ||||
| 12607	Obstruct passage to demand leadership change | ||||
| 12608	Express intent to yield | ||||
| 12609	Conduct hunger strike | ||||
| 12610	Threaten to halt mediation | ||||
| 12611	Reject judicial cooperation | ||||
| 12612	Reduce or stop military assistance | ||||
| 12613	Ease political dissent | ||||
| 12614	Threaten to reduce or break relations | ||||
| 12615	Demobilize armed forces | ||||
| 12616	Use as human shield | ||||
| 12617	Demand policy change | ||||
| 12618	Accede to demands for change in institutions, regime | ||||
| 12619	Reject economic cooperation | ||||
| 12620	Reject material cooperation | ||||
| 12621	Halt mediation | ||||
| 12622	Accede to demands for change in policy | ||||
| 12623	Investigate war crimes | ||||
| 12624	Threaten with administrative sanctions | ||||
| 12625	Reduce or stop material aid | ||||
| 12626	Destroy property | ||||
| 12627	Express intent to change policy | ||||
| 12628	Use chemical, biological, or radiological weapons | ||||
| 12629	Reject request for military protection or peacekeeping | ||||
| 12630	Demand material aid | ||||
| 12631	Engage in mass expulsion | ||||
| 12632	Investigate human rights abuses | ||||
| 12633	Carry out car bombing | ||||
| 12634	Expel or withdraw | ||||
| 12635	Ease state of emergency or martial law | ||||
| 12636	Carry out roadside bombing | ||||
| 12637	Appeal for target to allow international involvement (non-mediation) | ||||
| 12638	Reject request for change in leadership | ||||
| @@ -421,27 +421,3 @@ | ||||
| 420	P551[36-69] | ||||
| 421	P579[0-15] | ||||
| 422	P102[54-62] | ||||
| 423	P131 | ||||
| 424	P1435 | ||||
| 425	P39 | ||||
| 426	P54 | ||||
| 427	P31 | ||||
| 428	P463 | ||||
| 429	P512 | ||||
| 430	P190 | ||||
| 431	P150 | ||||
| 432	P1376 | ||||
| 433	P166 | ||||
| 434	P2962 | ||||
| 435	P108 | ||||
| 436	P17 | ||||
| 437	P793 | ||||
| 438	P69 | ||||
| 439	P26 | ||||
| 440	P579 | ||||
| 441	P1411 | ||||
| 442	P6 | ||||
| 443	P1346 | ||||
| 444	P102 | ||||
| 445	P27 | ||||
| 446	P551 | ||||
|   | ||||
							
								
								
									
										53
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								main.py
									
									
									
									
									
								
							| @@ -91,11 +91,9 @@ class Main(object): | ||||
|         for line in open('./data/{}/{}'.format(self.p.dataset, "relations.dict")): | ||||
|             id, rel = map(str.lower, line.strip().split('\t')) | ||||
|             self.rel2id[rel] = int(id) | ||||
|             rel_set.add(rel) | ||||
|  | ||||
|         # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)} | ||||
|         # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)} | ||||
|          | ||||
|         self.rel2id.update({rel+'_reverse': idx+len(self.rel2id) | ||||
|                            for idx, rel in enumerate(rel_set)}) | ||||
|  | ||||
| @@ -113,48 +111,47 @@ class Main(object): | ||||
|         for split in ['train', 'test', 'valid']: | ||||
|             for line in open('./data/{}/{}.txt'.format(self.p.dataset, split)): | ||||
|                 sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t')) | ||||
|                 nt_rel = rel.split('[')[0] | ||||
|                 sub, rel, obj, nt_rel = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj], self.rel2id[nt_rel] | ||||
|                 self.data[split].append((sub, rel, obj, nt_rel)) | ||||
|                 sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj] | ||||
|                 self.data[split].append((sub, rel, obj)) | ||||
|  | ||||
|                 if split == 'train': | ||||
|                     sr2o[(sub, rel, nt_rel)].add(obj) | ||||
|                     sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub) | ||||
|                     sr2o[(sub, rel)].add(obj) | ||||
|                     sr2o[(obj, rel+self.p.num_rel)].add(sub) | ||||
|         self.data = dict(self.data) | ||||
|  | ||||
|         self.sr2o = {k: list(v) for k, v in sr2o.items()} | ||||
|         for split in ['test', 'valid']: | ||||
|             for sub, rel, obj, nt_rel in self.data[split]: | ||||
|                 sr2o[(sub, rel, nt_rel)].add(obj) | ||||
|                 sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub) | ||||
|             for sub, rel, obj in self.data[split]: | ||||
|                 sr2o[(sub, rel)].add(obj) | ||||
|                 sr2o[(obj, rel+self.p.num_rel)].add(sub) | ||||
|  | ||||
|         self.sr2o_all = {k: list(v) for k, v in sr2o.items()} | ||||
|  | ||||
|         self.triples = ddict(list) | ||||
|  | ||||
|         if self.p.train_strategy == 'one_to_n': | ||||
|             for (sub, rel, nt_rel), obj in self.sr2o.items(): | ||||
|             for (sub, rel), obj in self.sr2o.items(): | ||||
|                 self.triples['train'].append( | ||||
|                     {'triple': (sub, rel, -1, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': 1}) | ||||
|                     {'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1}) | ||||
|         else: | ||||
|             for sub, rel, obj, nt_rel in self.data['train']: | ||||
|             for sub, rel, obj in self.data['train']: | ||||
|                 rel_inv = rel + self.p.num_rel | ||||
|                 sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \ | ||||
|                     len(self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)]) | ||||
|                 sub_samp = len(self.sr2o[(sub, rel)]) + \ | ||||
|                     len(self.sr2o[(obj, rel_inv)]) | ||||
|                 sub_samp = np.sqrt(1/sub_samp) | ||||
|  | ||||
|                 self.triples['train'].append({'triple': ( | ||||
|                     sub, rel, obj, nt_rel),     'label': self.sr2o[(sub, rel, nt_rel)],     'sub_samp': sub_samp}) | ||||
|                     sub, rel, obj),     'label': self.sr2o[(sub, rel)],     'sub_samp': sub_samp}) | ||||
|                 self.triples['train'].append({'triple': ( | ||||
|                     obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)], 'sub_samp': sub_samp}) | ||||
|                     obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp}) | ||||
|  | ||||
|         for split in ['test', 'valid']: | ||||
|             for sub, rel, obj, nt_rel in self.data[split]: | ||||
|             for sub, rel, obj in self.data[split]: | ||||
|                 rel_inv = rel + self.p.num_rel | ||||
|                 self.triples['{}_{}'.format(split, 'tail')].append( | ||||
|                     {'triple': (sub, rel, obj, nt_rel), 	   'label': self.sr2o_all[(sub, rel, nt_rel)]}) | ||||
|                     {'triple': (sub, rel, obj), 	   'label': self.sr2o_all[(sub, rel)]}) | ||||
|                 self.triples['{}_{}'.format(split, 'head')].append( | ||||
|                     {'triple': (obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o_all[(obj, rel_inv, nt_rel + self.p.num_rel)]}) | ||||
|                     {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]}) | ||||
|  | ||||
|         self.triples = dict(self.triples) | ||||
|  | ||||
| @@ -278,13 +275,13 @@ class Main(object): | ||||
|             if self.p.train_strategy == 'one_to_x': | ||||
|                 triple, label, neg_ent, sub_samp = [ | ||||
|                     _.to(self.device) for _ in batch] | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp | ||||
|             else: | ||||
|                 triple, label = [_.to(self.device) for _ in batch] | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None | ||||
|                 return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None | ||||
|         else: | ||||
|             triple, label = [_.to(self.device) for _ in batch] | ||||
|             return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label | ||||
|             return triple[:, 0], triple[:, 1], triple[:, 2], label | ||||
|  | ||||
|     def save_model(self, save_path): | ||||
|         """ | ||||
| @@ -419,8 +416,8 @@ class Main(object): | ||||
|             obj_pred = [] | ||||
|             obj_pred_score = [] | ||||
|             for step, batch in enumerate(train_iter): | ||||
|                 sub, rel, obj, nt_rel, label = self.read_batch(batch, split) | ||||
|                 pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n') | ||||
|                 sub, rel, obj, label = self.read_batch(batch, split) | ||||
|                 pred = self.model.forward(sub, rel, None, 'one_to_n') | ||||
|                 b_range = torch.arange(pred.size()[0], device=self.device) | ||||
|                 target_pred = pred[b_range, obj] | ||||
|                 pred = torch.where(label.byte(), torch.zeros_like(pred), pred) | ||||
| @@ -477,10 +474,10 @@ class Main(object): | ||||
|         for step, batch in enumerate(train_iter): | ||||
|             self.optimizer.zero_grad() | ||||
|  | ||||
|             sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch( | ||||
|             sub, rel, obj, label, neg_ent, sub_samp = self.read_batch( | ||||
|                 batch, 'train') | ||||
|  | ||||
|             pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy) | ||||
|             pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy) | ||||
|             loss = self.model.loss(pred, label, sub_samp) | ||||
|  | ||||
|             loss.backward() | ||||
| @@ -693,7 +690,7 @@ if __name__ == "__main__": | ||||
|                 collate_fn=TrainDataset.collate_fn | ||||
|             )) | ||||
|         for step, batch in enumerate(dataloader): | ||||
|             sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch( | ||||
|             sub, rel, obj, label, neg_ent, sub_samp = model.read_batch( | ||||
|                 batch, 'train') | ||||
|              | ||||
|             if (neg_ent is None): | ||||
|   | ||||
							
								
								
									
										363
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										363
									
								
								models.py
									
									
									
									
									
								
							| @@ -10,6 +10,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | ||||
| from timm.models.layers import DropPath, trunc_normal_ | ||||
| from timm.models.registry import register_model | ||||
| from timm.layers.helpers import to_2tuple | ||||
| from typing import * | ||||
| import math | ||||
|  | ||||
|  | ||||
| class ConvE(torch.nn.Module): | ||||
| @@ -466,10 +468,6 @@ class FouriER(torch.nn.Module): | ||||
|             self.p.ent_vec_dim, image_h*image_w) | ||||
|         torch.nn.init.xavier_normal_(self.ent_fusion.weight) | ||||
|  | ||||
|         self.ent_attn = torch.nn.Linear( | ||||
|             128, 128) | ||||
|         torch.nn.init.xavier_normal_(self.ent_attn.weight) | ||||
|  | ||||
|         self.rel_fusion = torch.nn.Linear( | ||||
|             self.p.rel_vec_dim, image_h*image_w) | ||||
|         torch.nn.init.xavier_normal_(self.rel_fusion.weight) | ||||
| @@ -530,6 +528,22 @@ class FouriER(torch.nn.Module): | ||||
|  | ||||
|         self.network = nn.ModuleList(network) | ||||
|         self.norm = norm_layer(embed_dims[-1]) | ||||
|         self.graph_type = 'Spatial' | ||||
|         N = (image_h // patch_size)**2 | ||||
|         if self.graph_type in ["Spatial", "Mixed"]: | ||||
|             # Create a range tensor of node indices | ||||
|             indices = torch.arange(N) | ||||
|             # Reshape the indices tensor to create a grid of row and column indices | ||||
|             row_indices = indices.view(-1, 1).expand(-1, N) | ||||
|             col_indices = indices.view(1, -1).expand(N, -1) | ||||
|             # Compute the adjacency matrix | ||||
|             row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N)) | ||||
|             row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N)) | ||||
|             graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float()) | ||||
|             graph = graph - torch.eye(N) | ||||
|             self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu | ||||
|         self.class_token = False | ||||
|         self.token_scale = False | ||||
|         self.head = nn.Linear( | ||||
|                 embed_dims[-1], num_classes) if num_classes > 0 \ | ||||
|                 else nn.Identity() | ||||
| @@ -547,19 +561,49 @@ class FouriER(torch.nn.Module): | ||||
|  | ||||
|     def forward_tokens(self, x): | ||||
|         outs = [] | ||||
|         B, C, H, W = x.shape | ||||
|         N = H*W | ||||
|         if self.graph_type in ["Semantic", "Mixed"]: | ||||
|             # Generate the semantic graph w.r.t. the cosine similarity between tokens | ||||
|             # Compute cosine similarity | ||||
|             if self.class_token: | ||||
|                 x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True) | ||||
|             else: | ||||
|                 x_normed = x / x.norm(dim=-1, keepdim=True) | ||||
|             x_cossim = x_normed @ x_normed.transpose(-1, -2) | ||||
|             threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1  | ||||
|             semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0) | ||||
|             if self.class_token: | ||||
|                 semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0) | ||||
|             else: | ||||
|                 semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0) | ||||
|          | ||||
|         if self.graph_type == "None": | ||||
|             graph = None | ||||
|         else: | ||||
|             if self.graph_type == "Spatial": | ||||
|                 graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device) | ||||
|             elif self.graph_type == "Semantic": | ||||
|                 graph = semantic_graph | ||||
|             elif self.graph_type == "Mixed": | ||||
|                 # Integrate the spatial graph and semantic graph | ||||
|                 spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device) | ||||
|                 graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float() | ||||
|              | ||||
|             # Symmetrically normalize the graph | ||||
|             degree = graph.sum(-1) # B, N | ||||
|             degree = torch.diag_embed(degree**(-1/2)) | ||||
|             graph = degree @ graph @ degree | ||||
|  | ||||
|         for idx, block in enumerate(self.network): | ||||
|             try: | ||||
|                 x = block(x, graph) | ||||
|             except: | ||||
|                 x = block(x) | ||||
|         # output only the features of last layer for image classification | ||||
|         return x | ||||
|  | ||||
|     def fuse_attention(self, s_embedding, l_embedding): | ||||
|         w1 = self.ent_attn(torch.tanh(s_embedding)) | ||||
|         w2 = self.ent_attn(torch.tanh(l_embedding)) | ||||
|         aff = F.softmax(torch.cat((w1,w2),1), 1) | ||||
|         en_embedding = aff[:,0].unsqueeze(1) * s_embedding + aff[:, 1].unsqueeze(1) * l_embedding | ||||
|         return en_embedding | ||||
|  | ||||
|     def forward(self, sub, rel, nt_rel, neg_ents, strategy='one_to_x'): | ||||
|     def forward(self, sub, rel, neg_ents, strategy='one_to_x'): | ||||
|         sub_emb = self.ent_fusion(self.ent_embed(sub)) | ||||
|         rel_emb = self.rel_fusion(self.rel_embed(rel)) | ||||
|         comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) | ||||
| @@ -568,17 +612,6 @@ class FouriER(torch.nn.Module): | ||||
|         z = self.forward_embeddings(y) | ||||
|         z = self.forward_tokens(z) | ||||
|         z = z.mean([-2, -1]) | ||||
|  | ||||
|         nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel)) | ||||
|         comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) | ||||
|         y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w) | ||||
|         y_1 = self.bn0(y_1) | ||||
|         z_1 = self.forward_embeddings(y_1) | ||||
|         z_1 = self.forward_tokens(z_1) | ||||
|         z_1 = z_1.mean([-2, -1]) | ||||
|  | ||||
|         z = self.fuse_attention(z, z_1) | ||||
|  | ||||
|         z = self.norm(z) | ||||
|         x = self.head(z) | ||||
|         x = self.hidden_drop(x) | ||||
| @@ -725,7 +758,7 @@ def basic_blocks(dim, index, layers, | ||||
|             use_layer_scale=use_layer_scale,  | ||||
|             layer_scale_init_value=layer_scale_init_value,  | ||||
|             )) | ||||
|     blocks = nn.Sequential(*blocks) | ||||
|     blocks = SeqModel(*blocks) | ||||
|  | ||||
|     return blocks | ||||
|  | ||||
| @@ -890,6 +923,279 @@ def window_reverse(windows, window_size, H, W): | ||||
|     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W) | ||||
|     return x | ||||
|  | ||||
| class SeqModel(nn.Sequential): | ||||
| 	def forward(self, *inputs): | ||||
| 		for module in self._modules.values(): | ||||
| 			if type(inputs) == tuple: | ||||
| 				inputs = module(*inputs) | ||||
| 			else: | ||||
| 				inputs = module(inputs) | ||||
| 		return inputs | ||||
|  | ||||
| def propagate(x: torch.Tensor, weight: torch.Tensor,  | ||||
|               index_kept: torch.Tensor, index_prop: torch.Tensor,  | ||||
|               standard: str = "None", alpha: Optional[float] = 0,  | ||||
|               token_scales: Optional[torch.Tensor] = None, | ||||
|               cls_token=True): | ||||
|     """ | ||||
|     Propagate tokens based on the selection results. | ||||
|     ================================================ | ||||
|     Args: | ||||
|         - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token. | ||||
|          | ||||
|         - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,  | ||||
|                                          excluding the [CLS] token. weight could be a pre-defined  | ||||
|                                          graph of the current feature map (by default) or the | ||||
|                                          attention map (need to manually modify the Block Module). | ||||
|                                           | ||||
|         - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X | ||||
|          | ||||
|         - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X | ||||
|          | ||||
|         - standard: str: the method applied to propagate the tokens, including "None", "Mean" and  | ||||
|                          "GraphProp" | ||||
|          | ||||
|         - alpha: float: the coefficient of propagated features | ||||
|          | ||||
|         - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales | ||||
|                                         is None by default. If it is not None, then token_scales  | ||||
|                                         represents the scales of each token and should sum up to N. | ||||
|          | ||||
|     Return: | ||||
|         - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation | ||||
|          | ||||
|         - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation | ||||
|          | ||||
|         - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation | ||||
|     """ | ||||
|      | ||||
|     B, N, C = x.shape | ||||
|      | ||||
|     # Step 1: divide tokens | ||||
|     if cls_token: | ||||
|         x_cls = x[:, 0:1] # B, 1, C | ||||
|     x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C | ||||
|     x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C | ||||
|      | ||||
|     # Step 2: divide token_scales if it is not None | ||||
|     if token_scales is not None: | ||||
|         if cls_token: | ||||
|             token_scales_cls = token_scales[:, 0:1] # B, 1 | ||||
|         token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop | ||||
|         token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop | ||||
|      | ||||
|     # Step 3: propagate tokens | ||||
|     if standard == "None": | ||||
|         """ | ||||
|         No further propagation | ||||
|         """ | ||||
|         pass | ||||
|          | ||||
|     elif standard == "Mean": | ||||
|         """ | ||||
|         Calculate the mean of all the propagated tokens, | ||||
|         and concatenate the result token back to kept tokens. | ||||
|         """ | ||||
|         # naive average | ||||
|         x_prop = x_prop.mean(1, keepdim=True) # B, 1, C | ||||
|         # Concatenate the average token  | ||||
|         x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C | ||||
|              | ||||
|     elif standard == "GraphProp": | ||||
|         """ | ||||
|         Propagate all the propagated token to kept token | ||||
|         with respect to the weights and token scales. | ||||
|         """ | ||||
|         assert weight is not None, "The graph weight is needed for graph propagation" | ||||
|          | ||||
|         # Step 3.1: divide propagation weights. | ||||
|         if cls_token: | ||||
|             index_kept = index_kept - 1 # since weights do not include the [CLS] token | ||||
|             index_prop = index_prop - 1 # since weights do not include the [CLS] token | ||||
|             weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1 | ||||
|             weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop | ||||
|             weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop | ||||
|         else: | ||||
|             weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1 | ||||
|             weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop | ||||
|             weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop | ||||
|          | ||||
|         # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens | ||||
|         # Simple implementation | ||||
|         x_prop = weight_prop @ x_prop # B, N-1-num_prop, C | ||||
|         x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C | ||||
|          | ||||
|         """ scatter_reduce implementation for batched inputs | ||||
|         # Get the non-zero values | ||||
|         non_zero_indices = torch.nonzero(weight_prop, as_tuple=True) | ||||
|         non_zero_values = weight_prop[non_zero_indices] | ||||
|          | ||||
|         # Sparse multiplication | ||||
|         batch_indices, row_indices, col_indices = non_zero_indices | ||||
|         sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :] | ||||
|         reduce_indices = batch_indices * x_kept.shape[1] + row_indices | ||||
|          | ||||
|         x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,  | ||||
|                                                       index=reduce_indices[:, None],  | ||||
|                                                       src=sparse_matmul,  | ||||
|                                                       reduce="sum", | ||||
|                                                       include_self=True) | ||||
|         x_kept = x_kept.reshape(B, -1, C) | ||||
|         """ | ||||
|          | ||||
|         # Step 3.3: calculate the scale of each token if token_scales is not None | ||||
|         if token_scales is not None: | ||||
|             if cls_token: | ||||
|                 token_scales_cls = token_scales[:, 0:1] # B, 1 | ||||
|                 token_scales = token_scales[:, 1:] | ||||
|             token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop | ||||
|             token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop | ||||
|             token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1 | ||||
|             token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop | ||||
|             if cls_token: | ||||
|                 token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop | ||||
|     else: | ||||
|         assert False, "Propagation method \'%f\' has not been supported yet." % standard | ||||
|      | ||||
|      | ||||
|     if cls_token: | ||||
|         # Step 4: concatenate the [CLS] token and generate returned value | ||||
|         x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C | ||||
|     else: | ||||
|         x = x_kept | ||||
|     return x, weight, token_scales | ||||
|  | ||||
|  | ||||
|  | ||||
| def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True): | ||||
|     """ | ||||
|     Select image tokens to be propagated. The [CLS] token will be ignored.  | ||||
|     ====================================================================== | ||||
|     Args: | ||||
|         - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the | ||||
|                                         attention map of tokens at the moment. | ||||
|          | ||||
|         - standard: str: the method applied to select the tokens | ||||
|          | ||||
|         - num_prop: int: the number of tokens to be propagated | ||||
|          | ||||
|     Return: | ||||
|         - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens  | ||||
|          | ||||
|         - index_prop: Tensor([B, num_prop]): the index of propagated tokens | ||||
|     """ | ||||
|      | ||||
|     assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet." | ||||
|     B, H, N1, N2 = weight.shape | ||||
|     assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet." | ||||
|     N = N1 | ||||
|     assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative." | ||||
|      | ||||
|     if cls_token: | ||||
|         if standard == "CLSAttnMean": | ||||
|             token_rank = weight[:,:,0,1:].mean(1) | ||||
|              | ||||
|         elif standard == "CLSAttnMax": | ||||
|             token_rank = weight[:,:,0,1:].max(1)[0] | ||||
|                  | ||||
|         elif standard == "IMGAttnMean": | ||||
|             token_rank = weight[:,:,:,1:].sum(-2).mean(1) | ||||
|          | ||||
|         elif standard == "IMGAttnMax": | ||||
|             token_rank = weight[:,:,:,1:].sum(-2).max(1)[0] | ||||
|                  | ||||
|         elif standard == "DiagAttnMean": | ||||
|             token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1) | ||||
|              | ||||
|         elif standard == "DiagAttnMax": | ||||
|             token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0] | ||||
|              | ||||
|         elif standard == "MixedAttnMean": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1) | ||||
|             token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1) | ||||
|             token_rank = token_rank_1 * token_rank_2 | ||||
|              | ||||
|         elif standard == "MixedAttnMax": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0] | ||||
|             token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0] | ||||
|             token_rank = token_rank_1 * token_rank_2 | ||||
|              | ||||
|         elif standard == "SumAttnMax": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0] | ||||
|             token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0] | ||||
|             token_rank = token_rank_1 + token_rank_2 | ||||
|              | ||||
|         elif standard == "CosSimMean": | ||||
|             weight = weight[:,:,1:,:].mean(1) | ||||
|             weight = weight / weight.norm(dim=-1, keepdim=True) | ||||
|             token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1) | ||||
|          | ||||
|         elif standard == "CosSimMax": | ||||
|             weight = weight[:,:,1:,:].max(1)[0] | ||||
|             weight = weight / weight.norm(dim=-1, keepdim=True) | ||||
|             token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1) | ||||
|              | ||||
|         elif standard == "Random": | ||||
|             token_rank = torch.randn((B, N-1), device=weight.device) | ||||
|                  | ||||
|         else: | ||||
|             print("Type\'", standard, "\' selection not supported.") | ||||
|             assert False | ||||
|          | ||||
|         token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1 | ||||
|         index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop | ||||
|         index_prop = token_rank[:, -num_prop:]+1 # B, num_prop | ||||
|              | ||||
|     else: | ||||
|         if standard == "IMGAttnMean": | ||||
|             token_rank = weight.sum(-2).mean(1) | ||||
|          | ||||
|         elif standard == "IMGAttnMax": | ||||
|             token_rank = weight.sum(-2).max(1)[0] | ||||
|                  | ||||
|         elif standard == "DiagAttnMean": | ||||
|             token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1) | ||||
|              | ||||
|         elif standard == "DiagAttnMax": | ||||
|             token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0] | ||||
|              | ||||
|         elif standard == "MixedAttnMean": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1) | ||||
|             token_rank_2 = weight.sum(-2).mean(1) | ||||
|             token_rank = token_rank_1 * token_rank_2 | ||||
|              | ||||
|         elif standard == "MixedAttnMax": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0] | ||||
|             token_rank_2 = weight.sum(-2).max(1)[0] | ||||
|             token_rank = token_rank_1 * token_rank_2 | ||||
|              | ||||
|         elif standard == "SumAttnMax": | ||||
|             token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0] | ||||
|             token_rank_2 = weight.sum(-2).max(1)[0] | ||||
|             token_rank = token_rank_1 + token_rank_2 | ||||
|              | ||||
|         elif standard == "CosSimMean": | ||||
|             weight = weight.mean(1) | ||||
|             weight = weight / weight.norm(dim=-1, keepdim=True) | ||||
|             token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1) | ||||
|          | ||||
|         elif standard == "CosSimMax": | ||||
|             weight = weight.max(1)[0] | ||||
|             weight = weight / weight.norm(dim=-1, keepdim=True) | ||||
|             token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1) | ||||
|              | ||||
|         elif standard == "Random": | ||||
|             token_rank = torch.randn((B, N-1), device=weight.device) | ||||
|                  | ||||
|         else: | ||||
|             print("Type\'", standard, "\' selection not supported.") | ||||
|             assert False | ||||
|          | ||||
|         token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1 | ||||
|         index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop | ||||
|         index_prop = token_rank[:, -num_prop:] # B, num_prop | ||||
|     return index_kept, index_prop | ||||
|  | ||||
| class PoolFormerBlock(nn.Module): | ||||
|     """ | ||||
|     Implementation of one PoolFormer block. | ||||
| @@ -932,13 +1238,20 @@ class PoolFormerBlock(nn.Module): | ||||
|             self.layer_scale_2 = nn.Parameter( | ||||
|                 layer_scale_init_value * torch.ones((dim)), requires_grad=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|     def forward(self, x, weight, token_scales = None): | ||||
|         B, C, H, W = x.shape | ||||
|         x_windows = window_partition(x, self.window_size) | ||||
|         x_windows = x_windows.view(-1, self.window_size * self.window_size, C) | ||||
|         attn_windows = self.token_mixer(x_windows, mask=self.attn_mask) | ||||
|         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) | ||||
|         x_attn = window_reverse(attn_windows, self.window_size, H, W) | ||||
|         index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, | ||||
|                                             cls_token=False) | ||||
|         original_shape = x_attn.shape | ||||
|         x_attn = x_attn.view(-1, self.window_size * self.window_size, C) | ||||
|         x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp", | ||||
|                                         alpha=0.1, token_scales=token_scales, cls_token=False) | ||||
|         x_attn = x_attn.view(*original_shape) | ||||
|         if self.use_layer_scale: | ||||
|             x = x + self.drop_path( | ||||
|                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user