网易首页 > 网易号 > 正文 申请入驻

自然语言生成任务中的5种采样方法介绍和Pytorch代码实现

0
分享至


在自然语言生成任务(NLG)中,采样方法是指从生成模型中获取文本输出的一种技术。本文将介绍常用的5中方法并用Pytorch进行实现。

1、Greedy Decoding

Greedy Decoding在每个时间步选择当前条件概率最高的词语作为输出,直到生成结束。在贪婪解码中,生成模型根据输入序列,逐个时间步地预测输出序列中的每个词语。在每个时间步,模型根据当前的隐藏状态和已生成的部分序列计算每个词语的条件概率分布,模型选择具有最高条件概率的词语作为当前时间步的输出。这个词语成为下一个时间步的输入,生成过程持续直到满足某种终止条件,比如生成了指定长度的序列或者生成了特殊的结束标记。

这种方法简单高效,每个时间步只需计算当前条件概率最高的词语,因此计算速度较快。但是由于每个时间步只考虑当前条件概率最高的词语,贪婪解码可能会陷入局部最优解,而无法获得全局最优解。这可能导致生成的文本缺乏多样性或不准确。

尽管贪婪解码存在一些局限性,但它仍然是许多序列生成任务中常用的一种方法,特别是在对速度要求较高或者任务较为简单的情况下。

def greedy_decoding(input_ids, max_tokens=300):
with torch.inference_mode():
for _ in range(max_tokens):
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
if next_token == tokenizer.eos_token_id:
break
input_ids = torch.cat([input_ids, rearrange(next_token, 'c -> 1 c')], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text

2、Beam Search

束搜索(Beam Search)是贪婪解码的一种扩展,通过在每个时间步保留多个候选序列来克服贪婪解码的局部最优问题。

在每个时间步保留概率最高的前几个候选词语,然后在下一个时间步基于这些候选词语继续扩展,直到生成结束。束搜索通过考虑多个候选词语路径,可以在一定程度上增加生成文本的多样性。

在束搜索中,模型在每个时间步会生成多个候选序列,而不是仅选择一个最优序列。模型会根据当前已生成的部分序列和隐藏状态,预测下一个时间步可能的词语,并计算每个词语的条件概率分布。

上图的每一步中,只保留两条最可能的路径(根据beam =2),而所有其他都被丢弃。此过程将继续进行,直到满足停止条件,该停止条件可以是生成序列结束令牌或达到最大序列长度的模型。最终输出将是最后一组路径中具有最高总体概率的序列。

from einops import rearrange
import torch.nn.functional as F
def beam_search(input_ids, max_tokens=100, beam_size=2):
beam_scores = torch.zeros(beam_size).to(device)
beam_sequences = input_ids.clone()
active_beams = torch.ones(beam_size, dtype=torch.bool)
for step in range(max_tokens):
outputs = model(beam_sequences)
logits = outputs.logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
top_scores, top_indices = torch.topk(probs.flatten(), k=beam_size, sorted=False)
beam_indices = top_indices // probs.shape[-1]
token_indices = top_indices % probs.shape[-1]
beam_sequences = torch.cat([
beam_sequences[beam_indices],
token_indices.unsqueeze(-1)
], dim=-1)
beam_scores = top_scores
active_beams = ~(token_indices == tokenizer.eos_token_id)
if not active_beams.any():
print("no active beams")
break
best_beam = beam_scores.argmax()
best_sequence = beam_sequences[best_beam]
generated_text = tokenizer.decode(best_sequence)
return generated_text

3、Temperature Sampling

温度参数采样(Temperature Sampling)常用于基于概率的生成模型,如语言模型。它通过引入一个称为“温度”(Temperature)的参数来调整模型输出的概率分布,从而控制生成文本的多样性。

在温度参数采样中,模型在每个时间步生成词语时,会计算出词语的条件概率分布。然后模型将这个条件概率分布中的每个词语的概率值除以温度参数,对结果进行归一化处理,获得新的归一化概率分布。较高的温度值会使概率分布更平滑,从而增加生成文本的多样性。低概率的词语也有较高的可能性被选择;而较低的温度值则会使概率分布更集中,更倾向于选择高概率的词语,因此生成的文本更加确定性。最后模型根据这个新的归一化概率分布进行随机采样,选择生成的词语。

import torch
import torch.nn.functional as F
def temperature_sampling(logits, temperature=1.0):
logits = logits / temperature
probabilities = F.softmax(logits, dim=-1)
sampled_token = torch.multinomial(probabilities, 1)
return sampled_token.item()

4、Top-K Sampling

Top-K 采样(在每个时间步选择条件概率排名前 K 的词语,然后在这 K 个词语中进行随机采样。这种方法既能保持一定的生成质量,又能增加文本的多样性,并且可以通过限制候选词语的数量来控制生成文本的多样性。

这个过程使得生成的文本在保持一定的生成质量的同时,也具有一定的多样性,因为在候选词语中仍然存在一定的竞争性。

参数 K 控制了在每个时间步中保留的候选词语的数量。较小的 K 值会导致更加贪婪的行为,因为只有少数几个词语参与随机采样,而较大的 K 值会增加生成文本的多样性,但也会增加计算开销。

def top_k_sampling(input_ids, max_tokens=100, top_k=50, temperature=1.0):
for _ in range(max_tokens):
with torch.inference_mode():
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
top_k_probs = F.softmax(top_k_logits / temperature, dim=-1)
next_token_index = torch.multinomial(top_k_probs, num_samples=1)
next_token = top_k_indices.gather(-1, next_token_index)
input_ids = torch.cat([input_ids, next_token], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text

5、Top-P (Nucleus) Sampling:

Nucleus Sampling(核采样),也被称为Top-p Sampling旨在在保持生成文本质量的同时增加多样性。这种方法可以视作是Top-K Sampling的一种变体,它在每个时间步根据模型输出的概率分布选择概率累积超过给定阈值p的词语集合,然后在这个词语集合中进行随机采样。这种方法会动态调整候选词语的数量,以保持一定的文本多样性。

在Nucleus Sampling中,模型在每个时间步生成词语时,首先按照概率从高到低对词汇表中的所有词语进行排序,然后模型计算累积概率,并找到累积概率超过给定阈值p的最小词语子集,这个子集就是所谓的“核”(nucleus)。模型在这个核中进行随机采样,根据词语的概率分布来选择最终输出的词语。这样做可以保证所选词语的总概率超过了阈值p,同时也保持了一定的多样性。

参数p是Nucleus Sampling中的重要参数,它决定了所选词语的概率总和。p的值会被设置在(0,1]之间,表示词语总概率的一个下界。

Nucleus Sampling 能够保持一定的生成质量,因为它在一定程度上考虑了概率分布。通过选择概率总和超过给定阈值p的词语子集进行随机采样,Nucleus Sampling 能够增加生成文本的多样性。

def top_p_sampling(input_ids, max_tokens=100, top_p=0.95):
with torch.inference_mode():
for _ in range(max_tokens):
outputs = model(input_ids)
next_token_logits = outputs.logits[:, -1, :]
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probabilities, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
next_token_logits.scatter_(-1, indices_to_remove[None, :], float('-inf'))
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=-1)
generated_text = tokenizer.decode(input_ids[0])
return generated_text

总结

自然语言生成任务中,采样方法是非常重要的。选择合适的采样方法可以在一定程度上影响生成文本的质量、多样性和效率。上面介绍的几种采样方法各有特点,适用于不同的应用场景和需求。

贪婪解码是一种简单直接的方法,适用于速度要求较高的情况,但可能导致生成文本缺乏多样性。束搜索通过保留多个候选序列来克服贪婪解码的局部最优问题,生成的文本质量更高,但计算开销较大。Top-K 采样和核采样可以控制生成文本的多样性,适用于需要平衡质量和多样性的场景。温度参数采样则可以根据温度参数灵活调节生成文本的多样性,适用于需要平衡多样性和质量的任务。

https://avoid.overfit.cn/post/42c2631bc56347849d538768d84d47c2

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
曹杰,培植个人势力,违规为亲属谋求特殊待遇

曹杰,培植个人势力,违规为亲属谋求特殊待遇

新京报
2025-09-15 21:50:37
首尔爆发千人反美集会 高喊"特朗普滚出地球"

首尔爆发千人反美集会 高喊"特朗普滚出地球"

看看新闻Knews
2025-09-16 00:00:04
河北一高中收2700元教辅费?教体局通报

河北一高中收2700元教辅费?教体局通报

界面新闻
2025-09-15 19:22:15
“治港败类”曾荫权:治理香港7年,为何却在卸任后,获刑20个月

“治港败类”曾荫权:治理香港7年,为何却在卸任后,获刑20个月

卷史
2025-09-15 11:50:59
炸裂!200多名包括旅长在内的俄军官兵,被自己战友勾结乌军屠杀

炸裂!200多名包括旅长在内的俄军官兵,被自己战友勾结乌军屠杀

雪中风车
2025-09-15 11:18:12
豪门玩物,几年被折磨成牙齿全脱落,不足80斤,似骷髅,惨不忍睹

豪门玩物,几年被折磨成牙齿全脱落,不足80斤,似骷髅,惨不忍睹

观察鉴娱
2025-08-17 09:54:54
易会满出事后,浙江这家企业被查出多项问题

易会满出事后,浙江这家企业被查出多项问题

温百君
2025-09-15 22:09:14
日军为何要在帽子上挂两块布?专家:此举可以少死10万人!​

日军为何要在帽子上挂两块布?专家:此举可以少死10万人!​

顾史
2025-09-15 20:03:30
非必要,不做CT!JAMA子刊:每多做一次CT,癌症风险或增加43%

非必要,不做CT!JAMA子刊:每多做一次CT,癌症风险或增加43%

医诺维
2025-09-14 14:48:47
贾国龙还是嫌西贝死得不够透

贾国龙还是嫌西贝死得不够透

亮见
2025-09-15 12:50:18
我们错怪向余望了?要怪就怪安东尼奥把对的人放在了错误的位置!

我们错怪向余望了?要怪就怪安东尼奥把对的人放在了错误的位置!

田先生篮球
2025-09-14 23:01:45
央行史无前例大放水

央行史无前例大放水

边际财经实验室
2025-09-15 17:20:35
中华人民共和国政府与波兰共和国政府间合作委员会第四次全体会议共同文件

中华人民共和国政府与波兰共和国政府间合作委员会第四次全体会议共同文件

新京报
2025-09-15 21:06:21
2岁宝宝梅毒阳性,婆婆大骂儿媳,不料全家血检结果让人难以置信

2岁宝宝梅毒阳性,婆婆大骂儿媳,不料全家血检结果让人难以置信

二十一号故事铺
2024-08-15 01:10:02
世界第2遭爆冷!英格兰公开赛32决出8名国手5人!丁俊晖对手确定

世界第2遭爆冷!英格兰公开赛32决出8名国手5人!丁俊晖对手确定

越岭寻踪
2025-09-16 00:14:30
雷军:我们要认真向特斯拉学习,除了特斯拉,比我们好的没有!网友:“天下造车,唯马与雷尔!”

雷军:我们要认真向特斯拉学习,除了特斯拉,比我们好的没有!网友:“天下造车,唯马与雷尔!”

大白聊IT
2025-09-14 10:43:28
顿巴斯油荒爆发!俄军士兵自掏腰包加油,莫斯科12万人游行

顿巴斯油荒爆发!俄军士兵自掏腰包加油,莫斯科12万人游行

知兵
2025-09-11 14:13:07
哇塞!计划加盟日本男篮!未来的NBA首轮秀

哇塞!计划加盟日本男篮!未来的NBA首轮秀

篮球实战宝典
2025-09-15 22:22:52
7条中欧班列,全都要经过波兰?19架来历不明的无人机闯大祸了

7条中欧班列,全都要经过波兰?19架来历不明的无人机闯大祸了

历史摆渡
2025-09-15 12:50:03
涉嫌严重违纪违法!四川省古蔺县政协原党组副书记、副主席朱明辉被查(附简历)

涉嫌严重违纪违法!四川省古蔺县政协原党组副书记、副主席朱明辉被查(附简历)

鲁中晨报
2025-09-15 15:59:12
2025-09-16 01:27:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1769文章数 1427关注度
往期回顾 全部

科技要闻

官方:英伟达违反反垄断法 将施进一步调查

头条要闻

中美就TikTok等经贸问题在西班牙马德里举行会谈

头条要闻

中美就TikTok等经贸问题在西班牙马德里举行会谈

体育要闻

诺维茨基退役十年后,德国篮球走向巅峰

娱乐要闻

60岁张曼玉定居法国:瘦成皮包骨?

财经要闻

华与华秒怂 罗永浩称已接到对方道歉

汽车要闻

后轮转向和5C 2026款梦想家把想到的都给了

态度原创

旅游
家居
数码
公开课
军事航空

旅游要闻

热闻|清明假期将至,热门目的地有哪些?

家居要闻

典雅大气 舒适中带童趣

数码要闻

CASETiFY推出iPhone 17系列手机壳:晶釉手机壳亮相

公开课

李玫瑾:为什么性格比能力更重要?

军事要闻

三人伪装"外卖员""钓鱼佬"窃取军事秘密 详情公布

无障碍浏览 进入关怀版