import torch
from torch import nn
from torch.nn import functional as F


class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns ``fn(LayerNorm(x)) + x``.

    Args:
        dim: feature dimension normalized by the LayerNorm (last axis of x).
        fn: any module mapping a ``(..., dim)`` tensor to a same-shaped tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Normalize first, apply fn, then add the untouched input (residual).
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Build a two-layer MLP block: dense -> GELU -> dropout -> dense -> dropout.

    Args:
        dim: input and output feature dimension.
        expansion_factor: hidden width multiplier (hidden = int(dim * factor)).
        dropout: dropout probability applied after both dense layers.
        dense: layer constructor taking (in_features, out_features); defaults
            to ``nn.Linear`` so callers can swap in e.g. a conv-based dense.

    Returns:
        An ``nn.Sequential`` implementing the feed-forward block.
    """
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        # Fix: was Dropout(dropout * 2), inconsistent with the second layer and
        # invalid (ValueError) for dropout > 0.5 since nn.Dropout requires
        # 0 <= p <= 1. Both layers now use the same probability.
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout),
    )