网易首页 > 网易号 > 正文 申请入驻

PyTorch那些事儿(七):深入解析nn.modules

0
分享至

正文共:11632字 预计阅读时间:30分钟

1 引言

如果要从PyTorch中选出一个最为核心的类,那无疑非nn.Module莫属了。nn.Module为神经网络层及整个模型提供了基本结构,塔不仅包含了神经网络中各种层的实现,还提供了易于扩展的基类,让用户可以方便地自定义层或组合已有层来构建复杂的模型。在PyTorch中,我们所熟知的网络结构,如卷积层、全连接层都是通过继承nn.Module实现的。这使得nn.Module成为PyTorch中最重要的组件之一。

使用nn.Module进行网络搭建时,还有以下优点:

  • 参数管理:nn.Module可以自动地跟踪模型中的所有参数(如权重和偏置),这样用户无需手动管理这些参数。此外,nn.Module还提供了方便的方法来遍历、更新和保存模型参数。

  • 设备无关性:nn.Module使得模型可以轻松在不同的设备(如CPU、GPU或TPU)之间切换,只需调用一个简单的方法即可。

  • 可组合性:用户可以将多个nn.Module子类组合成一个更大的网络结构,而不需要关心底层的实现细节。这使得神经网络的设计变得非常灵活和模块化。

  • 易于扩展:用户可以方便地创建自定义层或模型,只需继承nn.Module并实现特定的方法即可。这使得PyTorch适应各种不同的应用场景和研究方向。

总之,nn.Module为构建神经网络提供了一个灵活且功能强大的框架。通过继承nn.Module并实现自定义的前向传播方法,用户可以快速地搭建各种复杂的模型结构。在接下来的部分中,我们将详细介绍nn.Module的基础知识、高级功能以及如何使用它构建实际的神经网络模型。

2 自定义网络结构

网络结构是神经网络模型的基本框架,它描述了网络中各层的类型、顺序、连接方式以及各层之间的参数。无论何种网络结构,简单的单一网络层也好,错综复杂的网络模块也罢,在定义时,都必须继承nn.Module,并实现init()和forward()方法。这两个方法分别负责初始化网络结构的参数和定义网络结构的计算逻辑。

  • __init__()方法

在自定义网络结构时,我们需要在init()方法中定义网络中需要使用的各种资源,可能包括参数亦或者是其他网络结构。这些资源作为网络的属性存储在类实例中,供forward()方法使用。

在init()方法中,我们首先需要调用父类nn.Module的init()方法,以确保父类的初始化工作能够正常进行。

  • forward()方法

forward()方法是自定义网络结构中的核心部分,它负责定义网络的计算逻辑。在forward()方法中,我们需要根据网络结构的设计,依次调用__init__()方法中定义的各个层和模块,并传递输入数据。最终,forward()方法返回网络的输出结果。

forward()方法的参数通常包括网络的输入数据。在实际应用中,我们可能需要根据任务的不同需求来处理一个或多个输入。

如下代码所示,我们定义了一种最简单的网络结构——网络层:

import torch
from torch import nn

class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_features, out_features))
self.bias = nn.Parameter(torch.randn(out_features))

def forward(self, input):
return (input @ self.weight) + self.bias
m = MyLinear(4, 3)
sample_input = torch.randn(4)
sample_input

Out:

tensor([-0.1949, -0.5323, -0.2206, 0.1164])

网络层之间可以进行相互组合以及嵌套,形成更加复杂的网络模块,这种特性使得模型的网络结构更加丰富,实现更加精细化对模型进行控制,网络模型本身就是一个网络模块,只不过功能完善。这里,为了区分单一的简单网络层,我把多层网络的组合或者嵌套之后的网络结构称为网络模块。话要说回来,网络层与网络模块的概念本就没有明确的区分。

import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.l0 = MyLinear(4, 3)
self.l1 = MyLinear(4, 3)
def forward(self, x0, x1):
x0 = self.l0(x0)
x1 = self.l1(x1)
x = x0 + x1
x = F.relu(x)
return x
net2 = Net()
x0 = torch.randn(4)
x1 = torch.randn(4)
net2(x0, x1)

Out:

tensor([2.3498, 0.0000, 2.7529], grad_fn=)

在上面代码中,我们通过继承nn.Module创建了一个网络模块,这个网络模块由两个MyLinear组成,并定义了前向传播的逻辑:两个MyLinear分别对输入的两输入进行运算,然后对输入进行求和,最后对求和结果进行ReLu激活。通过继承nn.Module创建了网络模块是最常用、最基本、最重要的一种方法,其他的所有方法都是在这种方法基础上进行了进一步的封装而已。

接下来,我们介绍几个在PyTorch中提供的一些在nn.Module基础上进一步封装好的其他创建网络模块的方法。

  • Sequential

Sequential是构建网络模块最简单的方式。Sequential本身就是nn.Module的子类,它会将我们输入的多个网络层进行连接,输入张量可以自动地在网络模块内容多个网络层间按顺序进行前向传播,例如下方代码中,输入张量首先会传递给第一个MyLinear进行运算,其结果将传递给ReLU层,然后ReLU的结果作为输入张量进入到最后一个MyLinear层中参与前向传播。

net1 = nn.Sequential(
MyLinear(4, 3),
nn.ReLU(),
MyLinear(3, 1)

Sequential构建网络模块快捷、方便,但是只能简单地将已有的网络结构按照既定顺序进行堆叠,但是PyTorch官方更加推荐通过自行继承nn.Module的方法去定义网络模块,这种方法相较于Sequential方法有以下优势:

  • 灵活性:继承nn.Module允许更灵活地定义网络结构和前向传播方法。您可以在forward()方法中包含任意的计算逻辑,例如条件判断、循环等。这使得自定义复杂的网络结构变得更加简单。

  • 多输入输出:继承nn.Module方法允许在forward()方法中接受多个输入和输出。这对于处理多模态数据或实现多任务学习等场景非常有用。

  • 易于调试:通过继承nn.Module方法创建的网络结构更容易进行调试,因为可以在forward()方法中添加断点或打印语句以检查中间结果。

对于复杂的网络结构,建议使用继承nn.Module方法;对于简单的线性网络结构,可以考虑使用Sequential方法,当然,也可以混合使用:

class DeepNet(nn.Module):
def __init__(self):
super().__init__()
self.mdl0 = nn.Sequential( # 使用Sequential组合出一个网络模块
MyLinear(4, 3),
nn.ReLU(),
MyLinear(3, 4)
self.mdl1 = nn.Sequential( # 使用Sequential组合出一个网络模块
MyLinear(4, 6),
nn.ReLU(),
MyLinear(6, 4)
self.mdl2 = Net() # 上述定义的Net网络模块

def forward(self, x):
x1 = self.mdl0(x)
x2 = self.mdl1(x)
x = self.mdl2(x1, x2)
return x
net2 = DeepNet()
x = torch.randn(4)
net2(x)

Out:

tensor([0.0000, 0.0000, 1.8571], grad_fn=)

  • ModuleList和ModuleDict

有时候,我们需要动态地设定网络模块内部的结构,这时候就需要使用ModuleList和ModuleDict两个类了。ModuleList和ModuleDict可以通过Python原生的list和dict来创建网络模块。

class DynamicNet(nn.Module):
def __init__(self, num_layers):
super().__init__()
# self.linears = [MyLinear(4, 4) for _ in range(num_layers)] # 直接将所有网络层放在Python原生的list中
self.linears = nn.ModuleList(
[MyLinear(4, 4) for _ in range(num_layers)])

# self.activations = { # 直接使用Python原生dict存放
# 'relu': nn.ReLU(),
# 'lrelu': nn.LeakyReLU()
# }
self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'lrelu': nn.LeakyReLU()
})
self.final = MyLinear(4, 1)
def forward(self, x, act):
for linear in self.linears:
x = linear(x)
x = self.activations[act](x)
x = self.final(x)
return x

dynamic_net = DynamicNet(3)
sample_input = torch.randn(4)
output = dynamic_net(sample_input, 'relu')
output

Out:

tensor([-1.3024], grad_fn=)

注意上述代码中,有两部分代码注释,注释部分代码使用的是Python原生的list和dict来存放网络层,并没有使用ModuleList和ModuleDict两个类,虽然这种方法在某些情况下仍然可以正常工作,但它不会自动跟踪和管理子模块。这意味着调用parameters()方法时,不会返回所有的参数,可能导致在优化过程中参数无法更新。为了解决这个问题,需要手动将这些子模块注册到父模块,更加麻烦,所以,还是建议通过ModuleList和ModuleDict两个类更加方便。

3 网络结构管理

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
华晨宇哭了损失大了,在云南投资上亿拿下地皮,如今紧急叫停

华晨宇哭了损失大了,在云南投资上亿拿下地皮,如今紧急叫停

以茶带书
2026-04-25 16:22:06
A股:2.5亿股民,今晚可能要兴奋得睡不着觉了,你知道为什么吗?

A股:2.5亿股民,今晚可能要兴奋得睡不着觉了,你知道为什么吗?

夜深爱杂谈
2026-04-25 20:37:19
欧洲用20年严格监管美国科技巨头,却把自己管成美国的“经济农奴”

欧洲用20年严格监管美国科技巨头,却把自己管成美国的“经济农奴”

风向观察
2026-04-25 14:32:55
特朗普:已取消威特科夫和库什纳前往巴基斯坦的行程

特朗普:已取消威特科夫和库什纳前往巴基斯坦的行程

新华社
2026-04-25 23:53:04
孩子脱臼复位只收100元,家长举报乱收费!卫健委:应收110元,你还少给了!家长拒缴费后离开!

孩子脱臼复位只收100元,家长举报乱收费!卫健委:应收110元,你还少给了!家长拒缴费后离开!

医脉圈
2026-04-25 20:04:06
失联十余日!两届奥运冠军,羽协主席张军被查!后果与影响分析

失联十余日!两届奥运冠军,羽协主席张军被查!后果与影响分析

史海流年号
2026-04-25 08:22:32
被指涉嫌在美强奸27岁女性?陆宏达辞任智度股份、国光电器董事长不到一周再起风波

被指涉嫌在美强奸27岁女性?陆宏达辞任智度股份、国光电器董事长不到一周再起风波

蓝鲸新闻
2026-04-25 15:54:05
澳洲首创! 全新癌症治疗法重磅突破: 不化疗, 不手术, 能治愈80%癌症, 已开始人体实验!

澳洲首创! 全新癌症治疗法重磅突破: 不化疗, 不手术, 能治愈80%癌症, 已开始人体实验!

澳微Daily
2026-04-25 15:43:13
斯凯奇,在三亚交上「专业」答卷 !

斯凯奇,在三亚交上「专业」答卷 !

跑步侠
2026-02-05 16:46:15
“全新以赴”不再是口号 大众在北京车展交出“兑现”答卷

“全新以赴”不再是口号 大众在北京车展交出“兑现”答卷

网上车市
2026-04-25 18:05:49
美国发出宣战书!美军集结到位,21国要求本国公民立即从伊朗撤离

美国发出宣战书!美军集结到位,21国要求本国公民立即从伊朗撤离

史政先锋
2026-04-25 14:47:45
世锦赛战报:再爆大冷,世界第2惨败出局,8强决出2席!罗伯逊4-1

世锦赛战报:再爆大冷,世界第2惨败出局,8强决出2席!罗伯逊4-1

求球不落谛
2026-04-25 19:43:42
48小时内,美日都打算派高层访华,特朗普说:我不生中国气

48小时内,美日都打算派高层访华,特朗普说:我不生中国气

一口娱乐
2026-04-25 17:37:52
山姆“爆雷”,3亿中产炸了!

山姆“爆雷”,3亿中产炸了!

新零售参考Pro
2026-04-23 16:31:50
5月1日起全国严查!以前的“小事”现在可能坐牢,抓紧了解一下!

5月1日起全国严查!以前的“小事”现在可能坐牢,抓紧了解一下!

细说职场
2026-04-25 17:42:02
10国签反华协议!沉默一天后,中方出手,不得未经允许接受美资

10国签反华协议!沉默一天后,中方出手,不得未经允许接受美资

清欢百味
2026-04-25 16:25:29
4200万人断缴社保,年轻人和灵活就业群体断缴率最高,均超30%!

4200万人断缴社保,年轻人和灵活就业群体断缴率最高,均超30%!

灯锦年
2026-04-25 15:52:24
悲催!浙江一女子出轨,丈夫直言婚姻本就是一场赌注,放手去爱吧

悲催!浙江一女子出轨,丈夫直言婚姻本就是一场赌注,放手去爱吧

火山詩话
2026-04-25 16:19:12
战与和的拉扯:美国无限反转在消磨什么?日本扩军狂飙想干什么?

战与和的拉扯:美国无限反转在消磨什么?日本扩军狂飙想干什么?

上观新闻
2026-04-25 18:49:05
2025年中国私人对乌克兰捐款位列全球第四

2025年中国私人对乌克兰捐款位列全球第四

刘耘博士
2026-04-25 10:25:41
2026-04-26 01:44:49
Ai学习的老章 incentive-icons
Ai学习的老章
Ai学习的老章
3351文章数 11140关注度
往期回顾 全部

科技要闻

DeepSeek V4发布!黄仁勋预言的"灾难"降临

头条要闻

媒体:美军在中东罕见高密度集结 伊朗开始调整战术

头条要闻

媒体:美军在中东罕见高密度集结 伊朗开始调整战术

体育要闻

那一刻开始,两支球队的命运悄然改变了

娱乐要闻

《我们的爸爸2》第一季完美爸爸翻车了

财经要闻

90%订单消失,中东旺季没了

汽车要闻

2026款乐道L90亮相北京车展 乐道L80正式官宣

态度原创

本地
时尚
旅游
数码
公开课

本地新闻

云游中国|逛世界风筝都 留学生探秘中国传统文化

这些穿搭适合春天!外套彩色内搭白色、裤子穿基础款,舒适大方

旅游要闻

美猴王VS水蜜桃,连云港、无锡文旅“双向奔赴”

数码要闻

联发科亮相2026北京车展:主动式智能体座舱解决方案

公开课

李玫瑾:为什么性格比能力更重要?

无障碍浏览 进入关怀版