22 lines
622 B
Python
22 lines
622 B
Python
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
class PreNormResidual(torch.nn.Module):
    """Residual wrapper that layer-normalises the input before calling ``fn``.

    Computes ``fn(LayerNorm(x)) + x``.

    Args:
        dim: ``normalized_shape`` handed to ``torch.nn.LayerNorm``; with an int
            this normalises over the trailing dimension of ``x``.
        fn: callable (typically a ``torch.nn.Module``) applied to the
            normalised input; its output must be addable to ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep this assignment order: it fixes submodule registration order
        # (and therefore state_dict / parameters() ordering).
        self.fn = fn
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        transformed = self.fn(normed)
        return transformed + x
|
|
|
|
def FeedForward(dim, expansion_factor=4, dropout=0., dense=torch.nn.Linear):
    """Build a two-layer MLP: dense -> GELU -> dropout -> dense -> dropout.

    Args:
        dim: input and output feature size.
        expansion_factor: hidden-width multiplier; the hidden size is
            ``int(dim * expansion_factor)`` (truncated toward zero).
        dropout: dropout probability used after the activation and after the
            output projection. ``torch.nn.Dropout`` rejects values outside
            ``[0, 1]``.
        dense: factory called as ``dense(in_features, out_features)`` to build
            each projection; defaults to ``torch.nn.Linear``.

    Returns:
        A ``torch.nn.Sequential`` whose layers are, in order: expand
        projection, GELU, dropout, contract projection, dropout.
    """
    inner_dim = int(dim * expansion_factor)
    return torch.nn.Sequential(
        dense(dim, inner_dim),
        torch.nn.GELU(),
        # Same rate as the output dropout. Doubling it (``dropout * 2``) made
        # any dropout > 0.5 raise ValueError and zeroed the hidden layer at 0.5.
        torch.nn.Dropout(dropout),
        dense(inner_dim, dim),
        torch.nn.Dropout(dropout),
    )