try swin
@@ -800,7 +800,7 @@ class WindowAttention(nn.Module):
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).cuda()).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
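The only change in this hunk moves the clamp's max tensor onto the GPU with .cuda(), presumably so it is on the same device as self.logit_scale when the model runs on CUDA. Below is a minimal, self-contained sketch of the cosine-attention scaling this hunk belongs to (the WindowAttention pattern from Swin Transformer V2); the tensor shapes and the logit_scale initialization are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn.functional as F

# Hypothetical sizes chosen only for illustration.
B, num_heads, N, head_dim = 2, 4, 49, 32
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)

# Per-head learnable temperature; in the module this would be an nn.Parameter
# (log(10) init is an assumption for this sketch).
logit_scale = torch.log(10 * torch.ones(num_heads, 1, 1))

# Normalizing q and k makes q @ k^T a cosine similarity in [-1, 1] ...
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)

# ... and the exponentiated scale, capped at 1/0.01 = 100, restores the dynamic range.
# The commit's edit adds .cuda() on this max tensor, presumably to keep it on the
# same device as logit_scale during GPU training.
scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
attn = attn * scale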