Fixed a calculation error in the backpropagation process (#607)

This commit is contained in:
DaGang 2021-01-25 12:15:36 +08:00 коммит произвёл GitHub
Родитель b69d6f42d5
Коммит f7aefbed59
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 11 добавлений и 11 удалений

Просмотреть файл

@ -157,7 +157,7 @@ $$
\begin{aligned}
\frac{dJ}{dW2}&=\frac{dJ}{dZ2}\frac{dZ2}{dW2}+\frac{dJ}{dW2}
\\
&=(Z2-Y)\cdot A1^T+\lambda \odot W2
&=A1^T\cdot (Z2-Y)+\lambda \odot W2
\end{aligned}
\tag{9}
$$
@ -170,9 +170,9 @@ $$dB2=dZ2 \tag{10}$$
再继续反向传播到第一层网络:
$$dZ1 = W2^T \times dZ2 \odot A1 \odot (1-A1) \tag{11}$$
$$dZ1 = dZ2 \cdot W2^T \odot A1 \odot (1-A1) \tag{11}$$
$$dW1= dZ1 \cdot X^T + \lambda \odot W1 \tag{12}$$
$$dW1= X^T \cdot dZ1+ \lambda \odot W1 \tag{12}$$
$$dB1= dZ1 \tag{13}$$
@ -183,13 +183,13 @@ $$dB1= dZ1 \tag{13}$$
dZ = delta_in
m = self.x.shape[1]
if self.regular == RegularMethod.L2:
self.weights.dW = (np.dot(dZ, self.x.T) + self.lambd * self.weights.W) / m
self.weights.dW = (np.dot(self.x.T, dZ) + self.lambd * self.weights.W) / m
else:
self.weights.dW = np.dot(dZ, self.x.T) / m
self.weights.dW = np.dot(self.x.T, dZ) / m
# end if
self.weights.dB = np.sum(dZ, axis=1, keepdims=True) / m
delta_out = np.dot(self.weights.W.T, dZ)
delta_out = np.dot(dZ, self.weights.W.T)
if len(self.input_shape) > 2:
return delta_out.reshape(self.input_shape)

Просмотреть файл

@ -147,10 +147,10 @@ $$J(w,b) = J_0 + \lambda (\lvert W1 \rvert+\lvert W2 \rvert)$$
$$
\begin{aligned}
dW2&=\frac{dJ}{dW2}=\frac{dJ}{dZ2}\frac{dZ2}{dW2}+\frac{dJ}{dW2} \\\\
&=dZ2 \cdot A1^T+\lambda \odot sign(W2)
&=A1^T \cdot dZ2+\lambda \odot sign(W2)
\end{aligned}
$$
$$dW1= dZ1 \cdot X^T + \lambda \odot sign(W1) $$
$$dW1= X^T \cdot dZ1 + \lambda \odot sign(W1) $$
从上面的公式中可以看到正则项在方向传播过程中唯一影响的就是求W的梯度时要增加一个$\lambda \odot sign(W)$sign是符号函数返回该值的符号即1或-1。所以我们可以修改`FullConnectionLayer.py`中的反向传播函数如下:
@ -159,11 +159,11 @@ def backward(self, delta_in, idx):
dZ = delta_in
m = self.x.shape[1]
if self.regular == RegularMethod.L2:
self.weights.dW = (np.dot(dZ, self.x.T) + self.lambd * self.weights.W) / m
self.weights.dW = (np.dot(self.x.T, dZ) + self.lambd * self.weights.W) / m
elif self.regular == RegularMethod.L1:
self.weights.dW = (np.dot(dZ, self.x.T) + self.lambd * np.sign(self.weights.W)) / m
self.weights.dW = (np.dot(self.x.T, dZ) + self.lambd * np.sign(self.weights.W)) / m
else:
self.weights.dW = np.dot(dZ, self.x.T) / m
self.weights.dW = np.dot(self.x.T, dZ) / m
# end if
self.weights.dB = np.sum(dZ, axis=1, keepdims=True) / m
......