6 Commits

SHA1 Message Date
9075b53be6 try sep vit 2024-04-28 11:04:26 +07:00
ab5c1d0b4b try sep vit 2024-04-28 11:00:08 +07:00
3243b1d963 try sep vit 2024-04-28 09:42:39 +07:00
37b01708b4 try sep vit 2024-04-28 01:44:33 +07:00
a246d2bb64 try sep vit 2024-04-28 01:01:31 +07:00
4a962a02ad try sep vit 2024-04-27 21:57:24 +07:00
2 changed files with 235 additions and 106 deletions


@@ -478,7 +478,11 @@ class Main(object):
                 batch, 'train')
             pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-            loss = self.model.loss(pred, label, sub_samp)
+            try:
+                loss = self.model.loss(pred, label, sub_samp)
+            except Exception as e:
+                print(pred)
+                raise e
             loss.backward()
             self.optimizer.step()
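The try/except added here only reports the offending predictions once loss() has already raised. A minimal torch-native sketch of the same diagnostic idea, checking for non-finite predictions up front (loss_with_debug and its arguments are illustrative names, not part of the committed code):

import torch

def loss_with_debug(model, pred, label, sub_samp):
    # Illustrative helper: report non-finite predictions before calling the loss,
    # mirroring the intent of the try/except added above.
    if not torch.isfinite(pred).all():
        print("non-finite entries in pred:", (~torch.isfinite(pred)).sum().item())
    return model.loss(pred, label, sub_samp)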

models.py

@@ -1,9 +1,10 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -435,50 +436,6 @@ class TuckER(torch.nn.Module):
         return pred
 
-class PatchMerging(nn.Module):
-    r""" Patch Merging Layer.
-
-    Args:
-        input_resolution (tuple[int]): Resolution of input feature.
-        dim (int): Number of input channels.
-        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
-    """
-
-    def __init__(self, dim, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.dim = dim
-        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
-        self.norm = norm_layer(2 * dim)
-
-    def forward(self, x):
-        """
-        x: B, C, H, W
-        """
-        B, C, H, W = x.shape
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
-        x = x.view(B, H, W, C)
-
-        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
-        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
-        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
-        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
-        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
-        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
-
-        x = self.reduction(x)
-        x = self.norm(x)
-
-        return x
-
-    def extra_repr(self) -> str:
-        return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
-    def flops(self):
-        H, W = self.input_resolution
-        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
-        flops += H * W * self.dim // 2
-        return flops
 
 class FouriER(torch.nn.Module):
     def __init__(self, params, hid_drop = None, embed_dim = None):
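For reference, the removed PatchMerging module mapped a (B, C, H, W) feature map to a sequence of merged patches of shape (B, (H/2)*(W/2), 2*C) via its 4C -> 2C linear reduction. A small shape check against the class as it was defined before this commit (sizes are arbitrary; H and W must be even):

import torch

merge = PatchMerging(dim=64)       # class as defined before this commit
x = torch.randn(2, 64, 32, 32)     # B, C, H, W with H and W even
y = merge(x)                       # -> torch.Size([2, 256, 128]): (H/2)*(W/2) tokens, 2*C channels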
@@ -532,10 +489,9 @@ class FouriER(torch.nn.Module):
         self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size,
                                       embed_dim=self.p.embed_dim, stride=4, padding=2)
         network = []
-        layers = [2, 2, 6, 2]
-        embed_dims = [self.p.embed_dim, 320, 256, 128]
-        mlp_ratios = [4, 4, 8, 12]
-        num_heads = [4, 4, 4, 4]
+        layers = [4, 4, 12, 4]
+        embed_dims = [self.p.embed_dim, 128, 320, 128]
+        mlp_ratios = [4, 4, 4, 4]
         downsamples = [True, True, True, True]
         pool_size=3
         act_layer=nn.GELU
@@ -547,7 +503,6 @@ class FouriER(torch.nn.Module):
         down_patch_size=3
         down_stride=2
         down_pad=1
-        window_size = 4
         num_classes=self.p.embed_dim
         for i in range(len(layers)):
             stage = basic_blocks(embed_dims[i], i, layers,
@@ -556,9 +511,7 @@ class FouriER(torch.nn.Module):
                                  drop_rate=drop_rate,
                                  drop_path_rate=drop_path_rate,
                                  use_layer_scale=use_layer_scale,
-                                 layer_scale_init_value=layer_scale_init_value,
-                                 num_heads=num_heads[i], input_resolution=(image_h // (2**i), image_w // (2**i)),
-                                 window_size=window_size, shift_size=0)
+                                 layer_scale_init_value=layer_scale_init_value)
             network.append(stage)
             if i >= len(layers) - 1:
                 break
@@ -570,7 +523,6 @@ class FouriER(torch.nn.Module):
                         padding=down_pad,
                         in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
                     )
-                    # PatchMerging(dim=embed_dims[i+1])
                 )
 
         self.network = nn.ModuleList(network)
@@ -606,6 +558,8 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
+        if np.count_nonzero(np.isnan(z)) > 0:
+            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
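The NaN probe added above goes through NumPy, which requires the tensor to live on the CPU (np.isnan cannot consume a CUDA tensor directly). A torch-native check that works on either device would be, for example:

# device-agnostic equivalent of the NumPy-based probe
if torch.isnan(z).any():
    print("ZZZ")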
@@ -736,7 +690,7 @@ def basic_blocks(dim, index, layers,
                  pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
                  drop_rate=.0, drop_path_rate=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5, num_heads = 4, input_resolution = None, window_size = 4, shift_size = 2):
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
     """
     generate PoolFormer blocks for a stage
     return: PoolFormer blocks
@@ -751,8 +705,6 @@ def basic_blocks(dim, index, layers,
                  drop=drop_rate, drop_path=block_dpr,
                  use_layer_scale=use_layer_scale,
                  layer_scale_init_value=layer_scale_init_value,
-                 num_heads=num_heads, input_resolution = input_resolution,
-                 window_size=window_size, shift_size=shift_size
                  ))
 
     blocks = nn.Sequential(*blocks)
@@ -872,12 +824,9 @@ class WindowAttention(nn.Module):
             attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
-            try:
-                nW = mask.shape[0]
-                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
-                attn = attn.view(-1, self.num_heads, N, N)
-            except:
-                pass
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
             attn = self.softmax(attn)
         else:
             attn = self.softmax(attn)
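The restored mask path assumes the window dimension has been folded into the batch, i.e. B_ = B * nW. A small self-contained shape sketch of the broadcast (sizes here are arbitrary):

import torch

B, nW, num_heads, N = 2, 4, 3, 16                 # N = window_size * window_size
attn = torch.randn(B * nW, num_heads, N, N)
mask = torch.randn(nW, N, N)
# per-window mask broadcast over batch and heads, then folded back
attn = attn.view(B, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_heads, N, N)             # back to (B*nW, num_heads, N, N)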
@@ -922,6 +871,203 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x
+
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
+# helper classes
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
+
+class OverlappingPatchEmbed(nn.Module):
+    def __init__(self, dim_in, dim_out, stride = 2):
+        super().__init__()
+        kernel_size = stride * 2 - 1
+        padding = kernel_size // 2
+        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
+
+    def forward(self, x):
+        return self.conv(x)
+
+class PEG(nn.Module):
+    def __init__(self, dim, kernel_size = 3):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
+
+    def forward(self, x):
+        return self.proj(x) + x
+
+# feedforward
+
+class FeedForwardDSSA(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, 1),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+# attention
+
+class DSSA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 32,
+        dropout = 0.,
+        window_size = 7
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.window_size = window_size
+        inner_dim = dim_head * heads
+
+        self.norm = ChanLayerNorm(dim)
+
+        self.attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
+
+        # window tokens
+
+        self.window_tokens = nn.Parameter(torch.randn(dim))
+
+        # prenorm and non-linearity for window tokens
+        # then projection to queries and keys for window tokens
+
+        self.window_tokens_to_qk = nn.Sequential(
+            nn.LayerNorm(dim_head),
+            nn.GELU(),
+            Rearrange('b h n c -> b (h c) n'),
+            nn.Conv1d(inner_dim, inner_dim * 2, 1),
+            Rearrange('b (h c) n -> b h n c', h = heads),
+        )
+
+        # window attention
+
+        self.window_attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        """
+        einstein notation
+
+        b - batch
+        c - channels
+        w1 - window size (height)
+        w2 - also window size (width)
+        i - sequence dimension (source)
+        j - sequence dimension (target dimension to be reduced)
+        h - heads
+        x - height of feature map divided by window size
+        y - width of feature map divided by window size
+        """
+
+        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
+        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
+        num_windows = (height // wsz) * (width // wsz)
+
+        x = self.norm(x)
+
+        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
+
+        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
+
+        # add windowing tokens
+
+        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
+        x = torch.cat((w, x), dim = -1)
+
+        # project for queries, keys, value
+
+        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
+
+        # split out heads
+
+        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
+
+        # scale
+
+        q = q * self.scale
+
+        # similarity
+
+        dots = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # attention
+
+        attn = self.attend(dots)
+
+        # aggregate values
+
+        out = torch.matmul(attn, v)
+
+        # split out windowed tokens
+
+        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
+
+        # early return if there is only 1 window
+
+        if num_windows == 1:
+            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+            return self.to_out(fmap)
+
+        # carry out the pointwise attention, the main novelty in the paper
+
+        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
+        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
+
+        # windowed queries and keys (preceded by prenorm activation)
+
+        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
+
+        # scale
+
+        w_q = w_q * self.scale
+
+        # similarities
+
+        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
+
+        w_attn = self.window_attend(w_dots)
+
+        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
+
+        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
+
+        # fold back the windows and then combine heads for aggregation
+
+        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+
+        return self.to_out(fmap)
 
 class PoolFormerBlock(nn.Module):
     """
     Implementation of one PoolFormer block.
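The new DSSA block performs windowed ("depthwise") attention followed by pointwise attention over per-window tokens, and FeedForwardDSSA is its 1x1-convolution MLP; both preserve the (B, C, H, W) shape as long as H and W are divisible by the window size. A minimal usage sketch (sizes are arbitrary; the import assumes models.py is importable):

import torch
from models import DSSA, FeedForwardDSSA

mixer = DSSA(dim=64, heads=4, dim_head=16, window_size=4)
ff = FeedForwardDSSA(dim=64)

x = torch.randn(2, 64, 16, 16)   # B, C, H, W with H % 4 == 0 and W % 4 == 0
y = ff(mixer(x))                 # both stages keep the shape: torch.Size([2, 64, 16, 16])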
@@ -938,18 +1084,22 @@ class PoolFormerBlock(nn.Module):
     """
     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
-                 drop=0., drop_path=0., num_heads=4,
-                 use_layer_scale=True, layer_scale_init_value=1e-5, input_resolution = None, window_size = 4, shift_size = 2):
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
 
         super().__init__()
 
         self.norm1 = norm_layer(dim)
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
-        self.window_size = window_size
-        self.shift_size = shift_size
-        self.input_resolution = input_resolution
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.2, proj_drop=0.1)
+        self.window_size = 4
+        self.attn_heads = 4
+        self.attn_mask = None
+        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -964,43 +1114,15 @@ class PoolFormerBlock(nn.Module):
                 layer_scale_init_value * torch.ones((dim)), requires_grad=True)
             self.layer_scale_2 = nn.Parameter(
                 layer_scale_init_value * torch.ones((dim)), requires_grad=True)
-        if self.shift_size > 0:
-            # calculate attention mask for SW-MSA
-            H, W = self.input_resolution
-            img_mask = torch.zeros((1, 1, H, W))  # 1 H W 1
-            h_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            w_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, :, h, w] = cnt
-                    cnt += 1
-
-            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-        else:
-            attn_mask = None
-
-        self.register_buffer("attn_mask", attn_mask)
 
     def forward(self, x):
         B, C, H, W = x.shape
-        x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        x_attn = window_reverse(attn_windows, self.window_size, H, W)
-        if self.shift_size > 0:
-            x = torch.roll(x_attn, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            x = x_attn
+        # x_windows = window_partition(x, self.window_size)
+        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        x_attn = self.token_mixer(x)
+
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
@@ -1011,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
         else:
             x = x + self.drop_path(x_attn)
             x = x + self.drop_path(self.mlp(self.norm2(x)))
+        if np.count_nonzero(np.isnan(x)) > 0:
+            print("PFBlock")
         return x
 
 class PatchEmbed(nn.Module):
     """
@@ -1096,7 +1221,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1132,7 +1257,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))
 
     def forward(self, x):
         for attn, ff in self.layers: