From 6fc56b920fb259cfe14b9d31a351ae7a4fb30fb7 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 15:24:05 +0700 Subject: [PATCH] try gtp vit --- models.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index db6be39..ffab680 100644 --- a/models.py +++ b/models.py @@ -597,7 +597,7 @@ class FouriER(torch.nn.Module): for idx, block in enumerate(self.network): try: - x = block(x, graph) + x = block((x, graph)) except: x = block(x) # output only the features of last layer for image classification @@ -758,7 +758,7 @@ def basic_blocks(dim, index, layers, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, )) - blocks = nn.Sequential(*blocks) + blocks = SeqModel(*blocks) return blocks @@ -923,6 +923,15 @@ def window_reverse(windows, window_size, H, W): x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W) return x +class SeqModel(nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + def propagate(x: torch.Tensor, weight: torch.Tensor, index_kept: torch.Tensor, index_prop: torch.Tensor, standard: str = "None", alpha: Optional[float] = 0,