网易首页 > 网易号 > 正文 申请入驻

PyTorch自定义学习率调度器实现指南

0
分享至


在深度学习训练过程中,学习率调度器扮演着至关重要的角色。这主要是因为在训练的不同阶段,模型的学习动态会发生显著变化。

在训练初期,损失函数通常呈现剧烈波动,梯度值较大且不稳定。此阶段的主要目标是在优化空间中快速接近某个局部最小值。然而,过高的学习率可能导致模型跳过潜在的优质局部最小值,从而限制了模型性能的充分发挥。

尽管PyTorch提供了多种预定义的学习率调度器,但在特定研究场景或需要更精细控制时,这些标准实现可能无法完全满足需求。在这种情况下,实现自定义学习率调度器成为了一个可行的解决方案。

本文将详细介绍如何通过扩展PyTorch的LRScheduler类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。

1、继承LRScheduler类

在PyTorch中实现自定义学习率调度器时,首先需要继承torch.optim.lr_scheduler.LRScheduler类。这个基类提供了管理学习率调度所需的核心功能。

通过继承LRScheduler,我们可以利用以下关键特性:

  1. self.optimizer:对优化器的引用,用于调整其学习率。
  2. self.base_lrs:存储优化器中所有参数组的初始学习率,可在自定义调度器中进行访问和修改。
  3. self.last_epoch:跟踪当前训练轮次,用于根据轮次数调整学习率。
  4. step()方法:在每个训练轮次后调用,用于自动更新学习率。
  5. 参数组处理:LRScheduler设计支持优化器中的多个参数组,允许对模型的不同部分应用不同的学习率调整策略。

以下是继承LRScheduler的基本代码结构:

from torch.optim.lr_scheduler import LRScheduler
class CosineWarmupScheduler(LRScheduler):
pass

通过继承LRScheduler,我们获得了上述所有功能,只需要通过实现get_lr()方法来定义学习率的具体变化逻辑。

2、实现构造函数

在自定义学习率调度器中,构造函数(__init__方法)用于初始化调度器的关键参数。这些参数定义了学习率调整的具体策略,包括预热期的长度、总训练轮次和最小学习率等。

以下是构造函数的实现示例:

class CosineWarmupScheduler(LRScheduler):
def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
self.warmup_epochs = warmup_epochs # 学习率线性增加的预热轮次
self.total_epochs = total_epochs # 总训练轮次
self.min_lr = min_lr # 学习率下限
super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

参数说明:

  • optimizer:PyTorch优化器实例,其学习率将被调整。
  • warmup_epochs:预热阶段的轮次数,在此期间学习率线性增加。
  • total_epochs:训练的总轮次,包括预热阶段和衰减阶段。
  • min_lr:学习率的下限,衰减阶段的最终学习率不会低于此值。
  • last_epoch:上一轮的索引,用于恢复训练。默认为-1,表示从头开始训练。

3、调用父类构造函数

在自定义调度器的构造函数中,通过super()调用父类(LRScheduler)的构造函数是非常重要的。这确保了基类被正确初始化,使我们能够访问诸如self.optimizer、self.base_lrs和self.last_epoch等关键属性。

super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

这行代码不仅初始化了基类,还使得自定义调度器能够继承LRScheduler的其他有用方法,如step()和get_last_lr()。

4、实现get_lr()方法

get_lr()方法是自定义调度器的核心,它定义了学习率如何随训练轮次变化的具体逻辑。在本例中,我们实现了一个包含预热阶段的余弦衰减调度策略:

预热阶段:在前warmup_epochs轮中,学习率从0线性增加到初始学习率。

余弦衰减阶段:预热结束后,学习率按余弦函数从初始值衰减到最小值。

以下是get_lr()方法的实现:

import math
class CosineWarmupScheduler(LRScheduler):
def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.min_lr = min_lr
super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
epoch = self.last_epoch + 1
if epoch <= self.warmup_epochs:
# 预热阶段:线性增加学习率
return [base_lr * epoch / self.warmup_epochs for base_lr in self.base_lrs]
else:
# 余弦衰减阶段
decay_epochs = self.total_epochs - self.warmup_epochs
cosine_decay = 0.5 * (1 + math.cos(math.pi * (epoch - self.warmup_epochs) / decay_epochs))
return [self.min_lr + (base_lr - self.min_lr) * cosine_decay for base_lr in self.base_lrs]

这个实现确保了学习率在预热阶段平滑增加,然后在剩余的训练过程中逐渐衰减,最终达到指定的最小值。

5、在训练流程中应用自定义调度器

实现自定义学习率调度器后,下一步是将其集成到训练流程中。以下示例展示了如何在PyTorch训练循环中初始化和使用自定义调度器:

import torch
import torch.optim as optim
# 定义模型(此处使用简单的线性模型作为示例)
model = torch.nn.Linear(10, 1)
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.1)
# 初始化自定义学习率调度器
scheduler = CosineWarmupScheduler(optimizer, warmup_epochs=5, total_epochs=50, min_lr=0.001)
# 训练循环
num_epochs = 50
for epoch in range(num_epochs):
model.train()
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 在每个epoch结束时更新学习率
scheduler.step()
# 记录当前学习率(用于监控)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs}, Learning Rate: {current_lr:.6f}")

在这个示例中,我们执行以下关键步骤:

  1. 定义模型和优化器。
  2. 使用之前实现的CosineWarmupScheduler初始化学习率调度器。
  3. 在每个训练epoch中:
  • 执行标准的前向传播、损失计算和反向传播步骤。
  • 调用optimizer.step()更新模型参数。
  • 在epoch结束时调用scheduler.step()更新学习率。
  1. 使用scheduler.get_last_lr()获取并记录当前学习率,用于监控训练过程。

关键组件说明

  • scheduler.step():这个方法在每个epoch结束时调用,根据当前epoch更新学习率。它是动态调整学习率的核心机制。
  • scheduler.get_last_lr():返回当前的学习率。在多参数组的情况下,它返回一个列表,每个元素对应一个参数组的学习率。

总结

通过继承PyTorch的LRScheduler类并实现自定义的get_lr()方法,我们可以创建灵活的学习率调度策略,以满足特定的训练需求。本指南展示的带预热的余弦衰减调度器只是众多可能实现的一个例子。

自定义学习率调度器的关键优势在于:

  1. 灵活性:可以实现任何所需的学习率调整策略。
  2. 精确控制:能够根据训练动态和模型特性精细调整学习过程。
  3. 适应性:可以轻松适应不同的模型架构和数据集特性。

在实际应用中,可能需要进行大量实验来确定最适合特定问题的学习率调度策略。通过掌握自定义调度器的实现技巧,研究人员和工程师可以更灵活地优化深度学习模型的训练过程,从而潜在地提高模型性能和训练效率。

https://avoid.overfit.cn/post/aa1e90e02eb24d9f982e2c933bdd97a7

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
真替柳岩感到尴尬

真替柳岩感到尴尬

TVB的四小花
2026-03-15 04:54:12
36岁张含韵近况曝光!春节一周胖6斤,如今和“五竹叔”恋情稳定

36岁张含韵近况曝光!春节一周胖6斤,如今和“五竹叔”恋情稳定

陈意小可爱
2026-03-14 20:39:54
真的难!2026年B级车市场开启“大降价”,最大降幅52%,合资霸榜

真的难!2026年B级车市场开启“大降价”,最大降幅52%,合资霸榜

美食格物
2026-03-15 00:06:22
我遭妻子告发贪污受贿,被澄清后,我立刻便带儿子去做了亲子鉴定

我遭妻子告发贪污受贿,被澄清后,我立刻便带儿子去做了亲子鉴定

千秋文化
2026-03-09 20:55:26
血色24小时!卡车撞穿犹太幼儿园,校园恐袭同日爆发,美国安全防线正崩塌?

血色24小时!卡车撞穿犹太幼儿园,校园恐袭同日爆发,美国安全防线正崩塌?

最英国
2026-03-13 18:41:36
曼城空欢喜!16岁小将刷爆英超纪录,阿森纳2-0埃弗顿10分领跑

曼城空欢喜!16岁小将刷爆英超纪录,阿森纳2-0埃弗顿10分领跑

钉钉陌上花开
2026-03-15 05:05:14
伊朗副外长:允许部分国家船只通过霍尔木兹海峡

伊朗副外长:允许部分国家船只通过霍尔木兹海峡

界面新闻
2026-03-13 06:56:50
上个月去了次河南驻马店,我实话实说:当地人的素质彻底颠覆认

上个月去了次河南驻马店,我实话实说:当地人的素质彻底颠覆认

天气观察站
2026-03-14 10:44:33
韩国犯规大王黄大宪又用阴招!中国短道超新星世锦赛摔倒冲冠梦碎

韩国犯规大王黄大宪又用阴招!中国短道超新星世锦赛摔倒冲冠梦碎

杨华评论
2026-03-15 02:47:52
史上首次!特朗普被警告,美财政部投下金融核弹,美债或将崩盘

史上首次!特朗普被警告,美财政部投下金融核弹,美债或将崩盘

可乐爱微笑
2026-03-14 20:45:52
伊朗称哈尔克岛局势已得到控制

伊朗称哈尔克岛局势已得到控制

界面新闻
2026-03-14 18:35:14
随着曼城爆冷1-1,切尔西0-1,阿森纳2-0,英超最新积分榜出炉

随着曼城爆冷1-1,切尔西0-1,阿森纳2-0,英超最新积分榜出炉

侧身凌空斩
2026-03-15 06:16:02
告诉大家一个坏消息:福州,武汉已出现4大怪象,值得每个人深思

告诉大家一个坏消息:福州,武汉已出现4大怪象,值得每个人深思

美食格物
2026-03-14 22:30:22
哈佛研究实锤:抗老根本不用医美!这6个行为坚持半年,年轻10岁

哈佛研究实锤:抗老根本不用医美!这6个行为坚持半年,年轻10岁

白宸侃片
2026-03-12 19:17:17
85岁核武老人魏世杰:与苦难握手言和|面孔

85岁核武老人魏世杰:与苦难握手言和|面孔

大象新闻
2026-03-14 09:57:13
巴拿马媒体发出警告,中资撤离或重创经济,金融界批巴政府鲁莽

巴拿马媒体发出警告,中资撤离或重创经济,金融界批巴政府鲁莽

纪中百大事
2026-03-14 11:01:33
一觉醒来,外盘又变!美元大涨、黄金重挫,发生了什么?

一觉醒来,外盘又变!美元大涨、黄金重挫,发生了什么?

萌生财经
2026-03-14 12:43:10
23岁女同事住院没人管,我请7天假陪护,出院后董事长却亲自来接

23岁女同事住院没人管,我请7天假陪护,出院后董事长却亲自来接

千秋文化
2026-02-21 19:44:13
俄方:美以“闪电战”失败!伊朗抓“内鬼”:查获多套“星链”终端,“试图影响舆论制造混乱”!伊方嘲讽美国:将阵亡写成“轻微脑震荡”

俄方:美以“闪电战”失败!伊朗抓“内鬼”:查获多套“星链”终端,“试图影响舆论制造混乱”!伊方嘲讽美国:将阵亡写成“轻微脑震荡”

每日经济新闻
2026-03-14 16:12:07
比红薯通便,比芋头养人!中老年多吃它,利尿通便,春天吃正合适

比红薯通便,比芋头养人!中老年多吃它,利尿通便,春天吃正合适

阿龙美食记
2026-03-11 14:15:23
2026-03-15 07:20:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1948文章数 1456关注度
往期回顾 全部

科技要闻

xAI创始伙伴只剩两人!马斯克“痛改前非”

头条要闻

伊朗“命根子”遭到中东史上最大轰炸 特朗普表态

头条要闻

伊朗“命根子”遭到中东史上最大轰炸 特朗普表态

体育要闻

NBA唯一巴西球员,增重20KG顶内线

娱乐要闻

九成美曝田栩宁孕期出轨 AI反转引热议

财经要闻

3·15影子暗访|神秘的“特供酒”

汽车要闻

吉利银河M7技术首秀 实力重构主流电混SUV

态度原创

艺术
手机
游戏
公开课
军事航空

艺术要闻

这是唯一存世的毛主席画作

手机要闻

追觅AURORA:无聊的高端手机市场,迎来了一个无情的“执剑人”

FS社新作终于有新消息!NS2独占 多人在线

公开课

李玫瑾:为什么性格比能力更重要?

军事要闻

特朗普宣布空袭伊石油出口枢纽哈尔克岛

无障碍浏览 进入关怀版