第21节——手搓一个 LM (语言模型-文本续写)训练和推理器
·
21. 简易小说(散文)续写 Transformer 语言模型
- 把任务从「复制序列」换成「小说续写」——也就是语言模型(LM)任务
- 做一个 Decoder-only Transformer
- 写一套训练 + 续写代码骨架
- 前面的 Transformer 是 Encoder-Decoder 结构,主要用于“翻译式”任务(输入一句话 → 输出另一种格式)。
- 而“小说续写”本质是:给定前文 token 序列 x_1,x_2,…,x_T,去预测下一个 token x_T+1
- 所以训练目标是:语言模型(Language Model, LM) / 自回归模型
- 常见的做法就是 GPT 那一类:只有 Decoder,没有 Encoder,每一层都是:Masked Self-Attention + FFN。
21.1 数据准备:把小说文本变成 token 序列
- 先用字符级别做,不引入额外 tokenizer,方便专心玩模型。
- 创建一个文本文件:./data/novel.txt,里面就是一堆散文/小说内容(中文/英文都行,越多越好)
- 我已经准备好了,用下面的分析工具可以看一下文本文件的信息
import os
import sys
import re
def analyze_txt_file(file_path, encoding='utf-8'):
# 验证文件后缀(可选,增强容错性)
if not file_path.lower().endswith('.txt'):
print(f"⚠️ 警告:文件后缀不是.txt,可能不是文本文件")
try:
# 4. 读取文件内容(高效处理大文件)
with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
# 读取全部内容(小文件)或处理大文件(优化内存)
file_content = f.read()
# 5. 计算统计信息
stats = {}
# 基本信息
stats['文件路径'] = file_path
stats['文件大小(字节)'] = os.path.getsize(file_path)
stats['总字符数'] = len(file_content)
stats['总行数'] = file_content.count('\n') + 1 # 最后一行可能没有换行符
# 字符类型统计
stats['中文字符数'] = len(re.findall(r'[\u4e00-\u9fa5]', file_content))
stats['英文字母数'] = len(re.findall(r'[a-zA-Z]', file_content))
stats['数字字符数'] = len(re.findall(r'[0-9]', file_content))
stats['空格字符数'] = len(re.findall(r'\s', file_content)) # 包含空格、制表符、换行符等
stats['标点符号数'] = len(re.findall(r'[^\w\s\u4e00-\u9fa5]', file_content)) # 非字母数字中文空格
# 单词数(以空格/标点分隔,简单统计)
words = re.findall(r'[a-zA-Z0-9\u4e00-\u9fa5]+', file_content)
stats['单词/词组数'] = len(words)
# 6. 提取前后各100字符
content_len = len(file_content)
prefix = file_content[:100] if content_len >= 100 else file_content
suffix = file_content[-100:] if content_len >= 100 else ""
# 7. 格式化输出结果
print("=" * 60)
print("📊 TXT文件统计信息")
print("=" * 60)
for key, value in stats.items():
print(f"{key:12}: {value}")
print("\n" + "-" * 60)
print("📝 文件内容预览(前后各100字符)")
print("-" * 60)
print("\n【前100字符】:")
if prefix:
print(prefix)
else:
print("(文件为空)")
if content_len > 100:
print(f"\n...(中间省略 {content_len - 200} 个字符)...")
print("\n【后100字符】:")
print(suffix)
else:
print("\n(文件长度不足200字符,无尾部预览)")
print("\n" + "=" * 60)
except Exception as e:
print(f"❌ 处理文件时出错:{str(e)}")
print("💡 建议尝试其他编码格式,例如:gbk、gb2312、latin-1")
if __name__ == "__main__":
file_path = "./data/novel.txt"
encoding = 'utf-8'
analyze_txt_file(file_path, encoding)
============================================================
📊 TXT文件统计信息
============================================================
文件路径 : ./data/novel.txt
文件大小(字节) : 11073
总字符数 : 3711
总行数 : 4
中文字符数 : 3197
英文字母数 : 3
数字字符数 : 8
空格字符数 : 20
标点符号数 : 482
单词/词组数 : 460
------------------------------------------------------------
📝 文件内容预览(前后各100字符)
------------------------------------------------------------
【前100字符】:
盼望着,盼望着,东风来了,春天的脚步近了。 一切都像刚睡醒的样子,欣欣然张开了眼。山朗润起来了,水涨起来了,太阳的脸红起来了。小草偷偷地从土里钻出来,嫩嫩的,绿绿的。园子里,田野里,瞧去,一大片一大片
...(中间省略 3511 个字符)...
【后100字符】:
的绿呀!我若能裁你以为带,我将赠给那轻盈的舞女;她必能临风飘举了。我若能挹你以为眼,我将赠给那善歌的盲妹;她必明眸善睐了。我舍不得你;我怎舍得你呢%3F我用手拍着你,抚摩着你,如同一个十二三岁的小姑娘
============================================================
读取文本并构建字符词表
import io
from collections import Counter
import torch
# 路径你自己改
NOVEL_PATH = "./data/novel.txt"
with io.open(NOVEL_PATH, "r", encoding="utf-8") as f:
text = f.read()
print("Corpus length (characters):", len(text))
print("前 500 个字符预览:")
print(text[:500])
Corpus length (characters): 3711
前 500 个字符预览:
盼望着,盼望着,东风来了,春天的脚步近了。 一切都像刚睡醒的样子,欣欣然张开了眼。山朗润起来了,水涨起来了,太阳的脸红起来了。小草偷偷地从土里钻出来,嫩嫩的,绿绿的。园子里,田野里,瞧去,一大片一大片满是的。坐着,躺着,打两个滚,踢几脚球,赛几趟跑,捉几回迷藏。风轻悄悄的,草软绵绵的。桃树、杏树、梨树,你不让我,我不让你,都开满了花赶趟儿。红的像火,粉的像霞,白的像雪。花里带着甜味儿;闭了眼,树上仿佛已经满是桃儿、杏儿、梨儿。花下成千成百的蜜蜂嗡嗡地闹着,大小的蝴蝶飞来飞去。野花遍地是:杂样儿,有名字的,没名字的,散在草丛里,像眼睛,像星星,还眨呀眨的。 “吹面不寒杨柳风”,不错的,像母亲的手抚摸着你。风里带来些新翻的泥土的气息,混着青草味儿,还有各种花的香,都在微微润湿的空气里酝酿。鸟儿将巢安在繁花嫩叶当中,高兴起来了,呼朋引伴地卖弄清脆的喉咙,唱出宛转的曲子,跟轻风流水应和着。牛背上牧童的短笛,这时候也成天嘹亮地响着。 雨是最寻常的,一下就是三两天。可别恼。看,像牛毛,像花针,像细丝,密密地斜织着,人家屋顶上全笼着一层薄烟。树叶儿却绿得发亮,小草儿也青得逼你的眼。傍晚时候,上灯了,
然后构建「字符 → id」映射(char-level vocab):
# 特殊符号
PAD = 0
BOS = 1
EOS = 2 # 可以不用,先预留
SPECIAL_TOKENS = ["<PAD>", "<BOS>", "<EOS>"]
# 统计所有字符
counter = Counter(text)
chars = sorted(counter.keys())
print("Unique chars:", len(chars))
# 为每个字符分配 id,从 3 开始,避开PAD/BOS/EOS
id2char = SPECIAL_TOKENS + chars
char2id = {ch: idx for idx, ch in enumerate(id2char)}
vocab_size = len(id2char)
print("vocab_size:", vocab_size)
# 把整个文本编码成 id 序列
ids = [char2id[ch] for ch in text]
ids = torch.tensor(ids, dtype=torch.long)
print("ids shape:", ids.shape)
Unique chars: 829
vocab_size: 832
ids shape: torch.Size([3711])
- 注意:这里的 vocab_size 就是我们之后 Transformer 的 tgt_vocab_size
用滑动窗口构造训练样本(语言模型)
- 我们要把长序列打成很多小样本,每个样本用来做「预测下一个 token」:
- 给定长度为 seq_len 的片段:[x0, x1, …, x_{L-1}]
- 输入是:x_input = [BOS, x0, x1, …, x_{L-2}]
- 标签是:x_target = [x0, x1, …, x_{L-1}]
- 实现一个简单 Dataset:
from torch.utils.data import Dataset, DataLoader
import random
class CharLMDataset(Dataset):
def __init__(self, ids, seq_len):
"""
ids: 整个语料的 token id 序列 (T,)
seq_len: 每个样本的目标长度(不含 BOS)
"""
self.ids = ids
self.seq_len = seq_len
# 能切多少个窗口(简单起见,忽略最后不够的部分)
self.num_samples = (len(ids) - 1) // seq_len
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
"""
返回:
input_ids: (seq_len + 1,) = [BOS, x0, x1, ..., x_{L-1}? or x_{L-2}]
target_ids: (seq_len + 1,) = [x0, x1, ..., x_{L-1}, x_{L}]
"""
start = idx * self.seq_len
end = start + self.seq_len + 1 # 多取一个当 label 尾巴
chunk = self.ids[start:end] # 长度 seq_len+1
# 如果不足就 pad(一般不会,保险)
if len(chunk) < self.seq_len + 1:
pad_len = self.seq_len + 1 - len(chunk)
chunk = torch.cat([chunk, torch.full((pad_len,), PAD, dtype=torch.long)])
# 模型输入:在前面加 BOS
# 例如 chunk = [x0, x1, x2, x3, x4]
# input_ids = [BOS, x0, x1, x2, x3]
# target_ids = [x0, x1, x2, x3, x4]
input_ids = torch.empty(self.seq_len + 1, dtype=torch.long)
input_ids[0] = BOS
input_ids[1:] = chunk[:-1]
target_ids = chunk # 长度 seq_len+1
return input_ids, target_ids
SEQ_LEN = 64
dataset = CharLMDataset(ids, seq_len=SEQ_LEN)
print("num_samples:", len(dataset))
input_ids, target_ids = dataset[0]
print("input_ids:", input_ids[:20])
print("target_ids:", target_ids[:20])
num_samples: 57
input_ids: tensor([ 1, 551, 420, 557, 828, 551, 420, 557, 828, 29, 809, 429, 46, 828,
403, 219, 543, 639, 455, 744])
target_ids: tensor([551, 420, 557, 828, 551, 420, 557, 828, 29, 809, 429, 46, 828, 403,
219, 543, 639, 455, 744, 46])
- DataLoader:
BATCH_SIZE = 32
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
21.2 定义 Decoder-only Transformer 语言模型
- 复用之前的组件:MultiHeadSelfAttention、PositionwiseFeedForward、PositionalEncoding、TokenEmbedding
from MyTransformer import MultiHeadSelfAttention, PositionwiseFeedForward, PositionalEncoding, TokenEmbedding
- 现在写一个「只有 Self-Attn + FFN」的层:
import torch
from torch import nn
class DecoderOnlyLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadSelfAttention(d_model, num_heads)
self.ffn = PositionwiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, pad_mask=None):
"""
x: (B, L, d_model)
pad_mask: (B, L) —— 1 表示 PAD
"""
B, L, _ = x.shape
device = x.device
# 生成因果 mask(不能看未来)
subsequent_mask = torch.triu(
torch.ones(L, L, device=device), diagonal=1
) # (L, L)
subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(0) # (1,1,L,L)
# Self-Attn
_attn_out, self_attn_map = self.self_attn(
x,
pad_mask=pad_mask, # 屏蔽 PAD
attn_mask=subsequent_mask # 屏蔽未来
)
x = x + self.dropout1(_attn_out)
x = self.norm1(x)
# FFN
_ffn_out = self.ffn(x)
x = x + self.dropout2(_ffn_out)
x = self.norm2(x)
return x, self_attn_map
- 然后定义 LM 模型:
class TransformerLM(nn.Module):
def __init__(self,
vocab_size,
d_model=256,
num_heads=4,
d_ff=512,
num_layers=4,
max_len=2048,
pad_id=PAD,
dropout=0.1):
super().__init__()
self.d_model = d_model
self.pad_id = pad_id
self.tok_embed = TokenEmbedding(vocab_size, d_model, pad_id=pad_id)
self.pos_encoding = PositionalEncoding(d_model, max_len=max_len)
self.dropout = nn.Dropout(dropout)
self.layers = nn.ModuleList([
DecoderOnlyLayer(d_model, num_heads, d_ff, dropout=dropout)
for _ in range(num_layers)
])
self.output_proj = nn.Linear(d_model, vocab_size)
def make_pad_mask(self, ids):
return (ids == self.pad_id).int() # (B, L)
def forward(self, input_ids):
"""
input_ids: (B, L) —— 已有上下文(含 BOS)
返回:
logits: (B, L, vocab_size)
"""
B, L = input_ids.shape
pad_mask = self.make_pad_mask(input_ids)
x = self.tok_embed(input_ids) # (B,L,d_model)
x = x * math.sqrt(self.d_model)
pos = self.pos_encoding(x) # (B,L,d_model)
x = x + pos
x = self.dropout(x)
attn_maps = []
for layer in self.layers:
x, attn = layer(x, pad_mask=pad_mask)
attn_maps.append(attn)
logits = self.output_proj(x)
return logits, attn_maps
21.3 训练循环:小说语言模型
- 和前面的 copy 任务几乎一样,只是这里:
- 输入:input_ids(已经是 [BOS, x0, x1,…])
- 标签:target_ids(对应 [x0, x1, …])
import math
import matplotlib.pyplot as plt
from IPython.display import clear_output
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerLM(
vocab_size=vocab_size,
d_model=256,
num_heads=4,
d_ff=512,
num_layers=4,
max_len=SEQ_LEN + 1,
pad_id=PAD,
dropout=0.1,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
EPOCHS = 5000
train_losses = [] # 用来记录每个 epoch 的 avg_loss
for epoch in range(1, EPOCHS + 1):
model.train()
total_loss = 0.0
for batch_idx, (input_ids, target_ids) in enumerate(loader):
input_ids = input_ids.to(DEVICE) # (B, L)
target_ids = target_ids.to(DEVICE) # (B, L)
logits, _ = model(input_ids) # (B, L, vocab_size)
B, L, V = logits.shape
logits_flat = logits.view(B * L, V)
target_flat = target_ids.view(B * L)
loss = criterion(logits_flat, target_flat)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / (batch_idx + 1)
train_losses.append(avg_loss)
# 每 N 个 epoch 更新一次图像(比如每 10 轮)
if epoch % 10 == 0 or epoch == 1:
clear_output(wait=True) # 清空输出区域,达到“动态刷新”效果
plt.figure(figsize=(8, 4))
plt.plot(range(1, len(train_losses) + 1), train_losses, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.title("Training Loss Curve")
plt.grid(True)
plt.tight_layout()
plt.show()
print(f"Epoch {epoch}/{EPOCHS}, avg_loss = {avg_loss:.4f}")

Epoch 5000/5000, avg_loss = 0.0660
- 作者这里设置 EPOCHS = 5000,你可以自己更改,loss能收敛就行
- 真正训练小说 LM,EPOCHS 可能更多,而且最好上 GPU。
21.4 续写函数:给一段开头,生成后续文本
- 现在,我们实现一个自回归生成的续写函数:
- 输入:prompt 文本(字符串)
- 过程:
- 1.把 prompt 编码成 id 序列
- 2.前面加 BOS,得到初始 input_ids
- 3.循环:
- 喂给模型,拿最后一个位置的 logits
- 取 argmax,得到下一个 token id
- 拼接到序列后面
- 4.最后把 id 转回字符,输出文本
import sys
import time
def generate_text_stream(model, prompt, max_new_tokens=100, delay=0.02):
"""
delay:每个字符之间暂停多少秒,0 表示立即输出
"""
model.eval()
with torch.no_grad():
# ===== 1) 编码 prompt =====
prompt_ids = encode_text(prompt)
if len(prompt_ids) == 0:
prompt_ids = [char2id.get("。", 3)]
# ===== 2) initial input_ids =====
input_ids = [BOS] + prompt_ids
input_ids = torch.tensor(input_ids, dtype=torch.long, device=DEVICE).unsqueeze(0) # (1, L)
# 先把 prompt 输出出来
print(prompt, end="", flush=True)
# ===== 3) 自回归生成 =====
for _ in range(max_new_tokens):
# 如果太长,截断
if input_ids.size(1) > SEQ_LEN:
input_chunk = input_ids[:, -SEQ_LEN:]
else:
input_chunk = input_ids
logits, _ = model(input_chunk)
next_token_logits = logits[0, -1]
next_id = int(next_token_logits.argmax(dim=-1).item())
# 拼接到输入序列中
next_token_tensor = torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)
input_ids = torch.cat([input_ids, next_token_tensor], dim=1)
# ===== 4) 输出字符 =====
if next_id < len(id2char):
next_char = id2char[next_id]
else:
next_char = "" # 出现未知 id 时避免报错
print(next_char, end="", flush=True)
if delay > 0:
time.sleep(delay)
print() # 最后换行
# 调用
generate_text_stream(model, "他静静地站在窗前,", max_new_tokens=1000, delay=0.02)
他静静地站在窗前,慢爬下眼。到这就是惦记着我走到车上。他去搀他。他和我儿,慢探身下午上车上车上车上车上一股脑儿放在我那年纪的迂;他和我北京什么一个胖的儿子,他和他们罢:“吹面,这平的紫毛大衣铺着,告诉去。我们罢:“吹面,他们罢:“进去不复返呢?
我们自己逃走了一个我们的时我们的我们的时,我们的时候,我们的时候,我与他 们的时候,他去。他终于他终于他相见,他戴着旋转。我,蹒跚13F我,聪明了一个橘子往回走到铁道:“进去不复返,我们去不好!”+我。他写了一个瀑布潭。我们去吧,他相见我们去吧。我说:“我们去吧,他相见我们去吧,他终于不送我们去不送我,自己逃走了靠车北京去不好座位。”+我,自己插嘴不送我,只是惦记着照看见他们去。我,怕茶房不送我,怕茶房不送我最低已二十岁,像火钱办吧。”他的儿将他写了一个橘子往往回家中光中光中光中光中光中光中光中,他终于决定了。其实我做的紫毛大年纪的皮大衣铺着,惦记着;我北京什么偏要回北京什么偏要紧,终于忘却我们。到南京到南京到南京已经从北京时,他们先到那醉人的绿呀!”+我走了。到南京什么痕迹呢?又藏在何处呢?是他们自己逃走了罢了罢了罢了罢:“爸爸,如此处呢?是他们自己逃走了罢?是自己逃去吧,他们自己逃走了我将我掩着照我买几个我最低。到这时候;我掩着照看见他已二十岁,说:“我掩着我掩的紫毛大衣上车北京去吧。他们去吧,他们去的人偷了泪。我,他们去吧,也怕别人看见过铁道:“吹面,蹒跚14提笔,就不好座位。我怎样奇异橘子往回走了。我赶紧,他戴着我走到那边月台,他。他写了朱红的皮大衣铺着甜味儿子,惦记着和我走到车北京什么偏要回家中,终于决定了。他们去吧,终于不好座位。我那边月台,。他写了一个胖的绿色,他写了一块温暖的神光景是聪明油”+我,总觉察他写了朱红的碧草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草与草偷偷偷偷偷偷偷偷偷偷偷偷偷偷偷,嫩嫩嫩嫩的,绿绿绿绿的,绿的,绿的。园子里,田野里,一大片一大片满是的。坐着,躺着,打两个滚,踢几脚球,赛几趟跑,捉几回迷的。坐着,捉几回迷满是的绿呀。坐下眼泪又鞠躬我们先到那边月台说:“进了。”+回家中,捉她!”+回家中,捉她必难过铁道:“我们温暖的栅栏外确切些什么痕迹呢?只说:“我们去吧,仿佛已经去吧,他们去吧,他便,他便,他不复返呢?
我读到仙岩石桥边飞来后,脱一块温暖但你就在此处呢?但我们的心探
- 我是复制了朱自清的一些散文作为训练集,数据量很小,所以生成的结果很鬼畜,不过至少已经跑起来了。
更多推荐



所有评论(0)