16 Commits

Author SHA1 Message Date
47bc661a91 try gtp vit 2024-04-28 15:40:31 +07:00
3b6db89be1 try gtp vit 2024-04-28 15:35:24 +07:00
352f5f9da9 try gtp vit 2024-04-28 15:31:58 +07:00
b9273b6696 try gtp vit 2024-04-28 15:27:41 +07:00
d0e4630dd6 try gtp vit 2024-04-28 15:26:24 +07:00
08a3780ba6 try gtp vit 2024-04-28 15:25:44 +07:00
6fc56b920f try gtp vit 2024-04-28 15:24:05 +07:00
fddea4769f try gtp vit 2024-04-28 15:17:40 +07:00
d9209a7ef1 try gtp vit 2024-04-28 15:14:27 +07:00
0f986d7517 try gtp vit 2024-04-28 15:10:09 +07:00
4daa40527b try gtp vit 2024-04-28 15:08:08 +07:00
541c4fa2b3 try gtp vit 2024-04-28 12:05:57 +07:00
68a94bd1e2 try gtp vit 2024-04-28 12:05:04 +07:00
b01e504874 try gtp vit 2024-04-28 12:01:40 +07:00
23c44d3582 try gtp vit 2024-04-28 11:59:09 +07:00
41a5c7b05a try gtp vit 2024-04-28 11:57:17 +07:00
4 changed files with 364 additions and 308 deletions


@@ -12407,233 +12407,3 @@
12406 Carry out roadside bombing[65]
12407 Appeal for target to allow international involvement (non-mediation)[1]
12408 Reject request for change in leadership[179]
12409 Criticize or denounce
12410 Express intent to meet or negotiate
12411 Consult
12412 Make an appeal or request
12413 Abduct, hijack, or take hostage
12414 Praise or endorse
12415 Engage in negotiation
12416 Use unconventional violence
12417 Make statement
12418 Arrest, detain, or charge with legal action
12419 Use conventional military force
12420 Complain officially
12421 Impose administrative sanctions
12422 Express intent to cooperate
12423 Make a visit
12424 Appeal for de-escalation of military engagement
12425 Sign formal agreement
12426 Attempt to assassinate
12427 Host a visit
12428 Increase military alert status
12429 Impose embargo, boycott, or sanctions
12430 Provide economic aid
12431 Demonstrate or rally
12432 Express intent to engage in diplomatic cooperation (such as policy support)
12433 Appeal for intelligence
12434 Demand
12435 Carry out suicide bombing
12436 Threaten
12437 Express intent to provide material aid
12438 Grant diplomatic recognition
12439 Meet at a 'third' location
12440 Accuse
12441 Investigate
12442 Reject
12443 Appeal for diplomatic cooperation (such as policy support)
12444 Engage in symbolic act
12445 Defy norms, law
12446 Consider policy option
12447 Provide aid
12448 Sexually assault
12449 Make empathetic comment
12450 Bring lawsuit against
12451 Impose blockade, restrict movement
12452 Make pessimistic comment
12453 Protest violently, riot
12454 Reduce or break diplomatic relations
12455 Grant asylum
12456 Engage in diplomatic cooperation
12457 Make optimistic comment
12458 Torture
12459 Refuse to yield
12460 Appeal for change in leadership
12461 Cooperate militarily
12462 Mobilize or increase armed forces
12463 fight with small arms and light weapons
12464 Ease administrative sanctions
12465 Appeal for political reform
12466 Return, release person(s)
12467 Discuss by telephone
12468 Demonstrate for leadership change
12469 Impose restrictions on political freedoms
12470 Reduce relations
12471 Investigate crime, corruption
12472 Engage in material cooperation
12473 Appeal to others to meet or negotiate
12474 Provide humanitarian aid
12475 Use tactics of violent repression
12476 Occupy territory
12477 Demand humanitarian aid
12478 Threaten non-force
12479 Express intent to cooperate economically
12480 Conduct suicide, car, or other non-military bombing
12481 Demand diplomatic cooperation (such as policy support)
12482 Demand meeting, negotiation
12483 Deny responsibility
12484 Express intent to change institutions, regime
12485 Give ultimatum
12486 Appeal for judicial cooperation
12487 Rally support on behalf of
12488 Obstruct passage, block
12489 Share intelligence or information
12490 Expel or deport individuals
12491 Confiscate property
12492 Accuse of aggression
12493 Physically assault
12494 Retreat or surrender militarily
12495 Veto
12496 Kill by physical assault
12497 Assassinate
12498 Appeal for change in institutions, regime
12499 Forgive
12500 Reject proposal to meet, discuss, or negotiate
12501 Express intent to provide humanitarian aid
12502 Appeal for release of persons or property
12503 Acknowledge or claim responsibility
12504 Ease economic sanctions, boycott, embargo
12505 Express intent to cooperate militarily
12506 Cooperate economically
12507 Express intent to provide economic aid
12508 Mobilize or increase police power
12509 Employ aerial weapons
12510 Accuse of human rights abuses
12511 Conduct strike or boycott
12512 Appeal for policy change
12513 Demonstrate military or police power
12514 Provide military aid
12515 Reject plan, agreement to settle dispute
12516 Yield
12517 Appeal for easing of administrative sanctions
12518 Mediate
12519 Apologize
12520 Express intent to release persons or property
12521 Express intent to de-escalate military engagement
12522 Accede to demands for rights
12523 Demand economic aid
12524 Impose state of emergency or martial law
12525 Receive deployment of peacekeepers
12526 Demand de-escalation of military engagement
12527 Declare truce, ceasefire
12528 Reduce or stop humanitarian assistance
12529 Appeal to others to settle dispute
12530 Reject request for military aid
12531 Threaten with political dissent, protest
12532 Appeal to engage in or accept mediation
12533 Express intent to ease economic sanctions, boycott, or embargo
12534 Coerce
12535 fight with artillery and tanks
12536 Express intent to cooperate on intelligence
12537 Express intent to settle dispute
12538 Express accord
12539 Decline comment
12540 Rally opposition against
12541 Halt negotiations
12542 Demand that target yields
12543 Appeal for military aid
12544 Threaten with military force
12545 Express intent to provide military protection or peacekeeping
12546 Threaten with sanctions, boycott, embargo
12547 Express intent to provide military aid
12548 Demand change in leadership
12549 Appeal for economic aid
12550 Refuse to de-escalate military engagement
12551 Refuse to release persons or property
12552 Increase police alert status
12553 Return, release property
12554 Ease military blockade
12555 Appeal for material cooperation
12556 Express intent to cooperate on judicial matters
12557 Appeal for economic cooperation
12558 Demand settling of dispute
12559 Accuse of crime, corruption
12560 Defend verbally
12561 Provide military protection or peacekeeping
12562 Accuse of espionage, treason
12563 Seize or damage property
12564 Accede to requests or demands for political reform
12565 Appeal for easing of economic sanctions, boycott, or embargo
12566 Threaten to reduce or stop aid
12567 Engage in judicial cooperation
12568 Appeal to yield
12569 Demand military aid
12570 Refuse to ease administrative sanctions
12571 Demand release of persons or property
12572 Accede to demands for change in leadership
12573 Appeal for humanitarian aid
12574 Threaten with repression
12575 Demand change in institutions, regime
12576 Demonstrate for policy change
12577 Appeal for aid
12578 Appeal for rights
12579 Engage in violent protest for rights
12580 Express intent to mediate
12581 Expel or withdraw peacekeepers
12582 Appeal for military protection or peacekeeping
12583 Engage in mass killings
12584 Accuse of war crimes
12585 Reject military cooperation
12586 Threaten to halt negotiations
12587 Ban political parties or politicians
12588 Express intent to change leadership
12589 Demand material cooperation
12590 Express intent to institute political reform
12591 Demand easing of administrative sanctions
12592 Express intent to engage in material cooperation
12593 Reduce or stop economic assistance
12594 Express intent to ease administrative sanctions
12595 Demand intelligence cooperation
12596 Ease curfew
12597 Receive inspectors
12598 Demand rights
12599 Demand political reform
12600 Demand judicial cooperation
12601 Engage in political dissent
12602 Detonate nuclear weapons
12603 Violate ceasefire
12604 Express intent to accept mediation
12605 Refuse to ease economic sanctions, boycott, or embargo
12606 Demand mediation
12607 Obstruct passage to demand leadership change
12608 Express intent to yield
12609 Conduct hunger strike
12610 Threaten to halt mediation
12611 Reject judicial cooperation
12612 Reduce or stop military assistance
12613 Ease political dissent
12614 Threaten to reduce or break relations
12615 Demobilize armed forces
12616 Use as human shield
12617 Demand policy change
12618 Accede to demands for change in institutions, regime
12619 Reject economic cooperation
12620 Reject material cooperation
12621 Halt mediation
12622 Accede to demands for change in policy
12623 Investigate war crimes
12624 Threaten with administrative sanctions
12625 Reduce or stop material aid
12626 Destroy property
12627 Express intent to change policy
12628 Use chemical, biological, or radiological weapons
12629 Reject request for military protection or peacekeeping
12630 Demand material aid
12631 Engage in mass expulsion
12632 Investigate human rights abuses
12633 Carry out car bombing
12634 Expel or withdraw
12635 Ease state of emergency or martial law
12636 Carry out roadside bombing
12637 Appeal for target to allow international involvement (non-mediation)
12638 Reject request for change in leadership


@@ -421,27 +421,3 @@
420 P551[36-69]
421 P579[0-15]
422 P102[54-62]
423 P131
424 P1435
425 P39
426 P54
427 P31
428 P463
429 P512
430 P190
431 P150
432 P1376
433 P166
434 P2962
435 P108
436 P17
437 P793
438 P69
439 P26
440 P579
441 P1411
442 P6
443 P1346
444 P102
445 P27
446 P551

main.py (53 changed lines)

@@ -91,11 +91,9 @@ class Main(object):
for line in open('./data/{}/{}'.format(self.p.dataset, "relations.dict")):
id, rel = map(str.lower, line.strip().split('\t'))
self.rel2id[rel] = int(id)
rel_set.add(rel)
# self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
# self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
self.rel2id.update({rel+'_reverse': idx+len(self.rel2id)
for idx, rel in enumerate(rel_set)})
@@ -113,48 +111,47 @@ class Main(object):
for split in ['train', 'test', 'valid']:
for line in open('./data/{}/{}.txt'.format(self.p.dataset, split)):
sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t'))
nt_rel = rel.split('[')[0]
sub, rel, obj, nt_rel = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj], self.rel2id[nt_rel]
self.data[split].append((sub, rel, obj, nt_rel))
sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
self.data[split].append((sub, rel, obj))
if split == 'train':
sr2o[(sub, rel, nt_rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub)
sr2o[(sub, rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel)].add(sub)
self.data = dict(self.data)
self.sr2o = {k: list(v) for k, v in sr2o.items()}
for split in ['test', 'valid']:
for sub, rel, obj, nt_rel in self.data[split]:
sr2o[(sub, rel, nt_rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel, nt_rel + self.p.num_rel)].add(sub)
for sub, rel, obj in self.data[split]:
sr2o[(sub, rel)].add(obj)
sr2o[(obj, rel+self.p.num_rel)].add(sub)
self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
self.triples = ddict(list)
if self.p.train_strategy == 'one_to_n':
for (sub, rel, nt_rel), obj in self.sr2o.items():
for (sub, rel), obj in self.sr2o.items():
self.triples['train'].append(
{'triple': (sub, rel, -1, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': 1})
{'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1})
else:
for sub, rel, obj, nt_rel in self.data['train']:
for sub, rel, obj in self.data['train']:
rel_inv = rel + self.p.num_rel
sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \
len(self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)])
sub_samp = len(self.sr2o[(sub, rel)]) + \
len(self.sr2o[(obj, rel_inv)])
sub_samp = np.sqrt(1/sub_samp)
self.triples['train'].append({'triple': (
sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp})
sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp})
self.triples['train'].append({'triple': (
obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)], 'sub_samp': sub_samp})
obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp})
for split in ['test', 'valid']:
for sub, rel, obj, nt_rel in self.data[split]:
for sub, rel, obj in self.data[split]:
rel_inv = rel + self.p.num_rel
self.triples['{}_{}'.format(split, 'tail')].append(
{'triple': (sub, rel, obj, nt_rel), 'label': self.sr2o_all[(sub, rel, nt_rel)]})
{'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
self.triples['{}_{}'.format(split, 'head')].append(
{'triple': (obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o_all[(obj, rel_inv, nt_rel + self.p.num_rel)]})
{'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})
self.triples = dict(self.triples)
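The indexing above maps each (sub, rel) pair to its object set and covers the reverse direction by offsetting the relation id by num_rel. A minimal standalone sketch of that scheme, with toy ids chosen only for illustration:

from collections import defaultdict

num_rel = 2
triples = [(0, 1, 5), (0, 1, 7), (7, 0, 0)]      # (sub, rel, obj) id triples

sr2o = defaultdict(set)
for sub, rel, obj in triples:
    sr2o[(sub, rel)].add(obj)                    # forward direction
    sr2o[(obj, rel + num_rel)].add(sub)          # inverse relation id = rel + num_rel

print(dict(sr2o))                                # e.g. (0, 1) -> {5, 7}, (5, 3) -> {0}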
@@ -278,13 +275,13 @@ class Main(object):
if self.p.train_strategy == 'one_to_x':
triple, label, neg_ent, sub_samp = [
_.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp
else:
triple, label = [_.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None
else:
triple, label = [_.to(self.device) for _ in batch]
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label
return triple[:, 0], triple[:, 1], triple[:, 2], label
def save_model(self, save_path):
"""
@@ -419,8 +416,8 @@ class Main(object):
obj_pred = []
obj_pred_score = []
for step, batch in enumerate(train_iter):
sub, rel, obj, nt_rel, label = self.read_batch(batch, split)
pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n')
sub, rel, obj, label = self.read_batch(batch, split)
pred = self.model.forward(sub, rel, None, 'one_to_n')
b_range = torch.arange(pred.size()[0], device=self.device)
target_pred = pred[b_range, obj]
pred = torch.where(label.byte(), torch.zeros_like(pred), pred)
@@ -477,10 +474,10 @@ class Main(object):
for step, batch in enumerate(train_iter):
self.optimizer.zero_grad()
sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(
sub, rel, obj, label, neg_ent, sub_samp = self.read_batch(
batch, 'train')
pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
loss = self.model.loss(pred, label, sub_samp)
loss.backward()
@@ -693,7 +690,7 @@ if __name__ == "__main__":
collate_fn=TrainDataset.collate_fn
))
for step, batch in enumerate(dataloader):
sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch(
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
batch, 'train')
if (neg_ent is None):

models.py (363 changed lines)

@@ -10,6 +10,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module):
@@ -466,10 +468,6 @@ class FouriER(torch.nn.Module):
self.p.ent_vec_dim, image_h*image_w)
torch.nn.init.xavier_normal_(self.ent_fusion.weight)
self.ent_attn = torch.nn.Linear(
128, 128)
torch.nn.init.xavier_normal_(self.ent_attn.weight)
self.rel_fusion = torch.nn.Linear(
self.p.rel_vec_dim, image_h*image_w)
torch.nn.init.xavier_normal_(self.rel_fusion.weight)
@@ -530,6 +528,22 @@ class FouriER(torch.nn.Module):
self.network = nn.ModuleList(network)
self.norm = norm_layer(embed_dims[-1])
self.graph_type = 'Spatial'
N = (image_h // patch_size)**2
if self.graph_type in ["Spatial", "Mixed"]:
# Create a range tensor of node indices
indices = torch.arange(N)
# Reshape the indices tensor to create a grid of row and column indices
row_indices = indices.view(-1, 1).expand(-1, N)
col_indices = indices.view(1, -1).expand(N, -1)
# Compute the adjacency matrix
row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
graph = graph - torch.eye(N)
self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
self.class_token = False
self.token_scale = False
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
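A standalone sketch of the spatial graph construction added above: tokens are laid out on a sqrt(N) x sqrt(N) grid, each token is connected to its (up to) 8 grid neighbours, and self-loops are removed. N = 9 here is chosen only for illustration.

import math
import torch

N = 9                                             # a 3 x 3 token grid
idx = torch.arange(N)
row = idx.view(-1, 1).expand(-1, N)
col = idx.view(1, -1).expand(N, -1)
r1, c1 = row // int(math.sqrt(N)), row % int(math.sqrt(N))
r2, c2 = col // int(math.sqrt(N)), col % int(math.sqrt(N))
graph = ((r1 - r2).abs() <= 1).float() * ((c1 - c2).abs() <= 1).float() - torch.eye(N)
print(graph[4])                                   # the centre token neighbours all 8 other tokens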
@@ -547,19 +561,49 @@ class FouriER(torch.nn.Module):
def forward_tokens(self, x):
outs = []
B, C, H, W = x.shape
N = H*W
if self.graph_type in ["Semantic", "Mixed"]:
# Generate the semantic graph w.r.t. the cosine similarity between tokens
# Compute cosine similarity
if self.class_token:
x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
else:
x_normed = x / x.norm(dim=-1, keepdim=True)
x_cossim = x_normed @ x_normed.transpose(-1, -2)
threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0)
if self.class_token:
semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
else:
semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
if self.graph_type == "None":
graph = None
else:
if self.graph_type == "Spatial":
graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device)
elif self.graph_type == "Semantic":
graph = semantic_graph
elif self.graph_type == "Mixed":
# Integrate the spatial graph and semantic graph
spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device)
graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
# Symmetrically normalize the graph
degree = graph.sum(-1) # B, N
degree = torch.diag_embed(degree**(-1/2))
graph = degree @ graph @ degree
for idx, block in enumerate(self.network):
try:
x = block(x, graph)
except:
x = block(x)
# output only the features of last layer for image classification
return x
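Before the blocks run, the spatial/semantic graph is symmetrically normalized (D^-1/2 A D^-1/2, with D the diagonal degree matrix). A tiny standalone check of that step on a made-up 4-node path graph:

import torch

A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
deg = A.sum(-1)                                   # node degrees
D_inv_sqrt = torch.diag_embed(deg ** -0.5)
A_norm = D_inv_sqrt @ A @ D_inv_sqrt              # entry (i, j) scaled by 1 / sqrt(d_i * d_j)
print(A_norm)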
def fuse_attention(self, s_embedding, l_embedding):
w1 = self.ent_attn(torch.tanh(s_embedding))
w2 = self.ent_attn(torch.tanh(l_embedding))
aff = F.softmax(torch.cat((w1,w2),1), 1)
en_embedding = aff[:,0].unsqueeze(1) * s_embedding + aff[:, 1].unsqueeze(1) * l_embedding
return en_embedding
def forward(self, sub, rel, nt_rel, neg_ents, strategy='one_to_x'):
def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
sub_emb = self.ent_fusion(self.ent_embed(sub))
rel_emb = self.rel_fusion(self.rel_embed(rel))
comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
@@ -568,17 +612,6 @@ class FouriER(torch.nn.Module):
z = self.forward_embeddings(y)
z = self.forward_tokens(z)
z = z.mean([-2, -1])
nt_rel_emb = self.rel_fusion(self.rel_embed(nt_rel))
comb_emb_1 = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), nt_rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
y_1 = comb_emb_1.view(-1, 2, self.p.image_h, self.p.image_w)
y_1 = self.bn0(y_1)
z_1 = self.forward_embeddings(y_1)
z_1 = self.forward_tokens(z_1)
z_1 = z_1.mean([-2, -1])
z = self.fuse_attention(z, z_1)
z = self.norm(z)
x = self.head(z)
x = self.hidden_drop(x)
@@ -725,7 +758,7 @@ def basic_blocks(dim, index, layers,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)
blocks = SeqModel(*blocks)
return blocks
@@ -890,6 +923,279 @@ def window_reverse(windows, window_size, H, W):
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
return x
class SeqModel(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
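SeqModel replaces nn.Sequential in basic_blocks because plain Sequential threads exactly one positional argument between modules, while the blocks here may return an (x, graph) tuple. A minimal self-contained sketch of that behaviour; the AddOne module is made up for illustration, and SeqModel is copied from above:

import torch
import torch.nn as nn

class SeqModel(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs

class AddOne(nn.Module):
    def forward(self, x, graph=None):
        return x + 1, graph                       # returns a tuple; both values reach the next module

blocks = SeqModel(AddOne(), AddOne())
x, g = blocks(torch.zeros(2, 3), torch.eye(3))    # x == 2 everywhere, g is passed through unchanged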
def propagate(x: torch.Tensor, weight: torch.Tensor,
index_kept: torch.Tensor, index_prop: torch.Tensor,
standard: str = "None", alpha: Optional[float] = 0,
token_scales: Optional[torch.Tensor] = None,
cls_token=True):
"""
Propagate tokens based on the selection results.
================================================
Args:
- x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
- weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
excluding the [CLS] token. weight could be a pre-defined
graph of the current feature map (by default) or the
attention map (need to manually modify the Block Module).
- index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
- index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
- standard: str: the method applied to propagate the tokens, including "None", "Mean" and
"GraphProp"
- alpha: float: the coefficient of propagated features
- token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
is None by default. If it is not None, then token_scales
represents the scales of each token and should sum up to N.
Return:
- x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
- weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
- token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
"""
B, N, C = x.shape
# Step 1: divide tokens
if cls_token:
x_cls = x[:, 0:1] # B, 1, C
x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
# Step 2: divide token_scales if it is not None
if token_scales is not None:
if cls_token:
token_scales_cls = token_scales[:, 0:1] # B, 1
token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
# Step 3: propagate tokens
if standard == "None":
"""
No further propagation
"""
pass
elif standard == "Mean":
"""
Calculate the mean of all the propagated tokens,
and concatenate the result token back to kept tokens.
"""
# naive average
x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
# Concatenate the average token
x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
elif standard == "GraphProp":
"""
Propagate all the propagated token to kept token
with respect to the weights and token scales.
"""
assert weight is not None, "The graph weight is needed for graph propagation"
# Step 3.1: divide propagation weights.
if cls_token:
index_kept = index_kept - 1 # since weights do not include the [CLS] token
index_prop = index_prop - 1 # since weights do not include the [CLS] token
weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
else:
weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
# Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
# Simple implementation
x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
""" scatter_reduce implementation for batched inputs
# Get the non-zero values
non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
non_zero_values = weight_prop[non_zero_indices]
# Sparse multiplication
batch_indices, row_indices, col_indices = non_zero_indices
sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
reduce_indices = batch_indices * x_kept.shape[1] + row_indices
x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
index=reduce_indices[:, None],
src=sparse_matmul,
reduce="sum",
include_self=True)
x_kept = x_kept.reshape(B, -1, C)
"""
# Step 3.3: calculate the scale of each token if token_scales is not None
if token_scales is not None:
if cls_token:
token_scales_cls = token_scales[:, 0:1] # B, 1
token_scales = token_scales[:, 1:]
token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
if cls_token:
token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
else:
assert False, "Propagation method \'%f\' has not been supported yet." % standard
if cls_token:
# Step 4 concatenate the [CLS] token and generate returned value
x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
else:
x = x_kept
return x, weight, token_scales
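A shape-level sketch of calling propagate along its simplest path, standard="Mean", with no [CLS] token and no token scales. The batch/token/channel sizes are made up, and the import assumes models.py and its dependencies are available:

import torch
from models import propagate                      # assumes this file is importable as models.py

B, N, C, num_prop = 2, 16, 8, 4
x = torch.randn(B, N, C)                          # token features, no [CLS] token
weight = torch.rand(B, N, N)                      # a dense token graph (unused by "Mean")
rank = torch.argsort(torch.rand(B, N), dim=1, descending=True)
index_kept, index_prop = rank[:, :-num_prop], rank[:, -num_prop:]

x_out, w_out, _ = propagate(x, weight, index_kept, index_prop,
                            standard="Mean", alpha=0.1, cls_token=False)
print(x_out.shape)                                # torch.Size([2, 13, 8]): 12 kept tokens + 1 averaged token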
def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True):
"""
Select image tokens to be propagated. The [CLS] token will be ignored.
======================================================================
Args:
- weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
attention map of tokens at the moment.
- standard: str: the method applied to select the tokens
- num_prop: int: the number of tokens to be propagated
Return:
- index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
- index_prop: Tensor([B, num_prop]): the index of propagated tokens
"""
assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
B, H, N1, N2 = weight.shape
assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
N = N1
assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
if cls_token:
if standard == "CLSAttnMean":
token_rank = weight[:,:,0,1:].mean(1)
elif standard == "CLSAttnMax":
token_rank = weight[:,:,0,1:].max(1)[0]
elif standard == "IMGAttnMean":
token_rank = weight[:,:,:,1:].sum(-2).mean(1)
elif standard == "IMGAttnMax":
token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
elif standard == "DiagAttnMean":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
elif standard == "DiagAttnMax":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
elif standard == "MixedAttnMean":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
token_rank = token_rank_1 * token_rank_2
elif standard == "MixedAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
token_rank = token_rank_1 * token_rank_2
elif standard == "SumAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
token_rank = token_rank_1 + token_rank_2
elif standard == "CosSimMean":
weight = weight[:,:,1:,:].mean(1)
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "CosSimMax":
weight = weight[:,:,1:,:].max(1)[0]
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "Random":
token_rank = torch.randn((B, N-1), device=weight.device)
else:
print("Type\'", standard, "\' selection not supported.")
assert False
token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
else:
if standard == "IMGAttnMean":
token_rank = weight.sum(-2).mean(1)
elif standard == "IMGAttnMax":
token_rank = weight.sum(-2).max(1)[0]
elif standard == "DiagAttnMean":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
elif standard == "DiagAttnMax":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
elif standard == "MixedAttnMean":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
token_rank_2 = weight.sum(-2).mean(1)
token_rank = token_rank_1 * token_rank_2
elif standard == "MixedAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
token_rank_2 = weight.sum(-2).max(1)[0]
token_rank = token_rank_1 * token_rank_2
elif standard == "SumAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
token_rank_2 = weight.sum(-2).max(1)[0]
token_rank = token_rank_1 + token_rank_2
elif standard == "CosSimMean":
weight = weight.mean(1)
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "CosSimMax":
weight = weight.max(1)[0]
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "Random":
token_rank = torch.randn((B, N-1), device=weight.device)
else:
print("Type\'", standard, "\' selection not supported.")
assert False
token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
index_prop = token_rank[:, -num_prop:] # B, num_prop
return index_kept, index_prop
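And the selection step on its own, here with the "DiagAttnMean" standard and no [CLS] token; the attention map is random and only used to check the returned shapes (same importability assumption as above):

import torch
from models import select

B, H, N, num_prop = 2, 4, 16, 4
attn = torch.rand(B, H, N, N)                     # attention map over N tokens
index_kept, index_prop = select(attn, standard="DiagAttnMean",
                                num_prop=num_prop, cls_token=False)
print(index_kept.shape, index_prop.shape)         # torch.Size([2, 12]) torch.Size([2, 4])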
class PoolFormerBlock(nn.Module):
"""
Implementation of one PoolFormer block.
@@ -932,13 +1238,20 @@ class PoolFormerBlock(nn.Module):
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x):
def forward(self, x, weight, token_scales = None):
B, C, H, W = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
x_attn = window_reverse(attn_windows, self.window_size, H, W)
index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
cls_token=False)
original_shape = x_attn.shape
x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
alpha=0.1, token_scales=token_scales, cls_token=False)
x_attn = x_attn.view(*original_shape)
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)