PyTorch 系列 之 nn.Module:所有模型的骨架
7 / 7 章
开篇:Module 不是“普通父类” 写 PyTorch 模型,最终都会回到 nn.Module。 它不是一个空壳父类。它是模型的总管:管参数、管子模块、管状态、管调用、管保存加载。 你写的 forward 只是计算逻辑。真正让模型“像模型
开篇:Module 不是“普通父类”
写 PyTorch 模型,最终都会回到 nn.Module。
它不是一个空壳父类。它是模型的总管:管参数、管子模块、管状态、管调用、管保存加载。
你写的 forward 只是计算逻辑。真正让模型“像模型一样工作”的,是 nn.Module 背后的注册和调用机制。
把 Module 放到 PyTorch 体系里看
1. 官方定位:Module 是神经网络模块的基类
PyTorch 官方文档说得很直接:torch.nn.Module 是所有神经网络模块的基类,自己的模型也应该继承它。
更关键的是:Module 可以包含其他 Module。也就是说,模型天然是一棵树。
这棵树的价值很大。只要树建好了,PyTorch 就能递归找参数、迁移设备、切换模式、保存状态。
模型是 Module 组件树
2. 写 Module,只记住两个函数
__init__:放结构,放状态,放需要被 PyTorch 管理的对象。
forward:放计算,描述输入如何一步步变成输出。
简单说:__init__ 管“有什么”,forward 管“怎么走”。
__init__ 与 forward 的分工
import torch
import torch.nn as nn
class TinyNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 32)
self.act = nn.ReLU()
self.fc2 = nn.Linear(32, 2)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
return self.fc2(x)
model = TinyNet()
y = model(torch.randn(4, 10))这段代码很短,但已经触发了 Module 的核心能力:fc1、act、fc2 都被注册成子模块;Linear 里的 weight 和 bias 会被递归找到。
3. 自动注册:不是魔法,是 __setattr__
为什么 self.fc = nn.Linear(...) 之后,model.parameters() 能找到里面的权重?
因为 nn.Module 重写了 __setattr__。你给属性赋值时,它会看 value 的类型。
如果是 nn.Parameter,放进 _parameters;如果是 nn.Module,放进 _modules;如果是 Buffer 或 register_buffer,放进 _buffers;其他对象才只是普通属性。
__setattr__ 如何把对象分流到不同注册表
4. 四类成员:命运完全不同
很多 PyTorch Bug,本质是对象放错了地方。
想让优化器更新,必须是 Parameter。想随模型保存但不被优化,应该是 Buffer。想嵌套层,必须是 Module 或容器 Module。
Parameter、Buffer、Module、普通属性对比
5. 源码入口一:Module.__init__ 先准备内部字典
源码里,Module.__init__ 会先创建几个关键容器:_parameters、_buffers、_modules,以及 hooks 相关字典。
这就是为什么自定义模型第一行通常要写 super().__init__()。没有这一步,后续注册没有地方可放。
# torch/nn/modules/module.py 的核心思路
super().__setattr__("training", True)
super().__setattr__("_parameters", {})
super().__setattr__("_buffers", {})
super().__setattr__("_modules", {})注意:这里源码用的是 super().__setattr__,不是 self.xxx = xxx。这样可以绕开 Module.__setattr__ 的注册逻辑,避免初始化时自己拦截自己。
6. 源码入口二:model(x) 会走 __call__
官方文档提醒:虽然计算配方写在 forward 里,但应该调用 Module 实例本身,而不是直接调用 forward。
原因很简单:model(x) 会进入 __call__ 和 _call_impl,那里会处理 forward hooks、pre hooks、编译调用等逻辑。直接 model.forward(x) 会绕开这些能力。
model(x) 的真实调用路径
这条链路也解释了很多高级能力:特征提取可以靠 hook,调试可以靠 hook,torch.compile 也可以挂在 Module 调用入口上。
7. parameters():优化器为什么能拿到所有权重
parameters() 返回的是 Module 参数迭代器,通常直接交给优化器。
它默认 recurse=True,会递归进入所有子模块。所以你只需要写 optimizer = torch.optim.Adam(model.parameters()),不需要手动收集每一层权重。
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for name, p in model.named_parameters():
print(name, p.shape)关键点:被注册的参数才会出现。普通 Python list 里的层不会自动注册。需要用 nn.ModuleList 或 nn.Sequential。
8. Buffer:不学习,但要跟着模型走
Buffer 是模型状态,但不是可学习参数。
典型例子是 BatchNorm 的 running_mean。它会影响计算,也需要保存和加载,但不应该交给优化器更新。
class RunningStat(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("running_mean", torch.zeros(10))
def forward(self, x):
return x - self.running_meanpersistent=True 的 Buffer 会进 state_dict。persistent=False 的 Buffer 会跟着 to()/cuda() 迁移,但不会被保存。
9. state_dict:模型保存的核心
PyTorch 推荐保存 state_dict,而不是把整个模型对象直接打包。
state_dict 里包含参数和持久 Buffer。key 是名字,value 是 Tensor。名字来自 Module 树的路径。
state_dict 如何收集模型状态
torch.save(model.state_dict(), "model.pt")
new_model = TinyNet()
new_model.load_state_dict(torch.load("model.pt"))理解 state_dict,后面学断点续训、迁移学习、分布式保存、部署导出都会轻松很多。
10. train() 和 eval():只切模式,不跑训练
model.train() 不会自动开始训练。model.eval() 也不会自动开始推理。
它们只是递归设置每个 Module 的 training 标志。某些层会读这个标志,比如 Dropout 和 BatchNorm。
train/eval 的真实作用
一个常见误解:eval() 不等于 no_grad()。eval() 只影响层行为;no_grad() 才影响 Autograd 是否记录梯度。
11. to():为什么整棵模型能搬到 GPU
model.to(device) 能生效,是因为 Module 会递归处理自己的参数、Buffer 和子模块。
普通属性不会自动被迁移。普通 Tensor 如果没注册成 Buffer,也不会被 Module 管。
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyNet().to(device)
x = torch.randn(4, 10, device=device)
y = model(x)记住:需要跟着模型走的 Tensor,要么是 Parameter,要么注册成 Buffer。
12. 源码级主线:六个动作串起来
读 Module 源码,不要从头到尾硬啃。按动作读:初始化、注册、调用、遍历、状态、模式/迁移。
最重要的三个问题:
对象赋值后,去了 _parameters、_buffers、_modules,还是普通 __dict__?
模型调用时,是走 model(x),还是绕过了 __call__?
保存加载时,这个对象会不会出现在 state_dict 里?
13. 常见坑:先看有没有被注册
当参数没有更新、模型保存后丢东西、to(cuda) 后设备不一致,先不要怀疑玄学。
先检查注册。
print(model)
print(dict(model.named_parameters()).keys())
print(dict(model.named_buffers()).keys())
print(model.state_dict().keys())总结
nn.Module 是 PyTorch 模型的骨架。
它把散落的层、参数、状态组织成一棵可训练、可迁移、可保存、可调试的树。
掌握 Module,后面讲 Linear、Conv、Transformer、训练循环、模型保存、分布式训练,都会有同一套底层语言。
相关推荐
LangChain 系列 之 Short-term Memory:当前会话内如何保存状态?
1. 短期记忆到底是什么? Short-term Memory = 当前 thread 内的状态保存。它让 Agent 在同一段会话里记住前面发生过什么。 LangChain 官方把短期记忆叫做 thread-level persisten

烧钱不止,AI大模型厂商陷入“订阅困局”
长期以来,OpenAI与Anthropic等大模型厂商依靠固定月费的订阅模式,迅速完成了在用户群体中的普及。然而,行业研究机构SemiAnalysis的一项深度测评显示,这种看似双赢的模式,正在让厂商面临日益尖锐的成本危机。 测评机构通过对比两家公司多档订阅计划发现,由于重度用户在编程和“智能体”交互中产生的高频Token消耗,平价的订阅费往往难以覆盖其背后

蚂蚁集团正秘密测试“AI版支付宝”,智能体助手或成未来核心
支付宝的界面或许即将迎来一场彻底的“智能化重构”。近期有消息传出,蚂蚁集团正在内部秘密测试一款全新的AI版本支付宝,计划通过全方位的改版,将人工智能深度融入支付与生活服务生态。 据悉,这款新版本最直观的变化在于其交互逻辑的升级。用户将能够通过一键操作直接切入原生AI界面,以文字或语音的方式,直接向AI助手“阿宝”下达指令。无论是复杂的资金管理,还是日常的生活

AI音乐视频创作新风向:立刻MV 1. 1 版本实现“一键成片”跨越
音乐视频(MV)的制作门槛正在被新技术拉低。近日,一站式AI音乐视频创作平台“立刻MV”正式发布了1. 1 版本,此次更新在网页端与iPhone端同步落地,标志着该平台在AI辅助影像创作领域迈出了关键一步。 此次版本迭代的核心亮点,在于大幅提升了视频生成的表现力与灵活性。区别于以往常见的“图片幻灯片式”处理方式,1. 1 版本引入了全新的AI视频生成模块,能
LangChain 系列之Memory:多轮对话为什么不能只靠历史消息硬塞?
记忆不是“多带几轮聊天记录”。记忆是一个可读、可写、可压缩、可持久化的上下文系统。 1. 记忆不是历史消息 很多人第一次做多轮对话,会把历史消息全部拼进 Prompt。刚开始没问题。聊十轮、二十轮以后,问题就来了:上下文越来越长,成本越来越

蚂蚁阿福试水"AI+医生"模式:AI回答可由医生把关 15%用户会选择
6月15日,健康AI应用"蚂蚁阿福"宣布“拍皮肤”功能升级:可识别皮肤病种类从50种增至100多种,覆盖99%的线上就医常见皮肤问题。同时,阿福还上线了“医生把关”这一新服务:用户获得阿福的解答后,可选择邀请三甲医院的医生对阿福的分析结果进行复核并补充意见。这也是国内首个落地"AI问答+医生把关"协作模式的AI应用,为AI与医生的合作提供了可行路径,打开了“
阅读补充
一句话看懂
开篇:Module 不是“普通父类” 写 PyTorch 模型,最终都会回到 nn.Module。 它不是一个空壳父类。
事件背景
这篇内容围绕“PyTorch”展开,热闻岛基于公开信息整理事件背景、主要进展与可继续关注的方向。
事件时间线
发布
相关信息进入公开传播
更新
热闻岛对内容进行整理与补充。
看点
- · PyTorch的最新进展是什么
- · 相关信息对用户或行业会带来哪些影响
- · 后续是否会有新的回应或处理结果
后续关注
- · 后续官方回应或权威通报
- · 相关主体的进一步说明
- · 事件对普通用户和平台传播的持续影响
免责声明:本文仅代表作者观点,不构成投资建议、法律建议、医疗建议。财经类内容尤其需要注意风险;爆料类信息请以权威通报为准。
评论 (0)
登录后即可发表评论
去登录