00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #include<lwpr.h>
00017
00018 void testLWPR(){
00019 uint i,k;
00020
00021
00022 doubleA X,Y;
00023 #if 1 //1D data
00024 grid(X,1,-1.,1.,1000);
00025 Y.resize(X.d0,1);
00026 for(i=0;i<X.d0;i++) Y(i,0)=::sin(5.*X(i,0));
00027 MT::save(X,"grid");
00028 #else //2D data
00029 grid(X,2,-1.,1.,30);
00030 Y.resize(X.d0,1);
00031 for(i=0;i<X.d0;i++) Y(i,0)=::sin(5.*X(i,0)) * ::sin(7.*X(i,1));
00032 #endif
00033
00034
00035 X*=2.;
00036 X+=1.;
00037
00038
00039 MT::IOraw=true;
00040 std::ofstream da("data");
00041 for(i=0;i<X.d0;i++) da <<X[i] <<Y[i] <<std::endl;
00042 if(X.d1==2){ Y.reshape(31,31); MT::save(Y,"dataY"); Y.reshape(X.d0,1); }
00043
00044
00045 Lwpr l;
00046 l.learnNormalization(X);
00047 for(k=0;k<3;k++){
00048 l.learn(X,Y);
00049 std::cout <<"iteration " <<k <<" #rfs=" <<l.rfsno <<std::endl;
00050 }
00051
00052
00053 l.report(std::cout);
00054
00055
00056
00057
00058
00059
00060 doubleA Z1;
00061 l.map(X,Z1);
00062
00063
00064 std::cout <<"MSE = " <<sumOfSqr(Z1-Y)/(double)Y.d0 <<std::endl;
00065
00066
00067 std::ofstream os("out");
00068 for(i=0;i<X.d0;i++) os <<X[i] <<Z1[i] <<' ' <<Z1[i]+l.confidence(X[i])<<' ' <<Z1[i]-l.confidence(X[i]) <<std::endl;
00069 if(X.d1==2){ Z1.reshape(31,31); MT::save(Z1,"outY"); Z1.reshape(X.d0,1); }
00070
00071
00072 l.save("model");
00073
00074
00075 Lwpr ll;
00076 ll.useNormalization(l.normMean,l.normTrans);
00077 ll.load("model");
00078 doubleA Z2;
00079 ll.map(X,Z2);
00080 std::cout <<"maximum difference in output between original and reloaded = "
00081 <<absMax(Z1-Z2) <<std::endl;
00082
00083
00084 std::cout <<"trying to display output..." <<std::endl;
00085 if(X.d1==1) gnuplot("plot 'data' title 'true','out' us 1:2 title 'learned mean','out' us 1:3 title 'upper bound','out' us 1:4 title 'lower bound'");
00086 else gnuplot("splot 'dataY' matrix title 'true','outY' matrix title 'learned'");
00087 }
00088
00089 int main(int argn,char** argv){
00090
00091 MT::init(argn,argv);
00092
00093 testLWPR();
00094
00095 return 0;
00096 }