diff --git a/models.py b/models.py index ca11dc6..3624eb4 100644 --- a/models.py +++ b/models.py @@ -565,12 +565,12 @@ class FouriER(torch.nn.Module): if downsamples[i] or embed_dims[i] != embed_dims[i+1]: # downsampling between two stages network.append( - # PatchEmbed( - # patch_size=down_patch_size, stride=down_stride, - # padding=down_pad, - # in_chans=embed_dims[i], embed_dim=embed_dims[i+1] - # ) - PatchMerging(dim=embed_dims[i+1]) + PatchEmbed( + patch_size=down_patch_size, stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], embed_dim=embed_dims[i+1] + ) + # PatchMerging(dim=embed_dims[i+1]) ) self.network = nn.ModuleList(network)