中文文本分类项目过程与问题解决
中文文本分类项目过程及问题解决记录 本项目基于 GitHub 仓库 Chinese-Text-Classification-Pytorch,实现了多种经典文本分类模型,包括 TextCNN、FastText、BiLSTM+Attention、DPCNN 和 Transformer 等,使用 THUCNews 子集(20 万条新闻标题,10 个类别)进行实验。该仓库最后更新于 2020 年,代码结构清晰,但存在环境兼容性和少量实现细节问题。以下记录实际运行中的主要步骤、遇到的问题及解决方案。 1. 环境准备与代码调整 Python 与 PyTorch 兼容性:原项目使用 Python 3.7 和 PyTorch 1.1,当前环境建议升级至 Python 3.10 和 PyTorch 2.x。推荐新建虚拟环境: conda create -n textclass python=3.10 随后安装依赖:pip install torch tqdm scikit-learn tensorboard。 模型保存路径:权重文件默认保存至 ./saved_dict/xxx.ckpt,而非 checkpoints 目录。运行结束后需检查 saved_dict 文件夹。 函数参数问题:utils.py 中 build_dataset 函数的第二个参数为 ues_word(原作者拼写错误,应为 use_word),调用时需显式传入 False(字符级输入)。 配置依赖处理:代码中多处直接访问 config 的属性(如 vocab_path、pad_size),但 config 并非标准对象。预测脚本建议硬编码关键参数:pad_size=32、embedding_dim=300、device=‘cpu’。 目录创建:建议在 run.py 开头添加 os.makedirs('saved_dict', exist_ok=True),防止保存时因目录缺失报错。 上述调整可显著减少后续运行中的兼容性问题。 2. 模型训练 推荐模型:优先选择 FastText,训练命令为 python run.py --model FastText --embedding random(随机初始化嵌入)。 数据规模较小,CPU 环境下约 10–20 分钟完成。 早停机制:验证集准确率在 91.3%–91.9% 区间波动后触发早停,最终测试集准确率 92.06%。若需更高性能,可去除 --embedding random,使用默认搜狗预训练向量,准确率可达 92.2% 以上。 日志监控:每 100 次迭代输出一次训练/验证指标,验证准确率稳定在 91% 以上即可视为收敛,无需强制完成 20 个 epoch。 权重保存:最优模型保存为 saved_dict/FastText.ckpt,文件大小约几十 MB,包含嵌入层及分类头参数。 3. 预测实现与测试 项目未提供预测脚本,需自行实现。 核心流程:通过 build_dataset(dataset, False) 获取词汇表和标签映射;加载 ckpt 文件;输入文本按字符切分,转为 ID 序列;填充或截断至 pad_size=32;模型前向传播;softmax 后取 argmax,并映射至类别名称。 交互式测试:可实现循环输入模式,实时输出预测类别及置信度。建议准备 8–10 条覆盖不同类别的测试样本,用于观察模型在财经与股票等易混类别上的表现。 优化建议:首次运行后,可将词汇表、标签映射及 pad_size 通过 pickle 保存为文件,后续预测直接加载,避免重复调用 build_dataset。 4. 网页接口部署 实现方式:可使用 FastAPI 结合 uvicorn 构建简单接口,接收文本输入,返回预测结果。 ...