结合代码理解各种注意力机制(二):多头注意力机制

前言

这是注意力机制系列的第二篇,在上一篇文章
结合代码理解各种注意力机制(一):自注意力机制中,我们介绍了自注意力机制。此篇文章我们将在自注意力机制的基础上介绍多头注意力机制。

多头注意力机制

概念

多头注意力机制(Multi-Head Attention)是自注意力机制的扩展,它可以通过不同的子空间,来捕捉更多的信息。

其实,也就是我们可以拥有多组Wq,Wk,Wv,获得多种不同视角的注意力分数,然后将其进行拼接并进行线性变换。

picture 0

通过多组QKV得到多个注意力分数,然后进行concat拼接,再进行线性变换。

picture 1

picture 2

代码实现

在代码实现上,我们可以通过封装上一篇中实现的SelfAttention类,来实现多头注意力机制。

通过使用nn.ModuleList,我们可以将多个自注意力机制的实例封装在一起,然后通过torch.cat进行拼接。拼接完成后,再通过一个nn.Linear建立全连接层进行线性变换,维度为num_heads * d_v, d_v。

MHA类封装:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class MultiHeadAttention(nn.Module):
def __init__(self, d_emb, d_q, d_k, d_v, num_heads):
super().__init__()
self.heads = nn.ModuleList(
[SelfAttention(d_emb=d_emb, d_q=d_q, d_k=d_k, d_v=d_v)
for _ in range(num_heads)]
)
self.fc_concat_output = nn.Linear(num_heads * d_v, d_v)

def forward(self, embed):
concat = torch.cat([head(embed) for head in self.heads], dim = -1)
#print(f"concat.shape:{concat.shape}")
output = self.fc_concat_output(concat)
#print(f"out_shape:{output.shape}")
return output

MHA类使用:

1
2
3
num_heads = 8
mha = MultiHeadAttention(d_emb, d_q, d_k, d_v, num_heads)
result = mha(embedding_sentence)

结果如下:

1
2
concat.shape:torch.Size([9, 128])
out_shape:torch.Size([9, 16])

结合代码理解各种注意力机制(二):多头注意力机制
https://abigail61.github.io/2025/01/26/结合代码理解各种注意力机制(二):多头注意力机制/
作者
Yajing Luo
发布于
2025年1月26日
许可协议