try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 15:27:41 +07:00
parent d0e4630dd6
commit b9273b6696

View File

@ -1238,7 +1238,7 @@ class PoolFormerBlock(nn.Module):
self.layer_scale_2 = nn.Parameter( self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True) layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x, weight): def forward(self, x, weight, token_scales = None):
B, C, H, W = x.shape B, C, H, W = x.shape
x_windows = window_partition(x, self.window_size) x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) x_windows = x_windows.view(-1, self.window_size * self.window_size, C)