"""Knowledge-graph link-prediction models and their building blocks.

Each model class maps a batch of (subject, relation) index pairs to sigmoid
scores over candidate object entities. The second half of the file holds the
PoolFormer-style blocks (with an FFT token mixer) used by ``FouriER``.
"""

# NOTE: import order is intentionally left as in the original file. The star
# imports can shadow (or be shadowed by) neighbouring names, so regrouping
# them could silently change which object a name refers to.
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from utils import *
from layers import *
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple


class _KGEModel(torch.nn.Module):
    """Loss and candidate scoring shared by the models in this file.

    Subclasses must set ``self.bceloss``. Subclasses that call
    ``_score_candidates`` must also set ``self.ent_embed`` (an
    ``nn.Embedding``) and ``self.bias`` (one entry per entity).
    """

    def loss(self, pred, true_label=None, sub_samp=None):
        """Return ``self.bceloss(pred, true_label)``.

        Parameters
        ----------
        pred: Predicted probabilities (output of ``forward``).
        true_label: Target tensor passed straight to the BCE loss.
        sub_samp: Unused; kept so existing callers keep working.
        """
        return self.bceloss(pred, true_label)

    def _score_candidates(self, x, neg_ents, strategy):
        """Score the 2-D query representation ``x`` against entities.

        With ``strategy == 'one_to_n'`` every row of ``x`` is multiplied
        against the full entity embedding matrix. Otherwise only the
        entities indexed by ``neg_ents`` are scored (dot product over the
        last dimension). The per-entity bias is added and a sigmoid applied.
        """
        if strategy == 'one_to_n':
            scores = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            scores += self.bias.expand_as(scores)
        else:
            # assumes neg_ents is (batch, num_candidates) -- TODO confirm
            candidates = self.ent_embed(neg_ents)
            scores = torch.mul(x.unsqueeze(1), candidates).sum(dim=-1)
            scores += self.bias[neg_ents]
        return torch.sigmoid(scores)


class ConvE(_KGEModel):
    """2-D convolution over the stacked subject/relation embeddings.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``in_channels``, ``out_channels``, ``inp_drop``,
        ``hid_drop``, ``feat_drop``, ``filt_h``, ``filt_w``, ``k_w``, ``k_h``.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.conv1 = torch.nn.Conv2d(
            self.in_channels, self.out_channels,
            (self.p.filt_h, self.p.filt_w), 1, 0, bias=True)
        self.bn0 = torch.nn.BatchNorm2d(self.in_channels)
        self.bn1 = torch.nn.BatchNorm2d(self.out_channels)
        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        # forward() feeds conv1 a (2 * k_w) x k_h map (stride 1, no padding).
        # The original hard-coded 20 x 20 here, which is the same value
        # whenever 2 * k_w == 20 and k_h == 20.
        conv_h = 2 * self.p.k_w - self.p.filt_h + 1
        conv_w = self.p.k_h - self.p.filt_w + 1
        fc_length = conv_h * conv_w * self.out_channels
        self.fc = torch.nn.Linear(fc_length, self.p.embed_dim)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_emb = self.ent_embed(sub).view(-1, 1, self.p.k_w, self.p.k_h)
        rel_emb = self.rel_embed(rel).view(-1, 1, self.p.k_w, self.p.k_h)
        x = torch.cat([sub_emb, rel_emb], 2)
        x = self.bn0(x)
        x = self.inp_drop(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(sub_emb.size(0), -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        return self._score_candidates(x, neg_ents, strategy)


class HypER(_KGEModel):
    """Convolves the subject embedding with relation-generated filters.

    ``fc1`` maps each relation embedding to a bank of convolution filters,
    which ``forward`` applies to that sample's subject embedding through a
    grouped convolution.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``in_channels``, ``out_channels``, ``inp_drop``,
        ``hid_drop``, ``feat_drop``, ``filt_h``, ``filt_w``.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.in_channels)
        self.bn1 = torch.nn.BatchNorm2d(self.out_channels)
        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        # The subject embedding is treated as a 1 x embed_dim image.
        fc_length = ((1 - self.p.filt_h + 1)
                     * (self.p.embed_dim - self.p.filt_w + 1)
                     * self.out_channels)
        self.fc = torch.nn.Linear(fc_length, self.p.embed_dim)
        fc1_length = (self.in_channels * self.out_channels
                      * self.p.filt_h * self.p.filt_w)
        self.fc1 = torch.nn.Linear(self.p.embed_dim, fc1_length)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_emb = self.ent_embed(sub).view(
            -1, 1, 1, self.ent_embed.weight.size(1))
        rel_emb = self.rel_embed(rel)
        batch = sub_emb.size(0)
        out_h = 1 - self.p.filt_h + 1
        out_w = sub_emb.size(3) - self.p.filt_w + 1

        x = self.bn0(sub_emb)
        x = self.inp_drop(x)

        # One filter bank per sample, generated from its relation embedding.
        k = self.fc1(rel_emb)
        k = k.view(-1, self.in_channels, self.out_channels,
                   self.p.filt_h, self.p.filt_w)
        k = k.view(batch * self.in_channels * self.out_channels, 1,
                   self.p.filt_h, self.p.filt_w)

        # Move the batch into the channel axis so groups=batch convolves
        # each sample with its own filters.
        x = x.permute(1, 0, 2, 3)
        x = F.conv2d(x, k, groups=batch)
        x = x.view(batch, 1, self.out_channels, out_h, out_w)
        x = x.permute(0, 3, 4, 1, 2)
        x = torch.sum(x, dim=3)
        x = x.permute(0, 3, 1, 2).contiguous()

        x = self.bn1(x)
        x = self.feature_map_drop(x)
        x = x.view(batch, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        return self._score_candidates(x, neg_ents, strategy)


class HypE(HypER):
    """Same model as ``HypER`` under a second name.

    The original ``HypE.forward`` was identical to ``HypER.forward``, but its
    ``__init__`` never created ``self.fc1`` and sized ``self.fc`` for a
    different feature map, so a forward pass could not run. Inheriting from
    ``HypER`` gives the layers that ``forward`` actually needs.
    """

    def __init__(self, params):
        super().__init__(params)


class DistMult(_KGEModel):
    """Scores with the element-wise product of subject and relation.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``inp_drop``.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_emb = self.ent_embed(sub)
        rel_emb = self.rel_embed(rel)
        sub_emb = self.bn0(sub_emb)
        sub_emb = self.inp_drop(sub_emb)
        return self._score_candidates(sub_emb * rel_emb, neg_ents, strategy)


class ComplEx(_KGEModel):
    """Scores with separate real and imaginary embedding tables.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``inp_drop``.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        self.ent_embed_real = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed_real.weight)
        self.ent_embed_imaginary = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed_imaginary.weight)
        self.rel_embed_real = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed_real.weight)
        self.rel_embed_imaginary = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed_imaginary.weight)
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_real = self.ent_embed_real(sub)
        sub_imag = self.ent_embed_imaginary(sub)
        rel_real = self.rel_embed_real(rel)
        rel_imag = self.rel_embed_imaginary(rel)

        sub_real = self.bn0(sub_real)
        sub_real = self.inp_drop(sub_real)
        # NOTE(review): bn0 is reused for the imaginary part and self.bn1 is
        # never called. This may be a copy-paste slip, but it is left as-is
        # so existing behaviour and checkpoints are unaffected.
        sub_imag = self.bn0(sub_imag)
        sub_imag = self.inp_drop(sub_imag)

        if strategy == 'one_to_n':
            ent_real_t = self.ent_embed_real.weight.transpose(1, 0)
            ent_imag_t = self.ent_embed_imaginary.weight.transpose(1, 0)
            x = (torch.mm(sub_real * rel_real, ent_real_t)
                 + torch.mm(sub_real * rel_imag, ent_imag_t)
                 + torch.mm(sub_imag * rel_real, ent_imag_t)
                 - torch.mm(sub_imag * rel_imag, ent_real_t))
            x += self.bias.expand_as(x)
        else:
            neg_real = self.ent_embed_real(neg_ents)
            neg_imag = self.ent_embed_imaginary(neg_ents)
            x = (torch.mul((sub_real * rel_real).unsqueeze(1), neg_real)
                 + torch.mul((sub_real * rel_imag).unsqueeze(1), neg_imag)
                 + torch.mul((sub_imag * rel_real).unsqueeze(1), neg_imag)
                 - torch.mul((sub_imag * rel_imag).unsqueeze(1), neg_real)
                 ).sum(dim=-1)
            x += self.bias[neg_ents]

        return torch.sigmoid(x)


class TuckER(_KGEModel):
    """Scores through a learned 3-way core tensor ``core_W``.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``embed_dim``, ``inp_drop``, ``hid_drop``.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        # Created on the default device like every other parameter here; the
        # original forced device="cuda", which fails on CPU-only hosts. Move
        # the whole module with .to(device) / .cuda() as usual.
        core_shape = (self.p.embed_dim, self.p.embed_dim, self.p.embed_dim)
        self.core_W = torch.nn.Parameter(torch.tensor(
            np.random.uniform(-1, 1, core_shape), dtype=torch.float))
        self.bceloss = torch.nn.BCELoss()
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop1 = torch.nn.Dropout(self.p.hid_drop)
        self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop)
        self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_emb = self.ent_embed(sub)
        sub_emb = self.bn0(sub_emb)
        sub_emb = self.inp_drop(sub_emb)
        dim = sub_emb.size(1)
        x = sub_emb.view(-1, 1, dim)

        # Contract the relation with the core tensor to get one
        # (dim x dim) matrix per sample.
        r = self.rel_embed(rel)
        W_mat = torch.mm(r, self.core_W.view(r.size(1), -1))
        W_mat = W_mat.view(-1, dim, dim)
        W_mat = self.hidden_drop1(W_mat)

        x = torch.bmm(x, W_mat)
        x = x.view(-1, dim)
        x = self.bn1(x)
        x = self.hidden_drop2(x)
        return self._score_candidates(x, neg_ents, strategy)


class FouriER(_KGEModel):
    """Projects subject and relation to a 2-channel image for the backbone.

    The backbone is a stack of ``PoolFormerBlock`` stages (with the FFT
    token mixer) separated by ``PatchEmbed`` downsampling layers.

    Parameters
    ----------
    params: Hyperparameter object; reads ``num_ent``, ``num_rel``,
        ``ent_vec_dim``, ``rel_vec_dim``, ``embed_dim``, ``image_h``,
        ``image_w``, ``in_channels``, ``out_channels``, ``hid_drop``,
        ``patch_size``, ``drop``, ``drop_path``.
    hid_drop: If given, written to ``params.hid_drop`` before use.
    embed_dim: If given, written to ``params.ent_vec_dim``,
        ``params.rel_vec_dim`` and ``params.embed_dim`` before use.

    Note that the overrides mutate the ``params`` object passed in.
    """

    def __init__(self, params, hid_drop=None, embed_dim=None):
        super().__init__()
        # Must be assigned before the overrides below; the original applied
        # the overrides first and raised AttributeError on self.p.
        self.p = params
        if hid_drop is not None:
            self.p.hid_drop = hid_drop
        if embed_dim is not None:
            self.p.ent_vec_dim = embed_dim
            self.p.rel_vec_dim = embed_dim
            self.p.embed_dim = embed_dim

        image_h, image_w = self.p.image_h, self.p.image_w
        self.in_channels = self.p.in_channels
        self.out_channels = self.p.out_channels
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.ent_vec_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.rel_vec_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.ent_fusion = torch.nn.Linear(
            self.p.ent_vec_dim, image_h * image_w)
        torch.nn.init.xavier_normal_(self.ent_fusion.weight)
        self.rel_fusion = torch.nn.Linear(
            self.p.rel_vec_dim, image_h * image_w)
        torch.nn.init.xavier_normal_(self.rel_fusion.weight)
        self.bceloss = torch.nn.BCELoss()

        channels = 2  # one image plane for the subject, one for the relation
        self.bn0 = torch.nn.BatchNorm2d(channels)
        self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))

        patch_size = self.p.patch_size
        assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
        self.patch_embed = PatchEmbed(
            in_chans=channels, patch_size=self.p.patch_size,
            embed_dim=self.p.embed_dim, stride=4, padding=2)

        # Fixed backbone configuration.
        layers = [4, 4, 12, 4]
        embed_dims = [self.p.embed_dim, 128, 320, 128]
        mlp_ratios = [4, 4, 4, 4]
        downsamples = [True, True, True, True]
        pool_size = 3
        act_layer = nn.GELU
        drop_rate = self.p.drop
        norm_layer = GroupNorm
        drop_path_rate = self.p.drop_path
        use_layer_scale = True
        layer_scale_init_value = 1e-5
        down_patch_size = 3
        down_stride = 2
        down_pad = 1
        num_classes = self.p.embed_dim

        network = []
        last_stage = len(layers) - 1
        for i in range(len(layers)):
            network.append(basic_blocks(
                embed_dims[i], i, layers,
                pool_size=pool_size, mlp_ratio=mlp_ratios[i],
                act_layer=act_layer, norm_layer=norm_layer,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value))
            if i >= last_stage:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(PatchEmbed(
                    patch_size=down_patch_size, stride=down_stride,
                    padding=down_pad, in_chans=embed_dims[i],
                    embed_dim=embed_dims[i + 1]))

        self.network = nn.ModuleList(network)
        self.norm = norm_layer(embed_dims[-1])
        if num_classes > 0:
            self.head = nn.Linear(embed_dims[-1], num_classes)
        else:
            self.head = nn.Identity()

    def forward_embeddings(self, x):
        """Apply the initial patch embedding to ``x``."""
        return self.patch_embed(x)

    def forward_tokens(self, x):
        """Run ``x`` through every stage; only the last output is returned."""
        for block in self.network:
            x = block(x)
        return x

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        image_h, image_w = self.p.image_h, self.p.image_w
        sub_emb = self.ent_fusion(self.ent_embed(sub))
        rel_emb = self.rel_fusion(self.rel_embed(rel))
        comb_emb = torch.stack(
            [sub_emb.view(-1, image_h, image_w),
             rel_emb.view(-1, image_h, image_w)], dim=1)
        y = comb_emb.view(-1, 2, image_h, image_w)
        y = self.bn0(y)

        z = self.forward_embeddings(y)
        z = self.forward_tokens(z)
        z = z.mean([-2, -1])  # average over the two spatial axes
        z = self.norm(z)

        x = self.head(z)
        x = self.hidden_drop(x)
        x = self.bn1(x)
        x = F.relu(x)
        return self._score_candidates(x, neg_ents, strategy)


class InteractE(_KGEModel):
    """
    Proposed method in the paper. Refer Section 6 of the paper for more details

    Parameters
    ----------
    params:         Hyperparameters of the model
    chequer_perm:   Reshaping to be used by the model

    Returns
    -------
    The InteractE model instance

    """

    def __init__(self, params, chequer_perm):
        super().__init__()
        self.p = params
        self.ent_embed = torch.nn.Embedding(
            self.p.num_ent, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(
            self.p.num_rel * 2, self.p.embed_dim, padding_idx=None)
        torch.nn.init.xavier_normal_(self.rel_embed.weight)
        self.bceloss = torch.nn.BCELoss()

        # forward() calls self.inp_drop, but the original __init__ never
        # created it. It is built from params.inp_drop, as in the other models.
        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.p.perm)

        flat_sz_h = self.p.k_h
        flat_sz_w = 2 * self.p.k_w
        self.padding = 0  # padding is done by circular_padding_chw instead

        self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt * self.p.perm)
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt * self.p.perm

        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.chequer_perm = chequer_perm

        self.register_parameter(
            'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent)))
        self.register_parameter('conv_filt', torch.nn.Parameter(
            torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))
        torch.nn.init.xavier_normal_(self.conv_filt)

    def circular_padding_chw(self, batch, padding):
        """Wrap-around pad the last two axes of a 4-D ``batch``.

        ``padding`` rows/columns from the opposite edge are attached to each
        side. ``padding == 0`` returns ``batch`` unchanged.
        """
        if padding == 0:
            # batch[..., -0:, :] would select everything, not nothing.
            return batch
        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        return torch.cat([left_pad, temp, right_pad], dim=3)

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        """Return sigmoid scores for ``(sub, rel)`` against candidates.

        ``neg_ents`` is only read when ``strategy != 'one_to_n'``.
        """
        sub_emb = self.ent_embed(sub)
        rel_emb = self.rel_embed(rel)
        comb_emb = torch.cat([sub_emb, rel_emb], dim=1)
        chequer_perm = comb_emb[:, self.chequer_perm]
        stack_inp = chequer_perm.reshape(
            (-1, self.p.perm, 2 * self.p.k_w, self.p.k_h))
        stack_inp = self.bn0(stack_inp)
        x = self.inp_drop(stack_inp)
        x = self.circular_padding_chw(x, self.p.ker_sz // 2)
        # The same filters are repeated for every permutation channel.
        x = F.conv2d(x, self.conv_filt.repeat(self.p.perm, 1, 1, 1),
                     padding=self.padding, groups=self.p.perm)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        return self._score_candidates(x, neg_ents, strategy)


class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


def basic_blocks(dim, index, layers,
                 pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop_rate=.0, drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
    """
    generate PoolFormer blocks for a stage
    return: PoolFormer blocks

    The drop-path rate grows linearly with a block's position in the whole
    network, from 0 at the first block to ``drop_path_rate`` at the last.
    """
    blocks_before = sum(layers[:index])
    # A one-block network has no range to interpolate over; use rate 0
    # instead of dividing by zero.
    denom = sum(layers) - 1
    blocks = []
    for block_idx in range(layers[index]):
        if denom > 0:
            block_dpr = drop_path_rate * (block_idx + blocks_before) / denom
        else:
            block_dpr = 0.
        blocks.append(PoolFormerBlock(
            dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
            act_layer=act_layer, norm_layer=norm_layer,
            drop=drop_rate, drop_path=block_dpr,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
        ))
    return nn.Sequential(*blocks)


class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer block.
    --dim: embedding dim
    --pool_size: pooling size (accepted but unused: the token mixer is
        FNetBlock, not Pooling)
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth,
        refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale,
        refer to https://arxiv.org/abs/2103.17239
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.token_mixer = FNetBlock()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)),
                requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)),
                requires_grad=True)

    def forward(self, x):
        """Apply the token-mixer and MLP residual branches to ``x``."""
        if self.use_layer_scale:
            # Per-channel scales broadcast over the two trailing axes.
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """
    Patch Embedding that is implemented by a layer of conv.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """

    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Project ``x`` with the convolution, then apply the norm."""
        x = self.proj(x)
        x = self.norm(x)
        return x


class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer
    --pool_size: pooling size
    """

    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2,
            count_include_pad=False)

    def forward(self, x):
        """Return the average-pooled ``x`` minus ``x`` itself."""
        return self.pool(x) - x


class Mlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for each ``nn.Conv2d``."""
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2, dropout in that order."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over dim 1, then apply per-channel scale and shift."""
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
            + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x


class FeedForward(nn.Module):
    """Linear, GELU, dropout, linear, dropout over the last dimension."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FNetBlock(nn.Module):
    """Parameter-free mixer: real part of an FFT over the last two axes."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
        return x


class FNet(nn.Module):
    """Stack of ``depth`` residual (FFT mixer, feed-forward) layer pairs."""

    def __init__(self, dim, depth, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, FNetBlock()),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x