PyTorch 实战:训练 CIFAR-10 图像分类器

张开发
2026/4/8 2:38:40 15 分钟阅读

分享文章

PyTorch 实战:训练 CIFAR-10 图像分类器
一、任务简介图像分类是计算机视觉的入门任务。本文我们将使用 PyTorch 训练一个卷积神经网络CNN对 CIFAR-10 数据集中的 10 类彩色图像进行分类。你将学会下载和预处理 CIFAR-10 数据集定义适用于彩色图像的 CNN 结构训练模型并保存/加载模型在测试集上评估模型性能整体准确率和各类别准确率将训练过程迁移到 GPU 以加速二、CIFAR-10 数据集介绍CIFAR-10 是一个经典的图像分类数据集包含 60000 张 32x32 的彩色图像3 个通道共 10 个类别类别标签0airplane飞机1automobile汽车2bird鸟3cat猫4deer鹿5dog狗6frog青蛙7horse马8ship船9truck卡车训练集 50000 张测试集 10000 张。三、数据下载与预处理使用torchvision可以轻松下载并转换数据。我们将图像转为张量并归一化到 [-1, 1] 范围。importtorchimporttorchvisionimporttorchvision.transformsastransforms# 定义预处理转换为张量 归一化均值0.5标准差0.5transformtransforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])# 下载训练集trainsettorchvision.datasets.CIFAR10(root./data,trainTrue,downloadTrue,transformtransform)trainloadertorch.utils.data.DataLoader(trainset,batch_size4,shuffleTrue,num_workers2)# 下载测试集testsettorchvision.datasets.CIFAR10(root./data,trainFalse,downloadTrue,transformtransform)testloadertorch.utils.data.DataLoader(testset,batch_size4,shuffleFalse,num_workers2)# 类别名称classes(plane,car,bird,cat,deer,dog,frog,horse,ship,truck)Windows 用户提示如果出现BrokenPipeError请将num_workers设为 0。四、可视化部分训练图像为了直观感受数据我们展示一个 batch 的图片importmatplotlib.pyplotaspltimportnumpyasnpdefimshow(img):imgimg/20.5# 反归一化到 [0,1]npimgimg.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))plt.show()# 获取一个 batchdataiteriter(trainloader)images,labelsnext(dataiter)# 显示图片网格imshow(torchvision.utils.make_grid(images))# 打印标签print( .join(f{classes[labels[j]]:5s}forjinrange(4)))五、定义卷积神经网络由于 CIFAR-10 是彩色 3 通道图像我们将第一层卷积的输入通道改为 3。网络结构两个卷积层5x5 卷积核 三个全连接层。importtorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.Module):def__init__(self):super(Net,self).__init__()self.conv1nn.Conv2d(3,6,5)# 输入3通道输出6通道self.poolnn.MaxPool2d(2,2)# 池化窗口2x2self.conv2nn.Conv2d(6,16,5)# 输入6输出16# 经过两次卷积池化后特征图尺寸为 16 * 5 * 5计算过程见下文self.fc1nn.Linear(16*5*5,120)self.fc2nn.Linear(120,84)self.fc3nn.Linear(84,10)defforward(self,x):xself.pool(F.relu(self.conv1(x)))xself.pool(F.relu(self.conv2(x)))xx.view(-1,16*5*5)xF.relu(self.fc1(x))xF.relu(self.fc2(x))xself.fc3(x)returnx netNet()print(net)尺寸计算输入 32x32第一个卷积5x5无 padding输出 28x28池化后 14x14第二个卷积输出 10x10池化后 5x5。因此展平维度为 1655400。六、定义损失函数和优化器分类任务使用交叉熵损失CrossEntropyLoss优化器选择带动量的 SGD。importtorch.optimasoptim criterionnn.CrossEntropyLoss()optimizeroptim.SGD(net.parameters(),lr0.001,momentum0.9)七、训练模型训练 2 个 epoch可根据需要增加每个 epoch 遍历整个训练集。forepochinrange(2):running_loss0.0fori,datainenumerate(trainloader,0):inputs,labelsdata# 梯度清零optimizer.zero_grad()# 前向传播outputsnet(inputs)# 计算损失losscriterion(outputs,labels)# 反向传播loss.backward()# 更新参数optimizer.step()running_lossloss.item()if(i1)%20000:print(f[Epoch{epoch1}, Batch{i1:5d}] loss:{running_loss/2000:.3f})running_loss0.0print(训练完成)输出示例[Epoch 1, Batch 2000] loss: 2.227 [Epoch 1, Batch 4000] loss: 1.884 ... [Epoch 2, Batch 12000] loss: 1.291 训练完成八、保存模型训练完成后通常只保存模型的参数state_dict而不是整个对象。PATH./cifar_net.pthtorch.save(net.state_dict(),PATH)九、测试模型1. 在单 batch 上测试# 加载模型netNet()net.load_state_dict(torch.load(PATH))# 获取测试集的一个 batchdataiteriter(testloader)images,labelsnext(dataiter)# 预测outputsnet(images)_,predictedtorch.max(outputs,1)# 显示真实标签和预测标签print(GroundTruth: , .join(f{classes[labels[j]]:5s}forjinrange(4)))print(Predicted: , .join(f{classes[predicted[j]]:5s}forjinrange(4)))2. 在整个测试集上评估准确率correct0total0withtorch.no_grad():# 不计算梯度节省内存fordataintestloader:images,labelsdata outputsnet(images)_,predictedtorch.max(outputs,1)totallabels.size(0)correct(predictedlabels).sum().item()print(f在 10000 张测试集上的准确率:{100*correct/total:.1f}%)通常可以达到 53% 左右的准确率随机猜测仅 10%说明模型确实学到了特征。3. 分类别统计准确率class_correct[0.0]*10class_total[0.0]*10withtorch.no_grad():fordataintestloader:images,labelsdata outputsnet(images)_,predictedtorch.max(outputs,1)c(predictedlabels).squeeze()foriinrange(4):# batch_size4labellabels[i]class_correct[label]c[i].item()class_total[label]1foriinrange(10):print(f{classes[i]:10s}准确率:{100*class_correct[i]/class_total[i]:.1f}%)输出示例plane 准确率: 62.0% car 准确率: 62.0% bird 准确率: 45.0% cat 准确率: 36.0% deer 准确率: 52.0% dog 准确率: 25.0% frog 准确率: 69.0% horse 准确率: 60.0% ship 准确率: 70.0% truck 准确率: 48.0%可以看到模型对猫、狗的分类效果较差对船、青蛙的分类效果较好。十、在 GPU 上训练加速如果你的电脑有 NVIDIA 显卡并已安装 CUDA可以轻松将训练迁移到 GPU。devicetorch.device(cuda:0iftorch.cuda.is_available()elsecpu)print(使用设备:,device)# 将网络转移到 GPUnet.to(device)# 在训练循环中将输入和标签也转移到 GPUforepochinrange(2):fori,datainenumerate(trainloader,0):inputs,labelsdata inputs,labelsinputs.to(device),labels.to(device)# 其余代码不变 ...只需添加.to(device)PyTorch 就会自动在 GPU 上执行运算训练速度显著提升。十一、总结使用torchvision下载并预处理 CIFAR-10 数据集为彩色图像设计 CNN 结构完成完整的训练、保存、加载和测试流程分析模型在不同类别上的表现将模型迁移到 GPU 加速训练。

更多文章