6 Commits

Author SHA1 Message Date
9075b53be6 try sep vit 2024-04-28 11:04:26 +07:00
ab5c1d0b4b try sep vit 2024-04-28 11:00:08 +07:00
3243b1d963 try sep vit 2024-04-28 09:42:39 +07:00
37b01708b4 try sep vit 2024-04-28 01:44:33 +07:00
a246d2bb64 try sep vit 2024-04-28 01:01:31 +07:00
4a962a02ad try sep vit 2024-04-27 21:57:24 +07:00
2 changed files with 221 additions and 343 deletions


@@ -478,7 +478,11 @@ class Main(object):
                 batch, 'train')
             pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-            loss = self.model.loss(pred, label, sub_samp)
+            try:
+                loss = self.model.loss(pred, label, sub_samp)
+            except Exception as e:
+                print(pred)
+                raise e
             loss.backward()
             self.optimizer.step()
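The try/except added here only reports `pred` after the loss has already blown up. A cheaper guard, offered as a sketch rather than as part of the commit (the helper name `assert_finite` is made up), tests the forward output directly:

import torch

def assert_finite(pred: torch.Tensor) -> None:
    # Fail fast with a readable message if the forward pass produced NaN/Inf,
    # instead of letting the loss raise (or silently propagate NaNs) downstream.
    if not torch.isfinite(pred).all():
        bad = int((~torch.isfinite(pred)).sum())
        raise ValueError(f"{bad} non-finite entries in pred; check the forward pass")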

models.py

@@ -1,17 +1,16 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.layers import DropPath, trunc_normal_
 from timm.models.registry import register_model
 from timm.layers.helpers import to_2tuple
-from typing import *
-import math

class ConvE(torch.nn.Module):
@@ -528,22 +527,6 @@ class FouriER(torch.nn.Module):
         self.network = nn.ModuleList(network)
         self.norm = norm_layer(embed_dims[-1])
-        self.graph_type = 'Spatial'
-        N = (image_h // patch_size)**2
-        if self.graph_type in ["Spatial", "Mixed"]:
-            # Create a range tensor of node indices
-            indices = torch.arange(N)
-            # Reshape the indices tensor to create a grid of row and column indices
-            row_indices = indices.view(-1, 1).expand(-1, N)
-            col_indices = indices.view(1, -1).expand(N, -1)
-            # Compute the adjacency matrix
-            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
-            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
-            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
-            graph = graph - torch.eye(N)
-            self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
-        self.class_token = False
-        self.token_scale = False
         self.head = nn.Linear(
             embed_dims[-1], num_classes) if num_classes > 0 \
             else nn.Identity()
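The deleted block built an 8-neighbour adjacency over the sqrt(N) x sqrt(N) patch grid and subtracted the self-loops. For reference, a self-contained sketch of the same construction (assuming N is a perfect square, as the `int(math.sqrt(N))` usage requires):

import math
import torch

def spatial_adjacency(N: int) -> torch.Tensor:
    # Two patches are connected iff their grid rows and columns both differ
    # by at most 1 (8-connectivity); subtracting the identity removes self-loops.
    side = int(math.sqrt(N))
    idx = torch.arange(N)
    row, col = idx // side, idx % side
    near_row = (row.view(-1, 1) - row.view(1, -1)).abs() <= 1
    near_col = (col.view(-1, 1) - col.view(1, -1)).abs() <= 1
    return (near_row & near_col).float() - torch.eye(N)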
@@ -561,45 +544,8 @@ class FouriER(torch.nn.Module):
     def forward_tokens(self, x):
         outs = []
-        B, C, H, W = x.shape
-        N = H*W
-        if self.graph_type in ["Semantic", "Mixed"]:
-            # Generate the semantic graph w.r.t. the cosine similarity between tokens
-            # Compute cosine similarity
-            if self.class_token:
-                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
-            else:
-                x_normed = x / x.norm(dim=-1, keepdim=True)
-            x_cossim = x_normed @ x_normed.transpose(-1, -2)
-            threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
-            semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0)
-            if self.class_token:
-                semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
-            else:
-                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
-
-        if self.graph_type == "None":
-            graph = None
-        else:
-            if self.graph_type == "Spatial":
-                graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device)
-            elif self.graph_type == "Semantic":
-                graph = semantic_graph
-            elif self.graph_type == "Mixed":
-                # Integrate the spatial graph and semantic graph
-                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device)
-                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
-
-            # Symmetrically normalize the graph
-            degree = graph.sum(-1) # B, N
-            degree = torch.diag_embed(degree**(-1/2))
-            graph = degree @ graph @ degree
-
         for idx, block in enumerate(self.network):
-            try:
-                x = block(x, graph)
-            except:
-                x = block(x)
+            x = block(x)
         # output only the features of last layer for image classification
         return x
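The removed graph construction ended with the symmetric normalization D^(-1/2) A D^(-1/2). A standalone sketch of that step (the `clamp` guard against isolated nodes is my addition, not in the removed code):

import torch

def sym_normalize(graph: torch.Tensor) -> torch.Tensor:
    # graph: (B, N, N) adjacency; returns D^(-1/2) A D^(-1/2).
    degree = graph.sum(-1).clamp(min=1e-12)      # clamp avoids 0 ** -0.5 = inf
    d_inv_sqrt = torch.diag_embed(degree.pow(-0.5))
    return d_inv_sqrt @ graph @ d_inv_sqrt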
@@ -612,6 +558,8 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
+        if np.count_nonzero(np.isnan(z)) > 0:
+            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
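Note that `np.isnan(z)` forces a tensor-to-ndarray conversion, which fails for CUDA tensors and for tensors that require grad. A torch-native check that stays on-device (a sketch, not part of the commit):

import torch

def has_nan(t: torch.Tensor) -> bool:
    # Works on any device and inside autograd; no implicit NumPy conversion.
    return bool(torch.isnan(t).any())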
@@ -758,7 +706,7 @@ def basic_blocks(dim, index, layers,
                  use_layer_scale=use_layer_scale,
                  layer_scale_init_value=layer_scale_init_value,
                  ))
-    blocks = SeqModel(*blocks)
+    blocks = nn.Sequential(*blocks)

     return blocks
@@ -923,278 +871,202 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x

-class SeqModel(nn.Sequential):
-    def forward(self, *inputs):
-        for module in self._modules.values():
-            if type(inputs) == tuple:
-                inputs = module(*inputs)
-            else:
-                inputs = module(inputs)
-        return inputs
-
-def propagate(x: torch.Tensor, weight: torch.Tensor,
-              index_kept: torch.Tensor, index_prop: torch.Tensor,
-              standard: str = "None", alpha: Optional[float] = 0,
-              token_scales: Optional[torch.Tensor] = None,
-              cls_token=True):
-    """
-    Propagate tokens based on the selection results.
-    ================================================
-    Args:
-        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
-        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
-                                         excluding the [CLS] token. weight could be a pre-defined
-                                         graph of the current feature map (by default) or the
-                                         attention map (need to manually modify the Block Module).
-        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
-        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
-        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
-                         "GraphProp"
-        - alpha: float: the coefficient of propagated features
-        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
-                                        is None by default. If it is not None, then token_scales
-                                        represents the scales of each token and should sum up to N.
-
-    Return:
-        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
-        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
-        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
-    """
-    B, N, C = x.shape
-
-    # Step 1: divide tokens
-    if cls_token:
-        x_cls = x[:, 0:1] # B, 1, C
-    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
-    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
-
-    # Step 2: divide token_scales if it is not None
-    if token_scales is not None:
-        if cls_token:
-            token_scales_cls = token_scales[:, 0:1] # B, 1
-        token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
-        token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
-
-    # Step 3: propagate tokens
-    if standard == "None":
-        """
-        No further propagation
-        """
-        pass
-
-    elif standard == "Mean":
-        """
-        Calculate the mean of all the propagated tokens,
-        and concatenate the result token back to kept tokens.
-        """
-        # naive average
-        x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
-        # Concatenate the average token
-        x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
-
-    elif standard == "GraphProp":
-        """
-        Propagate all the propagated token to kept token
-        with respect to the weights and token scales.
-        """
-        assert weight is not None, "The graph weight is needed for graph propagation"
-
-        # Step 3.1: divide propagation weights.
-        if cls_token:
-            index_kept = index_kept - 1 # since weights do not include the [CLS] token
-            index_prop = index_prop - 1 # since weights do not include the [CLS] token
-            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
-            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
-            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
-        else:
-            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
-            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
-            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
-
-        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
-        # Simple implementation
-        x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
-        x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
-
-        """ scatter_reduce implementation for batched inputs
-        # Get the non-zero values
-        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
-        non_zero_values = weight_prop[non_zero_indices]
-
-        # Sparse multiplication
-        batch_indices, row_indices, col_indices = non_zero_indices
-        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
-        reduce_indices = batch_indices * x_kept.shape[1] + row_indices
-
-        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
-                                                      index=reduce_indices[:, None],
-                                                      src=sparse_matmul,
-                                                      reduce="sum",
-                                                      include_self=True)
-        x_kept = x_kept.reshape(B, -1, C)
-        """
-
-        # Step 3.3: calculate the scale of each token if token_scales is not None
-        if token_scales is not None:
-            if cls_token:
-                token_scales_cls = token_scales[:, 0:1] # B, 1
-                token_scales = token_scales[:, 1:]
-            token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
-            token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
-            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
-            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
-            if cls_token:
-                token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
-
-    else:
-        assert False, "Propagation method '%f' has not been supported yet." % standard
-
-    if cls_token:
-        # Step 4: concatenate the [CLS] token and generate returned value
-        x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
-    else:
-        x = x_kept
-    return x, weight, token_scales
-
-def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True):
-    """
-    Select image tokens to be propagated. The [CLS] token will be ignored.
-    ======================================================================
-    Args:
-        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
-                                        attention map of tokens at the moment.
-
-        - standard: str: the method applied to select the tokens
-
-        - num_prop: int: the number of tokens to be propagated
-
-    Return:
-        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
-        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
-    """
-    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
-    B, H, N1, N2 = weight.shape
-    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
-    N = N1
-    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
-
-    if cls_token:
-        if standard == "CLSAttnMean":
-            token_rank = weight[:,:,0,1:].mean(1)
-
-        elif standard == "CLSAttnMax":
-            token_rank = weight[:,:,0,1:].max(1)[0]
-
-        elif standard == "IMGAttnMean":
-            token_rank = weight[:,:,:,1:].sum(-2).mean(1)
-
-        elif standard == "IMGAttnMax":
-            token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
-
-        elif standard == "DiagAttnMean":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
-
-        elif standard == "DiagAttnMax":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-
-        elif standard == "MixedAttnMean":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
-            token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
-            token_rank = token_rank_1 * token_rank_2
-
-        elif standard == "MixedAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
-            token_rank = token_rank_1 * token_rank_2
-
-        elif standard == "SumAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
-            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
-            token_rank = token_rank_1 + token_rank_2
-
-        elif standard == "CosSimMean":
-            weight = weight[:,:,1:,:].mean(1)
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-
-        elif standard == "CosSimMax":
-            weight = weight[:,:,1:,:].max(1)[0]
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-
-        elif standard == "Random":
-            token_rank = torch.randn((B, N-1), device=weight.device)
-
-        else:
-            print("Type'", standard, "' selection not supported.")
-            assert False
-
-        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
-        index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
-        index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
-
-    else:
-        if standard == "IMGAttnMean":
-            token_rank = weight.sum(-2).mean(1)
-
-        elif standard == "IMGAttnMax":
-            token_rank = weight.sum(-2).max(1)[0]
-
-        elif standard == "DiagAttnMean":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
-
-        elif standard == "DiagAttnMax":
-            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-
-        elif standard == "MixedAttnMean":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
-            token_rank_2 = weight.sum(-2).mean(1)
-            token_rank = token_rank_1 * token_rank_2
-
-        elif standard == "MixedAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-            token_rank_2 = weight.sum(-2).max(1)[0]
-            token_rank = token_rank_1 * token_rank_2
-
-        elif standard == "SumAttnMax":
-            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
-            token_rank_2 = weight.sum(-2).max(1)[0]
-            token_rank = token_rank_1 + token_rank_2
-
-        elif standard == "CosSimMean":
-            weight = weight.mean(1)
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-
-        elif standard == "CosSimMax":
-            weight = weight.max(1)[0]
-            weight = weight / weight.norm(dim=-1, keepdim=True)
-            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
-
-        elif standard == "Random":
-            token_rank = torch.randn((B, N-1), device=weight.device)
-
-        else:
-            print("Type'", standard, "' selection not supported.")
-            assert False
-
-        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
-        index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
-        index_prop = token_rank[:, -num_prop:] # B, num_prop
-    return index_kept, index_prop
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
+# helper classes
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
+
+class OverlappingPatchEmbed(nn.Module):
+    def __init__(self, dim_in, dim_out, stride = 2):
+        super().__init__()
+        kernel_size = stride * 2 - 1
+        padding = kernel_size // 2
+        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
+
+    def forward(self, x):
+        return self.conv(x)
+
+class PEG(nn.Module):
+    def __init__(self, dim, kernel_size = 3):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
+
+    def forward(self, x):
+        return self.proj(x) + x
+
+# feedforward
+
+class FeedForwardDSSA(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, 1),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+# attention
+
+class DSSA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 32,
+        dropout = 0.,
+        window_size = 7
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.window_size = window_size
+        inner_dim = dim_head * heads
+
+        self.norm = ChanLayerNorm(dim)
+
+        self.attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
+
+        # window tokens
+        self.window_tokens = nn.Parameter(torch.randn(dim))
+
+        # prenorm and non-linearity for window tokens
+        # then projection to queries and keys for window tokens
+        self.window_tokens_to_qk = nn.Sequential(
+            nn.LayerNorm(dim_head),
+            nn.GELU(),
+            Rearrange('b h n c -> b (h c) n'),
+            nn.Conv1d(inner_dim, inner_dim * 2, 1),
+            Rearrange('b (h c) n -> b h n c', h = heads),
+        )
+
+        # window attention
+        self.window_attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        """
+        einstein notation
+
+        b - batch
+        c - channels
+        w1 - window size (height)
+        w2 - also window size (width)
+        i - sequence dimension (source)
+        j - sequence dimension (target dimension to be reduced)
+        h - heads
+        x - height of feature map divided by window size
+        y - width of feature map divided by window size
+        """
+
+        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
+        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
+        num_windows = (height // wsz) * (width // wsz)
+
+        x = self.norm(x)
+
+        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
+        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
+
+        # add windowing tokens
+        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
+        x = torch.cat((w, x), dim = -1)
+
+        # project for queries, keys, value
+        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
+
+        # split out heads
+        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
+
+        # scale
+        q = q * self.scale
+
+        # similarity
+        dots = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # attention
+        attn = self.attend(dots)
+
+        # aggregate values
+        out = torch.matmul(attn, v)
+
+        # split out windowed tokens
+        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
+
+        # early return if there is only 1 window
+        if num_windows == 1:
+            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+            return self.to_out(fmap)
+
+        # carry out the pointwise attention, the main novelty in the paper
+        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
+        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
+
+        # windowed queries and keys (preceded by prenorm activation)
+        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
+
+        # scale
+        w_q = w_q * self.scale
+
+        # similarities
+        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
+
+        w_attn = self.window_attend(w_dots)
+
+        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
+        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
+
+        # fold back the windows and then combine heads for aggregation
+        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+        return self.to_out(fmap)
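The added helpers follow the SepViT design (the commit messages say "try sep vit"): DSSA runs window attention with one learned window token per window, then attends across windows through those tokens. Both sub-layers preserve the (B, C, H, W) shape, which the following sanity-check sketch illustrates; the sizes are chosen only for illustration:

import torch

# Illustrative sizes: dim 64, 28x28 feature map, 7x7 windows -> 16 windows.
mixer = DSSA(dim = 64, heads = 4, dim_head = 32, window_size = 7)
ff = FeedForwardDSSA(dim = 64)

x = torch.randn(2, 64, 28, 28)   # B, C, H, W with H and W divisible by 7
y = ff(mixer(x))                 # both sub-layers preserve (B, C, H, W)
assert y.shape == x.shape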
class PoolFormerBlock(nn.Module):
    """
@@ -1221,8 +1093,13 @@ class PoolFormerBlock(nn.Module):
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
         self.window_size = 4
+        self.attn_heads = 4
         self.attn_mask = None
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -1238,20 +1115,14 @@ class PoolFormerBlock(nn.Module):
             self.layer_scale_2 = nn.Parameter(
                 layer_scale_init_value * torch.ones((dim)), requires_grad=True)

-    def forward(self, x, weight, token_scales = None):
+    def forward(self, x):
         B, C, H, W = x.shape
-        x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        x_attn = window_reverse(attn_windows, self.window_size, H, W)
-        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
-                                        cls_token=False)
-        original_shape = x_attn.shape
-        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
-        x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
-                                                 alpha=0.1, token_scales=token_scales, cls_token=False)
-        x_attn = x_attn.view(*original_shape)
+        # x_windows = window_partition(x, self.window_size)
+        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        x_attn = self.token_mixer(x)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)

@@ -1262,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
         else:
             x = x + self.drop_path(x_attn)
             x = x + self.drop_path(self.mlp(self.norm2(x)))
+        if np.count_nonzero(np.isnan(x)) > 0:
+            print("PFBlock")
         return x
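One caveat with the new mixer: `nn.ModuleList` registers the two sub-modules but defines no `forward`, so `self.token_mixer(x)` in the rewritten `forward` raises `NotImplementedError` when called. If sequential application is the intent (an assumption on my part), `nn.Sequential` gives a callable chain:

import torch.nn as nn

def make_token_mixer(dim: int, heads: int, window_size: int) -> nn.Module:
    # Same two sub-layers as above, but chained so the result is callable.
    return nn.Sequential(
        DSSA(dim, heads = heads, window_size = window_size),
        FeedForwardDSSA(dim),
    )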
class PatchEmbed(nn.Module):
    """
@@ -1347,7 +1221,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x

-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1383,7 +1257,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))

     def forward(self, x):
         for attn, ff in self.layers:
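`FNetBlock` is defined elsewhere in models.py; in the FNet paper the token mixer is parameter-free, keeping the real part of a 2-D discrete Fourier transform. A sketch under that assumption (it may not match this repository's actual `FNetBlock`):

import torch
import torch.nn as nn

class FNetBlockSketch(nn.Module):
    # Mixes tokens by FFT over the hidden dim, then the sequence dim,
    # keeping only the real component, as in the FNet paper.
    def forward(self, x):
        return torch.fft.fft(torch.fft.fft(x, dim = -1), dim = -2).real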