From b01e504874cd25661a4e1828e70be4522185b7e7 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 12:01:40 +0700 Subject: [PATCH] try gtp vit --- models.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/models.py b/models.py index 60e6eb8..d8347a7 100644 --- a/models.py +++ b/models.py @@ -527,6 +527,9 @@ class FouriER(torch.nn.Module): self.network = nn.ModuleList(network) self.norm = norm_layer(embed_dims[-1]) + self.graph_type = 'Spatial' + self.class_token = False + self.token_scale = False self.head = nn.Linear( embed_dims[-1], num_classes) if num_classes > 0 \ else nn.Identity() @@ -578,10 +581,6 @@ class FouriER(torch.nn.Module): degree = torch.diag_embed(degree**(-1/2)) graph = degree @ graph @ degree - if self.token_scale: - token_scales = self.token_scales.unsqueeze(0).expand(B,-1).to(x.device) - else: - token_scales = None for idx, block in enumerate(self.network): x = block(x) # output only the features of last layer for image classification