16 Commits

Author SHA1 Message Date
47bc661a91 try gtp vit 2024-04-28 15:40:31 +07:00
3b6db89be1 try gtp vit 2024-04-28 15:35:24 +07:00
352f5f9da9 try gtp vit 2024-04-28 15:31:58 +07:00
b9273b6696 try gtp vit 2024-04-28 15:27:41 +07:00
d0e4630dd6 try gtp vit 2024-04-28 15:26:24 +07:00
08a3780ba6 try gtp vit 2024-04-28 15:25:44 +07:00
6fc56b920f try gtp vit 2024-04-28 15:24:05 +07:00
fddea4769f try gtp vit 2024-04-28 15:17:40 +07:00
d9209a7ef1 try gtp vit 2024-04-28 15:14:27 +07:00
0f986d7517 try gtp vit 2024-04-28 15:10:09 +07:00
4daa40527b try gtp vit 2024-04-28 15:08:08 +07:00
541c4fa2b3 try gtp vit 2024-04-28 12:05:57 +07:00
68a94bd1e2 try gtp vit 2024-04-28 12:05:04 +07:00
b01e504874 try gtp vit 2024-04-28 12:01:40 +07:00
23c44d3582 try gtp vit 2024-04-28 11:59:09 +07:00
41a5c7b05a try gtp vit 2024-04-28 11:57:17 +07:00
2 changed files with 341 additions and 219 deletions


@@ -478,11 +478,7 @@ class Main(object):
 				batch, 'train')
 			pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-			try:
-				loss = self.model.loss(pred, label, sub_samp)
-			except Exception as e:
-				print(pred)
-				raise e
+			loss = self.model.loss(pred, label, sub_samp)
 			loss.backward()
 			self.optimizer.step()

models.py

@@ -1,16 +1,17 @@
 import torch
-from torch import nn, einsum
+from torch import nn
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
-from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.layers import DropPath, trunc_normal_
 from timm.models.registry import register_model
 from timm.layers.helpers import to_2tuple
+from typing import *
+import math
 
 class ConvE(torch.nn.Module):
@@ -527,6 +528,22 @@ class FouriER(torch.nn.Module):
         self.network = nn.ModuleList(network)
         self.norm = norm_layer(embed_dims[-1])
+        self.graph_type = 'Spatial'
+        N = (image_h // patch_size)**2
+        if self.graph_type in ["Spatial", "Mixed"]:
+            # Create a range tensor of node indices
+            indices = torch.arange(N)
+            # Reshape the indices tensor to create a grid of row and column indices
+            row_indices = indices.view(-1, 1).expand(-1, N)
+            col_indices = indices.view(1, -1).expand(N, -1)
+            # Compute the adjacency matrix
+            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
+            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
+            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
+            graph = graph - torch.eye(N)
+            self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
+        self.class_token = False
+        self.token_scale = False
         self.head = nn.Linear(
             embed_dims[-1], num_classes) if num_classes > 0 \
             else nn.Identity()
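Note: the block added above builds an 8-neighbour adjacency over the patch grid (two tokens are connected when their row and column indices differ by at most one, with the diagonal removed). A minimal standalone sketch of the same construction, assuming a toy 3x3 grid (N = 9), not part of the commit:

import math
import torch

N = 9  # 3x3 patch grid
indices = torch.arange(N)
row = indices.view(-1, 1).expand(-1, N)
col = indices.view(1, -1).expand(N, -1)
r1, c1 = row // int(math.sqrt(N)), row % int(math.sqrt(N))
r2, c2 = col // int(math.sqrt(N)), col % int(math.sqrt(N))
graph = ((r1 - r2).abs() <= 1).float() * ((c1 - c2).abs() <= 1).float() - torch.eye(N)
print(graph[4])  # the centre patch is adjacent to all 8 surrounding patches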
@@ -544,8 +561,45 @@ class FouriER(torch.nn.Module):
     def forward_tokens(self, x):
         outs = []
+        B, C, H, W = x.shape
+        N = H*W
+        if self.graph_type in ["Semantic", "Mixed"]:
+            # Generate the semantic graph w.r.t. the cosine similarity between tokens
+            # Compute cosine similarity
+            if self.class_token:
+                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
+            else:
+                x_normed = x / x.norm(dim=-1, keepdim=True)
+            x_cossim = x_normed @ x_normed.transpose(-1, -2)
+            threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
+            semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
+            if self.class_token:
+                semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
+            else:
+                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
+
+        if self.graph_type == "None":
+            graph = None
+        else:
+            if self.graph_type == "Spatial":
+                graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1) # .to(x.device)
+            elif self.graph_type == "Semantic":
+                graph = semantic_graph
+            elif self.graph_type == "Mixed":
+                # Integrate the spatial graph and semantic graph
+                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
+                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
+
+            # Symmetrically normalize the graph
+            degree = graph.sum(-1) # B, N
+            degree = torch.diag_embed(degree**(-1/2))
+            graph = degree @ graph @ degree
+
         for idx, block in enumerate(self.network):
-            x = block(x)
+            try:
+                x = block(x, graph)
+            except:
+                x = block(x)
         # output only the features of last layer for image classification
         return x
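The normalization step added above is the standard symmetric graph normalization D^(-1/2) A D^(-1/2). A small self-contained sketch on a made-up 3-node graph, not part of the commit:

import torch

A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])        # toy adjacency matrix
degree = A.sum(-1)                      # node degrees: [2, 1, 1]
D_inv_sqrt = torch.diag_embed(degree ** (-0.5))
A_norm = D_inv_sqrt @ A @ D_inv_sqrt    # each edge rescaled by 1/sqrt(deg_i * deg_j)
print(A_norm)

Note that a node with degree zero would produce inf entries here, since the code takes degree**(-1/2) directly.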
@@ -558,8 +612,6 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
-        if np.count_nonzero(np.isnan(z)) > 0:
-            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
@@ -706,7 +758,7 @@ def basic_blocks(dim, index, layers,
                  use_layer_scale=use_layer_scale,
                  layer_scale_init_value=layer_scale_init_value,
                  ))
-    blocks = nn.Sequential(*blocks)
+    blocks = SeqModel(*blocks)
 
     return blocks
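The switch from nn.Sequential to SeqModel (defined in a later hunk) is needed because nn.Sequential.forward accepts only a single input, while the blocks now optionally take (x, graph). A minimal sketch of the pass-through behaviour; AddGraph is a made-up stand-in for a block that accepts an optional graph argument, not part of the commit:

import torch
from torch import nn

class SeqModel(nn.Sequential):
    # forwards a tuple of inputs through each child, unpacking it when needed
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs

class AddGraph(nn.Module):
    def forward(self, x, graph=None):
        return x if graph is None else x + graph

seq = SeqModel(AddGraph(), nn.Identity())
out = seq(torch.zeros(2, 3), torch.ones(2, 3))  # both arguments reach the first block
print(out)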
@@ -871,202 +923,278 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x
 
-def cast_tuple(val, length = 1):
-    return val if isinstance(val, tuple) else ((val,) * length)
-
-# helper classes
-
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
-
-class OverlappingPatchEmbed(nn.Module):
-    def __init__(self, dim_in, dim_out, stride = 2):
-        super().__init__()
-        kernel_size = stride * 2 - 1
-        padding = kernel_size // 2
-        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
-
-    def forward(self, x):
-        return self.conv(x)
-
-class PEG(nn.Module):
-    def __init__(self, dim, kernel_size = 3):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
-
-    def forward(self, x):
-        return self.proj(x) + x
-
-# feedforward
-
-class FeedForwardDSSA(nn.Module):
-    def __init__(self, dim, mult = 4, dropout = 0.):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        self.net = nn.Sequential(
-            ChanLayerNorm(dim),
-            nn.Conv2d(dim, inner_dim, 1),
-            nn.GELU(),
-            nn.Dropout(dropout),
-            nn.Conv2d(inner_dim, dim, 1),
-            nn.Dropout(dropout)
-        )
-    def forward(self, x):
-        return self.net(x)
-
-# attention
-
-class DSSA(nn.Module):
-    def __init__(
-        self,
-        dim,
-        heads = 8,
-        dim_head = 32,
-        dropout = 0.,
-        window_size = 7
-    ):
-        super().__init__()
-        self.heads = heads
-        self.scale = dim_head ** -0.5
-        self.window_size = window_size
-        inner_dim = dim_head * heads
-
-        self.norm = ChanLayerNorm(dim)
-
-        self.attend = nn.Sequential(
-            nn.Softmax(dim = -1),
-            nn.Dropout(dropout)
-        )
-
-        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
-
-        # window tokens
-        self.window_tokens = nn.Parameter(torch.randn(dim))
-
-        # prenorm and non-linearity for window tokens
-        # then projection to queries and keys for window tokens
-        self.window_tokens_to_qk = nn.Sequential(
-            nn.LayerNorm(dim_head),
-            nn.GELU(),
-            Rearrange('b h n c -> b (h c) n'),
-            nn.Conv1d(inner_dim, inner_dim * 2, 1),
-            Rearrange('b (h c) n -> b h n c', h = heads),
-        )
-
-        # window attention
-        self.window_attend = nn.Sequential(
-            nn.Softmax(dim = -1),
-            nn.Dropout(dropout)
-        )
-
-        self.to_out = nn.Sequential(
-            nn.Conv2d(inner_dim, dim, 1),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x):
-        """
-        einstein notation
-
-        b - batch
-        c - channels
-        w1 - window size (height)
-        w2 - also window size (width)
-        i - sequence dimension (source)
-        j - sequence dimension (target dimension to be reduced)
-        h - heads
-        x - height of feature map divided by window size
-        y - width of feature map divided by window size
-        """
-
-        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
-        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
-        num_windows = (height // wsz) * (width // wsz)
-
-        x = self.norm(x)
-
-        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
-        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
-
-        # add windowing tokens
-        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
-        x = torch.cat((w, x), dim = -1)
-
-        # project for queries, keys, value
-        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
-
-        # split out heads
-        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
-
-        # scale
-        q = q * self.scale
-
-        # similarity
-        dots = einsum('b h i d, b h j d -> b h i j', q, k)
-
-        # attention
-        attn = self.attend(dots)
-
-        # aggregate values
-        out = torch.matmul(attn, v)
-
-        # split out windowed tokens
-        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
-
-        # early return if there is only 1 window
-        if num_windows == 1:
-            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
-            return self.to_out(fmap)
-
-        # carry out the pointwise attention, the main novelty in the paper
-        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
-        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
-
-        # windowed queries and keys (preceded by prenorm activation)
-        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
-
-        # scale
-        w_q = w_q * self.scale
-
-        # similarities
-        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
-
-        w_attn = self.window_attend(w_dots)
-
-        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
-        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
-
-        # fold back the windows and then combine heads for aggregation
-        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
-        return self.to_out(fmap)
+class SeqModel(nn.Sequential):
+    def forward(self, *inputs):
+        for module in self._modules.values():
+            if type(inputs) == tuple:
+                inputs = module(*inputs)
+            else:
+                inputs = module(inputs)
+        return inputs
+
+def propagate(x: torch.Tensor, weight: torch.Tensor,
+              index_kept: torch.Tensor, index_prop: torch.Tensor,
+              standard: str = "None", alpha: Optional[float] = 0,
+              token_scales: Optional[torch.Tensor] = None,
+              cls_token=True):
+    """
+    Propagate tokens based on the selection results.
+    ================================================
+    Args:
+        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
+
+        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
+                                         excluding the [CLS] token. weight could be a pre-defined
+                                         graph of the current feature map (by default) or the
+                                         attention map (need to manually modify the Block Module).
+
+        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
+
+        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
+
+        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
+                         "GraphProp"
+
+        - alpha: float: the coefficient of propagated features
+
+        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
+                                        is None by default. If it is not None, then token_scales
+                                        represents the scales of each token and should sum up to N.
+
+    Return:
+        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
+
+        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
+
+        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
+    """
+
+    B, N, C = x.shape
+
+    # Step 1: divide tokens
+    if cls_token:
+        x_cls = x[:, 0:1] # B, 1, C
+    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
+    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
+
+    # Step 2: divide token_scales if it is not None
+    if token_scales is not None:
+        if cls_token:
+            token_scales_cls = token_scales[:, 0:1] # B, 1
+        token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
+        token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
+
+    # Step 3: propagate tokens
+    if standard == "None":
+        """
+        No further propagation
+        """
+        pass
+
+    elif standard == "Mean":
+        """
+        Calculate the mean of all the propagated tokens,
+        and concatenate the result token back to kept tokens.
+        """
+        # naive average
+        x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
+        # Concatenate the average token
+        x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
+
+    elif standard == "GraphProp":
+        """
+        Propagate all the propagated token to kept token
+        with respect to the weights and token scales.
+        """
+        assert weight is not None, "The graph weight is needed for graph propagation"
+
+        # Step 3.1: divide propagation weights.
+        if cls_token:
+            index_kept = index_kept - 1 # since weights do not include the [CLS] token
+            index_prop = index_prop - 1 # since weights do not include the [CLS] token
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
+        else:
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
+
+        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
+        # Simple implementation
+        x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
+        x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
+
+        """ scatter_reduce implementation for batched inputs
+        # Get the non-zero values
+        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
+        non_zero_values = weight_prop[non_zero_indices]
+
+        # Sparse multiplication
+        batch_indices, row_indices, col_indices = non_zero_indices
+        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
+        reduce_indices = batch_indices * x_kept.shape[1] + row_indices
+
+        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
+                                                      index=reduce_indices[:, None],
+                                                      src=sparse_matmul,
+                                                      reduce="sum",
+                                                      include_self=True)
+        x_kept = x_kept.reshape(B, -1, C)
+        """
+
+        # Step 3.3: calculate the scale of each token if token_scales is not None
+        if token_scales is not None:
+            if cls_token:
+                token_scales_cls = token_scales[:, 0:1] # B, 1
+                token_scales = token_scales[:, 1:]
+            token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
+            token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
+            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
+            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
+            if cls_token:
+                token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
+    else:
+        assert False, "Propagation method \'%f\' has not been supported yet." % standard
+
+    if cls_token:
+        # Step 4: concatenate the [CLS] token and generate returned value
+        x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
+    else:
+        x = x_kept
+    return x, weight, token_scales
+
+def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True):
+    """
+    Select image tokens to be propagated. The [CLS] token will be ignored.
+    ======================================================================
+    Args:
+        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
+                                        attention map of tokens at the moment.
+
+        - standard: str: the method applied to select the tokens
+
+        - num_prop: int: the number of tokens to be propagated
+
+    Return:
+        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
+
+        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
+    """
+
+    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
+    B, H, N1, N2 = weight.shape
+    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
+    N = N1
+    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
+
+    if cls_token:
+        if standard == "CLSAttnMean":
+            token_rank = weight[:,:,0,1:].mean(1)
+
+        elif standard == "CLSAttnMax":
+            token_rank = weight[:,:,0,1:].max(1)[0]
+
+        elif standard == "IMGAttnMean":
+            token_rank = weight[:,:,:,1:].sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
+            token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight[:,:,1:,:].mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight[:,:,1:,:].max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N-1), device=weight.device)
+
+        else:
+            print("Type\'", standard, "\' selection not supported.")
+            assert False
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
+        index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
+        index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
+
+    else:
+        if standard == "IMGAttnMean":
+            token_rank = weight.sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight.sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+            token_rank_2 = weight.sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight.mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight.max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N-1), device=weight.device)
+
+        else:
+            print("Type\'", standard, "\' selection not supported.")
+            assert False
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
+        index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
+        index_prop = token_rank[:, -num_prop:] # B, num_prop
+
+    return index_kept, index_prop
 
 class PoolFormerBlock(nn.Module):
     """
@@ -1093,13 +1221,8 @@ class PoolFormerBlock(nn.Module):
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
         self.window_size = 4
-        self.attn_heads = 4
         self.attn_mask = None
-        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
-        self.token_mixer = nn.ModuleList([
-            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
-            FeedForwardDSSA(dim)
-        ])
+        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -1115,14 +1238,20 @@ class PoolFormerBlock(nn.Module):
         self.layer_scale_2 = nn.Parameter(
             layer_scale_init_value * torch.ones((dim)), requires_grad=True)
 
-    def forward(self, x):
+    def forward(self, x, weight, token_scales = None):
         B, C, H, W = x.shape
-        # x_windows = window_partition(x, self.window_size)
-        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
-        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
-        x_attn = self.token_mixer(x)
+        x_windows = window_partition(x, self.window_size)
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
+                                        cls_token=False)
+        original_shape = x_attn.shape
+        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
+        x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
+                                                 alpha=0.1, token_scales=token_scales, cls_token=False)
+        x_attn = x_attn.view(*original_shape)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
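One subtlety worth noting about the num_prop=0 call added above: with Python's negative-stop slicing, -0 is 0, so token_rank[:, :-num_prop] inside select yields an empty slice rather than all tokens. A two-line illustration, not part of the commit:

import torch
t = torch.arange(6).view(1, 6)
print(t[:, :-0].shape)  # torch.Size([1, 0]), an empty slice, not the full row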
@@ -1133,9 +1262,6 @@ class PoolFormerBlock(nn.Module):
         else:
             x = x + self.drop_path(x_attn)
             x = x + self.drop_path(self.mlp(self.norm2(x)))
-        if np.count_nonzero(np.isnan(x)) > 0:
-            print("PFBlock")
         return x
 
 class PatchEmbed(nn.Module):
     """
@@ -1221,7 +1347,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForwardFNet(nn.Module):
+class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1257,7 +1383,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
 
     def forward(self, x):
         for attn, ff in self.layers: