通俗易懂讲透批量梯度下降法(BGD)

张开发
2026/4/10 21:46:03 15 分钟阅读

分享文章

通俗易懂讲透批量梯度下降法(BGD)
通俗易懂讲透批量梯度下降法BGD本科生/研究生都能看懂本文用大白话下山比喻公式拆解完整代码可视化把批量梯度下降Batch Gradient Descent从原理、流程、优缺点到适用场景讲得明明白白适合机器学习入门、面试复习、课程笔记。一、先搞懂什么是批量梯度下降BGD一句话定义批量梯度下降 每一步都用「全部训练数据」算梯度再统一更新一次参数。超级形象比喻你在下山找最低点最小损失BGD每走一步都把整座山的地形看一遍再决定往哪走、走多远优点方向准、不跑偏缺点走一步超级慢二、核心思想超简单看所有样本的误差算平均梯度沿梯度反方向更新参数重复直到损失不再下降三、为什么要用 BGD方向最准确用全体数据梯度无噪声稳定不震荡更新路径平滑凸问题一定收敛到全局最优实现最简单、理论最扎实四、数学公式一步步看懂1. 损失函数以均方误差 MSE 为例J(θ)1m∑i1m(hθ(x(i))−y(i))2 J(\theta)\frac{1}{m}\sum_{i1}^m\left(h_\theta(x^{(i)})-y^{(i)}\right)^2J(θ)m1​i1∑m​(hθ​(x(i))−y(i))2m样本总数hθ(x)模型预测值2. 梯度对每个参数求偏导∂J∂θj1m∑i1m(hθ(x(i))−y(i))xj(i) \frac{\partial J}{\partial\theta_j}\frac{1}{m}\sum_{i1}^m\left(h_\theta(x^{(i)})-y^{(i)}\right)x_j^{(i)}∂θj​∂J​m1​i1∑m​(hθ​(x(i))−y(i))xj(i)​3. 参数更新核心公式θjθj−η⋅∂J∂θj \theta_j \theta_j - \eta \cdot \frac{\partial J}{\partial\theta_j}θj​θj​−η⋅∂θj​∂J​η学习率步长方向沿梯度反方向下降五、BGD 完整算法流程4步背会初始化参数θ随机用全部数据计算梯度按学习率更新参数重复直到损失收敛六、代码实战批量梯度下降训练线性回归房价预测直接复制可运行包含数据生成标准化BGD 训练损失曲线 预测对比图importnumpyasnpimportpandasaspdimportmatplotlib.pyplotaspltimportseabornassns# 1. 生成模拟房价数据 np.random.seed(42)defgenerate_data(num_samples):areanp.random.normal(1500,500,num_samples)agenp.random.normal(20,10,num_samples)roomsnp.random.randint(1,6,num_samples)pricearea*300age*(-1500)rooms*10000np.random.normal(0,50000,num_samples)returnpd.DataFrame({area:area,age:age,rooms:rooms,price:price})datagenerate_data(100000)# 标准化data(data-data.mean())/data.std()# 可视化价格分布 相关性plt.figure(figsize(14,5))plt.subplot(121)sns.histplot(data[price],kdeTrue)plt.title(房价分布)plt.subplot(122)sns.heatmap(data.corr(),annotTrue,cmapcoolwarm)plt.title(特征相关性)plt.show()# 2. 构建 X、y Xdata[[area,age,rooms]].values ydata[price].values.reshape(-1,1)mlen(X)# 添加偏置项截距Xnp.hstack((np.ones((m,1)),X))# 3. 批量梯度下降 BGD thetanp.random.randn(4,1)lr0.01iters1000cost_history[]defcompute_cost(X,y,theta):predX.dot(theta)returnnp.mean((pred-y)**2)/2foriinrange(iters):# 计算梯度全部样本gradX.T.dot(X.dot(theta)-y)/m# 更新参数theta-lr*grad# 记录损失costcompute_cost(X,y,theta)cost_history.append(cost)ifi%1000:print(fIter{i:4d}| Cost:{cost:.6f})# 4. 损失曲线 plt.figure(figsize(12,5))plt.plot(cost_history,b-,linewidth2)plt.title(批量梯度下降 损失收敛曲线)plt.xlabel(迭代次数)plt.ylabel(损失)plt.grid()plt.show()# 5. 预测 vs 真实 predX.dot(theta)plt.figure(figsize(12,5))plt.scatter(y,pred,alpha0.2)plt.plot([y.min(),y.max()],[y.min(),y.max()],r-,linewidth2)plt.title(真实价格 vs 预测价格)plt.xlabel(真实)plt.ylabel(预测)plt.grid()plt.show()七、批量梯度下降BGD优点梯度最准确无噪声方向最稳收敛可靠凸问题必到全局最优更新平滑不震荡、不抖动理论保证强最容易分析收敛性八、BGD 缺点非常关键速度极慢每步都要遍历全量数据占内存大必须把数据全部放进内存无法在线学习不能流式处理数据大数据不适用千万级别样本基本跑不动九、BGD vs SGD vs Mini-batch GD速记表算法每次用多少数据速度梯度噪声稳定性适用场景BGD全部最慢无最稳小数据集、凸优化SGD1个最快大震荡大数据、深度学习Mini-batch一小批中中较稳深度学习通用十、BGD 适用场景直接照抄✅适合用 BGD数据集较小万级以内线性回归、逻辑回归等凸优化希望训练过程最稳定教学、实验、推导❌不适合大数据集深度学习CNN、Transformer内存有限设备十一、一句话总结批量梯度下降BGD是最原始、最稳定、最准确的梯度下降版本但因为每一步都要跑完所有数据速度极慢只适合小数据集与教学演示。

更多文章