try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 15:10:09 +07:00
parent 4daa40527b
commit 0f986d7517

View File

@ -596,7 +596,7 @@ class FouriER(torch.nn.Module):
graph = degree @ graph @ degree
for idx, block in enumerate(self.network):
x = block(x, graph)
x = block(x)
# output only the features of last layer for image classification
return x
@ -1226,7 +1226,7 @@ class PoolFormerBlock(nn.Module):
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x, graph):
def forward(self, x):
B, C, H, W = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)