6 Commits

Author SHA1 Message Date
thanhvc3
9075b53be6 try sep vit 2024-04-28 11:04:26 +07:00
thanhvc3
ab5c1d0b4b try sep vit 2024-04-28 11:00:08 +07:00
thanhvc3
3243b1d963 try sep vit 2024-04-28 09:42:39 +07:00
thanhvc3
37b01708b4 try sep vit 2024-04-28 01:44:33 +07:00
thanhvc3
a246d2bb64 try sep vit 2024-04-28 01:01:31 +07:00
thanhvc3
4a962a02ad try sep vit 2024-04-27 21:57:24 +07:00
2 changed files with 221 additions and 343 deletions

View File

@@ -478,7 +478,11 @@ class Main(object):
batch, 'train') batch, 'train')
pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy) pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
try:
loss = self.model.loss(pred, label, sub_samp) loss = self.model.loss(pred, label, sub_samp)
except Exception as e:
print(pred)
raise e
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()

486
models.py
View File

@@ -1,17 +1,16 @@
import torch import torch
from torch import nn from torch import nn, einsum
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from functools import partial from functools import partial
from einops.layers.torch import Rearrange, Reduce from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, repeat
from utils import * from utils import *
from layers import * from layers import *
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_ from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model from timm.models.registry import register_model
from timm.layers.helpers import to_2tuple from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module): class ConvE(torch.nn.Module):
@@ -528,22 +527,6 @@ class FouriER(torch.nn.Module):
self.network = nn.ModuleList(network) self.network = nn.ModuleList(network)
self.norm = norm_layer(embed_dims[-1]) self.norm = norm_layer(embed_dims[-1])
self.graph_type = 'Spatial'
N = (image_h // patch_size)**2
if self.graph_type in ["Spatial", "Mixed"]:
# Create a range tensor of node indices
indices = torch.arange(N)
# Reshape the indices tensor to create a grid of row and column indices
row_indices = indices.view(-1, 1).expand(-1, N)
col_indices = indices.view(1, -1).expand(N, -1)
# Compute the adjacency matrix
row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
graph = graph - torch.eye(N)
self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
self.class_token = False
self.token_scale = False
self.head = nn.Linear( self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \ embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity() else nn.Identity()
@@ -561,44 +544,7 @@ class FouriER(torch.nn.Module):
def forward_tokens(self, x): def forward_tokens(self, x):
outs = [] outs = []
B, C, H, W = x.shape
N = H*W
if self.graph_type in ["Semantic", "Mixed"]:
# Generate the semantic graph w.r.t. the cosine similarity between tokens
# Compute cosine similarity
if self.class_token:
x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
else:
x_normed = x / x.norm(dim=-1, keepdim=True)
x_cossim = x_normed @ x_normed.transpose(-1, -2)
threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0)
if self.class_token:
semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
else:
semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
if self.graph_type == "None":
graph = None
else:
if self.graph_type == "Spatial":
graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device)
elif self.graph_type == "Semantic":
graph = semantic_graph
elif self.graph_type == "Mixed":
# Integrate the spatial graph and semantic graph
spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device)
graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
# Symmetrically normalize the graph
degree = graph.sum(-1) # B, N
degree = torch.diag_embed(degree**(-1/2))
graph = degree @ graph @ degree
for idx, block in enumerate(self.network): for idx, block in enumerate(self.network):
try:
x = block(x, graph)
except:
x = block(x) x = block(x)
# output only the features of last layer for image classification # output only the features of last layer for image classification
return x return x
@@ -612,6 +558,8 @@ class FouriER(torch.nn.Module):
z = self.forward_embeddings(y) z = self.forward_embeddings(y)
z = self.forward_tokens(z) z = self.forward_tokens(z)
z = z.mean([-2, -1]) z = z.mean([-2, -1])
if np.count_nonzero(np.isnan(z)) > 0:
print("ZZZ")
z = self.norm(z) z = self.norm(z)
x = self.head(z) x = self.head(z)
x = self.hidden_drop(x) x = self.hidden_drop(x)
@@ -758,7 +706,7 @@ def basic_blocks(dim, index, layers,
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value, layer_scale_init_value=layer_scale_init_value,
)) ))
blocks = SeqModel(*blocks) blocks = nn.Sequential(*blocks)
return blocks return blocks
@@ -923,278 +871,202 @@ def window_reverse(windows, window_size, H, W):
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
return x return x
class SeqModel(nn.Sequential): def cast_tuple(val, length = 1):
def forward(self, *inputs): return val if isinstance(val, tuple) else ((val,) * length)
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
def propagate(x: torch.Tensor, weight: torch.Tensor, # helper classes
index_kept: torch.Tensor, index_prop: torch.Tensor,
standard: str = "None", alpha: Optional[float] = 0, class ChanLayerNorm(nn.Module):
token_scales: Optional[torch.Tensor] = None, def __init__(self, dim, eps = 1e-5):
cls_token=True): super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
kernel_size = stride * 2 - 1
padding = kernel_size // 2
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForwardDSSA(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class DSSA(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 32,
dropout = 0.,
window_size = 7
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
# window tokens
self.window_tokens = nn.Parameter(torch.randn(dim))
# prenorm and non-linearity for window tokens
# then projection to queries and keys for window tokens
self.window_tokens_to_qk = nn.Sequential(
nn.LayerNorm(dim_head),
nn.GELU(),
Rearrange('b h n c -> b (h c) n'),
nn.Conv1d(inner_dim, inner_dim * 2, 1),
Rearrange('b (h c) n -> b h n c', h = heads),
)
# window attention
self.window_attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
""" """
Propagate tokens based on the selection results. einstein notation
================================================
Args:
- x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
- weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens, b - batch
excluding the [CLS] token. weight could be a pre-defined c - channels
graph of the current feature map (by default) or the w1 - window size (height)
attention map (need to manually modify the Block Module). w2 - also window size (width)
i - sequence dimension (source)
- index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X j - sequence dimension (target dimension to be reduced)
h - heads
- index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X x - height of feature map divided by window size
y - width of feature map divided by window size
- standard: str: the method applied to propagate the tokens, including "None", "Mean" and
"GraphProp"
- alpha: float: the coefficient of propagated features
- token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
is None by default. If it is not None, then token_scales
represents the scales of each token and should sum up to N.
Return:
- x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
- weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
- token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
""" """
B, N, C = x.shape batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
# Step 1: divide tokens x = self.norm(x)
if cls_token:
x_cls = x[:, 0:1] # B, 1, C
x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
# Step 2: divide token_scales if it is not None # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
if token_scales is not None:
if cls_token:
token_scales_cls = token_scales[:, 0:1] # B, 1
token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
# Step 3: propagate tokens x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
if standard == "None":
"""
No further propagation
"""
pass
elif standard == "Mean": # add windowing tokens
"""
Calculate the mean of all the propagated tokens,
and concatenate the result token back to kept tokens.
"""
# naive average
x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
# Concatenate the average token
x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
elif standard == "GraphProp": w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
""" x = torch.cat((w, x), dim = -1)
Propagate all the propagated token to kept token
with respect to the weights and token scales.
"""
assert weight is not None, "The graph weight is needed for graph propagation"
# Step 3.1: divide propagation weights. # project for queries, keys, value
if cls_token:
index_kept = index_kept - 1 # since weights do not include the [CLS] token
index_prop = index_prop - 1 # since weights do not include the [CLS] token
weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
else:
weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
# Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# Simple implementation
x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
""" scatter_reduce implementation for batched inputs # split out heads
# Get the non-zero values
non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
non_zero_values = weight_prop[non_zero_indices]
# Sparse multiplication q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
batch_indices, row_indices, col_indices = non_zero_indices
sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
reduce_indices = batch_indices * x_kept.shape[1] + row_indices
x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0, # scale
index=reduce_indices[:, None],
src=sparse_matmul,
reduce="sum",
include_self=True)
x_kept = x_kept.reshape(B, -1, C)
"""
# Step 3.3: calculate the scale of each token if token_scales is not None q = q * self.scale
if token_scales is not None:
if cls_token:
token_scales_cls = token_scales[:, 0:1] # B, 1
token_scales = token_scales[:, 1:]
token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
if cls_token:
token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
else:
assert False, "Propagation method \'%f\' has not been supported yet." % standard
# similarity
if cls_token: dots = einsum('b h i d, b h j d -> b h i j', q, k)
# Step 4 concatenate the [CLS] token and generate returned value
x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
else:
x = x_kept
return x, weight, token_scales
# attention
attn = self.attend(dots)
def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True): # aggregate values
"""
Select image tokens to be propagated. The [CLS] token will be ignored.
======================================================================
Args:
- weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
attention map of tokens at the moment.
- standard: str: the method applied to select the tokens out = torch.matmul(attn, v)
- num_prop: int: the number of tokens to be propagated # split out windowed tokens
Return: window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
- index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
- index_prop: Tensor([B, num_prop]): the index of propagated tokens # early return if there is only 1 window
"""
assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet." if num_windows == 1:
B, H, N1, N2 = weight.shape fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet." return self.to_out(fmap)
N = N1
assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
if cls_token: # carry out the pointwise attention, the main novelty in the paper
if standard == "CLSAttnMean":
token_rank = weight[:,:,0,1:].mean(1)
elif standard == "CLSAttnMax": window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
token_rank = weight[:,:,0,1:].max(1)[0] windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
elif standard == "IMGAttnMean": # windowed queries and keys (preceded by prenorm activation)
token_rank = weight[:,:,:,1:].sum(-2).mean(1)
elif standard == "IMGAttnMax": w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
elif standard == "DiagAttnMean": # scale
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
elif standard == "DiagAttnMax": w_q = w_q * self.scale
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
elif standard == "MixedAttnMean": # similarities
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
token_rank = token_rank_1 * token_rank_2
elif standard == "MixedAttnMax": w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
token_rank = token_rank_1 * token_rank_2
elif standard == "SumAttnMax": w_attn = self.window_attend(w_dots)
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
token_rank = token_rank_1 + token_rank_2
elif standard == "CosSimMean": # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
weight = weight[:,:,1:,:].mean(1)
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "CosSimMax": aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
weight = weight[:,:,1:,:].max(1)[0]
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "Random": # fold back the windows and then combine heads for aggregation
token_rank = torch.randn((B, N-1), device=weight.device)
else: fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
print("Type\'", standard, "\' selection not supported.") return self.to_out(fmap)
assert False
token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
else:
if standard == "IMGAttnMean":
token_rank = weight.sum(-2).mean(1)
elif standard == "IMGAttnMax":
token_rank = weight.sum(-2).max(1)[0]
elif standard == "DiagAttnMean":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
elif standard == "DiagAttnMax":
token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
elif standard == "MixedAttnMean":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
token_rank_2 = weight.sum(-2).mean(1)
token_rank = token_rank_1 * token_rank_2
elif standard == "MixedAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
token_rank_2 = weight.sum(-2).max(1)[0]
token_rank = token_rank_1 * token_rank_2
elif standard == "SumAttnMax":
token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
token_rank_2 = weight.sum(-2).max(1)[0]
token_rank = token_rank_1 + token_rank_2
elif standard == "CosSimMean":
weight = weight.mean(1)
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "CosSimMax":
weight = weight.max(1)[0]
weight = weight / weight.norm(dim=-1, keepdim=True)
token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
elif standard == "Random":
token_rank = torch.randn((B, N-1), device=weight.device)
else:
print("Type\'", standard, "\' selection not supported.")
assert False
token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
index_prop = token_rank[:, -num_prop:] # B, num_prop
return index_kept, index_prop
class PoolFormerBlock(nn.Module): class PoolFormerBlock(nn.Module):
""" """
@@ -1221,8 +1093,13 @@ class PoolFormerBlock(nn.Module):
#self.token_mixer = Pooling(pool_size=pool_size) #self.token_mixer = Pooling(pool_size=pool_size)
# self.token_mixer = FNetBlock() # self.token_mixer = FNetBlock()
self.window_size = 4 self.window_size = 4
self.attn_heads = 4
self.attn_mask = None self.attn_mask = None
self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4) # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
self.token_mixer = nn.ModuleList([
DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
FeedForwardDSSA(dim)
])
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -1238,20 +1115,14 @@ class PoolFormerBlock(nn.Module):
self.layer_scale_2 = nn.Parameter( self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True) layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x, weight, token_scales = None): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
x_windows = window_partition(x, self.window_size) # x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.token_mixer(x_windows, mask=self.attn_mask) # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
x_attn = window_reverse(attn_windows, self.window_size, H, W) # x_attn = window_reverse(attn_windows, self.window_size, H, W)
index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, x_attn = self.token_mixer(x)
cls_token=False)
original_shape = x_attn.shape
x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
alpha=0.1, token_scales=token_scales, cls_token=False)
x_attn = x_attn.view(*original_shape)
if self.use_layer_scale: if self.use_layer_scale:
x = x + self.drop_path( x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
@@ -1262,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
else: else:
x = x + self.drop_path(x_attn) x = x + self.drop_path(x_attn)
x = x + self.drop_path(self.mlp(self.norm2(x))) x = x + self.drop_path(self.mlp(self.norm2(x)))
if np.count_nonzero(np.isnan(x)) > 0:
print("PFBlock")
return x return x
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" """
@@ -1347,7 +1221,7 @@ class LayerNormChannel(nn.Module):
+ self.bias.unsqueeze(-1).unsqueeze(-1) + self.bias.unsqueeze(-1).unsqueeze(-1)
return x return x
class FeedForward(nn.Module): class FeedForwardFNet(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.): def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
@@ -1383,7 +1257,7 @@ class FNet(nn.Module):
for _ in range(depth): for _ in range(depth):
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([
PreNorm(dim, FNetBlock()), PreNorm(dim, FNetBlock()),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
])) ]))
def forward(self, x): def forward(self, x):
for attn, ff in self.layers: for attn, ff in self.layers: