try gtp vit
This commit is contained in:
parent
23c44d3582
commit
b01e504874
@ -527,6 +527,9 @@ class FouriER(torch.nn.Module):
|
|||||||
|
|
||||||
self.network = nn.ModuleList(network)
|
self.network = nn.ModuleList(network)
|
||||||
self.norm = norm_layer(embed_dims[-1])
|
self.norm = norm_layer(embed_dims[-1])
|
||||||
|
self.graph_type = 'Spatial'
|
||||||
|
self.class_token = False
|
||||||
|
self.token_scale = False
|
||||||
self.head = nn.Linear(
|
self.head = nn.Linear(
|
||||||
embed_dims[-1], num_classes) if num_classes > 0 \
|
embed_dims[-1], num_classes) if num_classes > 0 \
|
||||||
else nn.Identity()
|
else nn.Identity()
|
||||||
@ -578,10 +581,6 @@ class FouriER(torch.nn.Module):
|
|||||||
degree = torch.diag_embed(degree**(-1/2))
|
degree = torch.diag_embed(degree**(-1/2))
|
||||||
graph = degree @ graph @ degree
|
graph = degree @ graph @ degree
|
||||||
|
|
||||||
if self.token_scale:
|
|
||||||
token_scales = self.token_scales.unsqueeze(0).expand(B,-1).to(x.device)
|
|
||||||
else:
|
|
||||||
token_scales = None
|
|
||||||
for idx, block in enumerate(self.network):
|
for idx, block in enumerate(self.network):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
# output only the features of last layer for image classification
|
# output only the features of last layer for image classification
|
||||||
|
Loading…
Reference in New Issue
Block a user