网易首页 > 网易号 > 正文 申请入驻

Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现

0
分享至

掩码注意力(Causal Attention)是生成式模型的核心技术,它传统自注意力机制有根本的不同,掩码注意力限制模型只能关注当前位置之前的tokens,确保了自回归生成的因果性。

自注意力的掩码

自注意力机制在Transformer编码器和BERT等模型中广泛应用。这种机制的特点是每个token都能访问序列中的所有其他tokens,包括前面和后面的位置。这种双向注意力让模型能够充分利用上下文信息,将静态词嵌入转换为富含语境的动态表示。

而掩码注意力作为解码器的关键组件,人为地阻断了对未来tokens的访问。这种单向约束虽然看起来是限制,实际上正是语言生成任务的核心要求——模型必须基于已有的上下文来预测下一个词,而不能"偷看"答案。

Pytorch实现

实现掩码注意力需要五个关键步骤:

先看基础的类结构定义。这里需要为Query、Key、Value分别创建线性变换层,同时初始化一个上三角掩码矩阵:

import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) #dropout
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))

register_buffer这个方法很关键。它确保掩码矩阵会跟随模型在CPU和GPU之间移动,但不会作为可训练参数参与梯度更新。

然后就是前向传播的第一步,计算注意力分数。这部分和标准自注意力完全一样:

import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) #dropout
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
#attention_score
attention_score= vec_q @ vec_k.transpose(1,2) # 记住我们在处理批量数据

下面就是最关键的掩码操作。在这一步masked_fill_函数会将掩码为True的位置填充为负无穷大,这样在后续softmax操作中这些位置的权重就会变成0:

import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) #dropout
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
#attention_score
attention_score= vec_q @ vec_k.transpose(1,2)
#重要的代码行 #########
attention_score.masked_fill_(mask.bool()[:num_tokens,:num_tokens],-torch.inf)

然后是就是标准的缩放和softmax归一化。这里除法运算中的vec_k.shape[-1]是Key向量的维度,这个缩放因子能够稳定梯度:

import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) #dropout
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
#attention_score
attention_score= vec_q @ vec_k.transpose(1,2)
#重要的代码行 #########
attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)
#通过attention_weight进行缩放
attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)

最后加入dropout防止过拟合(也可以不加,现在的模型基本上不会dropout了,但是为了演示,我们可以在这里加入dropout),并与Value向量相乘得到最终的上下文表示:

import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) #dropout
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
#attention_score
attention_score= vec_q @ vec_k.transpose(1,2)
#重要的代码行 #########
attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)
#通过attention_weight进行缩放
attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)
drop_out=self.Drop(attention_weight)
return drop_out @ vec_v

最后我们来详细解释一下这行代码:

attention_score.masked_fill_(mask.bool()[:num_tokens,num_tokens],-torch.inf)

整个掩码操作分几个部分:首先计算原始的注意力分数矩阵,然后从预先注册的上三角掩码中切取对应大小的子矩阵。mask.bool()将0/1矩阵转换为布尔型,这样masked_fill_函数就将这些位置填充负无穷。

因为负无穷,所以当这些位置经过softmax函数时,exp(-∞)会趋向于0,从而实现了完全屏蔽未来tokens的效果。切片操作[:num_tokens,num_tokens]处理了不同序列长度的情况,因为上下文窗口是固定的,但实际输入序列长度可能变化。

总结

这种掩码机制让GPT等模型能够逐词生成文本,每次预测都只基于已经生成的内容,这正是自回归语言模型的精髓所在。通过一个上三角掩码矩阵,就能让模型在训练时学会"单向思考",这种设计的巧妙之处在于它完美平衡了计算效率和生成质量。

从技术实现角度来看,整个过程其实就是在标准自注意力基础上加了一步masked_fill_操作。但正是这简单的一步,让模型具备了真正的文本生成能力。相比之下,BERT等双向模型虽然在理解任务上表现出色,但在生成任务上就显得力不从心。

掌握了掩码注意力,你就理解了GPT、LLaMA等主流生成模型的核心工作原理。下次看到这些模型的论文或代码时,相信你会有更深刻的认识。

https://avoid.overfit.cn/post/1eaccf4c67f74b27839e3c5b2372f23c

作者:VIGNESHWARAN

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
湖人98-78火箭,4-2晋级!詹姆斯创3大历史纪录,一战看清6个现实

湖人98-78火箭,4-2晋级!詹姆斯创3大历史纪录,一战看清6个现实

毒舌NBA
2026-05-02 12:19:56
中年后存款到这个数,就不用太焦虑了(不是100万)

中年后存款到这个数,就不用太焦虑了(不是100万)

白浅娱乐聊
2026-05-02 09:13:53
浙江很低调的城市,人口仅543万,GDP却直逼9000亿,凭什么?

浙江很低调的城市,人口仅543万,GDP却直逼9000亿,凭什么?

跟着萱仔去旅游
2026-05-01 02:10:06
悲哀!38.8万彩礼加20万下车礼,新娘进门清礼金,新郎坐旁吃残羹

悲哀!38.8万彩礼加20万下车礼,新娘进门清礼金,新郎坐旁吃残羹

火山詩话
2026-05-02 06:53:57
猪大肠被关注!研究发现:糖尿病患者常吃猪大肠,或有5种变化

猪大肠被关注!研究发现:糖尿病患者常吃猪大肠,或有5种变化

芹姐说生活
2026-05-01 14:34:43
5月8日,国内成品油价格将调整

5月8日,国内成品油价格将调整

海峡网
2026-05-02 10:18:06
美国网友疑惑:美国曾7次帮助中国,为何中国人不感恩?

美国网友疑惑:美国曾7次帮助中国,为何中国人不感恩?

霹雳炮
2026-05-01 22:58:18
季后赛被打废!最失望阵容:从核心到角色,顶薪打飞了!

季后赛被打废!最失望阵容:从核心到角色,顶薪打飞了!

篮球盛世
2026-05-02 01:12:29
美国警告赖清德当局,只要大陆决定武力统一,台湾的结局只有一个

美国警告赖清德当局,只要大陆决定武力统一,台湾的结局只有一个

猫女的小树屋
2026-05-02 10:14:06
打什么电话比12345更管用?这些电话比它管用100倍,建议收藏好

打什么电话比12345更管用?这些电话比它管用100倍,建议收藏好

细说职场
2026-04-28 10:39:02
笑疯了!新加坡媒体尬吹印度,称用手吃饭更香,评论区怼得太狠了

笑疯了!新加坡媒体尬吹印度,称用手吃饭更香,评论区怼得太狠了

谭谈社会
2026-05-01 22:49:20
366.12分夺冠!陈芋汐换搭档拿高分:超联手全红婵奥运夺金成绩

366.12分夺冠!陈芋汐换搭档拿高分:超联手全红婵奥运夺金成绩

李喜林篮球绝杀
2026-05-01 17:38:44
美伊还没打完,第二个伊朗出现!对华使出卸磨杀驴,反向收割中企

美伊还没打完,第二个伊朗出现!对华使出卸磨杀驴,反向收割中企

解锁世界风云
2026-04-30 23:27:26
奥萨苏纳1-2巴萨:缺兵少将打手下败将,不给皇马任何的机会|前瞻

奥萨苏纳1-2巴萨:缺兵少将打手下败将,不给皇马任何的机会|前瞻

体育世界
2026-05-02 13:50:47
美伊谈判,大消息!昨夜今晨,大涨

美伊谈判,大消息!昨夜今晨,大涨

中国基金报
2026-05-02 08:17:51
约基奇时代第2次一轮游 要给这支掘金判死刑了吗?

约基奇时代第2次一轮游 要给这支掘金判死刑了吗?

体坛周报
2026-05-02 18:28:24
“不去后悔,去了更后悔!”五一最堵6大景区曝光,堵到一动不动

“不去后悔,去了更后悔!”五一最堵6大景区曝光,堵到一动不动

阿伧说事
2026-05-02 12:28:11
2026斯诺克世锦赛半决赛:希金斯与墨菲暂时战平,吴宜泽强势开局

2026斯诺克世锦赛半决赛:希金斯与墨菲暂时战平,吴宜泽强势开局

林子说事
2026-05-02 15:08:48
五种废品价格暴涨!提醒老人千万别乱扔,扔了就是白扔钱!

五种废品价格暴涨!提醒老人千万别乱扔,扔了就是白扔钱!

爱下厨的阿酾
2026-05-02 14:11:39
热搜第一!孙杨玩游戏下狠手 扇52岁范明耳光 后者捂脸:你真打啊

热搜第一!孙杨玩游戏下狠手 扇52岁范明耳光 后者捂脸:你真打啊

念洲
2026-05-02 09:23:08
2026-05-02 19:16:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1986文章数 1461关注度
往期回顾 全部

科技要闻

AI热潮耗尽库存,Mac Mini起售调高200美元

头条要闻

单亲妈妈被无辜羁押821天申请国赔遭叫停 最新消息来了

头条要闻

单亲妈妈被无辜羁押821天申请国赔遭叫停 最新消息来了

体育要闻

休赛期总冠军,轮到休斯顿火箭

娱乐要闻

白百何罕晒大儿子 18岁元宝越来越帅

财经要闻

雷军很努力 小米还是跌破了30港元大关

汽车要闻

新纪录!零跑汽车4月交付达71387台

态度原创

本地
健康
教育
公开课
军事航空

本地新闻

用青花瓷的方式,打开西溪湿地

干细胞治烧烫伤面临这些“瓶颈”

教育要闻

五年级几何,很多学生都无从下笔,其实一点也不难

公开课

李玫瑾:为什么性格比能力更重要?

军事要闻

特朗普:对伊战事结束 无限期延长停火

无障碍浏览 进入关怀版