try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 12:01:40 +07:00
parent 23c44d3582
commit b01e504874

View File

@ -527,6 +527,9 @@ class FouriER(torch.nn.Module):
self.network = nn.ModuleList(network)
self.norm = norm_layer(embed_dims[-1])
self.graph_type = 'Spatial'
self.class_token = False
self.token_scale = False
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
@ -578,10 +581,6 @@ class FouriER(torch.nn.Module):
degree = torch.diag_embed(degree**(-1/2))
graph = degree @ graph @ degree
if self.token_scale:
token_scales = self.token_scales.unsqueeze(0).expand(B,-1).to(x.device)
else:
token_scales = None
for idx, block in enumerate(self.network):
x = block(x)
# output only the features of last layer for image classification