try modify swin
This commit is contained in:
		
							
								
								
									
										55
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										55
									
								
								models.py
									
									
									
									
									
								
							@@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return pred
 | 
					        return pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PatchMerging(nn.Module):
 | 
				
			||||||
 | 
					    r""" Patch Merging Layer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        input_resolution (tuple[int]): Resolution of input feature.
 | 
				
			||||||
 | 
					        dim (int): Number of input channels.
 | 
				
			||||||
 | 
					        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, dim, norm_layer=nn.LayerNorm):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.dim = dim
 | 
				
			||||||
 | 
					        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
 | 
				
			||||||
 | 
					        self.norm = norm_layer(2 * dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        x: B, C, H, W
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        B, C, H, W = x.shape
 | 
				
			||||||
 | 
					        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x = x.view(B, H, W, C)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
 | 
				
			||||||
 | 
					        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
 | 
				
			||||||
 | 
					        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
 | 
				
			||||||
 | 
					        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
 | 
				
			||||||
 | 
					        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
 | 
				
			||||||
 | 
					        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x = self.reduction(x)
 | 
				
			||||||
 | 
					        x = self.norm(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def extra_repr(self) -> str:
 | 
				
			||||||
 | 
					        return f"input_resolution={self.input_resolution}, dim={self.dim}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def flops(self):
 | 
				
			||||||
 | 
					        H, W = self.input_resolution
 | 
				
			||||||
 | 
					        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
 | 
				
			||||||
 | 
					        flops += H * W * self.dim // 2
 | 
				
			||||||
 | 
					        return flops
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FouriER(torch.nn.Module):
 | 
					class FouriER(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, params, hid_drop = None, embed_dim = None):
 | 
					    def __init__(self, params, hid_drop = None, embed_dim = None):
 | 
				
			||||||
@@ -519,11 +563,12 @@ class FouriER(torch.nn.Module):
 | 
				
			|||||||
            if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
 | 
					            if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
 | 
				
			||||||
                # downsampling between two stages
 | 
					                # downsampling between two stages
 | 
				
			||||||
                network.append(
 | 
					                network.append(
 | 
				
			||||||
                    PatchEmbed(
 | 
					                    # PatchEmbed(
 | 
				
			||||||
                        patch_size=down_patch_size, stride=down_stride, 
 | 
					                    #     patch_size=down_patch_size, stride=down_stride, 
 | 
				
			||||||
                        padding=down_pad, 
 | 
					                    #     padding=down_pad, 
 | 
				
			||||||
                        in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
 | 
					                    #     in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
 | 
				
			||||||
                        )
 | 
					                    #     )
 | 
				
			||||||
 | 
					                    PatchMerging(dim=embed_dims[i+1])
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.network = nn.ModuleList(network)
 | 
					        self.network = nn.ModuleList(network)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user