00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #ifndef GP_LWPR_H
00017 #define GP_LWPR_H
00018
00019 #include "array.h"
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00032 class Lmodel {
00033 public:
00034 double initD;
00035 double init_lambda;
00036 double tau_lambda;
00037 double final_lambda;
00038 double add_threshold;
00039 bool meta;
00040 double meta_rate;
00041
00042 doubleA center;
00043 doubleA mean_x;
00044 double mean_y;
00045 doubleA W;
00046 double MyMSE_R;
00047 doubleA lambda;
00048 doubleA M;
00049 doubleA n_data;
00050
00051
00052 MT::Array<doubleA> uprojections;
00053 MT::Array<doubleA> pprojections;
00054 MT::Array<doubleA> sXresYres;
00055 doubleA betas;
00056 doubleA MSE;
00057
00058 doubleA azz;
00059 doubleA azres;
00060 MT::Array<doubleA> axz;
00061
00062
00063 double sum_e2;
00064 doubleA sum_ecv2;
00065 doubleA aH;
00066 doubleA aG;
00067 double aE;
00068
00069
00070 double gamma;
00071
00072 doubleA alpha;
00073
00074 doubleA b;
00075 doubleA h;
00076
00077 double apk;
00078 doubleA dist;
00079
00080 public:
00081 int R;
00082 bool trustworthy;
00083 bool degenerated;
00084
00085 public:
00086 Lmodel(){};
00087 Lmodel(const doubleA& x, double y);
00088 Lmodel(const Lmodel& model, const doubleA& x, double y);
00089 double activation(const doubleA& x);
00090 void updateMean(const doubleA& x, double y);
00091 void updateStat(double w);
00092 void updateError(const doubleA& x, double y);
00093 void updateError_matlab(double, double);
00094 void update(const doubleA& x, double y);
00095 double predict(const doubleA& x ) const;
00096 void update_dist(const doubleA& x, double y, double e_cv, double e);
00097 double gete(const doubleA& x, double y);
00098 double gete_cv(const doubleA& x, double y);
00099 bool check_add_projection(const doubleA& x, double y);
00100 void add_projection(const doubleA& x, double y);
00101 void printProj();
00102 friend std::ostream& operator<<(std::ostream& outs, const Lmodel& lm);
00103
00104 void save(std::ostream&);
00105 void load(std::istream&);
00106
00107 double s2pred(const doubleA& x, double w);
00108 void updateApk(const doubleA& x,double w);
00109 const doubleA& getCenter() const{ return center; }
00110
00111 private :
00112 void computeZs(doubleA& zs, const doubleA& x);
00113 void loadParameters(uint N);
00114 void initialise(int N);
00115
00116 doubleA convert_vec(const doubleA& x);
00117 void getdJdM(doubleA& dJdM,const doubleA& D,const doubleA& M,
00118 const doubleA& x, const doubleA& center,
00119 const doubleA& z, const doubleA& azz, double w,
00120 double W, double e_cv, double e, double gamma,
00121 const doubleA& derivative_ok, double& sum_dJ1dw,
00122 doubleA& dwdM);
00123
00124 double getsum_dJ1dw(double e_cv, double e, double w, double W,
00125 const doubleA& z_vec, const doubleA& azz_vec,
00126 const doubleA& derivative_ok);
00127 void update_dist_stat(double e_cv, double w, const doubleA& z_vec,
00128 const doubleA& azz_vec, double transient_multiplier,
00129 const doubleA& derivative_ok);
00130 void getdJ2dM(doubleA& dJ2dM,const doubleA& D, const doubleA& M, double gamma);
00131 void getdwdM(doubleA& dwdM,const doubleA& M, const doubleA& x,
00132 const doubleA& center, double w);
00133 void getdDdMkl(doubleA& dDdMkl,const doubleA& M, int k, int l);
00134
00135
00136 void meta_update(double e_cv, double W, double w, double e, const doubleA& aH,
00137 double aE, const doubleA& z_vec,
00138 const doubleA& azz_vec, const doubleA& derivative_ok,
00139 const doubleA& x, const doubleA& center, double dJ1dw,
00140 const doubleA& dwdM, double transient_multiplier, const doubleA& dJdM);
00141 void getdJ2dJ2dMdM(doubleA& dJ2dJ2dMdM,const doubleA& D, const doubleA& M,
00142 double gamma);
00143 void getdwdwdMdM(doubleA& dwdwdMdM,const doubleA& M, const doubleA& x,
00144 const doubleA& center, double w);
00145
00146 void cholesky(doubleA& L,const doubleA& A);
00147 intA find(const doubleA& A, double value, bool gt);
00148 double absmax(const doubleA& A);
00149 };
00150
00161 class Lwpr {
00162 public:
00164 Lwpr();
00165
00167 void learn(const doubleA& X,const doubleA& Y);
00168
00170 void map(const doubleA& X,doubleA& Y);
00171
00173 void updaterfs(const doubleA& X,const doubleA& Y){ learn(X,Y); }
00174
00176 doubleA predict(const doubleA& x){ doubleA y; map(x,y); return y; }
00177
00181 void useNormalization(const doubleA& mean,const doubleA& trans){
00182 normMean=mean; normTrans=trans;
00183 }
00184
00188 void learnNormalization(doubleA& X);
00189
00191 void save(char *);
00192
00194 void load(char *,bool alsoParameters=true);
00195
00197 double confidence(const doubleA &x){
00198 doubleA nx;
00199 normalizeInput(nx,x);
00200 return ::sqrt(s2pred(nx));
00201
00202
00203 }
00204
00206 void report(std::ostream& os){
00207 uint i,j;
00208 os
00209 <<"LWPR report:"
00210 <<"\n inDim="<<in_dim <<" outDim="<<out_dim <<std::endl;
00211 for(i=0;i<models.N;i++){
00212 os <<i <<"th dimension: #models=" <<models(i).N <<std::endl;
00213 for(j=0;j<models(i).N;j++){
00214 os
00215 <<' ' <<j <<"th model: center="
00216 <<((models(i)(j)).center)
00217 <<" mean="
00218 <<((models(i)(j)).mean_x)
00219 <<std::endl;
00220 }
00221 }
00222 }
00223
00225 int get_rfs_no(int out_dim);
00226
00228 double get_proj_average(int out_dim);
00229
00230 uint rfsno;
00231
00232
00233 doubleA normMean,normTrans;
00234
00235 private:
00236 void initialize(uint inDim,uint outDim);
00237 double s2pred(const doubleA& x);
00238 double predict(const MT::Array<Lmodel>& rfs, const doubleA& x);
00239 void updaterfs(MT::Array<Lmodel>& rfs, const doubleA& x, double y);
00240 void normalizeInput(doubleA& normIn,const doubleA& unNormIn);
00241
00242
00243 bool isInitialized;
00244 uint in_dim;
00245 uint out_dim;
00246 bool add_proj;
00247 bool updateD;
00248 double cut_off;
00249 double w_prune;
00250 double w_gen;
00251
00252 MT::Array<MT::Array<Lmodel> > models;
00253 };
00254
00255 #ifdef MT_IMPLEMENTATION
00256 #include "lwpr.cpp"
00257 #endif
00258
00259 #endif