手把手教你用PyTorch复现ConvNeXt-Tiny:从零搭建到图像分类实战(附完整代码)

张开发
2026/4/20 19:52:53 15 分钟阅读

分享文章

手把手教你用PyTorch复现ConvNeXt-Tiny:从零搭建到图像分类实战(附完整代码)
手把手教你用PyTorch复现ConvNeXt-Tiny从零搭建到图像分类实战ConvNeXt作为2022年横空出世的卷积神经网络新星在ImageNet分类任务上以纯卷积结构超越了当时最先进的Vision Transformer。本文将带您从零开始实现ConvNeXt-Tiny模型并完成完整的图像分类实战。不同于简单的API调用我们会深入每个模块的设计原理让您真正掌握这个现代化卷积网络的精华。1. 环境准备与数据加载工欲善其事必先利其器。在开始构建模型前我们需要配置合适的开发环境。推荐使用Python 3.8和PyTorch 1.12版本这些组合经过充分验证能避免大多数兼容性问题。conda create -n convnext python3.8 conda activate convnext pip install torch torchvision torchaudio pip install matplotlib tqdm对于数据集我们选择CIFAR-10作为演示它包含10类共60000张32x32彩色图像。虽然原始ConvNeXt设计用于224x224输入但我们会展示如何适配不同尺寸import torch from torchvision import datasets, transforms # 数据增强策略 train_transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding4), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) test_transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 加载数据集 train_set datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtrain_transform) test_set datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtest_transform) train_loader torch.utils.data.DataLoader(train_set, batch_size128, shuffleTrue) test_loader torch.utils.data.DataLoader(test_set, batch_size128, shuffleFalse)提示如果您的显存有限可以适当减小batch_size。ConvNeXt-Tiny在RTX 3060上使用batch_size128大约需要4GB显存。2. ConvNeXt核心模块实现ConvNeXt的成功源于对Transformer设计理念的巧妙借鉴。让我们拆解它的核心组件您会发现这些看似简单的改动如何带来质的飞跃。2.1 深度可分离卷积与大型核传统CNN使用3x3卷积核而ConvNeXt大胆采用7x7深度可分离卷积class DepthwiseConv(nn.Module): def __init__(self, dim): super().__init__() self.dwconv nn.Conv2d(dim, dim, kernel_size7, padding3, groupsdim) def forward(self, x): return self.dwconv(x)这种设计有两重优势大感受野7x7卷积核能捕获更广的上下文信息参数效率深度可分离卷积将参数量减少为普通卷积的1/72.2 反向瓶颈结构与层归一化ConvNeXt采用了类似Transformer的宽中间层设计class InvertedBottleneck(nn.Module): def __init__(self, dim, expansion4): super().__init__() self.pwconv1 nn.Linear(dim, dim * expansion) # 扩展 self.act nn.GELU() self.pwconv2 nn.Linear(dim * expansion, dim) # 压缩 def forward(self, x): return self.pwconv2(self.act(self.pwconv1(x)))配合LayerNorm使用形成了完整的特征处理流程class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.norm nn.LayerNorm(dim) def forward(self, x): x x.permute(0, 2, 3, 1) # [B,C,H,W] - [B,H,W,C] x self.norm(x) return x.permute(0, 3, 1, 2) # 恢复原维度2.3 完整ConvNeXt Block集成将上述组件组合起来就构成了ConvNeXt的基础模块class ConvNeXtBlock(nn.Module): def __init__(self, dim, drop_path_rate0.): super().__init__() self.dwconv DepthwiseConv(dim) self.norm LayerNorm(dim) self.pwconv InvertedBottleneck(dim) self.gamma nn.Parameter(torch.ones(dim)) self.drop_path DropPath(drop_path_rate) if drop_path_rate 0 else nn.Identity() def forward(self, x): shortcut x x self.dwconv(x) x self.norm(x) x self.pwconv(x) x self.gamma.view(1, -1, 1, 1) * x return shortcut self.drop_path(x)注意DropPath是随机深度正则化技术训练时随机丢弃部分路径测试时完整使用能有效防止过拟合。3. 构建ConvNeXt-Tiny网络现在我们可以组装完整的ConvNeXt-Tiny架构。该模型包含4个阶段(stage)每个阶段有不同的特征维度和块数量Stage特征维度块数量下采样方式19634x4 conv219232x2 conv338492x2 conv476832x2 convclass ConvNeXt(nn.Module): def __init__(self, in_chans3, num_classes1000, depths[3, 3, 9, 3], dims[96, 192, 384, 768]): super().__init__() # 下采样层 self.downsample_layers nn.ModuleList() stem nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size4, stride4), LayerNorm(dims[0]) ) self.downsample_layers.append(stem) for i in range(3): downsample_layer nn.Sequential( LayerNorm(dims[i]), nn.Conv2d(dims[i], dims[i1], kernel_size2, stride2) ) self.downsample_layers.append(downsample_layer) # 阶段块 self.stages nn.ModuleList() dp_rates [x.item() for x in torch.linspace(0, 0.1, sum(depths))] cur 0 for i in range(4): stage nn.Sequential( *[ConvNeXtBlock(dimdims[i], drop_path_ratedp_rates[curj]) for j in range(depths[i])] ) self.stages.append(stage) cur depths[i] # 分类头 self.norm nn.LayerNorm(dims[-1]) self.head nn.Linear(dims[-1], num_classes) def forward(self, x): for i in range(4): x self.downsample_layers[i](x) x self.stages[i](x) x x.mean([-2, -1]) # 全局平均池化 x self.norm(x) return self.head(x)4. 模型训练与微调技巧有了完整的模型架构接下来就是训练过程。这里分享几个关键技巧能显著提升训练效果。4.1 优化器配置ConvNeXt作者推荐使用AdamW优化器配合余弦退火学习率调度def create_optimizer(model, lr4e-3, weight_decay0.05): decay_params [] no_decay_params [] for name, param in model.named_parameters(): if not param.requires_grad: continue if name.endswith(.bias) or norm in name: no_decay_params.append(param) else: decay_params.append(param) optim_groups [ {params: decay_params, weight_decay: weight_decay}, {params: no_decay_params, weight_decay: 0.0} ] return torch.optim.AdamW(optim_groups, lrlr) optimizer create_optimizer(model) scheduler torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max200)4.2 训练循环实现完整的训练循环需要考虑混合精度训练和梯度裁剪scaler torch.cuda.amp.GradScaler() for epoch in range(200): model.train() for images, labels in train_loader: images, labels images.cuda(), labels.cuda() optimizer.zero_grad() with torch.cuda.amp.autocast(): outputs model(images) loss F.cross_entropy(outputs, labels) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() scheduler.step() # 验证集评估 model.eval() correct 0 total 0 with torch.no_grad(): for images, labels in test_loader: images, labels images.cuda(), labels.cuda() outputs model(images) _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() print(fEpoch {epoch1}, Accuracy: {100 * correct / total:.2f}%)4.3 关键调参技巧学习率预热前5个epoch线性增加学习率避免初期不稳定标签平滑使用0.1的标签平滑正则化减轻过拟合随机深度逐步增加drop path rate最高到0.1EMA模型参数指数移动平均提升测试时稳定性# 标签平滑示例 criterion nn.CrossEntropyLoss(label_smoothing0.1) # EMA实现 class ModelEMA: def __init__(self, model, decay0.9999): self.ema deepcopy(model).eval() self.decay decay def update(self, model): with torch.no_grad(): for ema_param, model_param in zip(self.ema.parameters(), model.parameters()): ema_param.mul_(self.decay).add_(model_param, alpha1-self.decay)5. 自定义数据集适配当我们需要在自己的数据集上应用ConvNeXt时有几个关键点需要注意5.1 输入尺寸调整原始ConvNeXt设计用于224x224输入但可以通过修改stem层适配不同尺寸def adapt_convnext_for_size(model, input_size32): # 调整初始下采样层 if input_size 32: model.downsample_layers[0][0] nn.Conv2d(3, 96, kernel_size3, stride1, padding1) elif input_size 64: model.downsample_layers[0][0] nn.Conv2d(3, 96, kernel_size4, stride2, padding1) return model5.2 类别数修改对于不同类别数的分类任务只需替换最后的分类头num_classes 10 # CIFAR-10的类别数 model ConvNeXt(num_classesnum_classes)5.3 数据增强策略根据数据集特点调整增强策略例如对于医学图像medical_transform transforms.Compose([ transforms.RandomAffine(degrees10, translate(0.1, 0.1)), transforms.RandomResizedCrop(224, scale(0.8, 1.0)), transforms.ColorJitter(brightness0.1, contrast0.1), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])6. 模型部署与性能优化训练好的模型需要优化才能高效部署。以下是几个实用技巧6.1 模型量化PyTorch支持动态和静态量化显著减少模型大小和推理时间# 动态量化 quantized_model torch.quantization.quantize_dynamic( model, {nn.Linear, nn.Conv2d}, dtypetorch.qint8 ) # 静态量化需要准备校准数据 model.qconfig torch.quantization.get_default_qconfig(fbgemm) torch.quantization.prepare(model, inplaceTrue) # 运行校准数据... torch.quantization.convert(model, inplaceTrue)6.2 ONNX导出导出为ONNX格式可实现跨平台部署dummy_input torch.randn(1, 3, 224, 224).cuda() torch.onnx.export( model, dummy_input, convnext_tiny.onnx, input_names[input], output_names[output], dynamic_axes{input: {0: batch}, output: {0: batch}} )6.3 TensorRT加速对于生产环境可以使用TensorRT进一步优化# 使用torch2trt转换 from torch2trt import torch2trt model_trt torch2trt(model, [dummy_input], fp16_modeTrue) # 推理时直接调用 output model_trt(dummy_input)在实际项目中ConvNeXt-Tiny量化后在NVIDIA Jetson Xavier NX上能达到约50 FPS的推理速度完全满足实时性要求。

更多文章