This commit is contained in:
thanhvc3 2024-04-27 11:12:52 +07:00
parent 465f98bef8
commit b9efe68d3c

View File

@@ -800,7 +800,7 @@ class WindowAttention(nn.Module):
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
-        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
+        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).cuda()).exp()
         attn = attn * logit_scale
         relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)