@@ -9,7 +9,9 @@ from layers import *
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
try:
    from timm.models.layers.helpers import to_2tuple
except ImportError:
    # newer timm releases moved the helpers module to timm.layers
    from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module):
@@ -526,6 +528,22 @@ class FouriER(torch.nn.Module):
        self.network = nn.ModuleList(network)
        self.norm = norm_layer(embed_dims[-1])
        self.graph_type = 'Spatial'
        N = (image_h // patch_size) ** 2
        if self.graph_type in ["Spatial", "Mixed"]:
            # Create a range tensor of node indices
            indices = torch.arange(N)
            # Reshape the indices tensor to create a grid of row and column indices
            row_indices = indices.view(-1, 1).expand(-1, N)
            col_indices = indices.view(1, -1).expand(N, -1)
            # Compute the adjacency matrix: tokens are 8-connected neighbours on the patch grid
            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
            graph = (abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float()
            graph = graph - torch.eye(N)  # remove self-loops
            self.spatial_graph = graph.cuda()  # remove .cuda() when running on CPU
        self.class_token = False
        self.token_scale = False
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 \
            else nn.Identity()
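# --- Illustration (not part of the patch): a minimal sketch of the spatial
# graph construction above for a hypothetical 2x2 patch grid (N = 4), relying
# on the torch/math imports at the top of the module. Tokens are 8-connected
# grid neighbours, so every off-diagonal entry comes out as 1 here.
_N = 4
_side = int(math.sqrt(_N))
_idx = torch.arange(_N)
_rows = _idx.view(-1, 1).expand(-1, _N)
_cols = _idx.view(1, -1).expand(_N, -1)
_r1, _c1 = _rows // _side, _rows % _side
_r2, _c2 = _cols // _side, _cols % _side
_graph = (abs(_r1 - _r2) <= 1).float() * (abs(_c1 - _c2) <= 1).float()
_graph = _graph - torch.eye(_N)  # zero the diagonal: no self-loops
assert torch.equal(_graph, torch.ones(_N, _N) - torch.eye(_N))
# --- end illustration ---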
@@ -543,8 +561,45 @@ class FouriER(torch.nn.Module):
    def forward_tokens(self, x):
        outs = []
        B, C, H, W = x.shape
        N = H * W
        if self.graph_type in ["Semantic", "Mixed"]:
            # Generate the semantic graph w.r.t. the cosine similarity between tokens
            # Compute cosine similarity
            if self.class_token:
                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
            else:
                x_normed = x / x.norm(dim=-1, keepdim=True)
            x_cossim = x_normed @ x_normed.transpose(-1, -2)
            threshold = torch.kthvalue(x_cossim, N - 1 - self.num_neighbours, dim=-1, keepdim=True)[0]  # B,H,1,1
            semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
            if self.class_token:
                semantic_graph = semantic_graph - torch.eye(N - 1, device=semantic_graph.device).unsqueeze(0)
            else:
                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
        if self.graph_type == "None":
            graph = None
        else:
            if self.graph_type == "Spatial":
                graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1)  # .to(x.device)
            elif self.graph_type == "Semantic":
                graph = semantic_graph
            elif self.graph_type == "Mixed":
                # Integrate the spatial graph and semantic graph
                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
            # Symmetrically normalize the graph: D^(-1/2) @ A @ D^(-1/2)
            degree = graph.sum(-1)  # B, N
            degree = torch.diag_embed(degree ** (-1 / 2))
            graph = degree @ graph @ degree
        for idx, block in enumerate(self.network):
            try:
                x = block(x, graph)
            except TypeError:
                # blocks that do not accept a graph fall back to the plain call
                x = block(x)
        # output only the features of last layer for image classification
        return x
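# --- Illustration (not part of the patch): the symmetric normalization step
# above, D^(-1/2) @ A @ D^(-1/2), on a toy batch holding one 3-node graph.
_A = torch.tensor([[[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]]])              # B=1, N=3 adjacency matrix
_deg = _A.sum(-1)                                 # node degrees: [[2., 1., 1.]]
_D = torch.diag_embed(_deg ** (-1 / 2))           # batched D^(-1/2)
_A_norm = _D @ _A @ _D
# each edge (i, j) is rescaled to 1 / sqrt(deg_i * deg_j); e.g. the (0, 1)
# entry becomes 1 / sqrt(2), keeping the matrix symmetric as in GCN-style
# propagation.
assert torch.allclose(_A_norm[0, 0, 1], torch.tensor(1.0 / 2.0 ** 0.5))
# --- end illustration ---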
@@ -703,10 +758,443 @@ def basic_blocks(dim, index, layers,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            ))
    blocks = SeqModel(*blocks)
    return blocks
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape  # channels-last, matching the documented layout
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
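# --- Illustration (not part of the patch): window_partition shape behaviour
# for the channels-last layout documented above. An 8x8 feature map with
# window_size=4 yields (2 * (8 // 4) ** 2) windows of 4x4 tokens each.
_demo_windows = window_partition(torch.randn(2, 8, 8, 3), 4)
assert _demo_windows.shape == (8, 4, 4, 3)  # (num_windows * B, ws, ws, C)
# --- end illustration ---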
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()  # scalar bound, device-agnostic
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
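# --- Illustration (not part of the patch): a minimal sketch of the module
# above on a batch of 4x4 windows (N = 16 tokens, C = 32 channels), assuming
# the file-level torch/np/F imports are in place.
_demo_attn = WindowAttention(dim=32, window_size=(4, 4), num_heads=4)
_demo_x = torch.randn(8, 16, 32)             # (num_windows * B, N, C)
_demo_out = _demo_attn(_demo_x, mask=None)   # cosine attention + log-CPB bias
assert _demo_out.shape == _demo_x.shape
# --- end illustration ---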
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)  # matches the documented (B, H, W, C) layout
    return x
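# --- Illustration (not part of the patch): window_reverse undoes
# window_partition for channels-last tensors, so the round trip is exact.
_demo_x = torch.randn(2, 8, 8, 3)
_demo_rt = window_reverse(window_partition(_demo_x, 4), 4, 8, 8)
assert torch.equal(_demo_rt, _demo_x)
# --- end illustration ---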
class SeqModel(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
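# --- Illustration (not part of the patch): unlike nn.Sequential, SeqModel
# unpacks a returned tuple into the next module's positional arguments, which
# is how (x, graph) can be threaded through a stack of blocks. _DemoBlock is a
# hypothetical module used only for this sketch.
class _DemoBlock(nn.Module):
    def forward(self, x, graph=None):
        return x + 1, graph  # tuple out -> unpacked for the next block

_demo_seq = SeqModel(_DemoBlock(), _DemoBlock())
_demo_y, _demo_graph = _demo_seq(torch.zeros(2, 3), None)
assert _demo_y.sum() == 12  # two +1 updates over a 2x3 tensor of zeros
# --- end illustration ---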
def propagate(x: torch.Tensor, weight: torch.Tensor,
              index_kept: torch.Tensor, index_prop: torch.Tensor,
              standard: str = "None", alpha: float = 0.,
              token_scales: Optional[torch.Tensor] = None,
              cls_token=True):
    """
    Propagate tokens based on the selection results.
    ================================================
    Args:
        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
                  excluding the [CLS] token. weight could be a pre-defined graph of the current
                  feature map (by default) or the attention map (need to manually modify the
                  Block Module).
        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
                  "GraphProp"
        - alpha: float: the coefficient of propagated features
        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
                  is None by default. If it is not None, then token_scales represents the scales of
                  each token and should sum up to N.
    Return:
        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
    """
    B, N, C = x.shape

    # Step 1: divide tokens
    if cls_token:
        x_cls = x[:, 0:1]  # B, 1, C
    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, C))  # B, N-1-num_prop, C
    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1, -1, C))  # B, num_prop, C

    # Step 2: divide token_scales if it is not None
    if token_scales is not None:
        if cls_token:
            token_scales_cls = token_scales[:, 0:1]  # B, 1
        token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
        token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop

    # Step 3: propagate tokens
    if standard == "None":
        """
        No further propagation
        """
        pass

    elif standard == "Mean":
        """
        Calculate the mean of all the propagated tokens,
        and concatenate the result token back to kept tokens.
        """
        # naive average
        x_prop = x_prop.mean(1, keepdim=True)  # B, 1, C
        # Concatenate the average token
        x_kept = torch.cat((x_kept, x_prop), dim=1)  # B, N-num_prop, C

    elif standard == "GraphProp":
        """
        Propagate all the propagated token to kept token
        with respect to the weights and token scales.
        """
        assert weight is not None, "The graph weight is needed for graph propagation"

        # Step 3.1: divide propagation weights.
        if cls_token:
            index_kept = index_kept - 1  # since weights do not include the [CLS] token
            index_prop = index_prop - 1  # since weights do not include the [CLS] token
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N - 1))  # B, N-1-num_prop, N-1
            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, num_prop
            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, N-1-num_prop
        else:
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N))  # B, N-num_prop, N
            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-num_prop, num_prop
            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-num_prop, N-num_prop

        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
        # Simple implementation
        x_prop = weight_prop @ x_prop  # B, N-1-num_prop, C
        x_kept = x_kept + alpha * x_prop  # B, N-1-num_prop, C

        """ scatter_reduce implementation for batched inputs
        # Get the non-zero values
        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
        non_zero_values = weight_prop[non_zero_indices]

        # Sparse multiplication
        batch_indices, row_indices, col_indices = non_zero_indices
        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
        reduce_indices = batch_indices * x_kept.shape[1] + row_indices

        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
                                                      index=reduce_indices[:, None],
                                                      src=sparse_matmul,
                                                      reduce="sum",
                                                      include_self=True)
        x_kept = x_kept.reshape(B, -1, C)
        """

        # Step 3.3: calculate the scale of each token if token_scales is not None
        if token_scales is not None:
            if cls_token:
                token_scales_cls = token_scales[:, 0:1]  # B, 1
                token_scales = token_scales[:, 1:]
            token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
            token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop
            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1)  # B, N-1-num_prop, 1
            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1)  # B, N-1-num_prop
            if cls_token:
                token_scales = torch.cat((token_scales_cls, token_scales), dim=1)  # B, N-num_prop

    else:
        assert False, "Propagation method '%s' has not been supported yet." % standard

    if cls_token:
        # Step 4: concatenate the [CLS] token and generate returned value
        x = torch.cat((x_cls, x_kept), dim=1)  # B, N-num_prop, C
    else:
        x = x_kept
    return x, weight, token_scales
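# --- Illustration (not part of the patch): propagate() with the "Mean"
# standard and no [CLS] token. Two of four tokens are kept and the two pruned
# tokens are averaged into one extra token, so N goes from 4 to 3.
_demo_x = torch.arange(8.).view(1, 4, 2)      # B=1, N=4, C=2
_demo_kept = torch.tensor([[0, 1]])           # indices of kept tokens
_demo_prop = torch.tensor([[2, 3]])           # indices of pruned tokens
_demo_out, _, _ = propagate(_demo_x, None, _demo_kept, _demo_prop,
                            standard="Mean", cls_token=False)
assert _demo_out.shape == (1, 3, 2)           # 2 kept tokens + 1 mean token
assert torch.equal(_demo_out[0, 2], torch.tensor([5., 6.]))
# --- end illustration ---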
def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token=True):
    """
    Select image tokens to be propagated. The [CLS] token will be ignored.
    ======================================================================
    Args:
        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
                  attention map of tokens at the moment.
        - standard: str: the method applied to select the tokens
        - num_prop: int: the number of tokens to be propagated
    Return:
        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
    """
    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
    B, H, N1, N2 = weight.shape
    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
    N = N1
    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."

    if cls_token:
        if standard == "CLSAttnMean":
            token_rank = weight[:, :, 0, 1:].mean(1)
        elif standard == "CLSAttnMax":
            token_rank = weight[:, :, 0, 1:].max(1)[0]
        elif standard == "IMGAttnMean":
            token_rank = weight[:, :, :, 1:].sum(-2).mean(1)
        elif standard == "IMGAttnMax":
            token_rank = weight[:, :, :, 1:].sum(-2).max(1)[0]
        elif standard == "DiagAttnMean":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
        elif standard == "DiagAttnMax":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
        elif standard == "MixedAttnMean":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
            token_rank_2 = weight[:, :, :, 1:].sum(-2).mean(1)
            token_rank = token_rank_1 * token_rank_2
        elif standard == "MixedAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
            token_rank = token_rank_1 * token_rank_2
        elif standard == "SumAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
            token_rank = token_rank_1 + token_rank_2
        elif standard == "CosSimMean":
            weight = weight[:, :, 1:, :].mean(1)
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "CosSimMax":
            weight = weight[:, :, 1:, :].max(1)[0]
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "Random":
            token_rank = torch.randn((B, N - 1), device=weight.device)
        else:
            raise ValueError("Type '%s' selection not supported." % standard)

        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N-1
        num_kept = token_rank.shape[1] - num_prop  # avoids the empty slice when num_prop == 0
        index_kept = token_rank[:, :num_kept] + 1  # B, N-1-num_prop
        index_prop = token_rank[:, num_kept:] + 1  # B, num_prop

    else:
        if standard == "IMGAttnMean":
            token_rank = weight.sum(-2).mean(1)
        elif standard == "IMGAttnMax":
            token_rank = weight.sum(-2).max(1)[0]
        elif standard == "DiagAttnMean":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
        elif standard == "DiagAttnMax":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
        elif standard == "MixedAttnMean":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
            token_rank_2 = weight.sum(-2).mean(1)
            token_rank = token_rank_1 * token_rank_2
        elif standard == "MixedAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
            token_rank_2 = weight.sum(-2).max(1)[0]
            token_rank = token_rank_1 * token_rank_2
        elif standard == "SumAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
            token_rank_2 = weight.sum(-2).max(1)[0]
            token_rank = token_rank_1 + token_rank_2
        elif standard == "CosSimMean":
            weight = weight.mean(1)
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "CosSimMax":
            weight = weight.max(1)[0]
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "Random":
            token_rank = torch.randn((B, N), device=weight.device)  # all N tokens rank without a [CLS] token
        else:
            raise ValueError("Type '%s' selection not supported." % standard)

        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N
        num_kept = token_rank.shape[1] - num_prop  # avoids the empty slice when num_prop == 0
        index_kept = token_rank[:, :num_kept]  # B, N-num_prop
        index_prop = token_rank[:, num_kept:]  # B, num_prop

    return index_kept, index_prop
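# --- Illustration (not part of the patch): select() with the "DiagAttnMean"
# standard and no [CLS] token. Tokens are ranked by head-averaged diagonal
# attention; the 4 lowest-ranked of 16 tokens are marked for propagation.
_demo_w = torch.rand(2, 4, 16, 16)            # B=2, H=4 heads, N=16 tokens
_demo_kept, _demo_prop = select(_demo_w, standard="DiagAttnMean",
                                num_prop=4, cls_token=False)
assert _demo_kept.shape == (2, 12) and _demo_prop.shape == (2, 4)
# --- end illustration ---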
class PoolFormerBlock(nn.Module):
    """
@@ -731,7 +1219,10 @@ class PoolFormerBlock(nn.Module):
        self.norm1 = norm_layer(dim)
        # self.token_mixer = Pooling(pool_size=pool_size)
        # self.token_mixer = FNetBlock()
        self.window_size = 4
        self.attn_mask = None
        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -747,16 +1238,29 @@ class PoolFormerBlock(nn.Module):
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones(dim), requires_grad=True)

    def forward(self, x, weight, token_scales=None):
        B, C, H, W = x.shape
        # window attention operates on channels-last windows
        x_windows = window_partition(x.permute(0, 2, 3, 1), self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x_attn = window_reverse(attn_windows, self.window_size, H, W).permute(0, 3, 1, 2)  # back to B, C, H, W
        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
                                        cls_token=False)
        original_shape = x_attn.shape
        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
        x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
                                                 alpha=0.1, token_scales=token_scales, cls_token=False)
        x_attn = x_attn.view(*original_shape)
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * x_attn)
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
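# --- Illustration (not part of the patch): the layer-scale residual pattern
# used above, x <- x + DropPath(gamma * f(x)), where gamma is a learnable
# per-channel scale initialised near zero so each block starts close to the
# identity mapping. Values here are hypothetical.
_demo_gamma = 1e-5 * torch.ones(8)            # layer_scale_init_value * ones(dim)
_demo_feat = torch.randn(2, 8, 4, 4)          # B, C, H, W residual branch output
_demo_update = _demo_gamma.unsqueeze(-1).unsqueeze(-1) * _demo_feat
assert _demo_update.abs().max() <= 1e-5 * _demo_feat.abs().max() + 1e-12
# --- end illustration ---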
class PatchEmbed(nn.Module):