这是一篇面向自我复习/长期查阅的笔记式教程,覆盖 RNN 基础原理、BPTT、梯度问题、LSTM/GRU 结构与 PyTorch 代码示例,可直接用于 Hexo 博客(已含 YAML Front‑Matter、Mermaid 图与公式)。

目录


为什么需要 RNN

许多真实数据是序列:文本、语音、时间序列(股价/气温/用电负荷)等。普通前馈网络缺少“记忆”,难以利用上下文依赖循环神经网络(RNN)通过在时间维度引入隐藏状态并复用参数,让模型具备“把过去带到现在”的能力,从而建模时间依赖。

RNN 工作原理

给定长度为 T 的输入序列 (x_1,\dots,x_T),基本 RNN 的递推为:

ht=ϕ(W<!swig7>h<!swig8>+W<!swig9>xt+bh),yt=W<!swig10>ht+by,h_t = \phi(W_ h_NaN + W_ x_t + b_h),\qquad y_t = W_ h_t + b_y\,,

其中 (\phi) 常用 (\tanh) 或 ReLU;(h_t) 作为到 t 为止的“记忆摘要”。展开为时间链后,参数在每个时间步共享

RNN 时间展开示意(Mermaid)

flowchart LR
  x1[X1] --> r1((RNN))
  h0((h0)) --> r1
  r1 --> h1((h1))
  r1 --> y1[Y1]

  x2[X2] --> r2((RNN))
  h1 --> r2
  r2 --> h2((h2))
  r2 --> y2[Y2]

  x3[...]
  h2 --> r3((RNN))
  x3 --> r3
  r3 --> h3((h3))
  r3 --> y3[...]

最小 PyTorch 示例(nn.RNN

1
2
3
4
5
6
7
8
import torch, torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1, batch_first=False)
x = torch.randn(5, 3, 10) # (seq_len=5, batch=3, input_size=10)
h0 = torch.zeros(1, 3, 20) # (num_layers, batch, hidden_size)

out, hn = rnn(x, h0)
print(out.shape, hn.shape) # torch.Size([5, 3, 20]) torch.Size([1, 3, 20])

前向传播与 BPTT

对每个时间步的损失 (\ell_t)(如交叉熵),总损失 (L=\sum_t \ell_t / T)。训练时对时间展开后的计算图反向传播(BPTT),梯度会沿着 (h_t\to h_{t-1}\to\cdots) 回传,出现连乘项,从而带来:

  • 梯度消失:连乘 < 1,远端依赖难学;
  • 梯度爆炸:连乘 > 1,训练不稳。

截断 BPTT:长序列按块(如 50 步)反传,降低不稳定。

梯度消失/爆炸与常见对策

  • 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
  • 良好初始化与归一化:如 Xavier/Kaiming;适度 LayerNorm;
  • 合适激活:ReLU/LeakyReLU 在部分任务上更稳;
  • 门控结构:LSTM/GRU 通过“直通路径”缓解长依赖;
  • 缩短有效深度:截断 BPTT、分层/分块建模。

LSTM:长短期记忆网络

LSTM 在 RNN 基础上引入细胞状态 (C_t) 与三类门(遗忘/输入/输出),核心:

\begin f_t &= \sigma(W_f [h_NaN, x_t] + b_f),\\ i_t &= \sigma(W_i [h_NaN, x_t] + b_i),\quad \tilde C_t = \tanh(W_C [h_NaN, x_t] + b_C),\\ C_t &= f_t \odot C_NaN + i_t \odot \tilde C_t,\\ o_t &= \sigma(W_o [h_NaN, x_t] + b_o),\quad h_t = o_t \odot \tanh(C_t). \end

LSTM 单元(Mermaid)

flowchart LR
  subgraph LSTM Cell at time t
    ht_1[h_{t-1}] --> concat
    xt[x_t] --> concat
    concat[[concat]] --> Wf[W_f] --> ft[f_t = σ(.)]
    concat --> Wi[W_i] --> it[i_t = σ(.)]
    concat --> Wc[W_C] --> Ct_tilde[~C_t = tanh(.)]
    Ctm1[C_{t-1}] --> mul1
    ft --> mul1[[⊙]] --> add1
    it --> mul2[[⊙]]
    Ct_tilde --> mul2 --> add1[[+]] --> Ct[C_t]
    concat --> Wo[W_o] --> ot[o_t = σ(.)]
    Ct --> tanhCt[tanh(C_t)] --> mul3[[⊙]] --> ht[h_t]
    ot --> mul3
  end

PyTorch:nn.LSTM

1
2
3
4
5
6
7
8
9
import torch, torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=False)
x = torch.randn(5, 3, 10)
h0 = torch.zeros(2, 3, 20)
c0 = torch.zeros(2, 3, 20)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape, hn.shape, cn.shape) # [5,3,20], [2,3,20], [2,3,20]

GRU:门控循环单元

GRU 将 LSTM 的 (C_t) 与 (h_t) 合并,并用更新门 (z_t) 与 重置门 (r_t) 控制信息:

\begin z_t &= \sigma(W_z [h_NaN, x_t]+b_z),\quad r_t = \sigma(W_r [h_NaN, x_t]+b_r),\\ \tilde h_t &= \tanh(W_h [r_t \odot h_NaN, x_t]+b_h),\\ h_t &= z_t \odot h_NaN + (1-z_t) \odot \tilde h_t\,. \end

GRU 单元(Mermaid)

flowchart LR
  subgraph GRU Cell at time t
    ht_1[h_{t-1}] --> concat1
    xt[x_t] --> concat1[[concat]] --> Wz[W_z] --> zt[z_t=σ(.)]
    ht_1 --> concat2
    xt --> concat2[[concat]] --> Wr[W_r] --> rt[r_t=σ(.)]
    ht_1 --> mul1[[⊙]]
    rt --> mul1
    mul1 --> concat3[[concat with x_t]]
    xt --> concat3 --> Wh[W_h] --> h_tilde[~h_t=tanh(.)]
    zt --> mul2[[⊙]]
    ht_1 --> mul2
    zt_bar[1 - z_t] --> mul3[[⊙]]
    h_tilde --> mul3
    mul2 --> add[[+]] --> ht[h_t]
    mul3 --> add
  end

PyTorch:nn.GRU

1
2
3
4
5
6
7
8
import torch, torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=False)
x = torch.randn(5, 3, 10)
h0 = torch.zeros(2, 3, 20)

out, hn = gru(x, h0)
print(out.shape, hn.shape) # [5,3,20], [2,3,20]

RNN / LSTM / GRU 对比

模型 记忆路径 门控数量 参数量 长依赖能力 训练/推理速度
RNN 仅 (h_t) 0 最少 最快
LSTM (C_t) 直通 + (h_t) 3 较多
GRU (h_t) 直通 2 适中 较强 较快

实践经验:若参数预算较紧、数据量中等,GRU 常是高性价比选择;追求稳健长依赖建模时,LSTM 更常用。

双向 RNN 与 Seq2Seq 扩展

  • 双向 RNN(BiRNN):正向与反向两个 RNN 拼接,利用过去与未来上下文,常用于序列标注(NER/词性标注)。
  • Seq2Seq(编码器-解码器):编码器压缩源序列到状态,解码器按步生成目标序列;可配注意力以缓解信息瓶颈。

简要 Seq2Seq(Mermaid)

flowchart LR
  subgraph Encoder
    x1 --> e1((RNN))
    e1 --> h1
    x2 --> e2((RNN))
    h1 --> e2 --> h2
    xT --> eT((RNN))
    h_{T-1} --> eT --> hT
  end
  hT --> d0[Init Decoder]
  subgraph Decoder
    y0[] --> d1((RNN))
    d0 --> d1 --> y1
    y1 --> d2((RNN)) --> y2
    y_{t-1} --> dt((RNN)) --> y_t
  end

PyTorch 实战代码清单

1) 序列分类(变长样本 + Pack)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch, torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class RNNClassifier(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden=256, num_classes=2):
super().__init__()
self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.rnn = nn.GRU(emb_dim, hidden, num_layers=1, batch_first=True, bidirectional=True)
self.out = nn.Linear(hidden*2, num_classes)
def forward(self, x, lengths):
# x: (B, L)
emb = self.emb(x) # (B, L, E)
packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
out_packed, h = self.rnn(packed) # h: (2, B, H)
out, _ = pad_packed_sequence(out_packed, batch_first=True)
# 取双向最后隐状态拼接
h_cat = torch.cat([h[-2], h[-1]], dim=-1) # (B, 2H)
logits = self.out(h_cat)
return logits

2) 字符级语言模型(下一个字符预测)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch, torch.nn as nn

class CharRNNLM(nn.Module):
def __init__(self, vocab_size, emb=128, hidden=256, layers=2):
super().__init__()
self.emb = nn.Embedding(vocab_size, emb)
self.rnn = nn.LSTM(emb, hidden, num_layers=layers, batch_first=True)
self.fc = nn.Linear(hidden, vocab_size)
def forward(self, x, h=None):
x = self.emb(x)
y, h = self.rnn(x, h)
logits = self.fc(y)
return logits, h

# 训练时:交叉熵,teacher forcing,将输入序列右移一位作为标签

3) 简易时间序列预测(单步回归)

1
2
3
4
5
6
7
8
9
10
11
import torch, torch.nn as nn

class TSRegressor(nn.Module):
def __init__(self, in_dim=1, hidden=64):
super().__init__()
self.rnn = nn.GRU(in_dim, hidden, batch_first=True)
self.fc = nn.Linear(hidden, 1)
def forward(self, x):
y, h = self.rnn(x) # x: (B, T, 1)
yhat = self.fc(y[:, -1]) # 取最后一步
return yhat

训练骨架(通用)

1
2
3
4
5
6
7
8
9
10
11
def train_one_epoch(model, loader, optimizer, clip=1.0, device="cuda"):
model.train()
total = 0.0
for batch in loader:
loss = compute_loss(model, batch, device)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
total += loss.item()
return total / max(1, len(loader))

训练与调参清单

  • 优化器:AdamW(1e-3~3e-4) 起步;配合余弦退火/StepLR;
  • 正则化:Dropout(0.1~0.5),权重衰减(1e-4~1e-2);
  • 批量与截断:按显存选择 batch,大序列用截断 BPTT;
  • Padding/Mask:NLP 常用 pad_sequence + pack_padded_sequence
  • 梯度稳定:开启裁剪、检查 NaN,必要时降低学习率;
  • 监控:分类看准确率/F1,语言建模看困惑度(PPL),回归看 MSE/MAE;
  • 选择结构:RNN(快/简)、GRU(性价比)、LSTM(长依赖稳健)。

总结

  • RNN 通过循环隐藏状态建模序列;
  • 长依赖难点在 BPTT 的梯度消失/爆炸;
  • LSTM/GRU 以门控和“直通路径”显著缓解;
  • PyTorch 提供 nn.RNN/nn.LSTM/nn.GRU 快速上手;
  • 结合 Pack、裁剪与正则,可在情感分析、语言建模、时间序列等任务稳定落地。

备注:文中 Mermaid 需在 Hexo 中启用相应插件(如 hexo-filter-mermaid-diagrams 或主题内置支持);数学公式需启用 hexo-math/katex/mathjax 等渲染插件。