6 Commits

SHA1 Message Date
9075b53be6 try sep vit 2024-04-28 11:04:26 +07:00
ab5c1d0b4b try sep vit 2024-04-28 11:00:08 +07:00
3243b1d963 try sep vit 2024-04-28 09:42:39 +07:00
37b01708b4 try sep vit 2024-04-28 01:44:33 +07:00
a246d2bb64 try sep vit 2024-04-28 01:01:31 +07:00
4a962a02ad try sep vit 2024-04-27 21:57:24 +07:00
2 changed files with 223 additions and 10 deletions


@@ -478,7 +478,11 @@ class Main(object):
                 batch, 'train')
             pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-            loss = self.model.loss(pred, label, sub_samp)
+            try:
+                loss = self.model.loss(pred, label, sub_samp)
+            except Exception as e:
+                print(pred)
+                raise e
             loss.backward()
             self.optimizer.step()
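
Note: the try/except above is a debugging aid that dumps `pred` only after `self.model.loss` has already raised. A lighter-weight alternative (a sketch; `pred` stands in for the prediction tensor above, and the helper name is hypothetical) is to check for non-finite values before the loss is computed:

    import torch

    def assert_finite(pred):
        # Fail fast, with the offending tensor attached, if predictions contain NaN/Inf.
        if not torch.isfinite(pred).all():
            raise ValueError(f"non-finite entries in pred: {pred}")

    # usage, just before the loss call:
    #   assert_finite(pred)
    #   loss = self.model.loss(pred, label, sub_samp)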

models.py

@@ -1,9 +1,10 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -557,6 +558,8 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
+        if torch.isnan(z).any():  # torch.isnan, since np.isnan fails on CUDA/grad tensors
+            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
@@ -868,6 +871,203 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x
+
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
+# helper classes
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
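
Note: ChanLayerNorm normalizes over the channel dimension of an NCHW tensor; with default initialization it matches `nn.LayerNorm` applied channels-last (both use biased variance and eps = 1e-5). A quick equivalence check (a sketch; sizes arbitrary):

    import torch
    from torch import nn

    x = torch.randn(2, 64, 8, 8)
    out = ChanLayerNorm(64)(x)
    ref = nn.LayerNorm(64)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(out, ref, atol=1e-5)  # same normalization, different layout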
+class OverlappingPatchEmbed(nn.Module):
+    def __init__(self, dim_in, dim_out, stride = 2):
+        super().__init__()
+        kernel_size = stride * 2 - 1
+        padding = kernel_size // 2
+        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
+
+    def forward(self, x):
+        return self.conv(x)
+
+class PEG(nn.Module):
+    def __init__(self, dim, kernel_size = 3):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
+
+    def forward(self, x):
+        return self.proj(x) + x
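
Note: PEG looks like the positional encoding generator of CPVT-style conditional positional encodings: a depthwise convolution (groups = dim) added residually, so positional information comes from zero-padding rather than fixed embeddings. It preserves the feature-map shape (a sketch; sizes arbitrary):

    import torch

    peg = PEG(dim=64)
    x = torch.randn(1, 64, 14, 14)
    assert peg(x).shape == x.shape  # depthwise conv + residual keeps (B, C, H, W)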
+# feedforward
+
+class FeedForwardDSSA(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, 1),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+# attention
+
+class DSSA(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 32,
+        dropout = 0.,
+        window_size = 7
+    ):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.window_size = window_size
+        inner_dim = dim_head * heads
+
+        self.norm = ChanLayerNorm(dim)
+
+        self.attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
+
+        # window tokens
+        self.window_tokens = nn.Parameter(torch.randn(dim))
+
+        # prenorm and non-linearity for window tokens
+        # then projection to queries and keys for window tokens
+        self.window_tokens_to_qk = nn.Sequential(
+            nn.LayerNorm(dim_head),
+            nn.GELU(),
+            Rearrange('b h n c -> b (h c) n'),
+            nn.Conv1d(inner_dim, inner_dim * 2, 1),
+            Rearrange('b (h c) n -> b h n c', h = heads),
+        )
+
+        # window attention
+        self.window_attend = nn.Sequential(
+            nn.Softmax(dim = -1),
+            nn.Dropout(dropout)
+        )
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        """
+        einstein notation
+
+        b - batch
+        c - channels
+        w1 - window size (height)
+        w2 - also window size (width)
+        i - sequence dimension (source)
+        j - sequence dimension (target dimension to be reduced)
+        h - heads
+        x - height of feature map divided by window size
+        y - width of feature map divided by window size
+        """
+
+        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
+        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
+        num_windows = (height // wsz) * (width // wsz)
+
+        x = self.norm(x)
+
+        # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
+        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
+
+        # add windowing tokens
+        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
+        x = torch.cat((w, x), dim = -1)
+
+        # project for queries, keys, values
+        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
+
+        # split out heads
+        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
+
+        # scale
+        q = q * self.scale
+
+        # similarity
+        dots = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # attention
+        attn = self.attend(dots)
+
+        # aggregate values
+        out = torch.matmul(attn, v)
+
+        # split out windowed tokens
+        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
+
+        # early return if there is only 1 window
+        if num_windows == 1:
+            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+            return self.to_out(fmap)
+
+        # carry out the pointwise attention, the main novelty in the paper
+        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
+        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
+
+        # windowed queries and keys (preceded by prenorm activation)
+        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
+
+        # scale
+        w_q = w_q * self.scale
+
+        # similarities
+        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
+
+        w_attn = self.window_attend(w_dots)
+
+        # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
+        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
+
+        # fold back the windows and then combine heads for aggregation
+        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
+
+        return self.to_out(fmap)
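
Note: per the commit messages, DSSA appears to follow SepViT's depthwise-separable self-attention: window-local attention with a learned window token per window, followed by pointwise attention across the window tokens. It maps an NCHW feature map to one of the same shape whenever both spatial dims are divisible by `window_size`. A shape smoke test (a sketch; sizes arbitrary, exercising the multi-window path):

    import torch

    dssa = DSSA(dim=64, heads=4, dim_head=32, window_size=7)
    x = torch.randn(2, 64, 14, 14)    # 2x2 = 4 windows of 7x7
    assert dssa(x).shape == x.shape   # (2, 64, 14, 14)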
class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer block.
@@ -893,8 +1093,13 @@ class PoolFormerBlock(nn.Module):
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
         self.window_size = 4
+        self.attn_heads = 4
         self.attn_mask = None
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = nn.Sequential(  # Sequential, not ModuleList: forward() calls self.token_mixer(x) directly
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        )
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -912,11 +1117,12 @@ class PoolFormerBlock(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
-        x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        # x_windows = window_partition(x, self.window_size)
+        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        x_attn = self.token_mixer(x)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
@@ -927,6 +1133,9 @@ class PoolFormerBlock(nn.Module):
         else:
             x = x + self.drop_path(x_attn)
         x = x + self.drop_path(self.mlp(self.norm2(x)))
+        if torch.isnan(x).any():  # torch.isnan, since np.isnan fails on CUDA/grad tensors
+            print("PFBlock")
         return x
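
Note: the token mixer now runs on the full NCHW map (DSSA handles its own windowing), so `window_partition` / `window_reverse` are no longer needed on this path. A hypothetical smoke test of just the rewired mixer pair (the real PoolFormerBlock constructor takes more arguments than shown; sizes arbitrary):

    import torch
    from torch import nn

    mixer = nn.Sequential(
        DSSA(64, heads=4, window_size=4),
        FeedForwardDSSA(64),
    )
    x = torch.randn(2, 64, 16, 16)    # 16 divisible by window_size = 4
    assert mixer(x).shape == x.shape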
class PatchEmbed(nn.Module):
    """
@@ -1012,7 +1221,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1048,7 +1257,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))
 
     def forward(self, x):
         for attn, ff in self.layers:
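
Note: `FNetBlock` is imported from elsewhere in this repo (via `from layers import *`). For reference, the canonical FNet token mixer is just a 2D Fourier transform keeping the real part; a sketch of that definition (an assumption, not necessarily this repo's exact implementation):

    import torch
    from torch import nn

    class FNetBlockSketch(nn.Module):
        # FNet mixing: FFT over the hidden dim, then the sequence dim; keep the real part.
        def forward(self, x):
            return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real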