import torch from torch import nn import torch.nn.functional as F import numpy as np from functools import partial from einops.layers.torch import Rearrange, Reduce from utils import * from layers import * from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropPath, trunc_normal_ from timm.models.registry import register_model from timm.layers.helpers import to_2tuple from typing import * import math class ConvE(torch.nn.Module): def __init__(self, params, ): super(ConvE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.conv1 = torch.nn.Conv2d( self.in_channels, self.out_channels, (self.p.filt_h, self.p.filt_w), 1, 0, bias=True) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (20-self.p.filt_h+1)*(20-self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub).view(-1, 1, self.p.k_w, self.p.k_h) rel_emb = self.rel_embed(rel).view(-1, 1, self.p.k_w, self.p.k_h) x = torch.cat([sub_emb, rel_emb], 2) x = self.bn0(x) x = 
self.inp_drop(x) x = self.conv1(x) x = self.bn1(x) x = F.relu(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class HypER(torch.nn.Module): def __init__(self, params, ): super(HypER, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (1-self.p.filt_h+1)*(self.p.embed_dim - self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) fc1_length = self.in_channels*self.out_channels*self.p.filt_h*self.p.filt_w self.fc1 = torch.nn.Linear(self.p.embed_dim, fc1_length) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed( sub).view(-1, 1, 1, self.ent_embed.weight.size(1)) rel_emb = self.rel_embed(rel) x = self.bn0(sub_emb) x = self.inp_drop(x) k = 
self.fc1(rel_emb) k = k.view(-1, self.in_channels, self.out_channels, self.p.filt_h, self.p.filt_w) k = k.view(sub_emb.size(0)*self.in_channels * self.out_channels, 1, self.p.filt_h, self.p.filt_w) x = x.permute(1, 0, 2, 3) x = F.conv2d(x, k, groups=sub_emb.size(0)) x = x.view(sub_emb.size(0), 1, self.out_channels, 1 - self.p.filt_h+1, sub_emb.size(3)-self.p.filt_w+1) x = x.permute(0, 3, 4, 1, 2) x = torch.sum(x, dim=3) x = x.permute(0, 3, 1, 2).contiguous() x = self.bn1(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class HypE(torch.nn.Module): def __init__(self, params, ): super(HypE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.in_channels) self.bn1 = torch.nn.BatchNorm2d(self.out_channels) self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) fc_length = (10-self.p.filt_h+1)*(20-self.p.filt_w+1)*self.out_channels self.fc = torch.nn.Linear(fc_length, self.p.embed_dim) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = 
self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed( sub).view(-1, 1, 1, self.ent_embed.weight.size(1)) rel_emb = self.rel_embed(rel) x = self.bn0(sub_emb) x = self.inp_drop(x) k = self.fc1(rel_emb) k = k.view(-1, self.in_channels, self.out_channels, self.p.filt_h, self.p.filt_w) k = k.view(sub_emb.size(0)*self.in_channels * self.out_channels, 1, self.p.filt_h, self.p.filt_w) x = x.permute(1, 0, 2, 3) x = F.conv2d(x, k, groups=sub_emb.size(0)) x = x.view(sub_emb.size(0), 1, self.out_channels, 1 - self.p.filt_h+1, sub_emb.size(3)-self.p.filt_w+1) x = x.permute(0, 3, 4, 1, 2) x = torch.sum(x, dim=3) x = x.permute(0, 3, 1, 2).contiguous() x = self.bn1(x) x = self.feature_map_drop(x) x = x.view(sub_emb.size(0), -1) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class DistMult(torch.nn.Module): def __init__(self, params, ): super(DistMult, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub) rel_emb = self.rel_embed(rel) 
sub_emb = self.bn0(sub_emb) sub_emb = self.inp_drop(sub_emb) if strategy == 'one_to_n': x = torch.mm(sub_emb * rel_emb, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul((sub_emb * rel_emb).unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class ComplEx(torch.nn.Module): def __init__(self, params, ): super(ComplEx, self).__init__() self.p = params self.ent_embed_real = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed_real.weight) self.ent_embed_imaginary = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed_imaginary.weight) self.rel_embed_real = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed_real.weight) self.rel_embed_imaginary = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed_imaginary.weight) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb_real = self.ent_embed_real(sub) sub_emb_imaginary = self.ent_embed_imaginary(sub) rel_emb_real = self.rel_embed_real(rel) rel_emb_imaginary = self.rel_embed_imaginary(rel) sub_emb_real = self.bn0(sub_emb_real) sub_emb_real = self.inp_drop(sub_emb_real) sub_emb_imaginary = self.bn0(sub_emb_imaginary) sub_emb_imaginary = self.inp_drop(sub_emb_imaginary) if strategy == 'one_to_n': x = torch.mm(sub_emb_real*rel_emb_real, 
self.ent_embed_real.weight.transpose(1, 0)) +\ torch.mm(sub_emb_real*rel_emb_imaginary, self.ent_embed_imaginary.weight.transpose(1, 0)) +\ torch.mm(sub_emb_imaginary*rel_emb_real, self.ent_embed_imaginary.weight.transpose(1, 0)) -\ torch.mm(sub_emb_imaginary*rel_emb_imaginary, self.ent_embed_real.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: neg_embs_real = self.ent_embed_real(neg_ents) neg_embs_imaginary = self.ent_embed_imaginary(neg_ents) x = (torch.mul((sub_emb_real*rel_emb_real).unsqueeze(1), neg_embs_real) + torch.mul((sub_emb_real*rel_emb_imaginary).unsqueeze(1), neg_embs_imaginary) + torch.mul((sub_emb_imaginary*rel_emb_real).unsqueeze(1), neg_embs_imaginary) - torch.mul((sub_emb_imaginary*rel_emb_imaginary).unsqueeze(1), neg_embs_real)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class TuckER(torch.nn.Module): def __init__(self, params, ): super(TuckER, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.core_W = torch.nn.Parameter(torch.tensor(np.random.uniform(-1, 1, (self.p.embed_dim, self.p.embed_dim, self.p.embed_dim)), dtype=torch.float, device="cuda", requires_grad=True)) self.bceloss = torch.nn.BCELoss() self.inp_drop = torch.nn.Dropout(self.p.inp_drop) self.hidden_drop1 = torch.nn.Dropout(self.p.hid_drop) self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop) self.bn0 = torch.nn.BatchNorm1d(self.p.embed_dim) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward(self, sub, rel, neg_ents, 
strategy='one_to_x'): sub_emb = self.ent_embed(sub) sub_emb = self.bn0(sub_emb) sub_emb = self.inp_drop(sub_emb) x = sub_emb.view(-1, 1, sub_emb.size(1)) r = self.rel_embed(rel) W_mat = torch.mm(r, self.core_W.view(r.size(1), -1)) W_mat = W_mat.view(-1, sub_emb.size(1), sub_emb.size(1)) W_mat = self.hidden_drop1(W_mat) x = torch.bmm(x, W_mat) x = x.view(-1, sub_emb.size(1)) x = self.bn1(x) x = self.hidden_drop2(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class FouriER(torch.nn.Module): def __init__(self, params, hid_drop = None, embed_dim = None): super(FouriER, self).__init__() if hid_drop is not None: self.p.hid_drop = hid_drop if embed_dim is not None: self.p.ent_vec_dim = embed_dim self.p.rel_vec_dim = embed_dim self.p.embed_dim = embed_dim self.p = params image_h, image_w = self.p.image_h, self.p.image_w self.in_channels = self.p.in_channels self.out_channels = self.p.out_channels self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.ent_vec_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.rel_vec_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.ent_fusion = torch.nn.Linear( self.p.ent_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.ent_fusion.weight) self.rel_fusion = torch.nn.Linear( self.p.rel_vec_dim, image_h*image_w) torch.nn.init.xavier_normal_(self.rel_fusion.weight) self.bceloss = torch.nn.BCELoss() channels = 2 self.bn0 = torch.nn.BatchNorm2d(channels) self.bn1 = torch.nn.BatchNorm1d(self.p.embed_dim) self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) patch_size = self.p.patch_size assert (image_h % patch_size) == 0 and (image_w % 
patch_size) == 0, 'image must be divisible by patch size' self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size, embed_dim=self.p.embed_dim, stride=4, padding=2) network = [] layers = [4, 4, 12, 4] embed_dims = [self.p.embed_dim, 128, 320, 128] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pool_size=3 act_layer=nn.GELU drop_rate=self.p.drop norm_layer=GroupNorm drop_path_rate=self.p.drop_path use_layer_scale=True layer_scale_init_value=1e-5 down_patch_size=3 down_stride=2 down_pad=1 num_classes=self.p.embed_dim for i in range(len(layers)): stage = basic_blocks(embed_dims[i], i, layers, pool_size=pool_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value) network.append(stage) if i >= len(layers) - 1: break if downsamples[i] or embed_dims[i] != embed_dims[i+1]: # downsampling between two stages network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, padding=down_pad, in_chans=embed_dims[i], embed_dim=embed_dims[i+1] ) ) self.network = nn.ModuleList(network) self.norm = norm_layer(embed_dims[-1]) self.graph_type = 'Spatial' N = (image_h // patch_size)**2 if self.graph_type in ["Spatial", "Mixed"]: # Create a range tensor of node indices indices = torch.arange(N) # Reshape the indices tensor to create a grid of row and column indices row_indices = indices.view(-1, 1).expand(-1, N) col_indices = indices.view(1, -1).expand(N, -1) # Compute the adjacency matrix row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N)) row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N)) graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float()) graph = graph - torch.eye(N) self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu self.class_token = False self.token_scale = False self.head = 
nn.Linear( embed_dims[-1], num_classes) if num_classes > 0 \ else nn.Identity() def loss(self, pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def forward_embeddings(self, x): x = self.patch_embed(x) return x def forward_tokens(self, x): outs = [] B, C, H, W = x.shape N = H*W if self.graph_type in ["Semantic", "Mixed"]: # Generate the semantic graph w.r.t. the cosine similarity between tokens # Compute cosine similarity if self.class_token: x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True) else: x_normed = x / x.norm(dim=-1, keepdim=True) x_cossim = x_normed @ x_normed.transpose(-1, -2) threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1 semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0) if self.class_token: semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0) else: semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0) if self.graph_type == "None": graph = None else: if self.graph_type == "Spatial": graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device) elif self.graph_type == "Semantic": graph = semantic_graph elif self.graph_type == "Mixed": # Integrate the spatial graph and semantic graph spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device) graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float() # Symmetrically normalize the graph degree = graph.sum(-1) # B, N degree = torch.diag_embed(degree**(-1/2)) graph = degree @ graph @ degree for idx, block in enumerate(self.network): x = block(x) # output only the features of last layer for image classification return x def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_fusion(self.ent_embed(sub)) rel_emb = self.rel_fusion(self.rel_embed(rel)) comb_emb = torch.stack([sub_emb.view(-1, self.p.image_h, 
self.p.image_w), rel_emb.view(-1, self.p.image_h, self.p.image_w)], dim=1) y = comb_emb.view(-1, 2, self.p.image_h, self.p.image_w) y = self.bn0(y) z = self.forward_embeddings(y) z = self.forward_tokens(z) z = z.mean([-2, -1]) z = self.norm(z) x = self.head(z) x = self.hidden_drop(x) x = self.bn1(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class InteractE(torch.nn.Module): """ Proposed method in the paper. Refer Section 6 of the paper for mode details Parameters ---------- params: Hyperparameters of the model chequer_perm: Reshaping to be used by the model Returns ------- The InteractE model instance """ def __init__(self, params, chequer_perm): super(InteractE, self).__init__() self.p = params self.ent_embed = torch.nn.Embedding( self.p.num_ent, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.ent_embed.weight) self.rel_embed = torch.nn.Embedding( self.p.num_rel*2, self.p.embed_dim, padding_idx=None) torch.nn.init.xavier_normal_(self.rel_embed.weight) self.bceloss = torch.nn.BCELoss() self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop) self.bn0 = torch.nn.BatchNorm2d(self.p.perm) flat_sz_h = self.p.k_h flat_sz_w = 2*self.p.k_w self.padding = 0 self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt*self.p.perm) self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt*self.p.perm self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim) self.chequer_perm = chequer_perm self.register_parameter( 'bias', torch.nn.Parameter(torch.zeros(self.p.num_ent))) self.register_parameter('conv_filt', torch.nn.Parameter( torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz))) torch.nn.init.xavier_normal_(self.conv_filt) def loss(self, 
pred, true_label=None, sub_samp=None): label_pos = true_label[0] label_neg = true_label[1:] loss = self.bceloss(pred, true_label) return loss def circular_padding_chw(self, batch, padding): upper_pad = batch[..., -padding:, :] lower_pad = batch[..., :padding, :] temp = torch.cat([upper_pad, batch, lower_pad], dim=2) left_pad = temp[..., -padding:] right_pad = temp[..., :padding] padded = torch.cat([left_pad, temp, right_pad], dim=3) return padded def forward(self, sub, rel, neg_ents, strategy='one_to_x'): sub_emb = self.ent_embed(sub) rel_emb = self.rel_embed(rel) comb_emb = torch.cat([sub_emb, rel_emb], dim=1) chequer_perm = comb_emb[:, self.chequer_perm] stack_inp = chequer_perm.reshape( (-1, self.p.perm, 2*self.p.k_w, self.p.k_h)) stack_inp = self.bn0(stack_inp) x = self.inp_drop(stack_inp) x = self.circular_padding_chw(x, self.p.ker_sz//2) x = F.conv2d(x, self.conv_filt.repeat(self.p.perm, 1, 1, 1), padding=self.padding, groups=self.p.perm) x = self.bn1(x) x = F.relu(x) x = self.feature_map_drop(x) x = x.view(-1, self.flat_sz) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) if strategy == 'one_to_n': x = torch.mm(x, self.ent_embed.weight.transpose(1, 0)) x += self.bias.expand_as(x) else: x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1) x += self.bias[neg_ents] pred = torch.sigmoid(x) return pred class GroupNorm(nn.GroupNorm): """ Group Normalization with 1 group. 
Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) def basic_blocks(dim, index, layers, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, norm_layer=GroupNorm, drop_rate=.0, drop_path_rate=0., use_layer_scale=True, layer_scale_init_value=1e-5): """ generate PoolFormer blocks for a stage return: PoolFormer blocks """ blocks = [] for block_idx in range(layers[index]): block_dpr = drop_path_rate * ( block_idx + sum(layers[:index])) / (sum(layers) - 1) blocks.append(PoolFormerBlock( dim, pool_size=pool_size, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, )) blocks = nn.Sequential(*blocks) return blocks def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, C, H, W = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
""" def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., pretrained_window_size=[0, 0]): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)) # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack( torch.meshgrid([relative_coords_h, relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) else: relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / np.log2(8) self.register_buffer("relative_coords_table", relative_coords_table) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += 
self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(dim)) self.v_bias = nn.Parameter(torch.zeros(dim)) else: self.q_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).cuda()).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias) attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, ' \ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W) return x def propagate(x: torch.Tensor, weight: torch.Tensor, index_kept: torch.Tensor, 
index_prop: torch.Tensor, standard: str = "None", alpha: Optional[float] = 0, token_scales: Optional[torch.Tensor] = None, cls_token=True): """ Propagate tokens based on the selection results. ================================================ Args: - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token. - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens, excluding the [CLS] token. weight could be a pre-defined graph of the current feature map (by default) or the attention map (need to manually modify the Block Module). - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X - standard: str: the method applied to propagate the tokens, including "None", "Mean" and "GraphProp" - alpha: float: the coefficient of propagated features - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales is None by default. If it is not None, then token_scales represents the scales of each token and should sum up to N. 
Return: - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation """ B, C, N = x.shape # Step 1: divide tokens if cls_token: x_cls = x[:, 0:1] # B, 1, C x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C # Step 2: divide token_scales if it is not None if token_scales is not None: if cls_token: token_scales_cls = token_scales[:, 0:1] # B, 1 token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop # Step 3: propagate tokens if standard == "None": """ No further propagation """ pass elif standard == "Mean": """ Calculate the mean of all the propagated tokens, and concatenate the result token back to kept tokens. """ # naive average x_prop = x_prop.mean(1, keepdim=True) # B, 1, C # Concatenate the average token x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C elif standard == "GraphProp": """ Propagate all the propagated token to kept token with respect to the weights and token scales. """ assert weight is not None, "The graph weight is needed for graph propagation" # Step 3.1: divide propagation weights. 
if cls_token: index_kept = index_kept - 1 # since weights do not include the [CLS] token index_prop = index_prop - 1 # since weights do not include the [CLS] token weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1 weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop else: weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1 weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens # Simple implementation x_prop = weight_prop @ x_prop # B, N-1-num_prop, C x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C """ scatter_reduce implementation for batched inputs # Get the non-zero values non_zero_indices = torch.nonzero(weight_prop, as_tuple=True) non_zero_values = weight_prop[non_zero_indices] # Sparse multiplication batch_indices, row_indices, col_indices = non_zero_indices sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :] reduce_indices = batch_indices * x_kept.shape[1] + row_indices x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0, index=reduce_indices[:, None], src=sparse_matmul, reduce="sum", include_self=True) x_kept = x_kept.reshape(B, -1, C) """ # Step 3.3: calculate the scale of each token if token_scales is not None if token_scales is not None: if cls_token: token_scales_cls = token_scales[:, 0:1] # B, 1 token_scales = token_scales[:, 1:] token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop token_scales_prop = 
token_scales.gather(dim=1, index=index_prop) # B, num_prop token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1 token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop if cls_token: token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop else: assert False, "Propagation method \'%f\' has not been supported yet." % standard if cls_token: # Step 4: concatenate the [CLS] token and generate returned value x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C else: x = x_kept return x, weight, token_scales


def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token=True):
    """
    Select image tokens to be propagated, ranking them by an attention statistic.
    The [CLS] token (index 0) is never ranked when ``cls_token`` is True.
    ======================================================================
    Args:
        - weight: Tensor([B, H, N, N]): used for ranking the tokens. Only the
          attention map of tokens is supported at the moment.
        - standard: str: the ranking criterion, one of "CLSAttnMean",
          "CLSAttnMax" (only with cls_token=True), "IMGAttnMean", "IMGAttnMax",
          "DiagAttnMean", "DiagAttnMax", "MixedAttnMean", "MixedAttnMax",
          "SumAttnMax", "CosSimMean", "CosSimMax", "Random".
        - num_prop: int: the number of lowest-ranked tokens to propagate.
          0 keeps every token; a value larger than the number of ranked tokens
          propagates all of them.
        - cls_token: bool: whether token 0 is a [CLS] token excluded from ranking.
    Return:
        With T = N-1 if cls_token else N:
        - index_kept: Tensor([B, T-num_prop]): indices (into the full token
          sequence) of the kept tokens, highest-ranked first
        - index_prop: Tensor([B, num_prop]): indices of the propagated tokens
    Raises:
        - AssertionError: if ``weight`` is not a square 4-D map, ``num_prop`` is
          negative, or ``standard`` is not supported.
    """
    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
    B, H, N1, N2 = weight.shape
    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
    N = N1
    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."

    # Rows/columns of the [CLS] token are skipped when it is present.
    offset = 1 if cls_token else 0

    def img_attn():
        # Attention received by each ranked token, summed over all queries: [B, H, T]
        return weight[:, :, :, offset:].sum(-2)

    def diag_attn():
        # Self-attention weight of each ranked token: [B, H, T]
        return torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, offset:]

    def cos_sim(rows):
        # rows: [B, T, N]. Tokens whose attention rows are least similar to the
        # others (lowest summed cosine similarity) receive the highest rank.
        rows = rows / rows.norm(dim=-1, keepdim=True)
        return -(rows @ rows.transpose(-1, -2)).sum(-1)

    if cls_token and standard == "CLSAttnMean":
        token_rank = weight[:, :, 0, 1:].mean(1)
    elif cls_token and standard == "CLSAttnMax":
        token_rank = weight[:, :, 0, 1:].max(1)[0]
    elif standard == "IMGAttnMean":
        token_rank = img_attn().mean(1)
    elif standard == "IMGAttnMax":
        token_rank = img_attn().max(1)[0]
    elif standard == "DiagAttnMean":
        token_rank = diag_attn().mean(1)
    elif standard == "DiagAttnMax":
        token_rank = diag_attn().max(1)[0]
    elif standard == "MixedAttnMean":
        token_rank = diag_attn().mean(1) * img_attn().mean(1)
    elif standard == "MixedAttnMax":
        token_rank = diag_attn().max(1)[0] * img_attn().max(1)[0]
    elif standard == "SumAttnMax":
        token_rank = diag_attn().max(1)[0] + img_attn().max(1)[0]
    elif standard == "CosSimMean":
        token_rank = cos_sim(weight[:, :, offset:, :].mean(1))
    elif standard == "CosSimMax":
        token_rank = cos_sim(weight[:, :, offset:, :].max(1)[0])
    elif standard == "Random":
        # One score per ranked token: N-1 with [CLS], N without it.
        token_rank = torch.randn((B, N - offset), device=weight.device)
    else:
        # Raised explicitly (not `assert False`) so it still fires under `python -O`.
        raise AssertionError(f"Type '{standard}' selection not supported.")

    # Highest-ranked tokens first, shifted back to positions in the full sequence.
    order = torch.argsort(token_rank, dim=1, descending=True) + offset  # B, T
    # Use an explicit split point: `order[:, :-num_prop]` is empty for
    # num_prop == 0 (`:-0` == `:0`), which would swap the kept/propagated sets.
    num_kept = max(order.shape[1] - num_prop, 0)
    index_kept = order[:, :num_kept]  # B, T-num_prop
    index_prop = order[:, num_kept:]  # B, num_prop
    return index_kept, index_prop


class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer block whose token mixer is a window attention.
    --dim: embedding dim
    --pool_size: pooling size (only used by the commented-out Pooling mixer)
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth, refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale,
        refer to https://arxiv.org/abs/2103.17239
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=GroupNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = norm_layer(dim)
        # Alternative token mixers kept for reference:
        # self.token_mixer = Pooling(pool_size=pool_size)
        # self.token_mixer = FNetBlock()
        self.window_size = 4
        self.attn_mask = None
        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x, graph):
        """
        Apply the token-mixer (window attention) and MLP residual branches.

        Args:
            x: input feature map, unpacked below as [B, C, H, W].
            graph: accepted for interface compatibility; currently unused
                (see the note on token propagation below).
        Returns:
            The feature map with both residual branches added.
        """
        B, C, H, W = x.shape
        # Pre-norm, as the MLP branch does with norm2 (norm1 was previously
        # constructed but never applied).
        # NOTE(review): Swin-style window_partition/window_reverse work on
        # channels-last tensors; confirm the `layers` versions accept this
        # [B, C, H, W] input and return a tensor laid out like `x`.
        x_windows = window_partition(self.norm1(x), self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x_attn = window_reverse(attn_windows, self.window_size, H, W)

        # Token selection/propagation used to run here with num_prop=0, passing
        # `weight` and `token_scales` that were never assigned, so every call
        # raised UnboundLocalError. With num_prop=0 no token is propagated, so
        # the step is omitted.

        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x_attn)
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """
    Patch Embedding that is implemented by a layer of conv.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, embed_dim, H/stride, W/stride]
    """

    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=padding)
        # Optional normalization; identity when no norm_layer is given.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer: returns avg-pool(x) - x.
    --pool_size: pooling size
    """

    def __init__(self, pool_size=3):
        super().__init__()
        # stride 1 with padding pool_size//2 keeps the spatial size for odd pool_size;
        # padded positions are excluded from the average.
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        return self.pool(x) - x


class Mlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal conv weights (std 0.02) and zero biases.
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        # Normalize over dim 1 (channels) using the biased variance.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Per-channel affine, broadcast over H and W.
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
            + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x


class FeedForward(nn.Module):
    """Linear -> GELU -> Dropout -> Linear -> Dropout over the last dimension."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FNetBlock(nn.Module):
    """Fourier token mixing: FFT along the last dim, then along dim -2; keep the real part."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
        return x


class FNet(nn.Module):
    """Stack of ``depth`` pre-norm residual (Fourier mixing, feed-forward) layer pairs."""

    def __init__(self, dim, depth, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, FNetBlock()),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x