try modify swin

models.py | 55
							@@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
         return pred
 
 
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(2 * dim)
+
+    def forward(self, x):
+        """
+        x: B, C, H, W
+        """
+        B, C, H, W = x.shape
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+        # view() would reinterpret B,C,H,W memory, not transpose it; permute moves channels last.
+        x = x.permute(0, 2, 3, 1)  # B H W C
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.reduction(x)
+        x = self.norm(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
+
+    def flops(self, H, W):
+        # input_resolution is no longer stored on the module, so take H, W explicitly.
+        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        flops += H * W * self.dim // 2
+        return flops
 
 class FouriER(torch.nn.Module):
     def __init__(self, params, hid_drop = None, embed_dim = None):
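A quick shape check of the class added above; a minimal sketch, assuming the patched models.py is importable. The batch size, dim=96, and the 56x56 resolution are illustrative values, not taken from the commit.

    import torch

    from models import PatchMerging  # assumes the class added above lives in models.py

    # Illustrative sizes only: dim=96 on a 56x56 feature map.
    merge = PatchMerging(dim=96)

    x = torch.randn(2, 96, 56, 56)  # B, C, H, W
    y = merge(x)                    # 2x2 neighbour gather, then Linear(4C -> 2C)

    # Spatial resolution halves, channels double, output is sequence-shaped:
    print(y.shape)  # torch.Size([2, 784, 192]) = B, (H/2)*(W/2), 2*C

Note that the output is B, (H/2)*(W/2), 2C rather than B, 2C, H/2, W/2, which matters for whatever stage consumes it next.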
@@ -519,11 +563,12 @@ class FouriER(torch.nn.Module):
             if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
                 # downsampling between two stages
                 network.append(
-                    PatchEmbed(
-                        patch_size=down_patch_size, stride=down_stride,
-                        padding=down_pad,
-                        in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
-                        )
+                    # PatchEmbed(
+                    #     patch_size=down_patch_size, stride=down_stride,
+                    #     padding=down_pad,
+                    #     in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
+                    #     )
+                    PatchMerging(dim=embed_dims[i+1])
                     )
 
         self.network = nn.ModuleList(network)
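As a sanity check on the swap above: PatchMerging(dim=d) applies Linear(4*d, 2*d), so it consumes a feature map with exactly d channels and emits 2*d. The sketch below traces shapes through the wiring in this hunk; the embed_dims values are hypothetical, not the actual FouriER configuration.

    import torch

    from models import PatchMerging  # assumes the patched models.py above

    # Hypothetical stage widths, not the real FouriER config.
    embed_dims = [64, 128]
    i = 0

    # As wired in the hunk above: dim = embed_dims[i+1].
    down = PatchMerging(dim=embed_dims[i + 1])

    # The module only shape-checks if the input already carries
    # embed_dims[i+1] channels, since reduction is Linear(4 * dim, 2 * dim).
    x = torch.randn(2, embed_dims[i + 1], 32, 32)  # B, C, H, W with C = 128
    y = down(x)
    print(y.shape)  # torch.Size([2, 256, 256]) = B, (H/2)*(W/2), 2*C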