@@ -1,17 +1,16 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, repeat
from utils import *
from layers import *
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module):
@@ -528,22 +527,6 @@ class FouriER(torch.nn.Module):
        self.network = nn.ModuleList(network)
        self.norm = norm_layer(embed_dims[-1])
        self.graph_type = 'Spatial'
        N = (image_h // patch_size) ** 2
        if self.graph_type in ["Spatial", "Mixed"]:
            # Create a range tensor of node indices
            indices = torch.arange(N)
            # Reshape the indices tensor to create a grid of row and column indices
            row_indices = indices.view(-1, 1).expand(-1, N)
            col_indices = indices.view(1, -1).expand(N, -1)
            # Compute the adjacency matrix
            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
            graph = graph - torch.eye(N)
            self.spatial_graph = graph.cuda()  # remove .cuda() here if running on CPU
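            # A minimal sketch of the resulting adjacency (illustrative, not from the
            # original code): for N = 9 patches arranged on a 3x3 grid, the centre
            # token 4 is connected to all 8 surrounding tokens, while the corner
            # token 0 is connected only to tokens 1, 3 and 4; the diagonal is zeroed
            # by subtracting torch.eye(N).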
        self.class_token = False
        self.token_scale = False
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 \
            else nn.Identity()
@@ -561,45 +544,8 @@ class FouriER(torch.nn.Module):
    def forward_tokens(self, x):
        outs = []
        B, C, H, W = x.shape
        N = H * W
        if self.graph_type in ["Semantic", "Mixed"]:
            # Generate the semantic graph w.r.t. the cosine similarity between tokens
            # Compute cosine similarity
            if self.class_token:
                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
            else:
                x_normed = x / x.norm(dim=-1, keepdim=True)
            x_cossim = x_normed @ x_normed.transpose(-1, -2)
            threshold = torch.kthvalue(x_cossim, N - 1 - self.num_neighbours, dim=-1, keepdim=True)[0]  # B,H,1,1
            semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
            if self.class_token:
                semantic_graph = semantic_graph - torch.eye(N - 1, device=semantic_graph.device).unsqueeze(0)
            else:
                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)

        if self.graph_type == "None":
            graph = None
        else:
            if self.graph_type == "Spatial":
                graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1)  # .to(x.device)
            elif self.graph_type == "Semantic":
                graph = semantic_graph
            elif self.graph_type == "Mixed":
                # Integrate the spatial graph and semantic graph
                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()

            # Symmetrically normalize the graph
            degree = graph.sum(-1)  # B, N
            degree = torch.diag_embed(degree ** (-1 / 2))
            graph = degree @ graph @ degree
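            # Worked example (illustrative, not from the original code): for a 2-token
            # graph A = [[0, 1], [1, 1]], the degrees are (1, 2), so D^(-1/2) = diag(1, 1/sqrt(2))
            # and the normalized graph D^(-1/2) A D^(-1/2) is [[0, 0.707], [0.707, 0.5]].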
        for idx, block in enumerate(self.network):
            try:
                x = block(x, graph)
            except:
                # fall back to calling the block without the graph
                x = block(x)
        # output only the features of last layer for image classification
        return x
@@ -612,6 +558,8 @@ class FouriER(torch.nn.Module):
        z = self.forward_embeddings(y)
        z = self.forward_tokens(z)
        z = z.mean([-2, -1])
        if torch.isnan(z).any():  # debug check for NaNs in the pooled features
            print("ZZZ")
        z = self.norm(z)
        x = self.head(z)
        x = self.hidden_drop(x)
@@ -758,7 +706,7 @@ def basic_blocks(dim, index, layers,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            ))
    # blocks = SeqModel(*blocks)
    blocks = nn.Sequential(*blocks)
    return blocks
@@ -923,278 +871,202 @@ def window_reverse(windows, window_size, H, W):
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
    return x
class SeqModel(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
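# Usage sketch (hypothetical blocks, not part of the original code): unlike nn.Sequential,
# SeqModel unpacks tuple outputs, so a block that returns (x, graph) feeds both values
# into the next block:
#   seq = SeqModel(block_a, block_b)
#   out = seq(x, graph)   # block_a is called as block_a(x, graph); if it returns a
#                         # tuple, block_b receives the unpacked values as well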
def cast_tuple(val, length=1):
    return val if isinstance(val, tuple) else ((val,) * length)
def propagate(x: torch.Tensor, weight: torch.Tensor,
              index_kept: torch.Tensor, index_prop: torch.Tensor,
              standard: str = "None", alpha: Optional[float] = 0,
              token_scales: Optional[torch.Tensor] = None,
              cls_token=True):
    """
    Propagate tokens based on the selection results.
    ================================================
    Args:
        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
                  excluding the [CLS] token. weight could be a pre-defined graph of the current
                  feature map (by default) or the attention map (need to manually modify the
                  Block Module).
        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
                  "GraphProp"
        - alpha: float: the coefficient of propagated features
        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
                  is None by default. If it is not None, then token_scales represents the scales of
                  each token and should sum up to N.
    Return:
        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
    """
    B, N, C = x.shape

    # Step 1: divide tokens
    if cls_token:
        x_cls = x[:, 0:1]  # B, 1, C
    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, C))  # B, N-1-num_prop, C
    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1, -1, C))  # B, num_prop, C

    # Step 2: divide token_scales if it is not None
    if token_scales is not None:
        if cls_token:
            token_scales_cls = token_scales[:, 0:1]  # B, 1
        token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
        token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop

    # Step 3: propagate tokens
    if standard == "None":
        """
        No further propagation
        """
        pass

    elif standard == "Mean":
        """
        Calculate the mean of all the propagated tokens,
        and concatenate the result token back to kept tokens.
        """
        # naive average
        x_prop = x_prop.mean(1, keepdim=True)  # B, 1, C
        # Concatenate the average token
        x_kept = torch.cat((x_kept, x_prop), dim=1)  # B, N-num_prop, C

    elif standard == "GraphProp":
        """
        Propagate all the propagated token to kept token
        with respect to the weights and token scales.
        """
        assert weight is not None, "The graph weight is needed for graph propagation"

        # Step 3.1: divide propagation weights.
        if cls_token:
            index_kept = index_kept - 1  # since weights do not include the [CLS] token
            index_prop = index_prop - 1  # since weights do not include the [CLS] token
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N - 1))  # B, N-1-num_prop, N-1
            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, num_prop
            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, N-1-num_prop
        else:
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N))  # B, N-1-num_prop, N-1
            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, num_prop
            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, N-1-num_prop

        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
        # Simple implementation
        x_prop = weight_prop @ x_prop  # B, N-1-num_prop, C
        x_kept = x_kept + alpha * x_prop  # B, N-1-num_prop, C

        """ scatter_reduce implementation for batched inputs
        # Get the non-zero values
        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
        non_zero_values = weight_prop[non_zero_indices]

        # Sparse multiplication
        batch_indices, row_indices, col_indices = non_zero_indices
        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
        reduce_indices = batch_indices * x_kept.shape[1] + row_indices

        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
                                                      index=reduce_indices[:, None],
                                                      src=sparse_matmul,
                                                      reduce="sum",
                                                      include_self=True)
        x_kept = x_kept.reshape(B, -1, C)
        """

        # Step 3.3: calculate the scale of each token if token_scales is not None
        if token_scales is not None:
            if cls_token:
                token_scales_cls = token_scales[:, 0:1]  # B, 1
                token_scales = token_scales[:, 1:]
            token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
            token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop
            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1)  # B, N-1-num_prop, 1
            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1)  # B, N-1-num_prop
            if cls_token:
                token_scales = torch.cat((token_scales_cls, token_scales), dim=1)  # B, N-num_prop

    else:
        assert False, "Propagation method '%s' has not been supported yet." % standard

    if cls_token:
        # Step 4: concatenate the [CLS] token and generate returned value
        x = torch.cat((x_cls, x_kept), dim=1)  # B, N-num_prop, C
    else:
        x = x_kept

    return x, weight, token_scales
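# A minimal shape sketch of the "Mean" standard (assumed sizes, not part of the original
# code):
#   x = torch.randn(2, 197, 64)                        # B, N, C with a [CLS] token
#   index_kept = torch.arange(1, 187).expand(2, -1)    # keep 186 image tokens
#   index_prop = torch.arange(187, 197).expand(2, -1)  # propagate the last 10
#   out, _, _ = propagate(x, None, index_kept, index_prop, standard="Mean", cls_token=True)
#   out.shape -> torch.Size([2, 188, 64])              # [CLS] + 186 kept + 1 averaged token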
# helper classes
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
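# Sanity-check sketch (illustrative, not part of the original code): ChanLayerNorm
# normalizes over the channel dimension of a channels-first tensor, so it should match
# F.layer_norm applied to the channels-last view:
#   x = torch.randn(2, 64, 8, 8)
#   cln = ChanLayerNorm(64)
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), cln.g.flatten(), cln.b.flatten(),
#                      cln.eps).permute(0, 3, 1, 2)
#   torch.allclose(cln(x), ref, atol=1e-5)  # expected: True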
def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token=True):
    """
    Select image tokens to be propagated. The [CLS] token will be ignored.
    ======================================================================
    Args:
        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
                  attention map of tokens at the moment.
        - standard: str: the method applied to select the tokens
        - num_prop: int: the number of tokens to be propagated
    Return:
        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
    """
    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
    B, H, N1, N2 = weight.shape
    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
    N = N1
    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."

    if cls_token:
        if standard == "CLSAttnMean":
            token_rank = weight[:, :, 0, 1:].mean(1)
        elif standard == "CLSAttnMax":
            token_rank = weight[:, :, 0, 1:].max(1)[0]
        elif standard == "IMGAttnMean":
            token_rank = weight[:, :, :, 1:].sum(-2).mean(1)
        elif standard == "IMGAttnMax":
            token_rank = weight[:, :, :, 1:].sum(-2).max(1)[0]
        elif standard == "DiagAttnMean":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
        elif standard == "DiagAttnMax":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
        elif standard == "MixedAttnMean":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
            token_rank_2 = weight[:, :, :, 1:].sum(-2).mean(1)
            token_rank = token_rank_1 * token_rank_2
        elif standard == "MixedAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
            token_rank = token_rank_1 * token_rank_2
        elif standard == "SumAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
            token_rank = token_rank_1 + token_rank_2
        elif standard == "CosSimMean":
            weight = weight[:, :, 1:, :].mean(1)
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "CosSimMax":
            weight = weight[:, :, 1:, :].max(1)[0]
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "Random":
            token_rank = torch.randn((B, N - 1), device=weight.device)
        else:
            print("Type '", standard, "' selection not supported.")
            assert False

        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N-1
        index_kept = token_rank[:, :-num_prop] + 1  # B, N-1-num_prop
        index_prop = token_rank[:, -num_prop:] + 1  # B, num_prop

    else:
        if standard == "IMGAttnMean":
            token_rank = weight.sum(-2).mean(1)
        elif standard == "IMGAttnMax":
            token_rank = weight.sum(-2).max(1)[0]
        elif standard == "DiagAttnMean":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
        elif standard == "DiagAttnMax":
            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
        elif standard == "MixedAttnMean":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
            token_rank_2 = weight.sum(-2).mean(1)
            token_rank = token_rank_1 * token_rank_2
        elif standard == "MixedAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
            token_rank_2 = weight.sum(-2).max(1)[0]
            token_rank = token_rank_1 * token_rank_2
        elif standard == "SumAttnMax":
            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
            token_rank_2 = weight.sum(-2).max(1)[0]
            token_rank = token_rank_1 + token_rank_2
        elif standard == "CosSimMean":
            weight = weight.mean(1)
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "CosSimMax":
            weight = weight.max(1)[0]
            weight = weight / weight.norm(dim=-1, keepdim=True)
            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
        elif standard == "Random":
            token_rank = torch.randn((B, N - 1), device=weight.device)
        else:
            print("Type '", standard, "' selection not supported.")
            assert False

        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N-1
        index_kept = token_rank[:, :-num_prop]  # B, N-1-num_prop
        index_prop = token_rank[:, -num_prop:]  # B, num_prop

    return index_kept, index_prop
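# A minimal end-to-end sketch of select() + propagate() (assumed sizes, not part of the
# original code):
#   attn = torch.rand(2, 4, 197, 197)    # B, H, N, N attention map with a [CLS] token
#   idx_kept, idx_prop = select(attn, standard="DiagAttnMax", num_prop=10, cls_token=True)
#   x = torch.randn(2, 197, 64)          # B, N, C token features
#   graph = torch.rand(2, 196, 196)      # token graph, excluding the [CLS] token
#   x, graph, _ = propagate(x, graph, idx_kept, idx_prop, standard="GraphProp",
#                           alpha=0.1, cls_token=True)
#   x.shape -> torch.Size([2, 187, 64])  # 197 tokens minus the 10 propagated ones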
class OverlappingPatchEmbed(nn.Module):
    def __init__(self, dim_in, dim_out, stride=2):
        super().__init__()
        kernel_size = stride * 2 - 1
        padding = kernel_size // 2
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        return self.conv(x)
class PEG(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, stride=1)

    def forward(self, x):
        return self.proj(x) + x
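# Shape sketch (assumed sizes, not part of the original code): PEG is a depthwise 3x3
# convolution added residually, so it keeps the input resolution:
#   peg = PEG(dim=64)
#   peg(torch.randn(2, 64, 14, 14)).shape -> torch.Size([2, 64, 14, 14])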
# feedforward

class FeedForwardDSSA(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
# attention

class DSSA(nn.Module):
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=32,
        dropout=0.,
        window_size=7
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.norm = ChanLayerNorm(dim)

        self.attend = nn.Sequential(
            nn.Softmax(dim=-1),
            nn.Dropout(dropout)
        )

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias=False)

        # window tokens
        self.window_tokens = nn.Parameter(torch.randn(dim))

        # prenorm and non-linearity for window tokens
        # then projection to queries and keys for window tokens
        self.window_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1),
            Rearrange('b (h c) n -> b h n c', h=heads),
        )

        # window attention
        self.window_attend = nn.Sequential(
            nn.Softmax(dim=-1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        """
        einstein notation

        b - batch
        c - channels
        w1 - window size (height)
        w2 - also window size (width)
        i - sequence dimension (source)
        j - sequence dimension (target dimension to be reduced)
        h - heads
        x - height of feature map divided by window size
        y - width of feature map divided by window size
        """
        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
        num_windows = (height // wsz) * (width // wsz)

        x = self.norm(x)

        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1=wsz, w2=wsz)

        # add windowing tokens
        w = repeat(self.window_tokens, 'c -> b c 1', b=x.shape[0])
        x = torch.cat((w, x), dim=-1)

        # project for queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim=1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h=heads), (q, k, v))

        # scale
        q = q * self.scale

        # similarity
        dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # attention
        attn = self.attend(dots)

        # aggregate values
        out = torch.matmul(attn, v)

        # split out windowed tokens
        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]

        # early return if there is only 1 window
        if num_windows == 1:
            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x=height // wsz, y=width // wsz, w1=wsz, w2=wsz)
            return self.to_out(fmap)

        # carry out the pointwise attention, the main novelty in the paper
        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x=height // wsz, y=width // wsz)
        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x=height // wsz, y=width // wsz)

        # windowed queries and keys (preceded by prenorm activation)
        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim=-1)

        # scale
        w_q = w_q * self.scale

        # similarities
        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)

        w_attn = self.window_attend(w_dots)

        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one I haven't seen before)
        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)

        # fold back the windows and then combine heads for aggregation
        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x=height // wsz, y=width // wsz, w1=wsz, w2=wsz)
        return self.to_out(fmap)
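# Shape sketch (assumed sizes, not part of the original code): DSSA preserves the
# (B, C, H, W) layout as long as H and W are divisible by window_size:
#   attn = DSSA(dim=64, heads=4, dim_head=16, window_size=7)
#   attn(torch.randn(2, 64, 14, 14)).shape -> torch.Size([2, 64, 14, 14])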
class PoolFormerBlock(nn.Module):
    """
@@ -1221,8 +1093,13 @@ class PoolFormerBlock(nn.Module):
        # self.token_mixer = Pooling(pool_size=pool_size)
        # self.token_mixer = FNetBlock()
        self.window_size = 4
        self.attn_heads = 4
        self.attn_mask = None
        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
        self.token_mixer = nn.ModuleList([
            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
            FeedForwardDSSA(dim)
        ])
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -1238,20 +1115,14 @@ class PoolFormerBlock(nn.Module):
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
    def forward(self, x):
        B, C, H, W = x.shape
        # Previous window-attention + graph-propagation token mixer, kept for reference:
        # x_windows = window_partition(x, self.window_size)
        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
        # index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, cls_token=False)
        # original_shape = x_attn.shape
        # x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
        # x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
        #                                          alpha=0.1, token_scales=token_scales, cls_token=False)
        # x_attn = x_attn.view(*original_shape)
        # Token mixing: apply DSSA and its feed-forward in sequence
        x_attn = x
        for mixer in self.token_mixer:
            x_attn = mixer(x_attn)
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
@@ -1262,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
        else:
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        if torch.isnan(x).any():  # debug check for NaNs in the block output
            print("PFBlock")
        return x
class PatchEmbed(nn.Module):
    """
@@ -1347,7 +1221,7 @@ class LayerNormChannel(nn.Module):
            + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x
class FeedForwardFNet(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
@@ -1383,7 +1257,7 @@ class FNet(nn.Module):
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, FNetBlock()),
                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers: