中文文本分类项目过程及问题解决记录
本项目基于 GitHub 仓库 Chinese-Text-Classification-Pytorch,实现了多种经典文本分类模型,包括 TextCNN、FastText、BiLSTM+Attention、DPCNN 和 Transformer 等,使用 THUCNews 子集(20 万条新闻标题,10 个类别)进行实验。该仓库最后更新于 2020 年,代码结构清晰,但存在环境兼容性和少量实现细节问题。以下记录实际运行中的主要步骤、遇到的问题及解决方案。
1. 环境准备与代码调整
- Python 与 PyTorch 兼容性:原项目使用 Python 3.7 和 PyTorch 1.1,当前环境建议升级至 Python 3.10 和 PyTorch 2.x。推荐新建虚拟环境:
conda create -n textclass python=3.10
随后安装依赖:pip install torch tqdm scikit-learn tensorboard。 - 模型保存路径:权重文件默认保存至
./saved_dict/xxx.ckpt,而非 checkpoints 目录。运行结束后需检查 saved_dict 文件夹。 - 函数参数问题:utils.py 中 build_dataset 函数的第二个参数为
ues_word(原作者拼写错误,应为 use_word),调用时需显式传入False(字符级输入)。 - 配置依赖处理:代码中多处直接访问 config 的属性(如 vocab_path、pad_size),但 config 并非标准对象。预测脚本建议硬编码关键参数:pad_size=32、embedding_dim=300、device=‘cpu’。
- 目录创建:建议在 run.py 开头添加
os.makedirs('saved_dict', exist_ok=True),防止保存时因目录缺失报错。
上述调整可显著减少后续运行中的兼容性问题。
2. 模型训练
- 推荐模型:优先选择 FastText,训练命令为
python run.py --model FastText --embedding random(随机初始化嵌入)。
数据规模较小,CPU 环境下约 10–20 分钟完成。 - 早停机制:验证集准确率在 91.3%–91.9% 区间波动后触发早停,最终测试集准确率 92.06%。若需更高性能,可去除
--embedding random,使用默认搜狗预训练向量,准确率可达 92.2% 以上。 - 日志监控:每 100 次迭代输出一次训练/验证指标,验证准确率稳定在 91% 以上即可视为收敛,无需强制完成 20 个 epoch。
- 权重保存:最优模型保存为 saved_dict/FastText.ckpt,文件大小约几十 MB,包含嵌入层及分类头参数。
3. 预测实现与测试
- 项目未提供预测脚本,需自行实现。
- 核心流程:通过 build_dataset(dataset, False) 获取词汇表和标签映射;加载 ckpt 文件;输入文本按字符切分,转为 ID 序列;填充或截断至 pad_size=32;模型前向传播;softmax 后取 argmax,并映射至类别名称。
- 交互式测试:可实现循环输入模式,实时输出预测类别及置信度。建议准备 8–10 条覆盖不同类别的测试样本,用于观察模型在财经与股票等易混类别上的表现。
- 优化建议:首次运行后,可将词汇表、标签映射及 pad_size 通过 pickle 保存为文件,后续预测直接加载,避免重复调用 build_dataset。
4. 网页接口部署
实现方式:可使用 FastAPI 结合 uvicorn 构建简单接口,接收文本输入,返回预测结果。
常见问题:在服务器执行
uvicorn app:app --host 0.0.0.0 --port 8000时提示命令未找到。该命令属于 Python 包而非系统包,需先执行pip install "uvicorn[standard]"。本地开发:使用
--reload参数支持热重载,访问 localhost:8000/docs 通过 Swagger 测试接口。