26 Commits

SHA1 Message Date
9075b53be6 try sep vit 2024-04-28 11:04:26 +07:00
ab5c1d0b4b try sep vit 2024-04-28 11:00:08 +07:00
3243b1d963 try sep vit 2024-04-28 09:42:39 +07:00
37b01708b4 try sep vit 2024-04-28 01:44:33 +07:00
a246d2bb64 try sep vit 2024-04-28 01:01:31 +07:00
4a962a02ad try sep vit 2024-04-27 21:57:24 +07:00
f8e969cbd1 try swin 2024-04-27 11:52:23 +07:00
ae0f43ab4d try swin 2024-04-27 11:51:35 +07:00
dda7f13dbd try swin 2024-04-27 11:49:07 +07:00
1dd423edf0 try swin 2024-04-27 11:48:25 +07:00
a1bf2d7389 try swin 2024-04-27 11:46:32 +07:00
c31588cc5f try swin 2024-04-27 11:45:24 +07:00
c03e24f4c2 try swin 2024-04-27 11:43:15 +07:00
a47a60f6a1 try swin 2024-04-27 11:40:27 +07:00
ba388148d4 try swin 2024-04-27 11:27:38 +07:00
1b816fed50 try swin 2024-04-27 11:24:57 +07:00
32962bf421 try swin 2024-04-27 11:23:28 +07:00
b9efe68d3c try swin 2024-04-27 11:12:52 +07:00
465f98bef8 try swin 2024-04-27 11:08:46 +07:00
d4ac470c54 try swin 2024-04-27 11:07:48 +07:00
28a8352044 try swin 2024-04-27 10:59:11 +07:00
b77c79708e try swin 2024-04-27 10:56:10 +07:00
22d44d1a99 try swin 2024-04-27 10:32:08 +07:00
63ccb4ec75 try swin 2024-04-27 10:26:58 +07:00
6ec566505f try swin 2024-04-27 10:18:48 +07:00
30805a0af9 try swin 2024-04-27 10:04:41 +07:00
3 changed files with 410 additions and 22 deletions

main.py (30 changes)

@@ -20,6 +20,7 @@ from data_loader import TrainDataset, TestDataset
 from utils import get_logger, get_combined_results, set_gpu, prepare_env, set_seed
 from models import ComplEx, ConvE, HypER, InteractE, FouriER, TuckER
+import traceback

 class Main(object):
@@ -477,7 +478,11 @@ class Main(object):
             batch, 'train')
         pred = self.model.forward(sub, rel, neg_ent, self.p.train_strategy)
-        loss = self.model.loss(pred, label, sub_samp)
+        try:
+            loss = self.model.loss(pred, label, sub_samp)
+        except Exception as e:
+            print(pred)
+            raise e
         loss.backward()
         self.optimizer.step()
@@ -715,16 +720,19 @@ if __name__ == "__main__":
         model.load_model(save_path)
         model.evaluate('test')
     else:
-        while True:
-            try:
-                model = Main(args, logger)
-                model.fit()
-            except Exception as e:
-                print(e)
-                try:
-                    del model
-                except Exception:
-                    pass
-                time.sleep(30)
-                continue
-            break
+        model = Main(args, logger)
+        model.fit()
+        # while True:
+        #     try:
+        #         model = Main(args, logger)
+        #         model.fit()
+        #     except Exception as e:
+        #         print(e)
+        #         traceback.print_exc()
+        #         try:
+        #             del model
+        #         except Exception:
+        #             pass
+        #         time.sleep(30)
+        #         continue
+        #     break
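Where the old driver looped forever, re-instantiating Main after any exception, the new code runs once and keeps the retry loop only as a comment; the training hunk above instead catches a failing loss call and prints the offending predictions before re-raising. A minimal standalone sketch of that guard pattern (guarded_loss and the BCE call are illustrative names, not from this repo):

import torch
import torch.nn.functional as F

def guarded_loss(loss_fn, pred, label):
    try:
        return loss_fn(pred, label)
    except Exception:
        # Dump the tensor that broke the loss (e.g. NaN logits), then re-raise.
        print(pred)
        raise

pred = torch.randn(4, 10, requires_grad=True)
label = torch.randint(0, 2, (4, 10)).float()
loss = guarded_loss(F.binary_cross_entropy_with_logits, pred, label)
loss.backward()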

models.py (392 changes)

@@ -1,15 +1,16 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 import numpy as np
 from functools import partial
 from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, repeat
 from utils import *
 from layers import *
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.layers import DropPath, trunc_normal_
 from timm.models.registry import register_model
-from timm.models.layers.helpers import to_2tuple
+from timm.layers.helpers import to_2tuple

 class ConvE(torch.nn.Module):
@@ -557,6 +558,8 @@ class FouriER(torch.nn.Module):
         z = self.forward_embeddings(y)
         z = self.forward_tokens(z)
         z = z.mean([-2, -1])
+        if torch.isnan(z).any():
+            print("ZZZ")
         z = self.norm(z)
         x = self.head(z)
         x = self.hidden_drop(x)
@@ -707,6 +710,363 @@ def basic_blocks(dim, index, layers,
    return blocks
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
nn.ReLU(inplace=True),
nn.Linear(512, num_heads, bias=False))
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(
torch.meshgrid([relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
self.register_buffer("relative_coords_table", relative_coords_table)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
# cosine attention
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, ' \
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
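# Editorial sketch, not part of the commit: quick CPU shape check for
# WindowAttention. A 4x4 window gives N = 16 tokens per window, and the
# module maps (num_windows*B, N, C) -> (num_windows*B, N, C).
if __name__ == "__main__":
    _attn = WindowAttention(dim=64, window_size=(4, 4), num_heads=4)
    _x = torch.randn(8, 16, 64)
    assert _attn(_x, mask=None).shape == (8, 16, 64)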
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
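# Editorial sketch, not part of the commit: window_partition and
# window_reverse are exact inverses for (B, H, W, C) inputs whose H and W
# are divisible by window_size.
if __name__ == "__main__":
    _x = torch.randn(2, 8, 8, 32)     # (B, H, W, C)
    _w = window_partition(_x, 4)      # -> (2*4, 4, 4, 32)
    assert torch.equal(window_reverse(_w, 4, 8, 8), _x)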
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
kernel_size = stride * 2 - 1
padding = kernel_size // 2
self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding)
def forward(self, x):
return self.conv(x)
class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)
def forward(self, x):
return self.proj(x) + x
# feedforward
class FeedForwardDSSA(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# attention
class DSSA(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 32,
dropout = 0.,
window_size = 7
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
# window tokens
self.window_tokens = nn.Parameter(torch.randn(dim))
# prenorm and non-linearity for window tokens
# then projection to queries and keys for window tokens
self.window_tokens_to_qk = nn.Sequential(
nn.LayerNorm(dim_head),
nn.GELU(),
Rearrange('b h n c -> b (h c) n'),
nn.Conv1d(inner_dim, inner_dim * 2, 1),
Rearrange('b (h c) n -> b h n c', h = heads),
)
# window attention
self.window_attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
"""
einstein notation
b - batch
c - channels
w1 - window size (height)
w2 - also window size (width)
i - sequence dimension (source)
j - sequence dimension (target dimension to be reduced)
h - heads
x - height of feature map divided by window size
y - width of feature map divided by window size
"""
batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
x = self.norm(x)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
# add windowing tokens
w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
x = torch.cat((w, x), dim = -1)
# project for queries, keys, value
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
# split out heads
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))
# scale
q = q * self.scale
# similarity
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# attention
attn = self.attend(dots)
# aggregate values
out = torch.matmul(attn, v)
# split out windowed tokens
window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]
# early return if there is only 1 window
if num_windows == 1:
fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
# carry out the pointwise attention, the main novelty in the paper
window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)
# windowed queries and keys (preceded by prenorm activation)
w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)
# scale
w_q = w_q * self.scale
# similarities
w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)
w_attn = self.window_attend(w_dots)
# aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before)
aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)
# fold back the windows and then combine heads for aggregation
fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
return self.to_out(fmap)
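# Editorial sketch, not part of the commit: DSSA preserves the (B, C, H, W)
# layout, which is what lets it stand in for the pooling/FNet token mixers
# in PoolFormerBlock below; H and W must be divisible by window_size.
if __name__ == "__main__":
    _mixer = DSSA(dim=64, heads=4, window_size=4)
    _fmap = torch.randn(2, 64, 8, 8)
    assert _mixer(_fmap).shape == (2, 64, 8, 8)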
class PoolFormerBlock(nn.Module):
    """
@@ -731,7 +1091,15 @@ class PoolFormerBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         #self.token_mixer = Pooling(pool_size=pool_size)
-        self.token_mixer = FNetBlock()
+        # self.token_mixer = FNetBlock()
+        self.window_size = 4
+        self.attn_heads = 4
+        self.attn_mask = None
+        # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = nn.Sequential(
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        )
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -748,16 +1116,26 @@ class PoolFormerBlock(nn.Module):
             layer_scale_init_value * torch.ones((dim)), requires_grad=True)

     def forward(self, x):
+        B, C, H, W = x.shape
+        # x_windows = window_partition(x, self.window_size)
+        # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
+        # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        # x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        x_attn = self.token_mixer(x)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
-                * self.token_mixer(self.norm1(x)))
+                * x_attn)
             x = x + self.drop_path(
                 self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                 * self.mlp(self.norm2(x)))
         else:
-            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(x_attn)
             x = x + self.drop_path(self.mlp(self.norm2(x)))
+        if torch.isnan(x).any():
+            print("PFBlock")
         return x

 class PatchEmbed(nn.Module):
     """
@@ -843,7 +1221,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x

-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -879,7 +1257,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))

     def forward(self, x):
         for attn, ff in self.layers:
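Net effect of the PoolFormerBlock changes: the FNet mixer is swapped for a SepViT-style pair, window attention (DSSA) followed by a channelwise feed-forward (FeedForwardDSSA), applied to the raw (B, C, H, W) feature map. A minimal sketch of that composition, assuming models.py is importable and using CPU-sized dimensions:

import torch
from torch import nn
from models import DSSA, FeedForwardDSSA

# Mirrors the token_mixer built in PoolFormerBlock.__init__ above.
token_mixer = nn.Sequential(
    DSSA(dim=64, heads=4, window_size=4),
    FeedForwardDSSA(dim=64),
)
x = torch.randn(2, 64, 8, 8)   # H and W must be divisible by window_size
assert token_mixer(x).shape == x.shape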

requirements.txt

@@ -2,3 +2,5 @@ torch==1.12.1+cu116
 ordered-set==4.1.0
 numpy==1.21.5
 einops==0.4.1
+pandas
+timm==0.9.16