import torch from torch import nn, einsum import torch.nn.functional as F import numpy as np from functools import partial from einops.layers.torch import Rearrange, Reduce from einops import rearrange, repeat from utils import * from layers import * from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropPath, trunc_normal_ from timm.models.registry import register_model from timm.layers.helpers import to_2tuple class ConvE(torch.nn.Module): def __init__(self, params, ): super(ConvE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.conv1 = torch.nn.Conv2d( self.in_channels, self.out_channels, (self.p.filt_h, self.p.filt_w), 1, 0, bias=True) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (20-self.p.filt_h+1)*(20-self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub).view(-1, 1, self.p.k_w, self.p.k_h) rel_emb = self.rel_embed(rel).view(-1, 1, self.p.k_w, self.p.k_h) x = torch.cat([sub_emb, rel_emb], 2) x = self.bn0(x) x = self.inp_drop(x) x = self.conv1(x) x = self.bn1(x) x = F.relu(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class HypER(torch.nn.Module): def __init__(self, params, ): super(HypER, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (1-self.p.filt_h+1)*(self.p.embed_dim - self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) fc1_length = self.in_channels*self.out_channels*self.p.filt_h*self.p.filt_w self.fc1 = torch.nn.Linear(self.p.embed_dim, fc1_length) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed( sub).view(-1, 1, 1, self.ent_embed.weight.size(1)) rel_emb = self.rel_embed(rel) x = self.bn0(sub_emb) x = self.inp_drop(x) k = self.fc1(rel_emb) k = k.view(-1, self.in_channels, self.out_channels, self.p.filt_h, self.p.filt_w) k = k.view(sub_emb.size(0)*self.in_channels * self.out_channels, 1, self.p.filt_h, self.p.filt_w) x = x.permute(1, 0, 2, 3) x = F.conv2d(x, k, groups=sub_emb.size(0)) x = x.view(sub_emb.size(0), 1, self.out_channels, 1 - self.p.filt_h+1, sub_emb.size(3)-self.p.filt_w+1) x = x.permute(0, 3, 4, 1, 2) x = torch.sum(x, dim=3) x = x.permute(0, 3, 1, 2).contiguous() x = self.bn1(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class HypE(torch.nn.Module): def __init__(self, params, ): super(HypE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (10-self.p.filt_h+1)*(20-self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed( sub).view(-1, 1, 1, self.ent_embed.weight.size(1)) rel_emb = self.rel_embed(rel) x = self.bn0(sub_emb) x = self.inp_drop(x) k = self.fc1(rel_emb) k = k.view(-1, self.in_channels, self.out_channels, self.p.filt_h, self.p.filt_w) k = k.view(sub_emb.size(0)*self.in_channels * self.out_channels, 1, self.p.filt_h, self.p.filt_w) x = x.permute(1, 0, 2, 3) x = F.conv2d(x, k, groups=sub_emb.size(0)) x = x.view(sub_emb.size(0), 1, self.out_channels, 1 - self.p.filt_h+1, sub_emb.size(3)-self.p.filt_w+1) x = x.permute(0, 3, 4, 1, 2) x = torch.sum(x, dim=3) x = x.permute(0, 3, 1, 2).contiguous() x = self.bn1(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class DistMult(torch.nn.Module): def __init__(self, params, ): super(DistMult, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub) rel_emb = self.rel_embed(rel) sub_emb = self.bn0(sub_emb) sub_emb = self.inp_drop(sub_emb) if strategy == 'one_to_n': x = torch.mm(sub_emb * rel_emb, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul((sub_emb * rel_emb).unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class ComplEx(torch.nn.Module): def __init__(self, params, ): super(ComplEx, self).__init__() self.p = params self.ent_embed_real = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed_real.weight) self.ent_embed_imaginary = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed_imaginary.weight) self.rel_embed_real = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed_real.weight) self.rel_embed_imaginary = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed_imaginary.weight) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb_real = self.ent_embed_real(sub) sub_emb_imaginary = self.ent_embed_imaginary(sub) rel_emb_real = self.rel_embed_real(rel) rel_emb_imaginary = self.rel_embed_imaginary(rel) sub_emb_real = self.bn0(sub_emb_real) sub_emb_real = self.inp_drop(sub_emb_real) sub_emb_imaginary = self.bn0(sub_emb_imaginary) sub_emb_imaginary = self.inp_drop(sub_emb_imaginary) if strategy == 'one_to_n': x = torch.mm(sub_emb_real*rel_emb_real, self.ent_embed_real.weight.transpose(1, 0)) +\ torch.mm(sub_emb_real*rel_emb_imaginary, self.ent_embed_imaginary.weight.transpose(1, 0)) +\ torch.mm(sub_emb_imaginary*rel_emb_real, self.ent_embed_imaginary.weight.transpose(1, 0)) -\ torch.mm(sub_emb_imaginary*rel_emb_imaginary, self.ent_embed_real.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: neg_embs_real = self.ent_embed_real(neg_ents) neg_embs_imaginary = self.ent_embed_imaginary(neg_ents) x = (torch.mul((sub_emb_real*rel_emb_real).unsqueeze(1), neg_embs_real) + torch.mul((sub_emb_real*rel_emb_imaginary).unsqueeze(1), neg_embs_imaginary) + torch.mul((sub_emb_imaginary*rel_emb_real).unsqueeze(1), neg_embs_imaginary) - torch.mul((sub_emb_imaginary*rel_emb_imaginary).unsqueeze(1), neg_embs_real)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class TuckER(torch.nn.Module): def __init__(self, params, ): super(TuckER, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.core_W = torch.nn.Parameter(torch.tensor(np.random.uniform(-1, 1, (self.p.embed_dim, self.p.embed_dim, self.p.embed_dim)), dtype=torch.float, device="cuda", requires_grad=True)) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop1 = torch.nn.Dropout(self.p.hid_drop) self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub) sub_emb = self.bn0(sub_emb) sub_emb = self.inp_drop(sub_emb) x = sub_emb.view(-1, 1, sub_emb.size(1)) r = self.rel_embed(rel) W_mat = torch.mm(r, self.core_W.view(r.size(1), -1)) W_mat = W_mat.view(-1, sub_emb.size(1), sub_emb.size(1)) W_mat = self.hidden_drop1(W_mat) x = torch.bmm(x, W_mat) x = x.view(-1, sub_emb.size(1)) x = self.bn1(x) x = self.hidden_drop2(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class FouriER(torch.nn.Module): def __init__(self, params, hid_drop = None, embed_dim = None): super(FouriER, self).__init__() if hid_drop is not None: self.p.hid_drop = hid_drop if embed_dim is not None: self.p.ent_vec_dim = embed_dim self.p.rel_vec_dim = embed_dim self.p.embed_dim = embed_dim self.p = params image_h, image_w = self.p.image_h, self.p.image_w self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.ent_vec_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.rel_vec_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.ent_fusion = torch.nn.Linear( self.p.ent_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.ent_fusion.weight) self.rel_fusion = torch.nn.Linear( self.p.rel_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.rel_fusion.weight) self.bceloss = torch.nn.BCELoss() channels = 2 self.bn0 = torch.nn.BatchNorm2d(channels) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) patch_size = self.p.patch_size assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size' self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size, embed_dim=self.p.embed_dim, stride=4, padding=2) network = [] layers = [4, 4, 12, 4] embed_dims = [self.p.embed_dim, 128, 320, 128] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pool_size=3 act_layer=nn.GELU drop_rate=self.p.drop norm_layer=GroupNorm drop_path_rate=self.p.drop_path use_layer_scale=True layer_scale_init_value=1e-5 down_patch_size=3 down_stride=2 down_pad=1 num_classes=self.p.embed_dim for i in range(len(layers)): stage = basic_blocks(embed_dims[i], i, layers, pool_size=pool_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value) network.append(stage) if i >= len(layers) - 1: break if downsamples[i] or embed_dims[i] != embed_dims[i+1]: # downsampling between two stages network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, padding=down_pad, in_chans=embed_dims[i], embed_dim=embed_dims[i+1] ) ) self.network = nn.ModuleList(network) self.norm = norm_layer(embed_dims[-1]) self.head = nn.Linear( embed_dims[-1], num_classes) if num_classes > 0 \ else nn.Identity() def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward_embeddings(self, x): x = self.patch_embed(x) return x def forward_tokens(self, x): outs = [] for idx, block in enumerate(self.network): x = block(x) # output only the features of last layer for image classification return x def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_fusion(self.ent_embed(sub)) rel_emb = self.rel_fusion(self.rel_embed(rel)) comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) y = comb_emb.view(-1, 2, self.p.image_h, self.p.image_w) y = self.bn0(y) z = self.forward_embeddings(y) z = self.forward_tokens(z) z = z.mean([-2, -1]) z = self.norm(z) x = self.head(z) x = self.hidden_drop(x) x = self.bn1(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class InteractE(torch.nn.Module): """ Proposed method in the paper. Refer Section 6 of the paper for mode details Parameters ---------- params: Hyperparameters of the model chequer_perm: Reshaping to be used by the model Returns ------- The InteractE model instance """ def __init__(self, params, chequer_perm): super(InteractE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.bceloss = torch.nn.BCELoss() self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.p.perm) flat_sz_h = self.p.k_h flat_sz_w = 2*self.p.k_w self.padding = 0 self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt*self.p.perm) self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt*self.p.perm self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim) self.chequer_perm = chequer_perm self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) self.register_parameter('conv_filt', torch.nn.Parameter( torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz))) torch.nn.init.xavier_normal_(self.conv_filt) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def circular_padding_chw(self, batch, padding): upper_pad = batch[..., -padding:, :] lower_pad = batch[..., :padding, :] temp = torch.cat([upper_pad, batch, lower_pad], dim=2) left_pad = temp[..., -padding:] right_pad = temp[..., :padding] padded = torch.cat([left_pad, temp, right_pad], dim=3) return padded def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub) rel_emb = self.rel_embed(rel) comb_emb = torch.cat([sub_emb, rel_emb], dim=1) chequer_perm = comb_emb[:, self.chequer_perm] stack_inp = chequer_perm.reshape( (-1, self.p.perm, 2*self.p.k_w, self.p.k_h)) stack_inp = self.bn0(stack_inp) x = self.inp_drop(stack_inp) x = self.circular_padding_chw(x, self.p.ker_sz//2) x = F.conv2d(x, self.conv_filt.repeat(self.p.perm, 1, 1, 1), padding=self.padding, groups=self.p.perm) x = self.bn1(x) x = F.relu(x) x = self.feature_map_drop(x) x = x.view(-1, self.flat_sz) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class GroupNorm(nn.GroupNorm): """ Group Normalization with 1 group. Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) def basic_blocks(dim, index, layers, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, norm_layer=GroupNorm, drop_rate=.0, drop_path_rate=0., use_layer_scale=True, layer_scale_init_value=1e-5): """ generate PoolFormer blocks for a stage return: PoolFormer blocks """ blocks = [] for block_idx in range(layers[index]): block_dpr = drop_path_rate * ( block_idx + sum(layers[:index])) / (sum(layers) - 1) blocks.append(PoolFormerBlock( dim, pool_size=pool_size, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, )) blocks = nn.Sequential(*blocks) return blocks def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, C, H, W = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 pretrained_window_size (tuple[int]): The height and width of the window in pre-training. """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., pretrained_window_size=[0, 0]): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)) # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack( torch.meshgrid([relative_coords_h, relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) else: relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / np.log2(8) self.register_buffer("relative_coords_table", relative_coords_table) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(dim)) self.v_bias = nn.Parameter(torch.zeros(dim)) else: self.q_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).cuda()).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias) attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, ' \ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W) return x def cast_tuple(val, length = 1): return val if isinstance(val, tuple) else ((val,) * length) # helper classes class ChanLayerNorm(nn.Module): def __init__(self, dim, eps = 1e-5): super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) def forward(self, x): var = torch.var(x, dim = 1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b class OverlappingPatchEmbed(nn.Module): def __init__(self, dim_in, dim_out, stride = 2): super().__init__() kernel_size = stride * 2 - 1 padding = kernel_size // 2 self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride = stride, padding = padding) def forward(self, x): return self.conv(x) class PEG(nn.Module): def __init__(self, dim, kernel_size = 3): super().__init__() self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1) def forward(self, x): return self.proj(x) + x # feedforward class FeedForward(nn.Module): def __init__(self, dim, mult = 4, dropout = 0.): super().__init__() inner_dim = int(dim * mult) self.net = nn.Sequential( ChanLayerNorm(dim), nn.Conv2d(dim, inner_dim, 1), nn.GELU(), nn.Dropout(dropout), nn.Conv2d(inner_dim, dim, 1), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) # attention class DSSA(nn.Module): def __init__( self, dim, heads = 8, dim_head = 32, dropout = 0., window_size = 7 ): super().__init__() self.heads = heads self.scale = dim_head ** -0.5 self.window_size = window_size inner_dim = dim_head * heads self.norm = ChanLayerNorm(dim) self.attend = nn.Sequential( nn.Softmax(dim = -1), nn.Dropout(dropout) ) self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False) # window tokens self.window_tokens = nn.Parameter(torch.randn(dim)) # prenorm and non-linearity for window tokens # then projection to queries and keys for window tokens self.window_tokens_to_qk = nn.Sequential( nn.LayerNorm(dim_head), nn.GELU(), Rearrange('b h n c -> b (h c) n'), nn.Conv1d(inner_dim, inner_dim * 2, 1), Rearrange('b (h c) n -> b h n c', h = heads), ) # window attention self.window_attend = nn.Sequential( nn.Softmax(dim = -1), nn.Dropout(dropout) ) self.to_out = nn.Sequential( nn.Conv2d(inner_dim, dim, 1), nn.Dropout(dropout) ) def forward(self, x): """ einstein notation b - batch c - channels w1 - window size (height) w2 - also window size (width) i - sequence dimension (source) j - sequence dimension (target dimension to be reduced) h - heads x - height of feature map divided by window size y - width of feature map divided by window size """ batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}' num_windows = (height // wsz) * (width // wsz) x = self.norm(x) # fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz) # add windowing tokens w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0]) x = torch.cat((w, x), dim = -1) # project for queries, keys, value q, k, v = self.to_qkv(x).chunk(3, dim = 1) # split out heads q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v)) # scale q = q * self.scale # similarity dots = einsum('b h i d, b h j d -> b h i j', q, k) # attention attn = self.attend(dots) # aggregate values out = torch.matmul(attn, v) # split out windowed tokens window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:] # early return if there is only 1 window if num_windows == 1: fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz) return self.to_out(fmap) # carry out the pointwise attention, the main novelty in the paper window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz) windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz) # windowed queries and keys (preceded by prenorm activation) w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1) # scale w_q = w_q * self.scale # similarities w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k) w_attn = self.window_attend(w_dots) # aggregate the feature maps from the "depthwise" attention step (the most interesting part of the paper, one i haven't seen before) aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps) # fold back the windows and then combine heads for aggregation fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz) return self.to_out(fmap) class PoolFormerBlock(nn.Module): """ Implementation of one PoolFormer block. --dim: embedding dim --pool_size: pooling size --mlp_ratio: mlp expansion ratio --act_layer: activation --norm_layer: normalization --drop: dropout rate --drop path: Stochastic Depth, refer to https://arxiv.org/abs/1603.09382 --use_layer_scale, --layer_scale_init_value: LayerScale, refer to https://arxiv.org/abs/2103.17239 """ def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, norm_layer=GroupNorm, drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5): super().__init__() self.norm1 = norm_layer(dim) #self.token_mixer = Pooling(pool_size=pool_size) # self.token_mixer = FNetBlock() self.window_size = 4 self.attn_heads = 4 self.attn_mask = None # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4) self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # The following two techniques are useful to train deep PoolFormers. self.drop_path = DropPath(drop_path) if drop_path > 0. \ else nn.Identity() self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True) def forward(self, x): B, C, H, W = x.shape # x_windows = window_partition(x, self.window_size) # x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # attn_windows = self.token_mixer(x_windows, mask=self.attn_mask) # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # x_attn = window_reverse(attn_windows, self.window_size, H, W) x_attn = self.token_mixer(x) if self.use_layer_scale: x = x + self.drop_path( self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x_attn) x = x + self.drop_path( self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))) else: x = x + self.drop_path(x_attn) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchEmbed(nn.Module): """ Patch Embedding that is implemented by a layer of conv. Input: tensor in shape [B, C, H, W] Output: tensor in shape [B, C, H/stride, W/stride] """ def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) padding = to_2tuple(padding) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): x = self.proj(x) x = self.norm(x) return x class Pooling(nn.Module): """ Implementation of pooling for PoolFormer --pool_size: pooling size """ def __init__(self, pool_size=3): super().__init__() self.pool = nn.AvgPool2d( pool_size, stride=1, padding=pool_size//2, count_include_pad=False) def forward(self, x): return self.pool(x) - x class Mlp(nn.Module): """ Implementation of MLP with 1*1 convolutions. Input: tensor with shape [B, C, H, W] """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Conv2d(in_features, hidden_features, 1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_features, out_features, 1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class LayerNormChannel(nn.Module): """ LayerNorm only for Channel Dimension. Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, eps=1e-05): super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x): u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \ + self.bias.unsqueeze(-1).unsqueeze(-1) return x class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout = 0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FNetBlock(nn.Module): def __init__(self): super().__init__() def forward(self, x): x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real return x class FNet(nn.Module): def __init__(self, dim, depth, mlp_dim, dropout = 0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, FNetBlock()), PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x