# %% [markdown]
# # Jacobian of a small ReLU network: forward vs. backward accumulation
#
# A bias-free multilayer perceptron with ReLU activations is built from random
# weights. Its Jacobian with respect to the input is accumulated in two orders
# (input-to-output and output-to-input) and both results are validated against
# central finite differences.

# %%
import numpy as np

# %%
# Configuration. A local, seeded generator makes "Restart & Run All" reproduce
# the same numbers without touching NumPy's global random state.
SEED = 0
rng = np.random.RandomState(SEED)

# %% [markdown]
# ## ReLU and its Jacobian

# %%
def myRelu(z):
    """Element-wise ReLU: clip every entry of ``z`` to the range [0, inf)."""
    x = np.clip(z, 0, np.inf)
    return x


def dMyRelu(z):
    """Jacobian of ``myRelu`` at a 1-D pre-activation vector ``z``.

    Returns a diagonal matrix whose i-th diagonal entry is 1.0 where
    ``z[i] > 0`` and 0.0 otherwise (so the derivative at exactly 0 is taken
    to be 0).

    Note: ``np.diag`` only *builds* a matrix from 1-D input; for 2-D input it
    would extract a diagonal instead, so ``z`` must be a vector.
    """
    tmp = (z > 0).astype(float)
    J = np.diag(tmp)
    return J


# %%
z = rng.normal(0, 1, 5)
x = myRelu(z)
J = dMyRelu(z)
print("z = {},\nx = {},\nJ = {}".format(z, x, J))

# %% [markdown]
# ## Network definition

# %%
h = [10, 256, 512, 1]  # layer widths: input, two hidden layers, output
depth = len(h) - 1     # number of weight matrices
W = [rng.normal(0, 1, (h[d + 1], h[d])) for d in range(depth)]


def f(x):
    """Evaluate the network at input vector ``x``.

    Every weight matrix except the last is followed by ``myRelu``; the last
    one is applied linearly. With the three matrices defined above this is
    ``W[2] @ myRelu(W[1] @ myRelu(W[0] @ x))``.

    Reads the module-level list ``W``.

    Returns:
        (y, z_hist, x_hist): the network output, the list of hidden
        pre-activations, and the list of hidden activations, in layer order.
    """
    z_hist, x_hist = [], []
    for W_d in W[:-1]:
        z = W_d @ x
        x = myRelu(z)
        z_hist.append(z)
        x_hist.append(x)
    return W[-1] @ x, z_hist, x_hist


# %%
x0 = rng.normal(0, 1, h[0])
print(f(x0)[0])

# %% [markdown]
# ## Jacobian by forward and backward accumulation
#
# Both functions evaluate the same chain-rule product
# `W[-1] · D · ... · D · W[0]` (with `D` the ReLU Jacobians); they differ only
# in the order in which the matrix products are carried out.

# %%
def df_forward(x, verbose=False):
    """Jacobian of ``f`` at ``x``, accumulated from the input side.

    Starts from ``W[0]`` and left-multiplies by each ReLU Jacobian and the
    following weight matrix, so the running product always has ``h[0]``
    columns.

    Args:
        x: input vector passed to ``f``.
        verbose: when True, print the shape of the running product after
            every layer.

    Returns:
        The Jacobian as a 2-D array.
    """
    _, z_hist, _ = f(x)
    J = W[0]
    if verbose:
        print("forward, J.shape = {}".format(J.shape))
    for d, z_d in enumerate(z_hist):
        J = dMyRelu(z_d) @ J
        J = W[d + 1] @ J
        if verbose:
            print("forward, J.shape = {}".format(J.shape))
    return J


def df_backward(x, verbose=False):
    """Jacobian of ``f`` at ``x``, accumulated from the output side.

    Starts from ``W[-1]`` and right-multiplies by each ReLU Jacobian and the
    preceding weight matrix, so the running product always has as many rows
    as the network output.

    Args:
        x: input vector passed to ``f``.
        verbose: when True, print the shape of the running product after
            every layer.

    Returns:
        The Jacobian as a 2-D array.
    """
    _, z_hist, _ = f(x)
    J = W[-1]
    if verbose:
        print("backward, J.shape = {}".format(J.shape))
    for d in reversed(range(len(z_hist))):
        J = J @ dMyRelu(z_hist[d])
        J = J @ W[d]
        if verbose:
            print("backward, J.shape = {}".format(J.shape))
    return J


# %%
J_forward = df_forward(x0, verbose=True)
print(J_forward)
J_backward = df_backward(x0, verbose=True)
print(J_backward)
# The two accumulation orders must agree up to floating-point rounding.
np.testing.assert_allclose(J_forward, J_backward, rtol=1e-7, atol=1e-8)

# %% [markdown]
# ## Gradient checking

# %%
def gradientChecking(x, f, df, eps=1e-6, tol=1e-4):
    """Compare an analytic Jacobian against central finite differences.

    Args:
        x: 1-D point at which to check.
        f: callable whose return value's first element is the 1-D function
            value (the remaining elements are ignored).
        df: callable returning the analytic Jacobian at ``x``.
        eps: finite-difference step size.
        tol: absolute threshold on the infinity norm of the difference.

    Returns:
        True when ``norm(J_numeric - df(x), inf) < tol``, else False.
    """
    n = x.shape[0]
    d = f(x)[0].shape[0]
    J = np.zeros((d, n))
    basis = np.eye(n)  # built once; row i is the i-th unit vector
    for i in range(n):
        step = eps * basis[i]
        J[:, i] = (f(x + step)[0] - f(x - step)[0]) / (2 * eps)
    return np.linalg.norm(J - df(x), np.inf) < tol


# %%
x0 = rng.normal(0, 1, h[0])
print("Result of gradient checking:", gradientChecking(x0, f, df_forward))
print("Result of gradient checking:", gradientChecking(x0, f, df_backward))