RNN 全面详解与 PyTorch 实战(含 LSTM/GRU)
这是一篇面向自我复习/长期查阅的笔记式教程,覆盖 RNN 基础原理、BPTT、梯度问题、LSTM/GRU 结构与 PyTorch 代码示例,可直接用于 Hexo 博客(已含 YAML Front‑Matter、Mermaid 图与公式)。
目录
- 为什么需要 RNN
- RNN 工作原理
- 前向传播与 BPTT
- 梯度消失/爆炸与常见对策
- LSTM:长短期记忆网络
- GRU:门控循环单元
- RNN / LSTM / GRU 对比
- 双向 RNN 与 Seq2Seq 扩展
- PyTorch 实战代码清单
- 训练与调参清单
- 总结
为什么需要 RNN
许多真实数据是序列:文本、语音、时间序列(股价/气温/用电负荷)等。普通前馈网络缺少“记忆”,难以利用上下文依赖。循环神经网络(RNN)通过在时间维度引入隐藏状态并复用参数,让模型具备“把过去带到现在”的能力,从而建模时间依赖。
RNN 工作原理
给定长度为 T 的输入序列 (x_1,\dots,x_T),基本 RNN 的递推为:
其中 (\phi) 常用 (\tanh) 或 ReLU;(h_t) 作为到 t 为止的“记忆摘要”。展开为时间链后,参数在每个时间步共享。
RNN 时间展开示意(Mermaid):
flowchart LR x1[X1] --> r1((RNN)) h0((h0)) --> r1 r1 --> h1((h1)) r1 --> y1[Y1] x2[X2] --> r2((RNN)) h1 --> r2 r2 --> h2((h2)) r2 --> y2[Y2] x3[...] h2 --> r3((RNN)) x3 --> r3 r3 --> h3((h3)) r3 --> y3[...]
最小 PyTorch 示例(nn.RNN):
1 | import torch, torch.nn as nn |
前向传播与 BPTT
对每个时间步的损失 (\ell_t)(如交叉熵),总损失 (L=\sum_t \ell_t / T)。训练时对时间展开后的计算图反向传播(BPTT),梯度会沿着 (h_t\to h_{t-1}\to\cdots) 回传,出现连乘项,从而带来:
- 梯度消失:连乘 < 1,远端依赖难学;
- 梯度爆炸:连乘 > 1,训练不稳。
截断 BPTT:长序列按块(如 50 步)反传,降低不稳定。
梯度消失/爆炸与常见对策
- 梯度裁剪:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - 良好初始化与归一化:如 Xavier/Kaiming;适度 LayerNorm;
- 合适激活:ReLU/LeakyReLU 在部分任务上更稳;
- 门控结构:LSTM/GRU 通过“直通路径”缓解长依赖;
- 缩短有效深度:截断 BPTT、分层/分块建模。
LSTM:长短期记忆网络
LSTM 在 RNN 基础上引入细胞状态 (C_t) 与三类门(遗忘/输入/输出),核心:
\begin f_t &= \sigma(W_f [h_NaN, x_t] + b_f),\\ i_t &= \sigma(W_i [h_NaN, x_t] + b_i),\quad \tilde C_t = \tanh(W_C [h_NaN, x_t] + b_C),\\ C_t &= f_t \odot C_NaN + i_t \odot \tilde C_t,\\ o_t &= \sigma(W_o [h_NaN, x_t] + b_o),\quad h_t = o_t \odot \tanh(C_t). \end
LSTM 单元(Mermaid):
flowchart LR
subgraph LSTM Cell at time t
ht_1[h_{t-1}] --> concat
xt[x_t] --> concat
concat[[concat]] --> Wf[W_f] --> ft[f_t = σ(.)]
concat --> Wi[W_i] --> it[i_t = σ(.)]
concat --> Wc[W_C] --> Ct_tilde[~C_t = tanh(.)]
Ctm1[C_{t-1}] --> mul1
ft --> mul1[[⊙]] --> add1
it --> mul2[[⊙]]
Ct_tilde --> mul2 --> add1[[+]] --> Ct[C_t]
concat --> Wo[W_o] --> ot[o_t = σ(.)]
Ct --> tanhCt[tanh(C_t)] --> mul3[[⊙]] --> ht[h_t]
ot --> mul3
end
PyTorch:nn.LSTM
1 | import torch, torch.nn as nn |
GRU:门控循环单元
GRU 将 LSTM 的 (C_t) 与 (h_t) 合并,并用更新门 (z_t) 与 重置门 (r_t) 控制信息:
\begin z_t &= \sigma(W_z [h_NaN, x_t]+b_z),\quad r_t = \sigma(W_r [h_NaN, x_t]+b_r),\\ \tilde h_t &= \tanh(W_h [r_t \odot h_NaN, x_t]+b_h),\\ h_t &= z_t \odot h_NaN + (1-z_t) \odot \tilde h_t\,. \end
GRU 单元(Mermaid):
flowchart LR
subgraph GRU Cell at time t
ht_1[h_{t-1}] --> concat1
xt[x_t] --> concat1[[concat]] --> Wz[W_z] --> zt[z_t=σ(.)]
ht_1 --> concat2
xt --> concat2[[concat]] --> Wr[W_r] --> rt[r_t=σ(.)]
ht_1 --> mul1[[⊙]]
rt --> mul1
mul1 --> concat3[[concat with x_t]]
xt --> concat3 --> Wh[W_h] --> h_tilde[~h_t=tanh(.)]
zt --> mul2[[⊙]]
ht_1 --> mul2
zt_bar[1 - z_t] --> mul3[[⊙]]
h_tilde --> mul3
mul2 --> add[[+]] --> ht[h_t]
mul3 --> add
end
PyTorch:nn.GRU
1 | import torch, torch.nn as nn |
RNN / LSTM / GRU 对比
| 模型 | 记忆路径 | 门控数量 | 参数量 | 长依赖能力 | 训练/推理速度 |
|---|---|---|---|---|---|
| RNN | 仅 (h_t) | 0 | 最少 | 弱 | 最快 |
| LSTM | (C_t) 直通 + (h_t) | 3 | 较多 | 强 | 慢 |
| GRU | (h_t) 直通 | 2 | 适中 | 较强 | 较快 |
实践经验:若参数预算较紧、数据量中等,GRU 常是高性价比选择;追求稳健长依赖建模时,LSTM 更常用。
双向 RNN 与 Seq2Seq 扩展
- 双向 RNN(BiRNN):正向与反向两个 RNN 拼接,利用过去与未来上下文,常用于序列标注(NER/词性标注)。
- Seq2Seq(编码器-解码器):编码器压缩源序列到状态,解码器按步生成目标序列;可配注意力以缓解信息瓶颈。
简要 Seq2Seq(Mermaid):
flowchart LR
subgraph Encoder
x1 --> e1((RNN))
e1 --> h1
x2 --> e2((RNN))
h1 --> e2 --> h2
xT --> eT((RNN))
h_{T-1} --> eT --> hT
end
hT --> d0[Init Decoder]
subgraph Decoder
y0[] --> d1((RNN))
d0 --> d1 --> y1
y1 --> d2((RNN)) --> y2
y_{t-1} --> dt((RNN)) --> y_t
end
PyTorch 实战代码清单
1) 序列分类(变长样本 + Pack)
1 | import torch, torch.nn as nn |
2) 字符级语言模型(下一个字符预测)
1 | import torch, torch.nn as nn |
3) 简易时间序列预测(单步回归)
1 | import torch, torch.nn as nn |
训练骨架(通用):
1 | def train_one_epoch(model, loader, optimizer, clip=1.0, device="cuda"): |
训练与调参清单
- 优化器:AdamW(1e-3~3e-4) 起步;配合余弦退火/StepLR;
- 正则化:Dropout(0.1~0.5),权重衰减(1e-4~1e-2);
- 批量与截断:按显存选择 batch,大序列用截断 BPTT;
- Padding/Mask:NLP 常用
pad_sequence+pack_padded_sequence; - 梯度稳定:开启裁剪、检查 NaN,必要时降低学习率;
- 监控:分类看准确率/F1,语言建模看困惑度(PPL),回归看 MSE/MAE;
- 选择结构:RNN(快/简)、GRU(性价比)、LSTM(长依赖稳健)。
总结
- RNN 通过循环隐藏状态建模序列;
- 长依赖难点在 BPTT 的梯度消失/爆炸;
- LSTM/GRU 以门控和“直通路径”显著缓解;
- PyTorch 提供
nn.RNN/nn.LSTM/nn.GRU快速上手; - 结合 Pack、裁剪与正则,可在情感分析、语言建模、时间序列等任务稳定落地。
备注:文中 Mermaid 需在 Hexo 中启用相应插件(如
hexo-filter-mermaid-diagrams或主题内置支持);数学公式需启用hexo-math/katex/mathjax等渲染插件。

