00001
00008 #ifndef ARTMAP_H
00009 #define ARTMAP_H
00010
#include <iostream>
#include <vector>
#include <fstream>
#include <valarray>
#include <string>
#include <sstream>

#include "Logger.h"
#include "MsgException.h"
#include "util.h"

using std::ifstream;
using std::ofstream;
using std::vector;
using std::string;
using std::istringstream;
00024
// Export/import decoration for the artmap class.
//   - ARTMAP_DLL not defined           -> no decoration (static build).
//   - ARTMAP_DLL and ARTMAP_IMPORT     -> __declspec(dllimport) (DLL client).
//   - ARTMAP_DLL without ARTMAP_IMPORT -> __declspec(dllexport) (building the DLL).
#if !defined(ARTMAP_DLL)
#define ARTMAP_DECLSPEC
#elif defined(ARTMAP_IMPORT)
#define ARTMAP_DECLSPEC __declspec(dllimport)
#else
#define ARTMAP_DECLSPEC __declspec(dllexport)
#endif
00034
00045 class ARTMAP_DECLSPEC artmap {
00046 public:
00048 typedef enum RunModeType { FUZZY, DEFAULT, IC, DISTRIB };
00049
00050 private:
00051 RunModeType NetworkType;
00052
00053 int M;
00054 int L;
00055
00056 float RhoBar;
00057 float RhoBarTest;
00058 float Alpha;
00059 float Beta;
00060 float Eps;
00061 float P;
00062
00063 int C;
00064
00065 int J;
00066 int K;
00067
00068 float rho;
00069
00071 float * A;
00072 float * x;
00073 float * y;
00074 float * Y;
00075 float * T;
00076 float * S;
00077 float * H;
00078 float * c;
00079 bool * lambda;
00080 float * sigma_i;
00081 float * sigma_k;
00082 int * kap;
00083 float * dKap;
00084 float * tIj;
00085 float * tJi;
00086
00087 bool dMapWeights;
00088
00089
00090 float Tu;
00091 float sum_x;
00092 int _2M;
00093 int N;
00094 int i, j, k;
00095
00096 void complementCode(float *a);
00097 int F0_to_F2_signal();
00098 void newNode();
00099 void CAM_distrib();
00100 void CAM_WTA();
00101 void F1signal_WTA();
00102 void F1signal_distrib();
00103 bool passesVigilance();
00104 int prediction_distrib();
00105 int prediction_WTA();
00106 void matchTracking();
00107 void creditAssignment();
00108 void resonance_distrib();
00109 void resonance_WTA();
00110 void growF2 (float factor);
00111
00121 float cost(float x) { return ((2-Alpha)*M - x); }
00122 ofstream *ostCategoryActivations;
00123 void toStr();
00124 void toStr_dimensions();
00125 void toStr_A();
00126 void toStr_nodeJTSH(int j);
00127 void toStr_nodeJdetails(int j);
00128 void toStr_nodeJtauIj(int j);
00129 void toStr_nodeJtauJi(int j);
00130 void toStr_x();
00131 void toStr_sigma_i();
00132 void toStr_sigma_k();
00133
00134 public:
00135 artmap (int M, int L);
00136 ~artmap();
00137 void train (float *a, int K);
00138 void test (float *a);
00144 float getOutput (int k) { return sigma_k[k]; }
00150 int getMaxOutputIndex () {
00151 std::valarray<float> outs = std::valarray<float> (sigma_k, L);
00152 return getIndexOfMaxElt (outs);
00153 }
00154
00155 void fwrite (ofstream &ofs);
00156 void fread (ifstream &ifs, string &specialRequest);
00157
00158 void setParam (const string &name, const string &value);
00160 int getC() { return C; }
00162 int getNodeClass (int j) { if ((j < 0) || (j > C) || dMapWeights) { return -1; } else { return kap[j]; } }
00164 int getLtmRequired () { return C * M * 2 * sizeof (float ); }
00165 float &tauIj (int i, int j);
00166 float &tauJi (int i, int j);
00167 int getOutputType (const string &name);
00168 int getInt (const string &name);
00169 float getFloat (const string &name);
00170 string &getString (const string &name);
00171
00172 void requestOutput (const string &name, ofstream *ost);
00173 void closeStreams ();
00174
00175 void setNetworkType (RunModeType v) { NetworkType = v; }
00176 void setM (int v) { M = v; }
00177 void setL (int v) { L = v; }
00178 void setRhoBar (float v) { RhoBar = v; }
00179 void setRhoBarTest (float v) { RhoBarTest = v; }
00180 void setAlpha (float v) { Alpha = v; }
00181 void setBeta (float v) { Beta = v; }
00182 void setEps (float v) { Eps = v; }
00183 void setP (float v) { P = v; }
00184
00185 RunModeType getNetworkType() { return NetworkType; }
00186 int getM() { return M; }
00187 int getL() { return L; }
00188 float getRhoBar() { return RhoBar; }
00189 float getRhoBarTest() { return RhoBarTest; }
00190 float getAlpha() { return Alpha; }
00191 float getBeta() { return Beta; }
00192 float getEps() { return Eps; }
00193 float getP() { return P; }
00194 };
00195
00196
00197 #define foreach_i for (i = 0; i < _2M; i++)
00198 #define foreach_j for (j = 0; j < C; j++)
00199 #define foreach_k for (k = 0; k < L; k++)
00200
00201 #define forall_j for (j = 0; j < N; j++)
00202
00203 #endif
00204
00205