跳转至

9.8 模型训练、评估与上线

模型训练是将数据集、算法、算力转化为可实际使用的智能系统的核心过程;评估确保模型在真实场景下表现可信;上线则完成从实验室到生产环境的最后一公里。


一、训练流程总览

数据集准备(9.7)
    ↓ 模型搭建(选骨干 + 任务头)
    ↓ 训练配置(学习率调度 / 损失函数 / 优化器)
    ↓ 训练循环(前向 → 损失 → 反向 → 更新)
    ↓ 验证集评估(每 N epoch)→ 保存最优 checkpoint
    ↓ 测试集最终评估
    ↓ 模型导出(ONNX / TorchScript / TensorRT)
    ↓ 线上部署

二、训练策略

2.1 迁移学习(Transfer Learning)

利用在大数据集(ImageNet、COCO、OpenImages)上预训练的权重作为起点,在目标数据集上微调:

策略 适用场景 做法
全量微调(Fine-Tune) 数据量较大(>5k) 解冻所有层,小学习率全局更新
冻结骨干微调 数据量较小(<1k) 冻结骨干,只训练任务头
线性探测(Linear Probe) 快速验证特征质量 骨干全冻,只训练一个 FC 层

2.2 学习率调度

Warmup(线性增长,前 N 步)→ 余弦退火衰减(训练全程)→ 可选二次 Warmup(多阶段)
  • Warmup 防止大学习率初期梯度爆炸
  • 余弦退火(Cosine Annealing)比阶梯衰减更平滑,最终性能通常更好

2.3 混合精度训练(AMP)

使用 FP16 前向/反向,FP32 参数更新,显存节省约 50%,速度提升 1.5–3×,PyTorch 用 torch.cuda.amp.autocast() 一键开启。

2.4 分布式训练

策略 原理 适用场景
DataParallel(DP) 单进程,多 GPU 复制模型 单机多卡,简单但低效
DistributedDataParallel(DDP) 多进程,梯度 AllReduce 单机/多机首选
ZeRO(DeepSpeed) 优化器状态 / 梯度 / 参数分片 超大模型(LLM)

三、评估方法

3.1 评估原则

  • 评估集严格隔离:test 集禁止用于任何决策(调参、选模型),只做一次性最终报告
  • 多指标联合报告:不能只看单一指标(如只看准确率会掩盖类别不均衡问题)
  • 置信区间:小测试集(<1000 样本)需报告标准差或置信区间

3.2 常见评估场景

任务 主评指标 辅助指标
分类 Top-1 Accuracy F1、AUC-ROC
目标检测 mAP@0.5:0.95 AR@100、推理延迟
分割 mIoU Dice、边界 F1
去噪/超分辨率 PSNR + SSIM LPIPS

3.3 消融实验(Ablation Study)

逐步增删模型组件,量化每个设计决策的贡献。典型格式:

配置 mAP 说明
基线 72.3 ResNet50 + FPN
+ CBAM 73.8 +1.5,注意力增益
+ 数据增广 75.1 +1.3,Mosaic + CutMix
+ 大分辨率输入 76.4 +1.3,640→1280

四、模型导出与优化

4.1 导出格式

格式 场景 工具
ONNX 跨框架部署,中间格式 torch.onnx.export
TorchScript PyTorch 原生,服务端 torch.jit.trace/script
TensorRT NVIDIA GPU 最快推理 TensorRT(TRT)优化
RKNN 瑞芯微 NPU(RK3588 等) RKNN-Toolkit
CoreML Apple 芯片 coremltools

4.2 量化(Quantization)

将浮点权重转为 INT8 / INT4,减少显存和延迟:

  • 训练后量化(PTQ):用 100–1000 张校准图,速度快,精度损失 1–3%
  • 量化感知训练(QAT):训练阶段模拟量化误差,精度损失 <1%,但耗时

4.3 剪枝(Pruning)

删除不重要的权重/通道:

  • 结构化剪枝(整通道删除)→ 直接得到小模型,推理加速明显
  • 非结构化剪枝(稀疏权重)→ 需稀疏计算库支持才能加速

五、上线部署

5.1 部署形态

形态 特点 代表场景
云端 API 弹性伸缩,GPU 充足 SaaS、在线识别服务
边缘服务器 低延迟,本地化 工业相机、安防视频流
嵌入式端(NPU/MCU) 极低功耗 机载、手持设备

5.2 推理服务化

模型文件(ONNX/TRT)
    ↓ 推理引擎(Triton / TorchServe / FastAPI)
REST / gRPC 接口
    ↓ 负载均衡(Nginx / k8s)
客户端调用

5.3 上线检查清单

  • [ ] 测试集最终指标已记录(与预期基线对比)
  • [ ] 推理延迟、吞吐已压测(P99 满足 SLA)
  • [ ] 模型版本与数据集版本绑定记录
  • [ ] 灰度发布或 A/B 测试策略就绪
  • [ ] 监控告警(精度漂移检测)已配置

六、线上监控与再训练

模型上线后需持续监控:

  • 数据漂移(Data Drift):输入数据分布偏离训练分布,可通过特征统计量检测
  • 概念漂移(Concept Drift):真实标签分布变化(如新产品外观变更)
  • 定期再训练:收集线上难样本 → 回流标注 → 增量或全量再训练 → 灰度发布

参考资料

  • PyTorch 文档:https://pytorch.org/docs/stable/
  • NVIDIA TensorRT 文档:https://docs.nvidia.com/deeplearning/tensorrt/
  • Sculley et al., \"Hidden Technical Debt in Machine Learning Systems\", NeurIPS, 2015

更新时间

2026-03-03