Thesis/models.py
2024-04-28 15:40:31 +07:00

1392 lines
54 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from utils import *
from layers import *
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module):
    """ConvE link-prediction model: a 2-D convolution over the stacked
    subject and relation embeddings, followed by a projection back to
    ``embed_dim`` that is scored against entity embeddings.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``in_channels``, ``out_channels``, ``inp_drop``,
        ``hid_drop``, ``feat_drop``, ``filt_h``, ``filt_w``, ``k_w`` and
        ``k_h``. ``embed_dim`` must equal ``k_w * k_h`` so each embedding can
        be viewed as a ``k_w x k_h`` grid.
    """

    def __init__(self, params, ):
        super(ConvE, self).__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        # 2 * num_rel rows; presumably one per relation plus its inverse.
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.conv1 = torch.nn.Conv2d(
            self.in_channels, self.out_channels, (self.p.filt_h, self.p.filt_w), 1, 0, bias=True)
        # forward() feeds a single-channel image, so in_channels must be 1.
        self.bn0 = torch.nn.BatchNorm2d(self.in_channels)
        self.bn1 = torch.nn.BatchNorm2d(self.out_channels)
        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        # Subject and relation grids are concatenated along the height axis,
        # giving a (2*k_w) x k_h image; the stride-1, unpadded convolution
        # shrinks each side by (filter size - 1). This used to be hard-coded
        # as a 20 x 20 image, which only matched k_w=10, k_h=20.
        conv_out_h = 2 * self.p.k_w - self.p.filt_h + 1
        conv_out_w = self.p.k_h - self.p.filt_w + 1
        fc_length = conv_out_h * conv_out_w * self.out_channels
        self.fc = torch.nn.Linear(fc_length, self.p.embed_dim)

    def loss(self, pred, true_label=None, sub_samp=None):
        """Binary cross-entropy between predicted probabilities and targets.

        ``pred`` and ``true_label`` must have the same shape. ``sub_samp`` is
        accepted for interface compatibility and ignored.
        """
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Score (sub, rel, ?) queries.

        Args:
            sub, rel: 1-D index tensors of subjects and relations (batch B).
            neg_ents: candidate entity indices used when ``strategy`` is not
                'one_to_n'; presumably (B, K) — TODO confirm with callers.
            strategy: 'one_to_n' scores every entity (returns B x num_ent);
                anything else scores only ``neg_ents``.

        Returns:
            Sigmoid probabilities.
        """
        sub_emb = self.ent_embed(sub).view(-1, 1, self.p.k_w, self.p.k_h)
        rel_emb = self.rel_embed(rel).view(-1, 1, self.p.k_w, self.p.k_h)
        x = torch.cat([sub_emb, rel_emb], 2)  # B, 1, 2*k_w, k_h
        x = self.bn0(x)
        x = self.inp_drop(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(sub_emb.size(0), -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class HypER(torch.nn.Module):
    """HypER link-prediction model.

    A hypernetwork (``fc1``) maps each relation embedding to a bank of
    convolution filters. The filters are convolved with the subject
    embedding, which is viewed as a 1 x embed_dim image. The resulting feature
    maps are projected back to ``embed_dim`` and scored against entity
    embeddings.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``in_channels``, ``out_channels``, ``inp_drop``,
        ``hid_drop``, ``feat_drop``, ``filt_h`` and ``filt_w``.
    """

    def __init__(self, params, ):
        super(HypER, self).__init__()
        self.p = params
        cfg = self.p
        dim = cfg.embed_dim

        self.ent_embed = torch.nn.Embedding(cfg.num_ent, dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(cfg.num_rel * 2, dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)

        self.in_channels = cfg.in_channels
        self.out_channels = cfg.out_channels
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(cfg.inp_drop)
        self.hidden_drop = torch.nn.Dropout(cfg.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(cfg.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.in_channels)
        self.bn1 = torch.nn.BatchNorm2d(self.out_channels)
        self.bn2 = torch.nn.BatchNorm1d(dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(cfg.num_ent)))

        # The input "image" is 1 pixel high, so a valid convolution leaves
        # (1 - filt_h + 1) rows and (dim - filt_w + 1) columns.
        map_h = 1 - cfg.filt_h + 1
        map_w = dim - cfg.filt_w + 1
        self.fc = torch.nn.Linear(map_h * map_w * self.out_channels, dim)
        filter_numel = self.in_channels * self.out_channels * cfg.filt_h * cfg.filt_w
        self.fc1 = torch.nn.Linear(dim, filter_numel)

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        _pos, _neg = true_label[0], true_label[1:]  # unused; retained from original
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        dim = self.ent_embed.weight.size(1)
        sub_emb = self.ent_embed(sub).view(-1, 1, 1, dim)
        rel_emb = self.rel_embed(rel)
        batch = sub_emb.size(0)
        fh, fw = self.p.filt_h, self.p.filt_w

        feats = self.inp_drop(self.bn0(sub_emb))

        # One filter bank per sample, generated from its relation embedding.
        filters = self.fc1(rel_emb)
        filters = filters.view(-1, self.in_channels, self.out_channels, fh, fw)
        filters = filters.view(batch * self.in_channels * self.out_channels, 1, fh, fw)

        # Move the batch into the channel axis so a grouped convolution applies
        # each sample's own filters to that sample only.
        feats = F.conv2d(feats.permute(1, 0, 2, 3), filters, groups=batch)
        feats = feats.view(batch, 1, self.out_channels, 1 - fh + 1, sub_emb.size(3) - fw + 1)
        feats = torch.sum(feats.permute(0, 3, 4, 1, 2), dim=3)
        feats = feats.permute(0, 3, 1, 2).contiguous()
        feats = self.feature_map_drop(self.bn1(feats))

        hidden = self.fc(feats.view(batch, -1))
        hidden = F.relu(self.bn2(self.hidden_drop(hidden)))

        if strategy == 'one_to_n':
            logits = torch.mm(hidden, self.ent_embed.weight.transpose(1, 0))
            logits += self.bias.expand_as(logits)
        else:
            logits = torch.mul(hidden.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
            logits += self.bias[neg_ents]
        return torch.sigmoid(logits)
class HypE(torch.nn.Module):
    """HypE link-prediction model.

    A hypernetwork (``fc1``) turns each relation embedding into 2-D
    convolution filters. These are applied to the subject embedding reshaped
    into an ``img_h x img_w`` image (10 x embed_dim//10; 10 x 20 for
    embed_dim=200, matching the size the original ``fc_length`` assumed).

    Fixes relative to the previous version: ``fc1`` was used in ``forward``
    but never created, and ``forward`` treated the subject as a
    1 x embed_dim strip that did not match ``fc_length``.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim`` (expected to be a multiple of 10), ``in_channels``,
        ``out_channels``, ``inp_drop``, ``hid_drop``, ``feat_drop``,
        ``filt_h`` and ``filt_w``.
    """

    def __init__(self, params, ):
        super(HypE, self).__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.in_channels)
        self.bn1 = torch.nn.BatchNorm2d(self.out_channels)
        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        # Image the subject embedding is reshaped into; keeps the 10-row
        # layout the original fc_length hard-coded (10 x 20 for dim 200).
        self.img_h = 10
        self.img_w = self.p.embed_dim // self.img_h
        fc_length = ((self.img_h - self.p.filt_h + 1)
                     * (self.img_w - self.p.filt_w + 1) * self.out_channels)
        self.fc = torch.nn.Linear(fc_length, self.p.embed_dim)
        # Hypernetwork: relation embedding -> flattened filter bank.
        fc1_length = self.in_channels*self.out_channels*self.p.filt_h*self.p.filt_w
        self.fc1 = torch.nn.Linear(self.p.embed_dim, fc1_length)

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``.

        ``neg_ents`` is presumably a (B, K) index tensor — TODO confirm.
        """
        fh, fw = self.p.filt_h, self.p.filt_w
        sub_emb = self.ent_embed(sub).view(-1, 1, self.img_h, self.img_w)
        batch = sub_emb.size(0)
        rel_emb = self.rel_embed(rel)
        x = self.bn0(sub_emb)
        x = self.inp_drop(x)
        k = self.fc1(rel_emb)
        k = k.view(-1, self.in_channels, self.out_channels, fh, fw)
        k = k.view(batch*self.in_channels*self.out_channels, 1, fh, fw)
        # Batch goes into the channel axis so the grouped conv applies each
        # sample's own filters to that sample only.
        x = x.permute(1, 0, 2, 3)
        x = F.conv2d(x, k, groups=batch)
        x = x.view(batch, 1, self.out_channels,
                   self.img_h - fh + 1, self.img_w - fw + 1)
        x = x.permute(0, 3, 4, 1, 2)
        x = torch.sum(x, dim=3)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.bn1(x)
        x = self.feature_map_drop(x)
        x = x.view(batch, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class DistMult(torch.nn.Module):
    """DistMult link-prediction model.

    The score of (s, r, o) is sum(s * r * o) plus a learned per-object bias,
    passed through a sigmoid. The subject embedding is batch-normalised and
    dropped out before the product.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim`` and ``inp_drop``.
    """

    def __init__(self, params, ):
        super(DistMult, self).__init__()
        self.p = params
        n_ent = self.p.num_ent
        n_rel_rows = self.p.num_rel * 2
        dim = self.p.embed_dim

        self.ent_embed = torch.nn.Embedding(n_ent, dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(n_rel_rows, dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.bn0 = torch.nn.BatchNorm1d(dim)
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(n_ent)))

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        _pos, _neg = true_label[0], true_label[1:]  # unused; retained from original
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        subject = self.inp_drop(self.bn0(self.ent_embed(sub)))
        query = subject * self.rel_embed(rel)

        if strategy == 'one_to_n':
            logits = query.mm(self.ent_embed.weight.t())
            logits += self.bias.expand_as(logits)
        else:
            candidates = self.ent_embed(neg_ents)
            logits = (query.unsqueeze(1) * candidates).sum(dim=-1)
            logits += self.bias[neg_ents]
        return torch.sigmoid(logits)
class ComplEx(torch.nn.Module):
    """ComplEx link-prediction model.

    Score of (s, r, o) is Re(<s, r, conj(o)>) plus a per-object bias, passed
    through a sigmoid. Real and imaginary parts are stored in separate
    embedding tables.

    Fix: the imaginary subject part is now normalised by ``bn1``. Previously
    ``bn1`` was created but unused, and ``bn0`` normalised both parts. That
    mixed their running statistics into a single BatchNorm.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim`` and ``inp_drop``.
    """

    def __init__(self, params, ):
        super(ComplEx, self).__init__()
        self.p = params
        self.ent_embed_real = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed_real.weight)
        self.ent_embed_imaginary = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed_imaginary.weight)
        self.rel_embed_real = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed_real.weight)
        self.rel_embed_imaginary = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed_imaginary.weight)
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim)  # real subject part
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)  # imaginary subject part
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        sub_emb_real = self.ent_embed_real(sub)
        sub_emb_imaginary = self.ent_embed_imaginary(sub)
        rel_emb_real = self.rel_embed_real(rel)
        rel_emb_imaginary = self.rel_embed_imaginary(rel)
        sub_emb_real = self.bn0(sub_emb_real)
        sub_emb_real = self.inp_drop(sub_emb_real)
        sub_emb_imaginary = self.bn1(sub_emb_imaginary)
        sub_emb_imaginary = self.inp_drop(sub_emb_imaginary)
        # Re(s * r * conj(o)) expanded into its four real-valued products.
        if strategy == 'one_to_n':
            x = torch.mm(sub_emb_real*rel_emb_real, self.ent_embed_real.weight.transpose(1, 0)) +\
                torch.mm(sub_emb_real*rel_emb_imaginary, self.ent_embed_imaginary.weight.transpose(1, 0)) +\
                torch.mm(sub_emb_imaginary*rel_emb_real, self.ent_embed_imaginary.weight.transpose(1, 0)) -\
                torch.mm(sub_emb_imaginary*rel_emb_imaginary,
                         self.ent_embed_real.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            neg_embs_real = self.ent_embed_real(neg_ents)
            neg_embs_imaginary = self.ent_embed_imaginary(neg_ents)
            x = (torch.mul((sub_emb_real*rel_emb_real).unsqueeze(1), neg_embs_real) +
                 torch.mul((sub_emb_real*rel_emb_imaginary).unsqueeze(1), neg_embs_imaginary) +
                 torch.mul((sub_emb_imaginary*rel_emb_real).unsqueeze(1), neg_embs_imaginary) -
                 torch.mul((sub_emb_imaginary*rel_emb_imaginary).unsqueeze(1), neg_embs_real)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class TuckER(torch.nn.Module):
    """TuckER link-prediction model.

    A shared core tensor ``core_W`` (embed_dim^3) is contracted with the
    relation embedding to give a per-relation embed_dim x embed_dim matrix.
    That matrix is applied to the subject embedding, and the result is scored
    against entity embeddings.

    Fix: ``core_W`` used to be allocated with a hard-coded ``device="cuda"``,
    so the model could not even be constructed on CPU-only machines. It is now
    created on the default device and follows ``model.to(...)``/``.cuda()``
    like every other parameter.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``inp_drop`` and ``hid_drop``.
    """

    def __init__(self, params, ):
        super(TuckER, self).__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        d = self.p.embed_dim
        # Core tensor initialised U(-1, 1) from numpy's RNG, as before.
        self.core_W = torch.nn.Parameter(torch.tensor(
            np.random.uniform(-1, 1, (d, d, d)), dtype=torch.float))
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop1 = torch.nn.Dropout(self.p.hid_drop)
        self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop)
        self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        return self.bceloss(pred, true_label)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        sub_emb = self.ent_embed(sub)
        sub_emb = self.bn0(sub_emb)
        sub_emb = self.inp_drop(sub_emb)
        dim = sub_emb.size(1)
        x = sub_emb.view(-1, 1, dim)
        r = self.rel_embed(rel)
        # Contract the relation with the core: (B, d) x (d, d*d) -> B x d x d.
        W_mat = torch.mm(r, self.core_W.view(r.size(1), -1))
        W_mat = W_mat.view(-1, dim, dim)
        W_mat = self.hidden_drop1(W_mat)
        x = torch.bmm(x, W_mat)
        x = x.view(-1, dim)
        x = self.bn1(x)
        x = self.hidden_drop2(x)
        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1),
                          self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class FouriER(torch.nn.Module):
    """FouriER link-prediction model.

    Subject and relation embeddings are linearly lifted to ``image_h x
    image_w`` maps. The two maps are stacked as a 2-channel image and encoded
    by a PoolFormer-style backbone (``PatchEmbed`` followed by
    ``basic_blocks`` stages). The spatially averaged feature goes through a
    head and is scored against entity embeddings.

    Parameters
    ----------
    params:
        Hyper-parameter namespace. Reads ``num_ent``, ``num_rel``,
        ``ent_vec_dim``, ``rel_vec_dim``, ``embed_dim``, ``image_h``,
        ``image_w``, ``in_channels``, ``out_channels``, ``hid_drop``,
        ``patch_size``, ``drop`` and ``drop_path``.
    hid_drop, embed_dim:
        Optional overrides written into ``params`` before the model is
        built. ``params`` is mutated in place.
    """

    def __init__(self, params, hid_drop=None, embed_dim=None):
        super(FouriER, self).__init__()
        # Bind params before applying the overrides: they write through
        # self.p, which previously raised AttributeError because self.p was
        # assigned only afterwards.
        self.p = params
        if hid_drop is not None:
            self.p.hid_drop = hid_drop
        if embed_dim is not None:
            self.p.ent_vec_dim = embed_dim
            self.p.rel_vec_dim = embed_dim
            self.p.embed_dim = embed_dim
        image_h, image_w = self.p.image_h, self.p.image_w
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.ent_vec_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel*2, self.p.rel_vec_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        # Lift each embedding to an image_h * image_w "pixel" vector.
        self.ent_fusion = torch.nn.Linear(
            self.p.ent_vec_dim, image_h*image_w)
        torch.nn.init.xavier_normal_(self.ent_fusion.weight)
        self.rel_fusion = torch.nn.Linear(
            self.p.rel_vec_dim, image_h*image_w)
        torch.nn.init.xavier_normal_(self.rel_fusion.weight)
        self.bceloss = torch.nn.BCELoss()
        channels = 2  # subject map + relation map
        self.bn0 = torch.nn.BatchNorm2d(channels)
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        patch_size = self.p.patch_size
        assert (image_h % patch_size) == 0 and (image_w %
                                                patch_size) == 0, 'image must be divisible by patch size'
        self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size,
                                      embed_dim=self.p.embed_dim, stride=4, padding=2)

        # Backbone configuration: four stages, each followed (except the
        # last) by a downsampling PatchEmbed.
        layers = [4, 4, 12, 4]
        embed_dims = [self.p.embed_dim, 128, 320, 128]
        mlp_ratios = [4, 4, 4, 4]
        downsamples = [True, True, True, True]
        pool_size = 3
        act_layer = nn.GELU
        drop_rate = self.p.drop
        norm_layer = GroupNorm
        drop_path_rate = self.p.drop_path
        use_layer_scale = True
        layer_scale_init_value = 1e-5
        down_patch_size = 3
        down_stride = 2
        down_pad = 1
        num_classes = self.p.embed_dim
        network = []
        for i in range(len(layers)):
            stage = basic_blocks(embed_dims[i], i, layers,
                                 pool_size=pool_size, mlp_ratio=mlp_ratios[i],
                                 act_layer=act_layer, norm_layer=norm_layer,
                                 drop_rate=drop_rate,
                                 drop_path_rate=drop_path_rate,
                                 use_layer_scale=use_layer_scale,
                                 layer_scale_init_value=layer_scale_init_value)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
                # downsampling between two stages
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
                    )
                )
        self.network = nn.ModuleList(network)
        self.norm = norm_layer(embed_dims[-1])

        self.graph_type = 'Spatial'
        # NOTE(review): assumes a square token grid of (image_h // patch_size)^2 nodes.
        N = (image_h // patch_size)**2
        if self.graph_type in ["Spatial", "Mixed"]:
            # 8-neighbourhood adjacency on a sqrt(N) x sqrt(N) grid, no self-loops.
            side = int(math.sqrt(N))
            indices = torch.arange(N)
            row_indices = indices.view(-1, 1).expand(-1, N)
            col_indices = indices.view(1, -1).expand(N, -1)
            row1, col1 = row_indices // side, row_indices % side
            row2, col2 = col_indices // side, col_indices % side
            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
            graph = graph - torch.eye(N)
            # Non-persistent buffer: moves with model.to()/.cuda() instead of
            # being pinned to CUDA at construction, and is not added to (or
            # expected in) the state_dict, so old checkpoints still load.
            self.register_buffer('spatial_graph', graph, persistent=False)
        self.class_token = False
        self.token_scale = False
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 \
            else nn.Identity()

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        return self.bceloss(pred, true_label)

    def forward_embeddings(self, x):
        """Patchify the 2-channel input image with ``self.patch_embed``."""
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x):
        """Run the backbone stages, offering each one the normalised token graph.

        ``x`` is unpacked as (B, C, H, W).
        """
        B, C, H, W = x.shape
        N = H*W
        if self.graph_type in ["Semantic", "Mixed"]:
            # NOTE(review): unreachable while graph_type is hard-coded to
            # 'Spatial'; also relies on self.num_neighbours, which __init__
            # never sets.
            # Semantic graph from cosine similarity between tokens.
            if self.class_token:
                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
            else:
                x_normed = x / x.norm(dim=-1, keepdim=True)
            x_cossim = x_normed @ x_normed.transpose(-1, -2)
            threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0]
            semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
            if self.class_token:
                semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
            else:
                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
        if self.graph_type == "None":
            graph = None
        else:
            if self.graph_type == "Spatial":
                graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
            elif self.graph_type == "Semantic":
                graph = semantic_graph
            elif self.graph_type == "Mixed":
                # Union of the spatial and semantic graphs.
                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
            # Symmetric normalisation D^-1/2 A D^-1/2.
            degree = graph.sum(-1)  # B, N
            degree = torch.diag_embed(degree**(-1/2))
            graph = degree @ graph @ degree
        for block in self.network:
            # Stages that do not take a graph argument fall back to block(x).
            # NOTE(review): any Exception (not only a signature mismatch)
            # triggers the fallback; a bare except previously also swallowed
            # KeyboardInterrupt/SystemExit.
            try:
                x = block(x, graph)
            except Exception:
                x = block(x)
        return x

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        sub_emb = self.ent_fusion(self.ent_embed(sub))
        rel_emb = self.rel_fusion(self.rel_embed(rel))
        comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w),
                                rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1)
        y = comb_emb.view(-1, 2, self.p.image_h, self.p.image_w)
        y = self.bn0(y)
        z = self.forward_embeddings(y)
        z = self.forward_tokens(z)
        z = z.mean([-2, -1])  # global average over the spatial axes
        z = self.norm(z)
        x = self.head(z)
        x = self.hidden_drop(x)
        x = self.bn1(x)
        x = F.relu(x)
        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1),
                          self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class InteractE(torch.nn.Module):
    """
    Proposed method in the paper. Refer Section 6 of the paper for more details.

    Fixes: ``inp_drop`` is now created (``forward`` used it but ``__init__``
    never defined it). ``circular_padding_chw`` with ``padding=0`` now returns
    its input unchanged instead of tiling it 3x3.

    Parameters
    ----------
    params: Hyperparameters of the model. Reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``inp_drop``, ``hid_drop``, ``feat_drop``, ``perm``,
        ``k_w``, ``k_h``, ``num_filt`` and ``ker_sz``.
    chequer_perm: Reshaping to be used by the model. It indexes the
        concatenated [sub, rel] embedding; presumably shape
        (perm, 2*embed_dim) — TODO confirm.

    Returns
    -------
    The InteractE model instance
    """

    def __init__(self, params, chequer_perm):
        super(InteractE, self).__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel*2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.p.perm)
        flat_sz_h = self.p.k_h
        flat_sz_w = 2*self.p.k_w
        # Explicit conv padding is 0 because circular padding is applied by hand.
        self.padding = 0
        self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt*self.p.perm)
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt*self.p.perm
        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.chequer_perm = chequer_perm
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        # One filter bank shared by all permutations (repeated in forward).
        self.register_parameter('conv_filt', torch.nn.Parameter(
            torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))
        torch.nn.init.xavier_normal_(self.conv_filt)

    def loss(self, pred, true_label=None, sub_samp=None):
        """BCE loss between predicted probabilities and targets (``sub_samp`` unused)."""
        return self.bceloss(pred, true_label)

    def circular_padding_chw(self, batch, padding):
        """Wrap-around pad the last two dimensions of ``batch`` by ``padding``.

        ``padding == 0`` returns ``batch`` unchanged; without the guard
        ``batch[..., -0:, :]`` selects the whole tensor and triples both
        spatial dimensions.
        """
        if padding == 0:
            return batch
        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)
        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for every entity ('one_to_n') or for ``neg_ents``."""
        sub_emb = self.ent_embed(sub)
        rel_emb = self.rel_embed(rel)
        comb_emb = torch.cat([sub_emb, rel_emb], dim=1)
        chequer_perm = comb_emb[:, self.chequer_perm]
        stack_inp = chequer_perm.reshape(
            (-1, self.p.perm, 2*self.p.k_w, self.p.k_h))
        stack_inp = self.bn0(stack_inp)
        x = self.inp_drop(stack_inp)
        x = self.circular_padding_chw(x, self.p.ker_sz//2)
        # Depthwise-style grouped conv: each permutation gets its own copy of the filters.
        x = F.conv2d(x, self.conv_filt.repeat(self.p.perm, 1, 1, 1),
                     padding=self.padding, groups=self.p.perm)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]
        pred = torch.sigmoid(x)
        return pred
class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` pinned to a single group.

    With one group, each sample is normalised over all of its channels and
    spatial positions together. Expected input: [B, C, H, W]. Any extra
    keyword arguments (e.g. ``eps``, ``affine``) are forwarded unchanged to
    ``nn.GroupNorm``.
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_groups=1, num_channels=num_channels, **kwargs)
def basic_blocks(dim, index, layers,
                 pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop_rate=.0, drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
    """Generate the PoolFormer blocks for one stage.

    Args:
        dim: channel width of every block in this stage.
        index: position of this stage in ``layers``.
        layers: number of blocks per stage across the whole network.
        drop_path_rate: maximum stochastic-depth rate. Block k of the network
            (counted over all stages) gets ``drop_path_rate * k / (total - 1)``,
            i.e. a linear ramp from 0 to ``drop_path_rate``.
        Remaining arguments are forwarded to ``PoolFormerBlock``.

    Returns:
        A ``SeqModel`` chaining the stage's blocks.
    """
    # Global index of this stage's first block.
    offset = sum(layers[:index])
    # Guard against a one-block network, where total - 1 == 0 used to raise
    # ZeroDivisionError; the single block then gets rate 0.
    denom = max(sum(layers) - 1, 1)
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + offset) / denom
        blocks.append(PoolFormerBlock(
            dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
            act_layer=act_layer, norm_layer=norm_layer,
            drop=drop_rate, drop_path=block_dpr,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
        ))
    return SeqModel(*blocks)
def window_partition(x, window_size):
    """Split a 4-D feature map into non-overlapping window_size x window_size windows.

    Args:
        x: 4-D tensor. The shape is unpacked below as (B, C, H, W), but the
            memory is then viewed as channels-last (..., C).
            NOTE(review): the docstring used to say (B, H, W, C). For a
            contiguous (B, C, H, W) tensor, ``view`` reinterprets memory rather
            than transposing it, so the "windows" are not spatial windows
            unless C == H == W or the caller's layout differs — TODO confirm
            the intended layout. ``window_reverse`` applies the inverse
            reinterpretation, so partition followed by reverse returns the
            original tensor.
        window_size (int): window size; H and W are presumably multiples of
            it (``view`` raises otherwise).
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, C, H, W = x.shape
    # Split H and W into (blocks, window) pairs, with C treated as the last axis.
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two block axes together, then flatten them into the batch axis.
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Attention logits use scaled cosine similarity with a learnable per-head
    scale (clamped at log(100) before ``exp``). A continuous relative
    position bias is produced by a small MLP (``cpb_mlp``) from log-spaced
    relative coordinates.

    Fixes: ``forward`` no longer calls ``.cuda()`` (it crashed on CPU-only
    hosts). ``pretrained_window_size`` no longer uses a mutable list default;
    omitting it still yields ``[0, 0]``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query and value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int], optional): The height and width of the window in
            pre-training. Default: [0, 0] (use ``window_size`` for normalisation).
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=None):
        super().__init__()
        if pretrained_window_size is None:
            pretrained_window_size = [0, 0]
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))
        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        # Log-spaced coordinates: sign(t) * log2(|t| + 1) / log2(8).
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        self.register_buffer("relative_coords_table", relative_coords_table)
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            # Only q and v get a bias; k's slot is filled with zeros in forward().
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor with the same shape as ``x``.
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        # Clamp bound as a Python float so it works on any device.
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale
        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Approximate multiply count for one window of ``N`` tokens."""
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into a 4-D map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: tensor viewed as (B, C, H, W) (the final ``view(B, -1, H, W)``).
            NOTE(review): the original docstring says (B, H, W, C). The data
            is laid out channels-last in memory but is *viewed*, not
            transposed, as (B, C, H, W). This exactly inverts
            ``window_partition``'s reinterpretation, so a round trip is
            lossless, but channels are not a true axis of the result unless
            the caller's layout makes it so — TODO confirm.
    """
    # Recover B from the window count: windows per image = (H/ws) * (W/ws).
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave block and in-window axes back into (H, W) order before flattening.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
    return x
class SeqModel(nn.Sequential):
    """``nn.Sequential`` variant that can thread several positional inputs.

    The running value starts as the tuple of ``forward``'s arguments. While
    it is exactly a ``tuple``, it is unpacked into the next module's
    positional arguments; any other value is passed as a single argument. A
    module can therefore return a tuple to feed multiple inputs onward.
    """

    def forward(self, *inputs):
        carried = inputs
        for stage in self._modules.values():
            carried = stage(*carried) if type(carried) is tuple else stage(carried)
        return carried
def propagate(x: torch.Tensor, weight: torch.Tensor,
              index_kept: torch.Tensor, index_prop: torch.Tensor,
              standard: str = "None", alpha: Optional[float] = 0,
              token_scales: Optional[torch.Tensor] = None,
              cls_token=True):
    """
    Propagate tokens based on the selection results.
    ================================================
    Args:
        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token
          when ``cls_token`` is True.
        - weight: Tensor([B, N-1, N-1]) (or [B, N, N] without a [CLS] token): the weight
          of each token propagated to the other tokens, excluding the [CLS] token.
          weight could be a pre-defined graph of the current feature map (by default)
          or the attention map (needs a manual change to the Block module).
        - index_kept: Tensor([B, N-1-num_prop]): indices into ``x`` of the kept image tokens
        - index_prop: Tensor([B, num_prop]): indices into ``x`` of the propagated image tokens
        - standard: str: the propagation method: "None", "Mean" or "GraphProp"
        - alpha: float: the coefficient of propagated features
        - token_scales: Tensor([B, N]): the scale of each token, including the [CLS] token.
          None by default; otherwise the scales should sum to N. Only updated by
          "GraphProp"; returned unchanged by the other methods.
        - cls_token: whether position 0 of ``x`` is a [CLS] token that is always kept.
    Return:
        - x: Tensor([B, N-num_prop, C]) for "None"/"GraphProp" with a [CLS] token
          ("Mean" appends one averaged token): the feature map after propagation
        - weight: for "GraphProp", Tensor([B, N-1-num_prop, N-1-num_prop]), the graph
          after propagation; otherwise the input ``weight`` unchanged
        - token_scales: the scale of tokens after propagation (see above)
    Raises:
        AssertionError: for an unknown ``standard`` (previously a ``%f`` format of the
        string raised TypeError instead), or for "GraphProp" without ``weight``.
    """
    B, N, C = x.shape

    # Step 1: divide tokens
    if cls_token:
        x_cls = x[:, 0:1]  # B, 1, C
    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, C))  # B, N-1-num_prop, C
    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1, -1, C))  # B, num_prop, C

    # Step 2: propagate tokens
    if standard == "None":
        # No further propagation.
        pass
    elif standard == "Mean":
        # Average all propagated tokens and append the result to the kept tokens.
        x_prop = x_prop.mean(1, keepdim=True)  # B, 1, C
        x_kept = torch.cat((x_kept, x_prop), dim=1)  # B, N-num_prop, C
    elif standard == "GraphProp":
        # Propagate every propagated token to the kept tokens according to the
        # graph weights (and token scales, if given).
        assert weight is not None, "The graph weight is needed for graph propagation"

        # Step 2.1: divide propagation weights.
        if cls_token:
            index_kept = index_kept - 1  # since weights do not include the [CLS] token
            index_prop = index_prop - 1  # since weights do not include the [CLS] token
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N-1))  # B, N-1-num_prop, N-1
        else:
            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N))  # B, N-num_prop, N
        weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, kept, num_prop
        weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, kept, kept

        # Step 2.2: broadcast messages from propagated tokens to the kept tokens (dense version).
        x_prop = weight_prop @ x_prop  # B, kept, C
        x_kept = x_kept + alpha * x_prop  # B, kept, C
        # A sparse scatter_reduce alternative for batched inputs:
        #   non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
        #   non_zero_values = weight_prop[non_zero_indices]
        #   batch_indices, row_indices, col_indices = non_zero_indices
        #   sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
        #   reduce_indices = batch_indices * x_kept.shape[1] + row_indices
        #   x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0, index=reduce_indices[:, None],
        #       src=sparse_matmul, reduce="sum", include_self=True)
        #   x_kept = x_kept.reshape(B, -1, C)

        # Step 2.3: update the scale of each token if token_scales is not None.
        if token_scales is not None:
            if cls_token:
                token_scales_cls = token_scales[:, 0:1]  # B, 1
                token_scales = token_scales[:, 1:]
            # index_kept/index_prop are already shifted to exclude [CLS] here.
            token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, kept
            token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop
            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1)  # B, kept, 1
            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1)  # B, kept
            if cls_token:
                token_scales = torch.cat((token_scales_cls, token_scales), dim=1)  # B, N-num_prop
    else:
        assert False, "Propagation method \'%s\' has not been supported yet." % standard

    # Step 3: re-attach the [CLS] token and return
    if cls_token:
        x = torch.cat((x_cls, x_kept), dim=1)  # B, N-num_prop, C
    else:
        x = x_kept
    return x, weight, token_scales
def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token: bool = True):
    """
    Select image tokens to be propagated. The [CLS] token (if present) is never selected.
    ======================================================================
    Args:
        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
            attention map of tokens at the moment.
        - standard: str: the ranking criterion. One of "CLSAttnMean", "CLSAttnMax"
            (both require cls_token=True), "IMGAttnMean", "IMGAttnMax", "DiagAttnMean",
            "DiagAttnMax", "MixedAttnMean", "MixedAttnMax", "SumAttnMax", "CosSimMean",
            "CosSimMax", "Random".
        - num_prop: int: the number of tokens to be propagated (0 keeps every token).
        - cls_token: bool: whether index 0 of the token axis is a [CLS] token that must be
            excluded from ranking.
    Return:
        - index_kept: Tensor([B, M-num_prop]): indices (into the full token axis) of the
            highest-ranked tokens, in descending rank order.
        - index_prop: Tensor([B, num_prop]): indices of the lowest-ranked tokens.
        Here M = N-1 when cls_token is True and M = N otherwise.
    Raises:
        - ValueError: if weight is not a square 4-D tensor, num_prop is negative or larger
            than M, or `standard` is not supported (for the given cls_token setting).
    """
    if weight.dim() != 4:
        raise ValueError("Selection methods on tensors other than the attention map haven't been supported yet.")
    B, H, N1, N2 = weight.shape
    if N1 != N2:
        raise ValueError("Selection methods on tensors other than the attention map haven't been supported yet.")
    N = N1
    # Image tokens start right after the [CLS] token when there is one.
    start = 1 if cls_token else 0
    num_tokens = N - start
    if num_prop < 0:
        raise ValueError("The number of propagated/pruned tokens must be non-negative.")
    if num_prop > num_tokens:
        raise ValueError(
            f"Cannot propagate {num_prop} tokens: only {num_tokens} image tokens are available.")

    def self_attn():
        # B, H, M: diagonal entries of the map, restricted to image tokens.
        return torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, start:]

    def received_attn():
        # B, H, M: column sums over all rows (attention received, if rows are queries).
        return weight[:, :, :, start:].sum(-2)

    if standard == "CLSAttnMean" and cls_token:
        token_rank = weight[:, :, 0, 1:].mean(1)
    elif standard == "CLSAttnMax" and cls_token:
        token_rank = weight[:, :, 0, 1:].max(1)[0]
    elif standard in ("CLSAttnMean", "CLSAttnMax"):
        raise ValueError(f"Type '{standard}' selection requires a [CLS] token (cls_token=True).")
    elif standard == "IMGAttnMean":
        token_rank = received_attn().mean(1)
    elif standard == "IMGAttnMax":
        token_rank = received_attn().max(1)[0]
    elif standard == "DiagAttnMean":
        token_rank = self_attn().mean(1)
    elif standard == "DiagAttnMax":
        token_rank = self_attn().max(1)[0]
    elif standard == "MixedAttnMean":
        token_rank = self_attn().mean(1) * received_attn().mean(1)
    elif standard == "MixedAttnMax":
        token_rank = self_attn().max(1)[0] * received_attn().max(1)[0]
    elif standard == "SumAttnMax":
        token_rank = self_attn().max(1)[0] + received_attn().max(1)[0]
    elif standard in ("CosSimMean", "CosSimMax"):
        rows = weight[:, :, start:, :]
        rows = rows.mean(1) if standard == "CosSimMean" else rows.max(1)[0]  # B, M, N
        rows = rows / rows.norm(dim=-1, keepdim=True)
        # Negated so tokens with the largest total cosine similarity to the others
        # rank lowest and are therefore propagated first.
        token_rank = -(rows @ rows.transpose(-1, -2)).sum(-1)
    elif standard == "Random":
        # One score per image token: N-1 with a [CLS] token, N without one.
        token_rank = torch.randn((B, num_tokens), device=weight.device)
    else:
        raise ValueError(f"Type '{standard}' selection not supported.")

    # Shift back into the full token axis so indices skip the [CLS] slot.
    order = torch.argsort(token_rank, dim=1, descending=True) + start  # B, M
    # Slice with an explicit split point: `[:, :-num_prop]` would be empty for num_prop == 0.
    num_kept = num_tokens - num_prop
    index_kept = order[:, :num_kept]  # B, M-num_prop
    index_prop = order[:, num_kept:]  # B, num_prop
    return index_kept, index_prop
class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer-style block. In this variant the token mixer is
    windowed self-attention (WindowAttention) followed by token propagation
    (select + propagate) instead of pooling.
    --dim: embedding dim
    --pool_size: pooling size (currently unused: the Pooling token mixer is commented out)
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth,
        refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale,
        refer to https://arxiv.org/abs/2103.17239
    """
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        # NOTE(review): norm1 is constructed but never applied in forward(); the token
        # mixer receives un-normalized x. Confirm whether that is intended.
        self.norm1 = norm_layer(dim)
        #self.token_mixer = Pooling(pool_size=pool_size)
        # self.token_mixer = FNetBlock()
        # Side length of the square attention windows (window_size * window_size tokens each).
        self.window_size = 4
        # attn_mask stays None, so WindowAttention is always called without a mask.
        self.attn_mask = None
        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            # Per-channel residual scales; broadcast over the two trailing spatial
            # dims in forward() via unsqueeze(-1).unsqueeze(-1).
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x, weight, token_scales = None):
        """
        Apply window attention + token propagation, then the channel MLP, each
        added back to x through a (layer-scaled, drop-path) residual connection.

        Args:
            x: input feature map, unpacked below as (B, C, H, W).
            weight: forwarded to propagate() as its `weight` argument (presumably a
                token-graph affinity matrix for "GraphProp" - TODO confirm with caller).
            token_scales: optional per-token scales forwarded to propagate().
        Returns:
            The updated x. The `weight` and `token_scales` returned by propagate()
            are discarded.
        """
        B, C, H, W = x.shape
        # NOTE(review): x is unpacked as (B, C, H, W); confirm that window_partition /
        # window_reverse (defined in layers) expect this channel layout.
        x_windows = window_partition(x, self.window_size)
        # Flatten each window into a sequence of window_size**2 tokens of width C.
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # Stitch the attended windows back into an H x W map.
        x_attn = window_reverse(attn_windows, self.window_size, H, W)
        # NOTE(review): select() expects a [B, H, N, N] attention map (4-D with equal
        # last two dims) but receives the attention *output* x_attn here. With
        # num_prop=0 the intent appears to be "keep every token" - confirm.
        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
                                        cls_token=False)
        original_shape = x_attn.shape
        # Per-window token sequences for propagate(); the shape is restored afterwards.
        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
        x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
                                                 alpha=0.1, token_scales=token_scales, cls_token=False)
        x_attn = x_attn.view(*original_shape)
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * x_attn)
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """
    Convolutional patch embedding: a single strided Conv2d, optionally followed
    by a normalization layer built as norm_layer(embed_dim).
    Input: tensor in shape [B, in_chans, H, W]
    Output: tensor in shape [B, embed_dim, H', W'], where H' and W' follow the
    Conv2d output-size formula for the given patch_size / stride / padding.
    """

    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        kernel, step, pad = (to_2tuple(v) for v in (patch_size, stride, padding))
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel,
                              stride=step, padding=pad)
        # Without a norm layer, pass the projection through unchanged.
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        return self.norm(self.proj(x))
class Pooling(nn.Module):
    """
    PoolFormer pooling token mixer: local average (stride 1) minus the input.
    --pool_size: pooling window size; use an odd size so the pooled map keeps
        the spatial size of x (padding is pool_size // 2).
    """

    def __init__(self, pool_size=3):
        super().__init__()
        half = pool_size // 2
        # count_include_pad=False: border averages only use real (non-padded) values.
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=half,
                                 count_include_pad=False)

    def forward(self, x):
        pooled = self.pool(x)
        # Return only the difference between each position's local average and itself.
        return pooled - x
class Mlp(nn.Module):
    """
    Channel MLP made of two 1x1 convolutions: fc1 -> act -> dropout -> fc2 -> dropout.
    Input: tensor with shape [B, in_features, H, W]
    Output: tensor with shape [B, out_features, H, W]
    hidden_features / out_features fall back to in_features when falsy.
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Only convolutions get re-initialised: truncated-normal weights, zero bias.
        if not isinstance(m, nn.Conv2d):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class LayerNormChannel(nn.Module):
    """
    LayerNorm over the channel dimension (dim 1) only, with a per-channel affine
    weight (init 1) and bias (init 0).
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(1, keepdim=True)
        # Biased (population) variance across channels, as in standard LayerNorm.
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Reshape the per-channel parameters to [C, 1, 1] so they broadcast over H, W.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normed + shift
class FeedForward(nn.Module):
    """
    Position-wise MLP applied on the last dim:
    Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        # Kept as one Sequential named `net` so parameter names (net.0.*, net.3.*) are stable.
        self.net = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout),
                                 project, nn.Dropout(dropout))

    def forward(self, x):
        out = self.net(x)
        return out
class PreNorm(nn.Module):
    """Pre-norm wrapper: apply LayerNorm(dim) to the input, then call `fn` on it."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        # Extra keyword arguments go straight to the wrapped callable.
        return self.fn(normalized, **kwargs)
class FNetBlock(nn.Module):
    """
    FNet token mixer: real part of a 2-D discrete Fourier transform over the
    last two dimensions (last dim first, then the second-to-last). Parameter-free.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        along_last = torch.fft.fft(x, dim=-1)
        along_second_last = torch.fft.fft(along_last, dim=-2)
        return along_second_last.real
class FNet(nn.Module):
    """
    Stack of `depth` FNet layers. Each layer applies, with residual connections,
    a pre-normed Fourier mixer followed by a pre-normed FeedForward(dim, mlp_dim).
    """

    def __init__(self, dim, depth, mlp_dim, dropout = 0.):
        super().__init__()
        # layers[i] = ModuleList([mixer, feed_forward]); structure kept for state_dict keys.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, FNetBlock()),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for mixer, feed_forward in self.layers:
            x = x + mixer(x)
            x = x + feed_forward(x)
        return x