introduction  —  namespaces  —  modules  —  classes  —  files  —  globals  —  members  —  examples  —  Marc Toussaint

lwpr.h

00001 /*  Copyright (C) 2000, 2006  Marc Toussaint (mtoussai@inf.ed.ac.uk)
00002     under the terms of the GNU LGPL (http://www.gnu.org/copyleft/lesser.html)
00003     see the `std.h' file for a full copyright statement  */
00004 
00005 /*  LWPR is a regression technique mainly developed by Sethu
00006     Vijayakumar. This file tests an implementation of this algorithm,
00007     coded by Narayanan Edakunni and Giorgos Petkos, cleaned up and
00008     made more efficient by myself.
00009 
00010     When using this code, please cite
00011 
00012     Sethu Vijayakumar, Aaron D'Souza and Stefan Schaal, "Incremental Online 
00013     Learning in High Dimensions", Neural Computation, 17:12, pp. 2602-2632 
00014     (2005) */
00015 
00016 #ifndef GP_LWPR_H
00017 #define GP_LWPR_H
00018 
00019 #include "array.h"
00020 
00021 /*const double initD=1.0;
00022 const double init_lambda = 0.999;
00023 const double tau_lambda = 0.99999;
00024 const double final_lambda = 0.9999;
00025 const double add_threshold = 0.5;
00026 const bool meta = false;
00027 const double meta_rate = 25;
00028 */
00029 
00032 class Lmodel {
00033 public:
00034   double initD;//=1.0;
00035   double init_lambda;// = 0.999;
00036   double tau_lambda;// = 0.99999;
00037   double final_lambda;// = 0.9999;
00038   double add_threshold;// = 0.5;
00039   bool meta;// = false;
00040   double meta_rate;// = 25;
00041   
00042   doubleA center;  
00043   doubleA mean_x;  
00044   double mean_y;   
00045   doubleA  W;    
00046   double MyMSE_R;  
00047   doubleA  lambda; 
00048   doubleA  M;    
00049   doubleA  n_data; 
00050 
00051   
00052   MT::Array<doubleA> uprojections;   
00053   MT::Array<doubleA> pprojections;   
00054   MT::Array<doubleA> sXresYres;
00055   doubleA betas;   
00056   doubleA MSE;     
00057 
00058   doubleA azz;
00059   doubleA azres;
00060   MT::Array<doubleA> axz;
00061   
00062   //distance update statistics
00063   double sum_e2;
00064   doubleA  sum_ecv2;
00065   doubleA  aH;
00066   doubleA  aG;
00067   double aE;
00068   
00069   //penalty
00070   double gamma;
00071   //learning param
00072   doubleA  alpha;
00073   
00074   doubleA  b;
00075   doubleA  h;
00076   
00077   double apk;
00078   doubleA  dist;
00079 
00080 public:
00081   int R;   
00082   bool trustworthy;
00083   bool degenerated; 
00084   
00085 public:
00086   Lmodel(){};
00087   Lmodel(const doubleA& x, double y);
00088   Lmodel(const Lmodel& model, const doubleA& x, double y);
00089   double activation(const doubleA& x);
00090   void updateMean(const doubleA& x, double y);
00091   void updateStat(double w);
00092   void updateError(const doubleA& x, double y);
00093   void updateError_matlab(double, double); // emulates the error measure that's used in Matlab's LWPR implementation
00094   void update(const doubleA& x, double y);
00095   double predict(const doubleA& x ) const;
00096   void  update_dist(const doubleA& x, double y, double e_cv, double e);
00097   double gete(const doubleA& x, double y);
00098   double gete_cv(const doubleA& x, double y);
00099   bool check_add_projection(const doubleA& x, double y);
00100   void add_projection(const doubleA& x, double y);
00101   void printProj();
00102   friend std::ostream& operator<<(std::ostream& outs, const Lmodel& lm);
00103 
00104   void save(std::ostream&);
00105   void load(std::istream&);
00106   
00107   double s2pred(const doubleA& x, double w);
00108   void updateApk(const doubleA& x,double w);
00109   const doubleA& getCenter() const{ return center; }
00110   
00111 private :
00112   void computeZs(doubleA& zs, const doubleA& x);
00113   void loadParameters(uint N);
00114   void initialise(int N);
00115   //void reinitialize();
00116   doubleA convert_vec(const doubleA& x);
00117   void getdJdM(doubleA& dJdM,const doubleA& D,const doubleA& M,
00118     const doubleA& x, const doubleA& center,
00119     const doubleA& z, const doubleA& azz, double w,
00120     double W, double e_cv, double e, double gamma,
00121     const doubleA& derivative_ok, double& sum_dJ1dw,
00122     doubleA& dwdM);
00123   
00124   double getsum_dJ1dw(double e_cv, double e, double w, double W,
00125     const doubleA& z_vec, const doubleA& azz_vec,
00126     const doubleA& derivative_ok);
00127   void update_dist_stat(double e_cv, double w, const doubleA& z_vec,
00128     const doubleA& azz_vec, double transient_multiplier,
00129     const doubleA& derivative_ok);
00130   void getdJ2dM(doubleA& dJ2dM,const doubleA& D, const doubleA& M, double gamma);
00131   void getdwdM(doubleA& dwdM,const doubleA& M, const doubleA& x, 
00132     const doubleA& center, double w);
00133   void getdDdMkl(doubleA& dDdMkl,const doubleA& M, int k, int l);
00134   
00135   //meta methods
00136   void meta_update(double e_cv, double W, double w, double e, const doubleA& aH,
00137     double aE, const doubleA& z_vec, 
00138     const doubleA& azz_vec, const doubleA& derivative_ok,
00139     const doubleA& x, const doubleA& center, double dJ1dw,
00140     const doubleA& dwdM, double transient_multiplier, const doubleA& dJdM);
00141   void getdJ2dJ2dMdM(doubleA& dJ2dJ2dMdM,const doubleA& D, const doubleA& M, 
00142     double gamma);
00143   void getdwdwdMdM(doubleA& dwdwdMdM,const doubleA& M, const doubleA& x, 
00144     const doubleA& center, double w);
00145   
00146   void cholesky(doubleA& L,const doubleA& A);
00147   intA find(const doubleA& A, double value, bool gt);
00148   double absmax(const doubleA& A);
00149 };
00150 
00161 class Lwpr {
00162 public:
00164   Lwpr();
00165 
00167   void learn(const doubleA& X,const doubleA& Y);
00168 
00170   void map(const doubleA& X,doubleA& Y);
00171   
00173   void updaterfs(const doubleA& X,const doubleA& Y){ learn(X,Y); }
00174 
00176   doubleA predict(const doubleA& x){ doubleA y; map(x,y); return y; }
00177 
00181   void useNormalization(const doubleA& mean,const doubleA& trans){
00182     normMean=mean; normTrans=trans;
00183   }
00184 
00188   void learnNormalization(doubleA& X);
00189 
00191   void save(char *);
00192 
00194   void load(char *,bool alsoParameters=true);
00195 
00197   double confidence(const doubleA &x){
00198     doubleA nx;
00199     normalizeInput(nx,x);
00200     return ::sqrt(s2pred(nx));
00201     //@todo allows for negative prediction confidences to exist (pathological cases)
00202     //return ( s2 >= 0 ? s2 : -s2 ); // this would swap sign in pathological cases
00203   }
00204   
00206   void report(std::ostream& os){
00207     uint i,j;
00208     os
00209       <<"LWPR report:"
00210       <<"\n inDim="<<in_dim <<" outDim="<<out_dim <<std::endl;
00211     for(i=0;i<models.N;i++){
00212       os <<i <<"th dimension: #models=" <<models(i).N <<std::endl;
00213       for(j=0;j<models(i).N;j++){
00214   os
00215     <<' ' <<j <<"th model: center="
00216     <<((models(i)(j)).center)
00217     <<" mean="
00218     <<((models(i)(j)).mean_x)
00219     <<std::endl;
00220       }
00221     }
00222   }
00223   
00225   int get_rfs_no(int out_dim);
00226 
00228   double get_proj_average(int out_dim);
00229 
00230   uint rfsno; 
00231 
00232   //normalizations
00233   doubleA normMean,normTrans;
00234 
00235 private:
00236   void initialize(uint inDim,uint outDim);
00237   double s2pred(const doubleA& x);
00238   double predict(const MT::Array<Lmodel>& rfs, const doubleA& x);
00239   void updaterfs(MT::Array<Lmodel>& rfs, const doubleA& x, double y);
00240   void normalizeInput(doubleA& normIn,const doubleA& unNormIn);
00241 
00242 
00243   bool isInitialized;
00244   uint in_dim;
00245   uint out_dim;   
00246   bool add_proj;
00247   bool updateD;
00248   double cut_off;
00249   double w_prune;
00250   double w_gen;
00251 
00252   MT::Array<MT::Array<Lmodel> > models;  //main data structure!
00253 };
00254 
00255 #ifdef MT_IMPLEMENTATION
00256 #include "lwpr.cpp"
00257 #endif
00258 
00259 #endif
[]