try gtp vit
This commit is contained in:
parent
0f986d7517
commit
d9209a7ef1
@ -596,7 +596,10 @@ class FouriER(torch.nn.Module):
|
||||
graph = degree @ graph @ degree
|
||||
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
if (isinstance(block, PoolFormerBlock)):
|
||||
x = block(x, graph)
|
||||
else:
|
||||
x = block(x)
|
||||
# output only the features of last layer for image classification
|
||||
return x
|
||||
|
||||
@ -1226,7 +1229,7 @@ class PoolFormerBlock(nn.Module):
|
||||
self.layer_scale_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, weight):
|
||||
B, C, H, W = x.shape
|
||||
x_windows = window_partition(x, self.window_size)
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||
|
Loading…
Reference in New Issue
Block a user