{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def gradientChecking(x, f, df, eps=1e-6, tol=1e-4):\n",
"    \"\"\"Compare an analytic Jacobian with a central finite-difference estimate.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : array_like, shape (n,)\n",
"        Point at which the derivative is checked.\n",
"    f : callable\n",
"        Maps a length-n vector to a length-d vector. A scalar result is\n",
"        treated as a vector with d = 1.\n",
"    df : callable\n",
"        Analytic Jacobian of f. Its result is compared by broadcasting\n",
"        against the (d, n) numerical Jacobian, so shape (d, n) works, and\n",
"        so does shape (n,) when d == 1.\n",
"    eps : float, optional\n",
"        Finite-difference step size.\n",
"    tol : float, optional\n",
"        Largest allowed absolute difference between any single entry of\n",
"        the numerical and the analytic Jacobian.\n",
"\n",
"    Returns\n",
"    -------\n",
"    bool\n",
"        True if every entry of the Jacobian error is below ``tol``.\n",
"    \"\"\"\n",
"    x = np.asarray(x)\n",
"    n = x.shape[0]\n",
"    # np.atleast_1d lets f return a plain scalar as well as a vector.\n",
"    d = np.atleast_1d(f(x)).shape[0]\n",
"    J = np.zeros((d, n))\n",
"    step = np.zeros(n)\n",
"    for i in range(n):\n",
"        # Perturb only coordinate i; reusing one buffer avoids building a\n",
"        # full n-by-n identity matrix on every iteration.\n",
"        step[i] = eps\n",
"        f_plus = np.atleast_1d(f(x + step))\n",
"        f_minus = np.atleast_1d(f(x - step))\n",
"        J[:, i] = (f_plus - f_minus) / (2 * eps)\n",
"        step[i] = 0.0\n",
"    # Compare entry by entry rather than with the matrix infinity-norm,\n",
"    # which sums the errors along each row and so gets stricter as n grows.\n",
"    err = np.abs(J - np.asarray(df(x)))\n",
"    return bool(np.all(err < tol))"
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Result of gradient checking: True\n",
"\n",
" x = [-0.05552067 0.67402907 0.50368608 0.34202009 -0.17492368 0.07465995\n",
" 0.27421375 -0.89329098 -1.49934746 0.53344681 0.45551155 1.41957592\n",
" -1.18824559 -0.12788941 -0.12112304 -0.50782059 -0.26612117 0.58705633\n",
" -0.39995979 0.18904155 -0.75398675 0.27202379 -0.56623774 -1.53722522\n",
" 0.3019853 -2.23411454 -0.2298651 0.1080393 -1.55767315 1.88821005\n",
" -0.91237782 -0.72130771 -0.88242758 0.09169899 0.65991994 0.22093076\n",
" -0.6348826 -0.41352092 -2.07294686 0.49106272 -2.22132309 0.20741678\n",
" -0.03614922 -1.72870479 -0.67439975 -0.53813268 -0.76132184 -1.68098301\n",
" -0.46476026 -1.37816936 0.44037233 -0.76425912 1.05805956 0.31334026\n",
" -1.08393679 -2.44001119 -0.69327004 -0.56234646 0.69830479 -0.21314466\n",
" -0.08647727 -1.58922865 -2.56622169 0.87126578 -0.07122054 1.05183164\n",
" 0.04007953 0.33856557 -0.77994659 1.11047719 -1.18590547 -0.3426766\n",
" -0.26779626 -2.03887823 0.15348207 -0.06604659 1.30965932 -0.51515658\n",
" -0.09195049 0.41267759 1.53913321 0.34747779 1.44723936 -0.91308325\n",
" 1.57984889 0.29403069 -0.12787797 0.29036943 0.15274965 1.81815799\n",
" -1.20616718 -2.3639856 0.93213294 0.58914528 -0.88670594 1.89095063\n",
" -0.42616178 0.58913204 0.22754284 1.1174765 0.39806539 2.13148983\n",
" -1.73988912 1.86166589 0.38533626 -0.70185231 0.62967356 -0.38993674\n",
" -0.16558697 -0.22131086 -0.05769881 0.5899888 1.52225214 
0.07744338\n", " 0.34200911 0.39776661 1.5443236 -0.01875822 -0.45153624 -1.07168829\n", " -0.57609912 -3.34045257 1.74671959 0.93751468 -1.26404285 -0.41917274\n", " -0.62373011 -2.53453508 -0.95015278 1.14708753 -1.72452924 -0.37169184\n", " 0.57286086 -0.22513731 0.74029086 -1.20075405 -1.47347247 -1.68421691\n", " -1.40008734 -1.18040522 1.50272066 -0.6967215 2.0343898 -1.45055896\n", " 1.3832795 -1.10554338 -0.28642787 0.38168838 -1.47609647 0.13120172\n", " 0.92432388 1.69139622 0.6561509 0.97099279 -0.53069475 0.51673201\n", " 1.22440477 -0.85065683 -0.46514794 -1.01227975 0.16332421 -0.38627518\n", " 0.12840765 -1.62909114 -0.49245356 -1.0721438 -0.13946779 0.27650121\n", " -0.39656159 0.88611064 -2.25677651 -0.20983128 0.21832808 0.27413769\n", " 1.20821516 0.65826114 1.26632251 -0.67664373 -0.56375033 -0.36308099\n", " 1.20209996 0.22466048 -1.18500913 -0.03523694 -0.60324785 -0.01465519\n", " 0.31197191 1.16987568 0.79612668 -0.11316887 -1.1910813 -1.57164011\n", " -1.74664369 -0.79895321 -1.35873374 0.41072751 -0.31272729 0.83270448\n", " 0.11284564 -1.00307481],\n", " f = [ 10.24481871 6.70351793 19.2002201 -5.70864037 -2.70403052\n", " -14.67261945 28.70543061 10.68066818 9.97375689 -12.78820011\n", " -2.54724471 17.81641489 -23.27731001 -23.9155024 6.55855495\n", " -9.30748488 7.50791194 -36.73043672 -8.52517481 0.30184681\n", " 0.89607502 30.97130052 -1.18941493 18.98992385 -13.51580593\n", " -2.60774774 -4.47703142 20.08646973 11.18212292 -4.27775256\n", " 31.54627774 29.96644123 -7.02425659 -1.90643492 0.39699102\n", " 10.49048787 6.46982536 23.69904473 -12.48788537 -5.61111319\n", " -10.54789752 -16.11567033 9.9885964 32.22934685 8.66189225\n", " -6.60336043 28.27017398 13.91786913 0.48355928 3.20881219\n", " 9.29336702 -13.5495338 10.82773294 -2.3766683 -13.90454403\n", " 14.61577829 -1.38965197 -29.85012428 -8.24912761 -11.29189546\n", " 1.98339976 23.95272313 -8.19055587 -3.18435714 10.63111417\n", " 7.3834312 -6.23434374 -14.99850984 
-17.78726101 -3.87169721\n", " -4.38332429 16.33428892 2.37522157 21.37888782 20.97592971\n", " -13.31077939 -22.47243223 17.52403201 13.12776887 21.65389046\n", " -7.48645507 -19.60477423 -38.94529913 -0.16341566 0.69969398\n", " 5.66710627 -12.14173202 0.23859072 11.31489966 -2.12131111\n", " 9.27616179 5.9495566 6.43664934 29.5957738 -0.6264163\n", " -19.34498338 17.73368491 5.98538416 14.66707208 51.39664093\n", " 5.94129911 -8.79745733 2.4122948 4.38775849 29.39446042\n", " 5.85732312 0.70647965 12.22930766 -2.7429887 13.15783304\n", " 32.22284892 10.77215193 -15.52020925 22.70529709 38.03868108\n", " 11.87197901 0.58553963 -16.33183978 3.15136868 -9.53250862\n", " -4.12880728 5.47332525 -12.53317245 -10.74881476 9.96261829\n", " 37.74471223 -8.55603272 10.79746545 7.78730151 -1.22450702\n", " -5.80137338 -9.9386949 11.54981987 -1.45976532 5.82253474\n", " -21.14419795 -3.46213982 34.06262374 15.96642132 -14.6278741\n", " 13.77709641 7.76580021 21.8305501 4.21483937 12.41093651\n", " -15.75126333 13.9911712 -11.05086304 -12.67054749 10.08739237\n", " 9.41185059 -17.6764974 -20.9219123 -3.59004252 -2.8411478\n", " -13.00857412 -5.56320956 7.12109161 5.78644988 17.96119868\n", " 26.71918865 6.42260649 7.17037968 -2.56065586 -22.374526\n", " -6.76422102 6.25729384 -15.45862898 -17.93532238 -9.24173027\n", " 17.65395473 -16.74008162 -11.20079769 -9.17099142 -7.58394017\n", " 9.70837015 -31.29710048 27.69571325 15.87545474 -14.93375024\n", " -5.39623567 21.53265651 -0.94135743 23.19256533 -21.95340624\n", " -8.50724853 17.45666499 24.44390904 7.67596803 15.76590943\n", " 13.84952277 0.22386026 -0.93808568 -5.73234375 11.17469647\n", " 3.60563719 1.88851345 3.1077788 12.94349423 -30.64355924\n", " -9.86622367 21.13536421 -4.60442932 7.22207331 -16.0865235\n", " -18.5727915 -20.31596728 -18.25342214 -6.77299132 33.49843442\n", " 22.70315172 -1.84184725 -10.24948694 -0.62157667 8.27985331\n", " 2.75358529 -3.88039698 -21.80793681 -9.75035949 11.57788482\n", " 
-23.67870784 17.79748144 -1.99002044 14.27806701 -11.34349002\n", " 4.5351635 8.76709766 -13.88468626 8.50189036 -12.68223645\n", " -12.69021527 -18.34040281 4.1201505 15.05315406 21.87816612\n", " 9.44168025 15.57884057 7.69147552 7.04836901 -6.98888763\n", " -20.97495436 -20.61312186 -19.66670249 6.40077102 -12.1730631\n", " 13.37101266 4.88981661 -12.22856561 -15.30910062 2.20038606\n", " -20.12444224 22.82841218 2.53624742 -21.5601334 -8.79103065\n", " -4.96806355 -14.32072922 -26.14990986 0.30542982 -6.92292299\n", " -11.83791031 -11.53637137 15.13648385 15.71770294 -4.25016216\n", " -13.45743755 -11.45626606 21.53415703 -26.0944011 -18.88423291\n", " -10.97861244 16.86007944 23.4364529 10.16461252 14.42642108\n", " -20.06724187 -16.02881954 6.61688375 14.35037663 -9.4631823\n", " -4.2568375 10.23456126 25.59317772 -15.04056625 -3.60254612\n", " 21.4447476 -31.66560442 -12.39590346 13.03115698 -4.86557388\n", " -17.66689595 0.92708649 -7.78715597 8.6626386 17.31254686\n", " 4.30901574 -22.17211774 -5.45158838 2.65772613 -15.36353039],\n", " df = [[-1.29680007e+00 -5.24213767e-01 9.48074946e-01 ... 7.23324055e-01\n", " -8.62981516e-02 -6.96272776e-01]\n", " [-3.09393108e-01 6.87811416e-01 1.42446002e-02 ... 9.51513879e-01\n", " 4.00767769e-01 -1.51265750e+00]\n", " [-1.06619206e-01 -5.82045513e-01 7.97363540e-01 ... 1.93772772e+00\n", " 5.81187856e-01 -7.29988665e-01]\n", " ...\n", " [-1.29399616e+00 6.04368043e-01 1.13191710e+00 ... -4.29574844e-01\n", " -1.17671525e+00 -1.47999948e-01]\n", " [ 8.22082364e-01 1.35832052e+00 5.65504837e-01 ... -3.47498868e-01\n", " -4.22178957e-01 6.05669253e-01]\n", " [ 9.32203533e-01 9.71323500e-01 3.05272457e-01 ... 
1.03043192e-01\n", " -1.37506258e-03 7.24442799e-01]]\n" ] } ], "source": [
"# Check 1: linear map f1(x) = A x, whose Jacobian is the constant matrix A.\n",
"# A fixed seed makes re-running this cell reproduce the same x1 and A.\n",
"rng1 = np.random.RandomState(0)\n",
"n1 = 200\n",
"m1 = 300\n",
"\n",
"x1 = rng1.normal(0, 1, n1)\n",
"A = rng1.normal(0, 1, (m1, n1))\n",
"\n",
"def f1(x):\n",
"    \"\"\"Linear map from R^n1 to R^m1: A @ x.\"\"\"\n",
"    return A @ x\n",
"\n",
"def df1(x):\n",
"    \"\"\"Jacobian of f1, equal to A at every x.\"\"\"\n",
"    return A\n",
"\n",
"print(\"Result of gradient checking:\", gradientChecking(x1, f1, df1))\n",
"print(\"\\n x = {},\\n f = {},\\n df = {}\".format(x1,f1(x1),df1(x1)))"
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Result of gradient checking: True\n",
"\n",
" x = [ 0.8206599 -0.03344141 0.38553765 -0.81092066 0.58261421 -1.03857561\n",
" 0.94632213 -1.60559213 0.42891458],\n",
" f = [6.5563306],\n",
" df = [ 1.64131981 -0.06688283 0.77107529 -1.62184132 1.16522842 -2.07715122\n",
" 1.89264427 -3.21118426 0.85782916]\n"
] } ], "source": [
"# Check 2: f2(x) = x^T x, wrapped in a length-1 array; its gradient is 2 x.\n",
"# A fixed seed makes re-running this cell reproduce the same x2.\n",
"rng2 = np.random.RandomState(1)\n",
"n2 = 9\n",
"x2 = rng2.normal(0, 1, n2)\n",
"\n",
"def f2(x):\n",
"    \"\"\"Squared Euclidean norm of x, returned as a shape-(1,) array.\"\"\"\n",
"    return np.array([x.T @ x])\n",
"\n",
"def df2(x):\n",
"    \"\"\"Gradient of f2; x.T is a no-op for 1-D x, so this is 2 * x.\"\"\"\n",
"    return 2 * x.T\n",
"\n",
"print(\"Result of gradient checking:\", gradientChecking(x2, f2, df2))\n",
"print(\"\\n x = {},\\n f = {},\\n df = {}\".format(x2,f2(x2),df2(x2)))"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }