try gtp vit
This commit is contained in:
parent
fddea4769f
commit
6fc56b920f
13
models.py
13
models.py
@ -597,7 +597,7 @@ class FouriER(torch.nn.Module):
|
|||||||
|
|
||||||
for idx, block in enumerate(self.network):
|
for idx, block in enumerate(self.network):
|
||||||
try:
|
try:
|
||||||
x = block(x, graph)
|
x = block((x, graph))
|
||||||
except:
|
except:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
# output only the features of last layer for image classification
|
# output only the features of last layer for image classification
|
||||||
@ -758,7 +758,7 @@ def basic_blocks(dim, index, layers,
|
|||||||
use_layer_scale=use_layer_scale,
|
use_layer_scale=use_layer_scale,
|
||||||
layer_scale_init_value=layer_scale_init_value,
|
layer_scale_init_value=layer_scale_init_value,
|
||||||
))
|
))
|
||||||
blocks = nn.Sequential(*blocks)
|
blocks = SeqModel(*blocks)
|
||||||
|
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
@ -923,6 +923,15 @@ def window_reverse(windows, window_size, H, W):
|
|||||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class SeqModel(nn.Sequential):
|
||||||
|
def forward(self, *inputs):
|
||||||
|
for module in self._modules.values():
|
||||||
|
if type(inputs) == tuple:
|
||||||
|
inputs = module(*inputs)
|
||||||
|
else:
|
||||||
|
inputs = module(inputs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
def propagate(x: torch.Tensor, weight: torch.Tensor,
|
def propagate(x: torch.Tensor, weight: torch.Tensor,
|
||||||
index_kept: torch.Tensor, index_prop: torch.Tensor,
|
index_kept: torch.Tensor, index_prop: torch.Tensor,
|
||||||
standard: str = "None", alpha: Optional[float] = 0,
|
standard: str = "None", alpha: Optional[float] = 0,
|
||||||
|
Loading…
Reference in New Issue
Block a user