从零到一:DETR模型在自定义数据集上的实战部署与调优

张开发
2026/4/5 10:11:44 15 分钟阅读

分享文章

从零到一:DETR模型在自定义数据集上的实战部署与调优
1. 环境准备与源码配置第一次接触DETR模型时我被它简洁的端到端设计惊艳到了。相比传统目标检测模型那些复杂的anchor设置和NMS后处理DETR直接用Transformer搞定一切。不过在实际部署时环境配置这块确实踩过不少坑。建议新手直接用Linux系统操作Windows下的兼容性问题会让你怀疑人生。1.1 创建Python虚拟环境我习惯用conda管理环境这里用Python 3.7创建一个新环境conda create -n detr python3.7 -y conda activate detr为什么选3.7因为PyTorch 1.7对这个版本支持最稳定。有次用3.8遇到奇怪的dll加载错误折腾半天才发现是Python版本问题。1.2 安装PyTorch与CUDA安装PyTorch前务必先确认CUDA版本nvidia-smi看到CUDA Version是11.0后去PyTorch官网找对应版本。这里有个坑CUDA 11.0其实兼容cudatoolkit 11.0及以下版本。我测试过用cudatoolkit10.2也能正常运行conda install pytorch1.7.0 torchvision0.8.0 torchaudio0.7.0 cudatoolkit11.0 -c pytorch安装完一定要验证import torch print(torch.cuda.is_available()) # 输出True才算成功1.3 克隆DETR源码官方源码在Facebook Research的GitHub仓库git clone https://github.com/facebookresearch/detr.git cd detr建议把源码放在固态硬盘上训练时数据加载速度能快不少。我试过机械硬盘数据加载经常成为训练瓶颈。1.4 安装依赖库最麻烦的是pycocotools的安装。直接pip安装经常报错推荐用源码编译conda install cython scipy pip install -U githttps://github.com/cocodataset/cocoapi.git#subdirectoryPythonAPI如果遇到unicode未定义错误可能是Python 2/3兼容性问题。这时需要手动修改cocoapi源码把unicode替换为str。2. 数据准备与格式转换2.1 COCO数据格式详解DETR默认使用COCO格式目录结构应该是这样dataset/ ├── train2017/ │ ├── 000001.jpg │ └── ... ├── val2017/ │ ├── 000002.jpg │ └── ... └── annotations/ ├── instances_train2017.json └── instances_val2017.json关键在annotations里的JSON文件。做过几个项目后发现COCO格式最核心的是这几个字段{ images: [{id: 1, file_name: 000001.jpg, width: 640, height: 480}], annotations: [{ id: 1, image_id: 1, category_id: 1, bbox: [x,y,width,height], area: width*height, iscrowd: 0 }], categories: [{id: 1, name: defect}] }2.2 数据转换实战技巧对于工业缺陷检测这种场景标注工具输出的通常是VOC格式或YOLO格式。我写过一个转换脚本核心逻辑是解析原始标注文件获取bbox坐标将像素坐标转换为COCO的相对坐标格式计算每个bbox的面积area字段处理类别ID映射关系特别提醒iscrowd字段千万别忽略有次训练结果异常排查半天发现是这个字段全设为1了导致模型把所有目标都当成密集人群处理。2.3 数据增强策略DETR对数据增强比较敏感推荐使用torchvision自带的变换from torchvision.transforms import Compose, RandomResizedCrop, ToTensor, Normalize transform Compose([ RandomResizedCrop(size800, scale(0.8, 1.2)), ToTensor(), Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])注意保持宽高比随机裁剪时scale别设太大否则小目标容易被裁掉。在PCB缺陷检测项目中我把scale调到(0.9,1.1)效果最好。3. 模型配置与迁移学习3.1 修改类别数DETR原始模型是为COCO的80类设计的。修改detr.py中的类别数时要注意两处# detr/models/detr.py 第313行 self.num_classes num_classes 1 # 记得1是给背景类 # 如果是全景分割还要改317行 self.num_classes_panoptic num_classes_panoptic我建议直接在代码里加个assert检查assert num_classes 2, 别忘了修改预训练权重3.2 预训练权重处理下载ResNet50预训练模型后需要用这段代码调整分类头import torch pretrained_weights torch.load(./detr-r50-e632da11.pth) num_classes 2 # 你的类别数1 pretrained_weights[model][class_embed.weight].resize_(num_classes1, 256) pretrained_weights[model][class_embed.bias].resize_(num_classes1) torch.save(pretrained_weights, detr_r50_%d.pth%num_classes)有个隐藏坑如果类别数变化很大比如从80类改成2类最好重新初始化最后一层。我试过在遥感图像场景下这样改能提升3-5个点的AP。4. 训练技巧与参数调优4.1 单卡训练配置基础训练命令长这样python main.py \ --batch_size 4 \ --lr_drop 10 \ --output_dir ./output \ --coco_path ./dataset \ --resume detr_r50_2.pth \ --epochs 200几个关键参数经验值batch_size显存够的话尽量大但超过8可能不稳定lr_drop小数据集建议设小点如10epochs工业缺陷检测一般100-200轮足够4.2 学习率策略DETR对学习率特别敏感。默认配置可能不适合小数据集我常用的调整策略# 在main.py里修改optimizer配置 param_dicts [ {params: [p for n, p in model.named_parameters() if backbone not in n and p.requires_grad]}, { params: [p for n, p in model.named_parameters() if backbone in n and p.requires_grad], lr: args.lr * 0.1, # backbone学习率降低 }, ]对于小数据集把初始lr从1e-4降到5e-5往往更稳定。配合warmup效果更好--lr 5e-5 --warmup_epochs 54.3 分布式训练多卡训练用这个命令模板python -m torch.distributed.launch --nproc_per_node4 \ --use_env main.py \ --output_dir ./output \ --coco_path ./dataset \ --resume detr_r50_2.pth \ --epochs 200注意batch_size是每卡的batch总batch_sizenproc_per_node×batch_size。同步BN能提升多卡训练稳定性model torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)5. 评估与问题排查5.1 验证指标解读训练完会输出这些指标Average Precision (AP) [ IoU0.50:0.95 | area all | maxDets100 ] 0.412 Average Precision (AP) [ IoU0.50 | area all | maxDets100 ] 0.687 Average Precision (AP) [ IoU0.75 | area all | maxDets100 ] 0.421重点关注AP50:95和AP50。工业场景中如果AP50很高但AP75低说明bbox定位不够准可能需要增加数据增强调整匈牙利匹配的cost权重增加训练epoch5.2 常见问题解决问题1Loss震荡严重调小学习率增加batch_size检查数据标注质量问题2验证集AP远低于训练集尝试更强的数据增强减少模型容量如用ResNet18早停early stopping问题3预测框偏离目标调整匈牙利匹配的bbox_loss权重检查标注框是否都是xywh格式增加GIoU损失的权重系数在PCB缺陷检测项目中通过调整matcher的cost权重将分类权重从1降到0.5bbox权重从5升到8AP提升了7个点。

更多文章