00001
00002
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00031
00032 #ifndef RCSC_ANN_NGNET_H
00033 #define RCSC_ANN_NGNET_H
00034
00035 #include <boost/array.hpp>
00036
00037 #include <vector>
00038 #include <iostream>
00039 #include <cmath>
00040
00041 namespace rcsc {
00042
00044
00049 class NGNet {
00050 public:
00051
00052 enum {
00053 INPUT = 2,
00054 };
00055
00056 enum {
00057 OUTPUT = 2,
00058 };
00059
00061 typedef boost::array< double, INPUT > input_vector;
00063 typedef boost::array< double, OUTPUT > output_vector;
00064
00069 struct Unit {
00070 input_vector center_;
00071
00072 output_vector weights_;
00073 output_vector delta_weights_;
00074
00075 double sigma_;
00076 double delta_sigma_;
00077
00081 Unit();
00082
00089 void randomize( const double & min_weight,
00090 const double & max_weight,
00091 const double & initial_sigma );
00092
00098 double dist2( const input_vector & input ) const
00099 {
00100 double dist2 = 0.0;
00101 for ( std::size_t i = 0; i < INPUT; ++i )
00102 {
00103 dist2 += std::pow( center_[i] - input[i], 2.0 );
00104 }
00105 return dist2;
00106 }
00107
00113 double calc( const input_vector & input ) const
00114 {
00115 return std::exp( - dist2( input ) / ( 2.0 * sigma_ * sigma_ ) );
00116 }
00117 };
00118
00119 private:
00120
00121 double M_eta;
00122 double M_alpha;
00123 double M_min_weight;
00124 double M_max_weight;
00125 double M_initial_sigma;
00126
00127 std::vector< Unit > M_units;
00128
00129 public:
00130
00134 NGNet();
00135
00141 void setLearningRate( const double & eta,
00142 const double & alpha )
00143 {
00144 M_eta = eta;
00145 M_alpha = alpha;
00146 }
00147
00153 void setWeightRange( const double & min_weight,
00154 const double & max_weight )
00155 {
00156 M_min_weight = min_weight;
00157 M_max_weight = max_weight;
00158 }
00159
00164 void setInitialSigma( const double & initial_sigma )
00165 {
00166 M_initial_sigma = initial_sigma;
00167 }
00168
00173 const
00174 std::vector< Unit > & units() const
00175 {
00176 return M_units;
00177 }
00178
00183 void addCenter( const input_vector & center );
00184
00190 void propagate( const input_vector & input,
00191 output_vector & output ) const;
00192
00199 double train( const input_vector & input,
00200 const output_vector & teacher );
00201
00207 bool read( std::istream & is );
00208
00214 std::ostream & print( std::ostream & os ) const;
00215
00221 std::ostream & printUnits( std::ostream & os ) const;
00222
00223 };
00224
00225 }
00226
00227 #endif