AI 学习之利用指数加权平均值来进行动量梯度下降
动量梯度下降的核心是指数加权平均值的计算。
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import math
import sklearn
import sklearn.datasets
from opt_utils import load_params_and_grads, initialize_parameters, forward_propagation, backward_propagation
from opt_utils import compute_cost, predict, predict_dec, plot_decision_boundary, load_dataset
from testCases import *
%matplotlib inline
plt.rcParams['figure.figsize'] = (7.0, 4.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
/Users/kaiyiwang/Code/AI/captainAI/8 动量梯度下降/opt_utils.py:76: SyntaxWarning: assertion is always true, perhaps remove parentheses?
assert(parameters['W' + str(l)].shape == layer_dims[l], layer_dims[l-1])
/Users/kaiyiwang/Code/AI/captainAI/8 动量梯度下降/opt_utils.py:77: SyntaxWarning: assertion is always true, perhaps remove parentheses?
assert(parameters['W' + str(l)].shape == layer_dims[l], 1)
1 - 动量梯度下降
因为mini-batch梯度下降每次只向一个子训练集集进行学习,学习的对象比较少,所以学习的方向会偏离得更加严重,学习路径就很曲折。而使用动量梯度下降技术会使学习路径更加平滑。
动量梯度下降会对之前的梯度值进行指数加权平均运算来得到更加平滑的学习路径。下图中红色的箭头就是使用了动量梯度下降后的学习路径,蓝色的虚线是原始的路径。可以看出新路径比老路径要平滑。
# 初始化指数加权平均值字典
def initialize_velocity(parameters):
L = len(parameters) // 2 # 获取神经网络的层数
v = {}
# 循环每一层
for l in range(L):
# 因为l是从0开始的,所以下面要在l后面加上1
# zeros_like会返回一个与输入参数维度相同的数组,而且将这个数组全部设置为0
# 指数加权平均值字典的维度应该是与梯度字典一样的,而梯度字典是与参数字典一样的,所以zeros_like的输入参数是参数字典
v["dW" + str(l + 1)] = np.zeros_like(parameters["W" + str(l+1)])
v["db" + str(l + 1)] = np.zeros_like(parameters["b" + str(l+1)])
return v
parameters = initialize_velocity_test_case()
v = initialize_velocity(parameters)
print("v[\"dW1\"] = " + str(v["dW1"]))
print("v[\"db1\"] = " + str(v["db1"]))
print("v[\"dW2\"] = " + str(v["dW2"]))
print("v[\"db2\"] = " + str(v["db2"]))
v["dW1"] = [[0. 0. 0.]
[0. 0. 0.]]
v["db1"] = [[0.]
[0.]]
v["dW2"] = [[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
v["db2"] = [[0.]
[0.]
[0.]]
# 使用动量梯度下降算法来更新参数
def update_parameters_with_momentum(parameters, grads, v, beta, learning_rate):
L = len(parameters) // 2
# 遍历每一层
for l in range(L):
# 算出指数加权平均值。
# 下面的beta就相当于我们文章中的k。
# 看这段代码时应该回想一下我们文章中学到的“一行代码搞定指数加权平均值”的知识点
# vt = k*v(t-1) + (1-k)wt 指数加权平均算法
v["dW" + str(l + 1)] = beta * v["dW" + str(l + 1)] + (1 - beta) * grads['dW' + str(l + 1)]
v["db" + str(l + 1)] = beta * v["db" + str(l + 1)] + (1 - beta) * grads['db' + str(l + 1)]
# 用指数加权平均值来更新参数
parameters["W" + str(l + 1)] = parameters["W" + str(l + 1)] - learning_rate * v["dW" + str(l + 1)]
parameters["b" + str(l + 1)] = parameters["b" + str(l + 1)] - learning_rate * v["db" + str(l + 1)]
return parameters, v
parameters, grads, v = update_parameters_with_momentum_test_case()
parameters, v = update_parameters_with_momentum(parameters, grads, v, beta = 0.9, learning_rate = 0.01)
print("W1 = " + str(parameters["W1"]))
print("b1 = " + str(parameters["b1"]))
print("W2 = " + str(parameters["W2"]))
print("b2 = " + str(parameters["b2"]))
print("v[\"dW1\"] = " + str(v["dW1"]))
print("v[\"db1\"] = " + str(v["db1"]))
print("v[\"dW2\"] = " + str(v["dW2"]))
print("v[\"db2\"] = " + str(v["db2"]))
W1 = [[ 1.62544598 -0.61290114 -0.52907334]
[-1.07347112 0.86450677 -2.30085497]]
b1 = [[ 1.74493465]
[-0.76027113]]
W2 = [[ 0.31930698 -0.24990073 1.4627996 ]
[-2.05974396 -0.32173003 -0.38320915]
[ 1.13444069 -1.0998786 -0.1713109 ]]
b2 = [[-0.87809283]
[ 0.04055394]
[ 0.58207317]]
v["dW1"] = [[-0.11006192 0.11447237 0.09015907]
[ 0.05024943 0.09008559 -0.06837279]]
v["db1"] = [[-0.01228902]
[-0.09357694]]
v["dW2"] = [[-0.02678881 0.05303555 -0.06916608]
[-0.03967535 -0.06871727 -0.08452056]
[-0.06712461 -0.00126646 -0.11173103]]
v["db2"] = [[0.02344157]
[0.16598022]
[0.07420442]]
注意 :
- 这里的指数加权平均值是没有添加修正算法的。所以在前面一小段的梯度下降中,趋势平均值是不准确的。
- 如果$\beta = 0$,那么上面的就成了一个普通的标准梯度下降算法了。
如何选择$\beta$?
- $\beta$越大,那么学习路径就越平滑,因为与指数加权平均值关系紧密的梯度值就越多。但是,如果$\beta$太大了,那么它就不能准确地实时反应出梯度的真实情况了.
- 一般来说,$\beta$的取值范围是0.8到0.999。$\beta = 0.9$是最常用的默认值。
- 当然,你可以尝试0.9之外的值,也许能找到一个更合适的值。建议大家尝试尝试。
大家需要记住下面几点:
- 动量梯度下降算法通过之前的梯度值而算出指数加权平均值,而使学习路径更加平滑。这个算法可以运用在batch梯度下降中,也可以运用在mini-batch梯度下降和随机梯度下降中。
- 如果使用这个算法,那么就又多了一个超参数$\beta$了.
为者常成,行者常至
自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)