rbf.h

説明を見る。
00001 // -*-c++-*-
00002 
00008 /*
00009  *Copyright:
00010 
00011  Copyright (C) Hidehisa AKIYAMA
00012 
00013  This code is free software; you can redistribute it and/or
00014  modify it under the terms of the GNU Lesser General Public
00015  License as published by the Free Software Foundation; either
00016  version 2.1 of the License, or (at your option) any later version.
00017 
00018  This library is distributed in the hope that it will be useful,
00019  but WITHOUT ANY WARRANTY; without even the implied warranty of
00020  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00021  Lesser General Public License for more details.
00022 
00023  You should have received a copy of the GNU Lesser General Public
00024  License along with this library; if not, write to the Free Software
00025  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
00026 
00027  *EndCopyright:
00028  */
00029 
00031 
00032 #ifndef RCSC_ANN_RBF_H
00033 #define RCSC_ANN_RBF_H
00034 
00035 #include <boost/array.hpp>
00036 
00037 #include <vector>
00038 #include <iostream>
00039 #include <cmath>
00040 
namespace rcsc {

/*!
  \class RBFNetwork
  \brief Radial basis function network built from Gaussian hidden units.

  Each hidden unit (see Unit) holds a center in input space and a width
  sigma_. Its activation for an input x is
  exp( -|x - center|^2 / ( 2 * sigma^2 ) ), as computed by Unit::calc().
  Construction, training, propagation and stream I/O are implemented out of
  line (not visible in this header).
*/
class RBFNetwork {
public:

    //! type of an input vector
    typedef std::vector< double > input_vector;
    //! type of an output vector
    typedef std::vector< double > output_vector;

    /*!
      \struct Unit
      \brief one Gaussian hidden unit of the network.
    */
    struct Unit {
        //! center of the basis function; compared element-wise with the input in dist2().
        input_vector center_;

        //! output-side weights of this unit (presumably one per output dimension -- TODO confirm in the implementation).
        output_vector weights_;
        //! last weight change (presumably kept for the momentum term -- TODO confirm).
        output_vector delta_weights_;

        //! width of the Gaussian used by calc().
        double sigma_;
        //! last sigma change (presumably kept for the momentum term -- TODO confirm).
        double delta_sigma_;

    private:
        // declared but intentionally not defined: a Unit must always be
        // created with explicit dimensions.
        Unit();

    public:
        /*!
          \brief create a unit for the given dimensions.
          \param input_dim dimension of the input space (size of center_)
          \param output_dim dimension of the output space (size of the weight vectors)
          \note defined out of line; NOTE(review): verify how members are initialized there.
        */
        Unit( std::size_t input_dim,
              std::size_t output_dim );

        /*!
          \brief (re)initialize the unit's trainable parameters.
          \param min_weight lower bound for initial weights
          \param max_weight upper bound for initial weights
          \param initial_sigma starting value for the Gaussian width
          \note defined out of line; presumably draws weights uniformly from
                [min_weight, max_weight] -- TODO confirm.
        */
        void randomize( const double & min_weight,
                        const double & max_weight,
                        const double & initial_sigma );

        /*!
          \brief squared Euclidean distance between center_ and the input.
          \param input input vector
          \return sum of squared element-wise differences, or 0.0 when the
                  input size differs from the center size.
          \note on a size mismatch the 0.0 result is indistinguishable from
                an exact hit, so calc() returns 1.0 for malformed input.
        */
        double dist2( const input_vector & input ) const
          {
              const std::size_t INPUT = input.size();
              if ( INPUT != center_.size() )
              {
                  // dimension mismatch: keep the historical behavior of
                  // reporting zero distance instead of failing.
                  return 0.0;
              }

              double d2 = 0.0;
              for ( std::size_t i = 0; i < INPUT; ++i )
              {
                  // plain multiplication instead of std::pow( diff, 2.0 ):
                  // same value without a general-purpose pow() call in the
                  // innermost loop.
                  const double diff = center_[i] - input[i];
                  d2 += diff * diff;
              }
              return d2;
          }

        /*!
          \brief Euclidean distance between center_ and the input.
          \param input input vector
          \return sqrt( dist2( input ) ); 0.0 on a size mismatch.
        */
        double dist( const input_vector & input ) const
          {
              return std::sqrt( dist2( input ) );
          }

        /*!
          \brief Gaussian activation of this unit for the given input.
          \param input input vector
          \return exp( -dist2( input ) / ( 2 * sigma_ * sigma_ ) ), which lies
                  in [0, 1] for a non-zero sigma_.
          \note sigma_ must be non-zero: with sigma_ == 0 an exact hit
                (dist2 == 0) evaluates 0/0 and yields NaN.
        */
        double calc( const input_vector & input ) const
          {
              return std::exp( - dist2( input ) / ( 2.0 * sigma_ * sigma_ ) );
          }
    };

private:

    //! dimension of the input space
    const std::size_t M_input_dim;
    //! dimension of the output space
    const std::size_t M_output_dim;

    //! learning rate set by setLearningRate() (consumed by train(), defined out of line)
    double M_eta;
    //! second training coefficient set by setLearningRate() (presumably momentum -- TODO confirm)
    double M_alpha;
    //! lower bound of the initial weight range (see setWeightRange())
    double M_min_weight;
    //! upper bound of the initial weight range (see setWeightRange())
    double M_max_weight;
    //! initial Gaussian width for new units (see setInitialSigma())
    double M_initial_sigma;

    //! hidden units of the network
    std::vector< Unit > M_units;

    // declared but intentionally not defined: the network dimensions are
    // mandatory.
    RBFNetwork();

public:

    /*!
      \brief create a network with the given input/output dimensions.
      \param input_dim dimension of the input space
      \param output_dim dimension of the output space
      \note defined out of line.
    */
    RBFNetwork( const std::size_t input_dim,
                const std::size_t output_dim );

    /*!
      \brief set the training coefficients.
      \param eta learning rate
      \param alpha second coefficient (presumably momentum -- TODO confirm)
    */
    void setLearningRate( const double & eta,
                          const double & alpha )
      {
          M_eta = eta;
          M_alpha = alpha;
      }

    /*!
      \brief set the range used for initial weights.
      \param min_weight lower bound
      \param max_weight upper bound
      \note no ordering check is performed here.
    */
    void setWeightRange( const double & min_weight,
                         const double & max_weight )
      {
          M_min_weight = min_weight;
          M_max_weight = max_weight;
      }

    /*!
      \brief set the Gaussian width given to newly created units.
      \param initial_sigma initial width; should be non-zero (see Unit::calc()).
    */
    void setInitialSigma( const double & initial_sigma )
      {
          M_initial_sigma = initial_sigma;
      }

    /*!
      \brief read-only access to the hidden units.
      \return const reference to the internal unit container.
    */
    const
    std::vector< Unit > & units() const
      {
          return M_units;
      }


    /*!
      \brief add a hidden unit at the given center.
      \param center center vector (presumably of size input_dim -- TODO confirm handling of other sizes)
      \note defined out of line.
    */
    void addCenter( const input_vector & center );

    /*!
      \brief compute the network output for an input.
      \param input input vector
      \param output receives the result
      \note defined out of line.
    */
    void propagate( const input_vector & input,
                    output_vector & output ) const;

    /*!
      \brief perform one training step toward the teacher signal.
      \param input input vector
      \param teacher desired output
      \return a double from the out-of-line implementation (presumably the
              training error -- TODO confirm).
    */
    double train( const input_vector & input,
                  const output_vector & teacher );

    /*!
      \brief read network parameters from a stream.
      \param is input stream
      \return a success flag (presumably true on success -- TODO confirm).
    */
    bool read( std::istream & is );

    /*!
      \brief write network parameters to a stream.
      \param os output stream
      \return os
    */
    std::ostream & print( std::ostream & os ) const;

    /*!
      \brief write the hidden units to a stream.
      \param os output stream
      \return os
    */
    std::ostream & printUnits( std::ostream & os ) const;

};

}
00183 
00184 #endif

librcscに対してThu May 1 15:41:20 2008に生成されました。  doxygen 1.5.0