try swin
This commit is contained in:
parent
465f98bef8
commit
b9efe68d3c
@ -800,7 +800,7 @@ class WindowAttention(nn.Module):
|
|||||||
|
|
||||||
# cosine attention
|
# cosine attention
|
||||||
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
||||||
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
|
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).cuda()).exp()
|
||||||
attn = attn * logit_scale
|
attn = attn * logit_scale
|
||||||
|
|
||||||
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
||||||
|
Loading…
Reference in New Issue
Block a user