try sep vit
models.py (11 changed lines)
@@ -906,7 +906,7 @@ class PEG(nn.Module):
 
 # feedforward
 
-class FeedForward(nn.Module):
+class FeedForwardDSSA(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
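Only the first lines of the renamed class survive in this hunk. Given the signature (dim, mult = 4, dropout = 0.) and the inner_dim = int(dim * mult) line, a SepViT-style feedforward over channel-first feature maps is the usual fit; the body below is a hedged sketch of what FeedForwardDSSA plausibly contains, not the actual contents of models.py.

import torch.nn as nn

class FeedForwardDSSA(nn.Module):
    # Sketch only: the diff shows __init__ just down to inner_dim; the layers
    # below are assumed from the common SepViT feedforward pattern
    # (channel-wise norm -> 1x1 conv -> GELU -> 1x1 conv) on (B, C, H, W) input.
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim),            # assumed channel-wise normalization
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)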
@@ -1094,7 +1094,10 @@ class PoolFormerBlock(nn.Module):
         self.attn_heads = 4
         self.attn_mask = None
         # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
-        self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
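An nn.ModuleList is not itself callable, so turning self.token_mixer into a list means PoolFormerBlock.forward can no longer call it directly; that part of the change is outside this hunk. A minimal sketch of how the forward pass could apply the pair, assuming the block keeps the stock PoolFormer members (self.norm1, self.drop_path) and that FeedForwardDSSA normalizes internally:

    # Hypothetical PoolFormerBlock.forward after this commit; not shown in the
    # diff. self.norm1 and self.drop_path are assumed from the standard
    # PoolFormer block, and each stage gets its own residual connection.
    def forward(self, x):
        attn, ff = self.token_mixer                  # DSSA, FeedForwardDSSA
        x = x + self.drop_path(attn(self.norm1(x)))
        x = x + self.drop_path(ff(x))                # ff assumed to norm internally
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x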
@@ -1213,7 +1216,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
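This second rename exists for the same reason as the first: models.py defined class FeedForward twice, and in Python the later definition rebinds the name, so only one of the two bodies was ever reachable; distinct names (FeedForwardDSSA, FeedForwardFNet) remove the shadowing. The diff again stops at self.net = nn.Sequential(, so the remainder below is a sketch assuming the standard token-wise MLP used with FNet-style blocks.

import torch.nn as nn

class FeedForwardFNet(nn.Module):
    # Sketch only: everything after nn.Sequential( is assumed, following the
    # usual Linear -> GELU -> Dropout -> Linear feedforward on token embeddings.
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)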
@@ -1249,7 +1252,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
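The hunk ends at the loop header of FNet.forward. For completeness, the customary PreNorm-plus-residual body that this loop implies would look roughly like the following; the residual additions are assumed rather than shown in the commit.

    # Hypothetical continuation of FNet.forward; only the loop header appears in
    # the diff. Each PreNorm-wrapped module is applied with a residual connection.
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x        # PreNorm(dim, FNetBlock()): Fourier token mixing
            x = ff(x) + x          # PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
        return x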
 