{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayesian optimization of a 2-D test function with a Gaussian process\n",
    "\n",
    "Sequentially minimizes a known test function $f$ over a grid on $[-1, 1]^2$. At every step a Gaussian-process (GP) surrogate is refitted to all evaluations so far. The next point to evaluate is the grid point with the smallest lower confidence bound $\\mu(x) - \\sigma(x)$ of the GP posterior.\n",
    "\n",
    "Modified from the scikit-learn example [plot_gpr_noisy_targets](https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpr_noisy_targets.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modified from https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpr_noisy_targets.html\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib\n",
    "from matplotlib import cm\n",
    "\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C  # C: unused below, handy for C(1.0) * RBF(...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interactive figures in the classic Notebook (use `%matplotlib widget` in JupyterLab).\n",
    "# Issued twice on purpose: a common workaround for the backend not activating on the first call.\n",
    "%matplotlib notebook\n",
    "%matplotlib notebook\n",
    "\n",
    "SEED = 0  # seeds the initial sample and the GP optimizer restarts -> reproducible runs\n",
    "\n",
    "# Bayesian-optimization settings\n",
    "LENGTH_SCALE_INIT = 0.3            # initial RBF length scale (re-estimated by every gp.fit)\n",
    "LENGTH_SCALE_BOUNDS = (1e-2, 1e2)  # search range for the length scale\n",
    "N_RESTARTS = 9                     # extra random restarts of the kernel-hyperparameter optimizer\n",
    "MAX_ITER = 300                     # upper bound on the number of acquisitions\n",
    "MIN_ITER = 10                      # stopping is only allowed once the loop index exceeds this\n",
    "TOL = 1e-7                         # stop once (y_prev - y_next)**2 < TOL\n",
    "PLOT_EVERY = 10                    # draw the GP state every PLOT_EVERY iterations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objective\n",
    "\n",
    "$$f(x) = \\sin^2(a x_1) + \\sin^2(a c\\, x_2) + 4 x_1^2 + 4 c^2 x_2^2, \\qquad a = 6,\\; c = 1.$$\n",
    "\n",
    "Every term is non-negative and all of them vanish at the origin. The global minimum is therefore $f(0, 0) = 0$, which is a useful check on the optimizer's result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_FREQ, C_ANISO = 6, 1  # sine frequency a; anisotropy c scales the x2 terms (c == 1 -> f symmetric in x1, x2)\n",
    "\n",
    "\n",
    "def f(x, a=A_FREQ, c=C_ANISO):\n",
    "    \"\"\"Test objective f(x) = sin^2(a x1) + sin^2(a c x2) + (2 x1)^2 + (2 c x2)^2.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like, shape (2,) or (2, ...)\n",
    "        Point(s) to evaluate. The first axis holds (x1, x2); any trailing axes are\n",
    "        evaluated element-wise (e.g. a stacked meshgrid).\n",
    "    a, c : float\n",
    "        Frequency of the sine terms and anisotropy factor applied to x2.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float or ndarray\n",
    "        ||phi(x)||^2 with phi = [sin(a x1), sin(a c x2), 2 x1, 2 c x2]. This is a scalar\n",
    "        for a single point, otherwise an array with the trailing shape of `x`.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    phi = np.array([np.sin(a * x[0]), np.sin(a * c * x[1]), 2 * x[0], 2 * c * x[1]])\n",
    "    # Sum over the feature axis only, so stacked grids are evaluated point-wise\n",
    "    # (phi.T @ phi would form a Gram matrix for multi-dimensional inputs).\n",
    "    return np.sum(phi ** 2, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ground truth on the evaluation grid\n",
    "\n",
    "`np.meshgrid` uses `'xy'` indexing, so `x1v[j, i] == x1[i]` and `x2v[j, i] == x2[j]`. Any values plotted against `(x1v, x2v)` must use the same layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx1, nx2 = 50, 50\n",
    "x1 = np.linspace(-1, 1, nx1)\n",
    "x2 = np.linspace(-1, 1, nx2)\n",
    "x1v, x2v = np.meshgrid(x1, x2)  # both of shape (nx2, nx1)\n",
    "\n",
    "# Vectorised evaluation: fx_s[j, i] == f([x1[i], x2[j]]), aligned with x1v/x2v.\n",
    "# (Filling fx_s[i, j] instead transposes the surface, which only goes unnoticed while f is symmetric.)\n",
    "fx_s = f(np.stack([x1v, x2v]))\n",
    "assert fx_s.shape == x1v.shape\n",
    "assert np.isclose(fx_s[0, -1], f([x1[-1], x2[0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8, 4))\n",
    "\n",
    "ax1 = fig.add_subplot(121, projection='3d')\n",
    "surf = ax1.plot_surface(x1v, x2v, fx_s, cmap=cm.coolwarm)\n",
    "fig.colorbar(surf, ax=ax1)\n",
    "ax1.set_xlabel('x1')\n",
    "ax1.set_ylabel('x2')\n",
    "ax1.set_zlabel('f')\n",
    "ax1.set_title('f (surface)')\n",
    "\n",
    "ax2 = fig.add_subplot(122)\n",
    "cont = ax2.contour(x1v, x2v, fx_s, levels=20, cmap=cm.coolwarm)\n",
    "fig.colorbar(cont, ax=ax2)\n",
    "ax2.set_xlabel('x1')\n",
    "ax2.set_ylabel('x2')\n",
    "ax2.set_title('f (contours)')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helper\n",
    "\n",
    "Shows the GP posterior mean with all observations (left) and the posterior variance (right). The variance colour range is fixed to $[0, 1]$ so maps from different iterations are comparable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_gp_state(x1v, x2v, mean, std, X_obs, x_next, next_style='y*'):\n",
    "    \"\"\"Plot a GP posterior mean (left) next to its variance (right) on the grid.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x1v, x2v : ndarray, shape (nx2, nx1)\n",
    "        Meshgrid coordinates the posterior was evaluated on.\n",
    "    mean, std : ndarray, shape (nx1 * nx2,)\n",
    "        Posterior mean and standard deviation at the grid points, flattened in\n",
    "        the row-major order of ``x1v`` (as produced by ``x1v.ravel()``).\n",
    "    X_obs : ndarray, shape (n, 2)\n",
    "        Points evaluated so far; drawn as red dots on the mean panel.\n",
    "    x_next : ndarray, shape (2,)\n",
    "        Most recently acquired point; drawn with ``next_style``.\n",
    "    next_style : str\n",
    "        Matplotlib format string for ``x_next`` (default: yellow star).\n",
    "    \"\"\"\n",
    "    fig, (ax_mean, ax_var) = plt.subplots(1, 2, figsize=(8, 4))\n",
    "\n",
    "    # Reshape with the grid's own shape: correct for non-square grids too.\n",
    "    cs_mean = ax_mean.contourf(x1v, x2v, mean.reshape(x1v.shape), levels=20, cmap=cm.coolwarm)\n",
    "    fig.colorbar(cs_mean, ax=ax_mean)\n",
    "    ax_mean.plot(X_obs[:, 0], X_obs[:, 1], 'ro', ms=5)\n",
    "    ax_mean.plot(x_next[0], x_next[1], next_style, ms=10)\n",
    "    ax_mean.set(xlabel='x1', ylabel='x2', title='posterior mean')\n",
    "\n",
    "    cs_var = ax_var.contourf(x1v, x2v, (std ** 2).reshape(x1v.shape), levels=20, cmap=cm.coolwarm)\n",
    "    # Fixed colour range; the RBF kernel used here has unit prior variance.\n",
    "    cs_var.set_clim(0, 1)\n",
    "    fig.colorbar(cs_var, ax=ax_var)\n",
    "    ax_var.set(xlabel='x1', ylabel='x2', title='posterior variance')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bayesian-optimization loop\n",
    "\n",
    "The loop proceeds as follows:\n",
    "\n",
    "- Start from one point drawn uniformly from $[-1, 1]^2$.\n",
    "- Each iteration evaluates $f$ at the candidate maximizing $\\sigma(x) - \\mu(x)$, i.e. the smallest lower confidence bound (we are minimizing).\n",
    "- After each evaluation the GP is refitted; its length scale is re-estimated by maximum likelihood.\n",
    "- **Stopping is a heuristic.** The loop halts once two consecutive sampled values agree to within $\\sqrt{\\mathrm{TOL}}$ (after `MIN_ITER` iterations), or after `MAX_ITER` iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(SEED)  # created here so re-running this cell repeats the same run\n",
    "\n",
    "kernel = RBF(LENGTH_SCALE_INIT, LENGTH_SCALE_BOUNDS)\n",
    "gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=N_RESTARTS, random_state=SEED)\n",
    "\n",
    "# Candidate set: all grid points as (nx1 * nx2, 2), in the row-major order of x1v/x2v.\n",
    "X_candidate = np.column_stack([x1v.ravel(), x2v.ravel()])\n",
    "assert np.allclose(X_candidate[1], [x1[1], x2[0]])\n",
    "\n",
    "# One initial observation, uniformly sampled from [-1, 1]^2.\n",
    "X = rng.uniform(-1, 1, size=(1, 2))\n",
    "y = np.array([f(X[0])])\n",
    "gp.fit(X, y)\n",
    "\n",
    "y_prev = np.inf\n",
    "for i in range(MAX_ITER):\n",
    "    # Posterior of the current fit at every candidate.\n",
    "    y_pred, sigma = gp.predict(X_candidate, return_std=True)\n",
    "    # Minimising f: maximise sigma - mu, i.e. pick the smallest lower confidence bound mu - sigma.\n",
    "    acquisition = sigma - y_pred\n",
    "    x_next = X_candidate[np.argmax(acquisition)]\n",
    "    y_next = f(x_next)\n",
    "\n",
    "    X = np.vstack([X, x_next])\n",
    "    y = np.append(y, y_next)\n",
    "    # Refit; kernel hyperparameters are re-estimated by maximum likelihood.\n",
    "    gp.fit(X, y)\n",
    "\n",
    "    # Heuristic stop: two consecutive samples with (almost) the same value,\n",
    "    # e.g. when the same candidate is picked again.\n",
    "    if i > MIN_ITER and (y_prev - y_next) ** 2 < TOL:\n",
    "        break\n",
    "    y_prev = y_next\n",
    "\n",
    "    if (i + 1) % PLOT_EVERY == 0:\n",
    "        plot_gp_state(x1v, x2v, y_pred, sigma, X, x_next)\n",
    "\n",
    "best = np.argmin(y)\n",
    "print('{} evaluations; best f = {:.3g} at x = {}'.format(len(y), y[best], X[best]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final GP state\n",
    "\n",
    "This shows the posterior that selected the last point, together with all observations. The last acquired point is marked with a black star. Compare the best value found with the true minimum $f(0, 0) = 0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_gp_state(x1v, x2v, y_pred, sigma, X, x_next, next_style='k*')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}