introduction  —  namespaces  —  modules  —  classes  —  files  —  globals  —  members  —  examples  —  Marc Toussaint

lwpr.h

00001 /*  Copyright (C) 2000, 2006  Marc Toussaint (mtoussai@inf.ed.ac.uk)
00002     under the terms of the GNU LGPL (http://www.gnu.org/copyleft/lesser.html)
00003     see the `std.h' file for a full copyright statement  */
00004 
00005 /*  This implements the online LWPR regression algorithm as described in:
00006     Sethu Vijayakumar, Aaron D'Souza and Stefan Schaal, "Incremental Online
00007     Learning in High Dimensions", Neural Computation, 17:12, pp. 2602-2632  (2005)
00008     (Please cite when using code) */
00009 
00010 #ifndef GP_LWPR_H
00011 #define GP_LWPR_H
00012 
00013 #include "array.h"
00014 
00015 
00016 //forward declaration
00017 class Lmodel;
00018 
00023 void getStdDeviations(doubleA& norm,const doubleA& X); // 'norm' is the output (non-const reference), X is read-only; presumably fills 'norm' with standard deviations computed from X - TODO confirm in lwpr.cpp
00024 
00025 
00026 
00029 class Lwpr { // Public interface of the online LWPR regressor named in the file header; its local models are Lmodel objects (see 'models' below)
00030 public:
00032   Lwpr();
00033   // learn(): both arguments are read-only; presumably performs a training update on input X with target Y - defined outside this header
00035   void learn(const doubleA& X,const doubleA& Y);
00036   // predict(): X is read-only, Y is a non-const reference and therefore the output argument
00038   void predict(const doubleA& X,doubleA& Y);
00039   // confidence(): normalizes x into a local copy via normalizeInput(), then returns the square root of s2pred() at that point
00041   double confidence(const doubleA &x){
00042     doubleA nx;
00043     normalizeInput(nx,x);
00044     return ::sqrt(s2pred(nx)); // NOTE(review): s2pred presumably yields a predictive variance, making this a standard deviation - confirm
00045   }
00046   // save()/load(): the char* argument is presumably a file name - confirm in lwpr.cpp
00048   void save(char *);
00049   // alsoParameters defaults to true; presumably selects whether the tuning fields below are restored too - confirm
00051   void load(char *,bool alsoParameters=true);
00052   // report(): takes the output stream to write to; the content written is defined in lwpr.cpp
00054   void report(std::ostream& os);
00055   // get_rfs_no(): presumably the number of local models (receptive fields) for output dimension out_dim - confirm
00057   int get_rfs_no(int out_dim);
00058   // get_proj_average(): presumably the average number of projections per local model for out_dim - confirm
00060   double get_proj_average(int out_dim);
00061   // Public configuration/state fields follow. Their original doc comments are missing from this listing, so every note below is an assumption to confirm against lwpr.cpp.
00063   int verbosity; // presumably the diagnostic output level
00064 
00066   uint rfsno;    // presumably a receptive-field count (cf. get_rfs_no)
00067 
00080   doubleA norm; // presumably the per-input scaling applied by normalizeInput()
00081 
00085   bool add_proj;   // presumably enables adding projections (cf. Lmodel::add_projection)
00086 
00088   bool updateD; // presumably enables distance-metric updates (cf. Lmodel::update_dist)
00089 
00091   bool useNorm; // presumably toggles use of 'norm' for input normalization
00092 
00100   double cut_off; // presumably an activation threshold below which a local model is ignored
00105   double w_prune; // presumably an activation threshold used when pruning local models
00110   double w_gen; // presumably an activation threshold used when generating a new local model
00113   double initD; // presumably the initial distance-metric value of a new local model
00116   double init_lambda; // presumably the initial forgetting factor (cf. Lmodel::lambda)
00119   double tau_lambda; // presumably the annealing rate of the forgetting factor
00122   double final_lambda; // presumably the final forgetting factor
00123 
00130   double add_threshold; // presumably the threshold tested in Lmodel::check_add_projection
00131 
00133   bool meta; // presumably enables the meta update (cf. Lmodel::meta_update)
00134 
00136   double meta_rate; // presumably the learning rate of the meta update
00137 
00141   int initR; // presumably the initial number of projections (cf. Lmodel::R)
00142 
00151   double alpha; // presumably a learning rate; Lmodel has a member of the same name (array-typed there)
00152 
00155   double gamma; // presumably a penalty coefficient; Lmodel has a member of the same name commented "penalty"
00156 
00159   bool blend; // meaning not recoverable from this listing
00160 
00161 private:
00162   void initialize(uint inDim,uint outDim); // presumably sets in_dim/out_dim and allocates 'models' - confirm
00163   double s2pred(const doubleA& x); // used by confidence() on an already-normalized input
00164   double predict(const MT::Array<Lmodel>& rfs, const doubleA& x); // scalar overload operating on one list of local models
00165   void updaterfs(MT::Array<Lmodel>& rfs, const doubleA& x, double y); // takes the model list by non-const reference, so it may modify it
00166   void normalizeInput(doubleA& normIn,const doubleA& unNormIn); // writes the result into normIn (first argument), as used in confidence()
00167 
00168   bool isInitialized; // presumably set once initialize() has run - confirm
00169   uint in_dim;
00170   uint out_dim;   
00171   // NOTE(review): nested array of local models, presumably one inner list per output dimension - confirm
00172   MT::Array<MT::Array<Lmodel> > models;  //main data structure!
00173 };
00174 
00175 //================================================================================
00176 
00177 #ifndef MT_doxy
00178 /* This class defines the local models used by Lwpr - for experts
00179    only. Please refer to the Lwpr class first. */
00180 class Lmodel { // A single local model used by Lwpr; all data members are public
00181 public:
00182   Lwpr *lwpr; // pointer to an Lwpr, presumably the instance passed as _lwpr to the constructors / load() - confirm
00183   
00184   doubleA center; // center of the model
00185   doubleA mean_x; // x00
00186   double  mean_y; // beta00
00187   doubleA W;    // running sum of weights
00188   double  MyMSE_R;// purpose undocumented (the original author's comment here was "don't know")
00189   doubleA lambda; // forgetting parameter
00190   doubleA M;    // cholesky decomposition of dist
00191   doubleA n_data; // number of data seen
00192 
00193   
00194   MT::Array<doubleA> uprojections;   //u projections
00195   MT::Array<doubleA> pprojections;   //p projections
00196   MT::Array<doubleA> sXresYres; // undocumented; presumably accumulated products of input and output residuals - confirm
00197   doubleA betas;   // betas
00198   doubleA MSE;     // squared errors
00199   // azz / azres / axz: undocumented; presumably accumulated statistics of the projected inputs z - confirm
00200   doubleA azz;
00201   doubleA azres;
00202   MT::Array<doubleA> axz;
00203   
00204   //distance update statistics
00205   double sum_e2;
00206   doubleA  sum_ecv2;
00207   doubleA  aH;
00208   doubleA  aG;
00209   double aE;
00210   
00211   //penalty (same name as the public field Lwpr::gamma)
00212   double gamma;
00213   //learning param (array-typed here, whereas Lwpr::alpha is a single double)
00214   doubleA  alpha;
00215   
00216   doubleA  b;
00217   doubleA  h;
00218   
00219   double apk; // undocumented; passed nowhere in this header but updateApk() presumably maintains it - confirm
00220   doubleA  dist;
00221 
00222 public: // redundant label: access is already public at this point
00223   int R;   // number of projections
00224   bool trustworthy;
00225   bool degenerated; // used for pruning of numerically degenerated local models
00226   
00227 public: // redundant label: access is already public at this point
00228   Lmodel(){}; // NOTE(review): empty body, so built-in-typed members (lwpr, mean_y, MyMSE_R, sum_e2, aE, gamma, apk, R, trustworthy, degenerated) are not initialized by this constructor
00229   Lmodel(Lwpr* _lwpr,const doubleA& x, double y);
00230   Lmodel(Lwpr* _lwpr,const Lmodel& model, const doubleA& x, double y); // takes an existing model in addition to a sample; presumably uses it as a template for the new one - confirm
00231   double activation(const doubleA& x);
00232   void updateMean(const doubleA& x, double y);
00233   void updateStat(double w);
00234   void updateError(const doubleA& x, double y);
00235   void updateError_matlab(double, double); // emulates the error measure that's used in Matlab's LWPR implementation
00236   void update(const doubleA& x, double y);
00237   double predict(const doubleA& x ) const; // const: does not modify the model
00238   void  update_dist(const doubleA& x, double y, double e_cv, double e);
00239   double gete(const doubleA& x, double y);
00240   double gete_cv(const doubleA& x, double y);
00241   bool check_add_projection(const doubleA& x, double y);
00242   void add_projection(const doubleA& x, double y);
00243   void printProj();
00244   friend std::ostream& operator<<(std::ostream& outs, const Lmodel& lm); // stream output operator, declared as a friend
00245 
00246   void save(std::ostream&);
00247   void load(Lwpr* _lwpr, std::istream&);
00248   
00249   double s2pred(const doubleA& x, double w);
00250   void updateApk(const doubleA& x,double w);
00251   const doubleA& getCenter() const{ return center; } // read-only access to 'center' by reference
00252   
00253 private :
00254   void computeZs(doubleA& zs, const doubleA& x); // zs is the output argument (non-const reference)
00255   void loadParameters(uint N);
00256   void initialise(int N);
00257   //void reinitialize();
00258   doubleA convert_vec(const doubleA& x);
00259   void getdJdM(doubleA& dJdM,const doubleA& D,const doubleA& M, // dJdM, sum_dJ1dw and dwdM are non-const references, i.e. outputs
00260     const doubleA& x, const doubleA& center,
00261     const doubleA& z, const doubleA& azz, double w,
00262     double W, double e_cv, double e, double gamma,
00263     const doubleA& derivative_ok, double& sum_dJ1dw,
00264     doubleA& dwdM);
00265   
00266   double getsum_dJ1dw(double e_cv, double e, double w, double W,
00267     const doubleA& z_vec, const doubleA& azz_vec,
00268     const doubleA& derivative_ok);
00269   void update_dist_stat(double e_cv, double w, const doubleA& z_vec,
00270     const doubleA& azz_vec, double transient_multiplier,
00271     const doubleA& derivative_ok);
00272   void getdJ2dM(doubleA& dJ2dM,const doubleA& D, const doubleA& M, double gamma); // first argument is the output
00273   void getdwdM(doubleA& dwdM,const doubleA& M, const doubleA& x, 
00274     const doubleA& center, double w);
00275   void getdDdMkl(doubleA& dDdMkl,const doubleA& M, int k, int l); // first argument is the output
00276   
00277   //meta methods
00278   void meta_update(double e_cv, double W, double w, double e, const doubleA& aH,
00279     double aE, const doubleA& z_vec, 
00280     const doubleA& azz_vec, const doubleA& derivative_ok,
00281     const doubleA& x, const doubleA& center, double dJ1dw,
00282     const doubleA& dwdM, double transient_multiplier, const doubleA& dJdM);
00283   void getdJ2dJ2dMdM(doubleA& dJ2dJ2dMdM,const doubleA& D, const doubleA& M, 
00284     double gamma);
00285   void getdwdwdMdM(doubleA& dwdwdMdM,const doubleA& M, const doubleA& x, 
00286     const doubleA& center, double w);
00287   // small numeric helpers; L is the output of cholesky(); semantics of find()'s 'gt' flag are defined in lwpr.cpp
00288   void cholesky(doubleA& L,const doubleA& A);
00289   intA find(const doubleA& A, double value, bool gt);
00290   double absmax(const doubleA& A);
00291 };
00292 #endif
00293 
00294 #ifdef MT_IMPLEMENTATION
00295 #include "lwpr.cpp"
00296 #endif
00297 
00298 #endif
[]