pytorch 深度学习目标检测算法yolov5训练电动车闯红灯检测数据集 建立基于深度学习Yolov5电动车闯红灯检测识别

张开发
2026/5/22 23:12:17 15 分钟阅读
pytorch 深度学习目标检测算法yolov5训练电动车闯红灯检测数据集 建立基于深度学习Yolov5电动车闯红灯检测识别
pytorch 深度学习目标检测算法yolov5训练电动车闯红灯检测数据集 建立基于深度学习Yolov5电动车闯红灯检测识别文章目录1. 数据准备数据集结构数据标注数据划分2. 环境搭建安装依赖3. 数据配置4. 模型训练使用YOLOv5进行训练5. 配置超参数6. 模型推理单张图片推理7. 批量推理8. 性能评估mAP计算自定义评估脚本yolov5电动车闯红灯检测 可检测 红灯 绿灯 电动车三类1基于YOLOv5的电动车闯红灯检测系统涉及多个步骤包括数据准备、环境搭建、模型训练、超参数配置、模型推理、批量推理以及性能评估。以下是详细的实现步骤和代码示例。1. 数据准备数据集结构确保你的数据集结构如下traffic_dataset/ ├── images/ │ ├── train/ │ └── val/ └── labels/ ├── train/ └── val/数据标注每个图像文件对应一个.txt文件格式为YOLO适用的标注格式每行包含类别ID和边界框信息归一化后的中心点坐标和宽高。例如0 0.5 0.5 0.2 0.2 1 0.7 0.3 0.1 0.1数据划分将数据集按比例划分为训练集和验证集例如9:1。importosimportrandomfromshutilimportcopyfile# 定义数据集路径dataset_pathtraffic_datasetimages_pathos.path.join(dataset_path,images)labels_pathos.path.join(dataset_path,labels)# 创建目录os.makedirs(images_path,exist_okTrue)os.makedirs(labels_path,exist_okTrue)# 获取所有图像文件image_files[fforfinos.listdir(dataset_path)iff.endswith(.jpg)orf.endswith(.png)]random.shuffle(image_files)# 按比例划分训练集和验证集train_ratio0.9train_sizeint(len(image_files)*train_ratio)train_imagesimage_files[:train_size]val_imagesimage_files[train_size:]# 复制文件到对应的目录forimg_fileintrain_images:label_fileimg_file.replace(.jpg,.txt).replace(.png,.txt)copyfile(os.path.join(dataset_path,img_file),os.path.join(images_path,train,img_file))copyfile(os.path.join(dataset_path,label_file),os.path.join(labels_path,train,label_file))forimg_fileinval_images:label_fileimg_file.replace(.jpg,.txt).replace(.png,.txt)copyfile(os.path.join(dataset_path,img_file),os.path.join(images_path,val,img_file))copyfile(os.path.join(dataset_path,label_file),os.path.join(labels_path,val,label_file))2. 环境搭建安装依赖# 创建并激活虚拟环境conda create-ntraffic_detectionpython3.8conda activate traffic_detection# 安装YOLOv5和相关库pipinstalltorch torchvisiongitclone https://github.com/ultralytics/yolov5.gitcdyolov5 pipinstall-rrequirements.txt3. 数据配置在traffic_dataset/目录下创建一个名为data.yaml的数据配置文件内容如下train:./images/train/val:./images/val/nc:3names:[red_light,green_light,non_motor_vehicle]4. 模型训练使用YOLOv5进行训练cdyolov5 python train.py--img640--batch16--epochs100--data../traffic_dataset/data.yaml--weightsyolov5s.pt5. 配置超参数在训练过程中可以通过修改yolov5/data/hyps/hyp.scratch.yaml文件来设置超参数。以下是一些常见的超参数及其说明lr0: 初始学习率默认值通常为0.01。lrf: 最终学习率因子默认值通常为0.01。momentum: 动量默认值通常为0.937。weight_decay: 权重衰减默认值通常为0.0005。6. 模型推理单张图片推理fromyolov5.models.experimentalimportattempt_loadfromyolov5.utils.datasetsimportLoadImagesfromyolov5.utils.generalimportnon_max_suppression,scale_coordsfromyolov5.utils.torch_utilsimportselect_deviceimportcv2importtorchdefinfer_image(weights,img_path):deviceselect_device()modelattempt_load(weights,map_locationdevice)datasetLoadImages(img_path)forpath,img,im0s,vid_capindataset:imgtorch.from_numpy(img).to(device)imgimg.float()/255.0ifimg.ndimension()3:imgimg.unsqueeze(0)predmodel(img,augmentFalse)[0]prednon_max_suppression(pred,conf_thres0.25,iou_thres0.45)fordetinpred:iflen(det):det[:,:4]scale_coords(img.shape[2:],det[:,:4],im0s.shape).round()for*xyxy,conf,clsinreversed(det):labelf{model.names[int(cls)]}{conf:.2f}plot_one_box(xyxy,im0s,labellabel,color(0,255,0),line_thickness3)cv2.imshow(Traffic Detection,im0s)cv2.waitKey(0)cv2.destroyAllWindows()if__name____main__:weightsruns/train/exp/weights/best.ptimg_pathpath/to/test_image.jpginfer_image(weights,img_path)7. 批量推理importosdefbatch_infer_images(weights,directory):forfilenameinos.listdir(directory):iffilename.endswith(.jpg)orfilename.endswith(.png):image_pathos.path.join(directory,filename)annotated_frameinfer_image(weights,image_path)cv2.imwrite(foutput_{filename},annotated_frame)if__name____main__:weightsruns/train/exp/weights/best.ptdirectorypath/to/imagesbatch_infer_images(weights,directory)8. 性能评估mAP计算YOLOv5自带评估功能可以在验证集上计算mAP。cdyolov5 python val.py--img640--batch16--data../traffic_dataset/data.yaml--weightsruns/train/exp/weights/best.pt自定义评估脚本如果你想要更详细的评估指标比如针对特定类别的准确率和召回率可以编写自定义脚本。fromsklearn.metricsimportprecision_recall_fscore_supportdefevaluate_model(model,dataset):all_preds[]all_labels[]forimages,labelsindataset:predsmodel(images)pred_classes[int(pred.cls[0])forpredinpreds]true_classeslabels.numpy().flatten()all_preds.extend(pred_classes)all_labels.extend(true_classes)precision,recall,f1,_precision_recall_fscore_support(all_labels,all_preds,averageweighted)print(fPrecision:{precision:.2f}, Recall:{recall:.2f}, F1 Score:{f1:.2f})# 示例评估模型性能evaluate_model(model,val_dataset)以上就是关于如何在YOLOv5基础上构建电动车闯红灯检测系统的完整指南包括数据准备、环境搭建、模型训练、超参数配置、模型推理、批量推理以及性能评估代码。根据实际需求调整上述代码段关键代码示例仅供参考。

更多文章