6 Commits

Author      SHA1          Message       Date
thanhvc3    9075b53be6    try sep vit   2024-04-28 11:04:26 +07:00
thanhvc3    ab5c1d0b4b    try sep vit   2024-04-28 11:00:08 +07:00
thanhvc3    3243b1d963    try sep vit   2024-04-28 09:42:39 +07:00
thanhvc3    37b01708b4    try sep vit   2024-04-28 01:44:33 +07:00
thanhvc3    a246d2bb64    try sep vit   2024-04-28 01:01:31 +07:00
thanhvc3    4a962a02ad    try sep vit   2024-04-27 21:57:24 +07:00
2 changed files with 221 additions and 343 deletions


@@ -478,7 +478,11 @@ class Main(object):
                     batch, 'train')
             pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-            loss = self.model.loss(pred, label, sub_samp)
+            try:
+                loss = self.model.loss(pred, label, sub_samp)
+            except Exception as e:
+                print(pred)
+                raise e
             loss.backward()
             self.optimizer.step()
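Note: the try/except added here is a debugging aid. It dumps the raw predictions when the loss computation throws (for example when NaN or inf scores reach a log-based loss) and then re-raises. A minimal standalone sketch of the same pattern, using stand-in names (debug_loss, loss_fn, pred, label are illustrations, not this repo's API):

    import torch
    import torch.nn.functional as F

    def debug_loss(loss_fn, pred, label):
        # Compute the loss; on failure, print diagnostics about the
        # offending predictions before re-raising the original error.
        try:
            return loss_fn(pred, label)
        except Exception as e:
            print("pred min/max/any-nan:", pred.min().item(),
                  pred.max().item(), torch.isnan(pred).any().item())
            raise e

    pred = torch.randn(4, 10)
    label = torch.randint(0, 10, (4,))
    loss = debug_loss(F.cross_entropy, pred, label)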

models.py

@@ -1,17 +1,16 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.layers import DropPath, trunc_normal_
 from timm.models.registry import register_model
 from timm.layers.helpers import to_2tuple
-from typing import *
-import math

 class ConvE(torch.nn.Module):
@@ -528,22 +527,6 @@ class FouriER(torch.nn.Module):
         self.network = nn.ModuleList(network)
         self.norm = norm_layer(embed_dims[-1])
-        self.graph_type = 'Spatial'
-        N = (image_h // patch_size)**2
-        if self.graph_type in ["Spatial", "Mixed"]:
-            # Create a range tensor of node indices
-            indices = torch.arange(N)
-            # Reshape the indices tensor to create a grid of row and column indices
-            row_indices = indices.view(-1, 1).expand(-1, N)
-            col_indices = indices.view(1, -1).expand(N, -1)
-            # Compute the adjacency matrix
-            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
-            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
-            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
-            graph = graph - torch.eye(N)
-            self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
-        self.class_token = False
-        self.token_scale = False
         self.head = nn.Linear(
             embed_dims[-1], num_classes) if num_classes > 0 \
             else nn.Identity()
@@ -561,45 +544,8 @@ class FouriER(torch.nn.Module):
     def forward_tokens(self, x):
         outs = []
-        B, C, H, W = x.shape
-        N = H*W
-        if self.graph_type in ["Semantic", "Mixed"]:
-            # Generate the semantic graph w.r.t. the cosine similarity between tokens
-            # Compute cosine similarity
-            if self.class_token:
-                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
-            else:
-                x_normed = x / x.norm(dim=-1, keepdim=True)
-            x_cossim = x_normed @ x_normed.transpose(-1, -2)
-            threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
-            semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0)
-            if self.class_token:
-                semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
-            else:
-                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
-        if self.graph_type == "None":
-            graph = None
-        else:
-            if self.graph_type == "Spatial":
-                graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device)
-            elif self.graph_type == "Semantic":
-                graph = semantic_graph
-            elif self.graph_type == "Mixed":
-                # Integrate the spatial graph and semantic graph
-                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device)
-                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
-            # Symmetrically normalize the graph
-            degree = graph.sum(-1) # B, N
-            degree = torch.diag_embed(degree**(-1/2))
-            graph = degree @ graph @ degree
         for idx, block in enumerate(self.network):
-            try:
-                x = block(x, graph)
-            except:
-                x = block(x)
+            x = block(x)
         # output only the features of last layer for image classification
         return x
@@ -612,6 +558,8 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
+        if np.count_nonzero(np.isnan(z)) > 0:
+            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
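Note: np.isnan needs a detached CPU array, so the probe above will raise on a CUDA tensor or on a tensor that still requires grad. A torch-native check sidesteps both issues (a sketch of an alternative, not what the commit does):

    import torch

    def has_nan(t: torch.Tensor) -> bool:
        # torch.isnan works on any device and under autograd,
        # unlike np.isnan, which requires a detached CPU array.
        return bool(torch.isnan(t).any())

    z = torch.tensor([1.0, float('nan')], requires_grad=True)
    print(has_nan(z))  # True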
@@ -758,7 +706,7 @@ def basic_blocks(dim, index, layers,
                  use_layer_scale=use_layer_scale,
                  layer_scale_init_value=layer_scale_init_value,
                  ))
-    blocks = SeqModel(*blocks)
+    blocks = nn.Sequential(*blocks)
     return blocks
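Note: nn.Sequential threads exactly one value from module to module, which is why the removed SeqModel subclass had to unpack tuples while blocks also took a graph argument. Once every block maps one tensor to one tensor, the plain container is enough (toy dims, for illustration only):

    import torch
    from torch import nn

    # Each block takes and returns a single tensor, so nn.Sequential works.
    blocks = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    out = blocks(torch.randn(2, 8))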
@@ -923,278 +871,202 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x

-class SeqModel(nn.Sequential):
-    def forward(self, *inputs):
-        for module in self._modules.values():
-            if type(inputs) == tuple:
-                inputs = module(*inputs)
-            else:
-                inputs = module(inputs)
-        return inputs
-
-def propagate(x: torch.Tensor, weight: torch.Tensor,
-              index_kept: torch.Tensor, index_prop: torch.Tensor,
-              standard: str = "None", alpha: Optional[float] = 0,
-              token_scales: Optional[torch.Tensor] = None,
-              cls_token=True):
-    """
-    Propagate tokens based on the selection results.
-    ================================================
-    Args:
-        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
-        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
-                  excluding the [CLS] token. weight could be a pre-defined
-                  graph of the current feature map (by default) or the
-                  attention map (need to manually modify the Block Module).
-        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
-        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
-        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
-                  "GraphProp"
-        - alpha: float: the coefficient of propagated features
-        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
-                  is None by default. If it is not None, then token_scales
-                  represents the scales of each token and should sum up to N.
-    Return:
-        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
-        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
-        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
-    """
-    B, N, C = x.shape
-
-    # Step 1: divide tokens
-    if cls_token:
-        x_cls = x[:, 0:1] # B, 1, C
-    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
-    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
-
-    # Step 2: divide token_scales if it is not None
-    if token_scales is not None:
-        if cls_token:
-            token_scales_cls = token_scales[:, 0:1] # B, 1
-        token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
-        token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
-
-    # Step 3: propagate tokens
-    if standard == "None":
-        """
-        No further propagation
-        """
-        pass
-    elif standard == "Mean":
-        """
-        Calculate the mean of all the propagated tokens,
-        and concatenate the result token back to kept tokens.
-        """
-        # naive average
-        x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
-        # Concatenate the average token
-        x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
-    elif standard == "GraphProp":
-        """
-        Propagate all the propagated token to kept token
-        with respect to the weights and token scales.
-        """
-        assert weight is not None, "The graph weight is needed for graph propagation"
-
-        # Step 3.1: divide propagation weights.
-        if cls_token:
-            index_kept = index_kept - 1 # since weights do not include the [CLS] token
-            index_prop = index_prop - 1 # since weights do not include the [CLS] token
-            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
-            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
-            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
-        else:
-            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
-            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
-            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
-
-        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
-        # Simple implementation
-        x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
-        x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
-
-        """ scatter_reduce implementation for batched inputs
-        # Get the non-zero values
-        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
-        non_zero_values = weight_prop[non_zero_indices]
-
-        # Sparse multiplication
-        batch_indices, row_indices, col_indices = non_zero_indices
-        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
-        reduce_indices = batch_indices * x_kept.shape[1] + row_indices
-        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
-                                                      index=reduce_indices[:, None],
-                                                      src=sparse_matmul,
-                                                      reduce="sum",
-                                                      include_self=True)
-        x_kept = x_kept.reshape(B, -1, C)
-        """
-
-        # Step 3.3: calculate the scale of each token if token_scales is not None
-        if token_scales is not None:
-            if cls_token:
-                token_scales_cls = token_scales[:, 0:1] # B, 1
-                token_scales = token_scales[:, 1:]
-            token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
-            token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
-            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
-            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
-            if cls_token:
-                token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
-    else:
-        assert False, "Propagation method '%f' has not been supported yet." % standard
-
-    if cls_token:
-        # Step 4: concatenate the [CLS] token and generate returned value
-        x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
-    else:
-        x = x_kept
-    return x, weight, token_scales
-
-def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True):
-    """
-    Select image tokens to be propagated. The [CLS] token will be ignored.
-    ======================================================================
-    Args:
-        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
-                  attention map of tokens at the moment.
-        - standard: str: the method applied to select the tokens
-        - num_prop: int: the number of tokens to be propagated
-    Return:
-        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
-        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
-    """
-    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
-    B, H, N1, N2 = weight.shape
-    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
-    N = N1
-    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
-
-    if cls_token:
-        if standard == "CLSAttnMean":
-            token_rank = weight[:,:,0,1:].mean(1)
-        elif standard == "CLSAttnMax":
-            token_rank = weight[:,:,0,1:].max(1)[0]
-        elif standard == "IMGAttnMean":
-            token_rank = weight[:,:,:,1:].sum(-2).mean(1)
-        elif standard == "IMGAttnMax":
-            token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
-        elif standard == "DiagAttnMean":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
-        elif standard == "DiagAttnMax":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-        elif standard == "MixedAttnMean":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
-            token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
-            token_rank = token_rank_1 * token_rank_2
-        elif standard == "MixedAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
-            token_rank = token_rank_1 * token_rank_2
-        elif standard == "SumAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
-            token_rank = token_rank_1 + token_rank_2
-        elif standard == "CosSimMean":
-            weight = weight[:,:,1:,:].mean(1)
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-        elif standard == "CosSimMax":
-            weight = weight[:,:,1:,:].max(1)[0]
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-        elif standard == "Random":
-            token_rank = torch.randn((B, N-1), device=weight.device)
-        else:
-            print("Type'", standard, "' selection not supported.")
-            assert False
-        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
-        index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
-        index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
-    else:
-        if standard == "IMGAttnMean":
-            token_rank = weight.sum(-2).mean(1)
-        elif standard == "IMGAttnMax":
-            token_rank = weight.sum(-2).max(1)[0]
-        elif standard == "DiagAttnMean":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
-        elif standard == "DiagAttnMax":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-        elif standard == "MixedAttnMean":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
-            token_rank_2 = weight.sum(-2).mean(1)
-            token_rank = token_rank_1 * token_rank_2
-        elif standard == "MixedAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-            token_rank_2 = weight.sum(-2).max(1)[0]
-            token_rank = token_rank_1 * token_rank_2
-        elif standard == "SumAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-            token_rank_2 = weight.sum(-2).max(1)[0]
-            token_rank = token_rank_1 + token_rank_2
-        elif standard == "CosSimMean":
-            weight = weight.mean(1)
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-        elif standard == "CosSimMax":
-            weight = weight.max(1)[0]
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-        elif standard == "Random":
-            token_rank = torch.randn((B, N-1), device=weight.device)
-        else:
-            print("Type'", standard, "' selection not supported.")
-            assert False
-        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
-        index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
-        index_prop = token_rank[:, -num_prop:] # B, num_prop
-    return index_kept, index_prop
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
+# helper classes
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
+
+class OverlappingPatchEmbed(nn.Module):
+    def __init__(self, dim_in, dim_out, stride = 2):
+        super().__init__()
+        kernel_size = stride * 2 - 1
+        padding = kernel_size // 2
+        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
+
+    def forward(self, x):
+        return self.conv(x)
+
+class PEG(nn.Module):
+    def __init__(self, dim, kernel_size = 3):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
+
+    def forward(self, x):
+        return self.proj(x) + x
+
+# feedforward
+
+class FeedForwardDSSA(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, 1),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+# attention
+
+class DSSA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 32,
+        dropout = 0.,
+        window_size = 7
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.window_size = window_size
+        inner_dim = dim_head * heads
+
+        self.norm = ChanLayerNorm(dim)
+
+        self.attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
+
+        # window tokens
+        self.window_tokens = nn.Parameter(torch.randn(dim))
+
+        # prenorm and non-linearity for window tokens
+        # then projection to queries and keys for window tokens
+        self.window_tokens_to_qk = nn.Sequential(
+            nn.LayerNorm(dim_head),
+            nn.GELU(),
+            Rearrange('b h n c -> b (h c) n'),
+            nn.Conv1d(inner_dim, inner_dim * 2, 1),
+            Rearrange('b (h c) n -> b h n c', h = heads),
+        )
+
+        # window attention
+        self.window_attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        """
+        einstein notation
+
+        b - batch
+        c - channels
+        w1 - window size (height)
+        w2 - also window size (width)
+        i - sequence dimension (source)
+        j - sequence dimension (target dimension to be reduced)
+        h - heads
+        x - height of feature map divided by window size
+        y - width of feature map divided by window size
+        """
+
+        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
+        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
+        num_windows = (height // wsz) * (width // wsz)
+
+        x = self.norm(x)
+
+        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
+        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
+
+        # add windowing tokens
+        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
+        x = torch.cat((w, x), dim = -1)
+
+        # project for queries, keys, value
+        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
+
+        # split out heads
+        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
+
+        # scale
+        q = q * self.scale
+
+        # similarity
+        dots = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # attention
+        attn = self.attend(dots)
+
+        # aggregate values
+        out = torch.matmul(attn, v)
+
+        # split out windowed tokens
+        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
+
+        # early return if there is only 1 window
+        if num_windows == 1:
+            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+            return self.to_out(fmap)
+
+        # carry out the pointwise attention, the main novelty in the paper
+        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
+        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
+
+        # windowed queries and keys (preceded by prenorm activation)
+        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
+
+        # scale
+        w_q = w_q * self.scale
+
+        # similarities
+        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
+
+        w_attn = self.window_attend(w_dots)
+
+        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
+        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
+
+        # fold back the windows and then combine heads for aggregation
+        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+        return self.to_out(fmap)

 class PoolFormerBlock(nn.Module):
     """
@@ -1221,8 +1093,13 @@ class PoolFormerBlock(nn.Module):
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
         self.window_size = 4
+        self.attn_heads = 4
         self.attn_mask = None
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
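Note: nn.ModuleList has no forward of its own, so the DSSA/feedforward pair stored here has to be applied explicitly by the caller. In the SepViT reference the two stages run in sequence with residual connections; a hypothetical sketch of that wiring (run_mixer is not part of the diff):

    def run_mixer(mixer, x):
        # mixer is the nn.ModuleList([DSSA(...), FeedForwardDSSA(...)])
        attn, ff = mixer
        x = attn(x) + x  # windowed plus pointwise attention, residual
        x = ff(x) + x    # 1x1-conv feedforward, residual
        return x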
@@ -1238,20 +1115,14 @@ class PoolFormerBlock(nn.Module):
         self.layer_scale_2 = nn.Parameter(
             layer_scale_init_value * torch.ones((dim)), requires_grad=True)

-    def forward(self, x, weight, token_scales = None):
+    def forward(self, x):
         B, C, H, W = x.shape
-        x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        x_attn = window_reverse(attn_windows, self.window_size, H, W)
-        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
-                                        cls_token=False)
-        original_shape = x_attn.shape
-        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
-        x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
-                                        alpha=0.1, token_scales=token_scales, cls_token=False)
-        x_attn = x_attn.view(*original_shape)
+        # x_windows = window_partition(x, self.window_size)
+        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        x_attn = self.token_mixer(x)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
@@ -1262,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
         else:
             x = x + self.drop_path(x_attn)
             x = x + self.drop_path(self.mlp(self.norm2(x)))
+        if np.count_nonzero(np.isnan(x)) > 0:
+            print("PFBlock")
         return x

 class PatchEmbed(nn.Module):
     """
@@ -1347,7 +1221,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x

-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1383,7 +1257,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))

     def forward(self, x):
         for attn, ff in self.layers: