diff --git a/models.py b/models.py index 91e0b0b..223a03e 100644 --- a/models.py +++ b/models.py @@ -596,7 +596,10 @@ class FouriER(torch.nn.Module): graph = degree @ graph @ degree for idx, block in enumerate(self.network): - x = block(x) + if (isinstance(block, PoolFormerBlock)): + x = block(x, graph) + else: + x = block(x) # output only the features of last layer for image classification return x @@ -1226,7 +1229,7 @@ class PoolFormerBlock(nn.Module): self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True) - def forward(self, x): + def forward(self, x, weight): B, C, H, W = x.shape x_windows = window_partition(x, self.window_size) x_windows = x_windows.view(-1, self.window_size * self.window_size, C)