Thesis/layers.py

22 lines
622 B
Python
Raw Normal View History

2023-05-04 08:49:41 +00:00
import torch
from torch import nn
from torch.nn import functional as F
class PreNormResidual(torch.nn.Module):
    """Pre-norm residual wrapper: ``forward(x) = fn(LayerNorm(x)) + x``.

    Args:
        dim: ``normalized_shape`` handed to ``torch.nn.LayerNorm`` (for an int,
            the size of the last dimension of ``x``).
        fn: Callable applied to the normalised input; its result must be
            addable to ``x``. If it is an ``nn.Module`` it is registered as a
            submodule under the name ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (fn first, then norm) fixes the state_dict key
        # order and parameters() order; keep it stable for saved checkpoints
        # and optimizer states.
        self.fn = fn
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x):
        skip = x
        transformed = self.fn(self.norm(x))
        return transformed + skip
def FeedForward(dim, expansion_factor=4, dropout=0., dense=torch.nn.Linear,
                *, hidden_dropout=None):
    """Build an MLP block: dense -> GELU -> Dropout -> dense -> Dropout.

    Args:
        dim: Input and output feature size.
        expansion_factor: Hidden-width multiplier; the hidden size is
            ``int(dim * expansion_factor)`` (truncated toward zero).
        dropout: Dropout probability applied after the second ``dense`` layer.
        dense: Factory called as ``dense(in_features, out_features)`` to build
            each projection. Defaults to ``torch.nn.Linear``.
        hidden_dropout: Dropout probability applied after the GELU. When
            ``None`` (default) it is ``2 * dropout``, which is the value this
            function has always used.

    Returns:
        A ``torch.nn.Sequential`` holding the five layers in the order above.

    Raises:
        ValueError: If ``dropout`` or the effective hidden dropout is outside
            ``[0, 1]``.
    """
    derived = hidden_dropout is None
    if derived:
        # Historical default: the hidden layer gets twice the output dropout.
        hidden_dropout = dropout * 2

    # nn.Dropout would also reject these values, but its message quotes the
    # already-doubled probability, which is confusing when the caller passed
    # e.g. dropout=0.6. Validate up front with an explicit message instead.
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")
    if not 0.0 <= hidden_dropout <= 1.0:
        hint = (" (defaults to 2 * dropout; pass hidden_dropout explicitly "
                "or use dropout <= 0.5)") if derived else ""
        raise ValueError(
            f"hidden_dropout must be in [0, 1], got {hidden_dropout}{hint}"
        )
    # NOTE: hidden_dropout == 1.0 (e.g. dropout=0.5 with the default) zeroes
    # every hidden activation in training mode; allowed, as before.

    inner_dim = int(dim * expansion_factor)
    return torch.nn.Sequential(
        dense(dim, inner_dim),
        torch.nn.GELU(),
        torch.nn.Dropout(hidden_dropout),
        dense(inner_dim, dim),
        torch.nn.Dropout(dropout),
    )