// NOTE(review): this block has extraction damage. Numbered "N、" markers are
// fused into the text; the #include lines, `using namespace std;`, the
// namespace opening, the constructor's doc comment and its first statements
// are collapsed onto the first three lines; and text was lost after "j" in the
// two "for(j=0;j" loops and after "i" in the output-layer "for (i=0;i" loop.
// It will not compile as shown -- restore it from an intact copy of the source
// rather than editing it piecemeal.
// Readable summary: builds an input layer (label 0, `inputs` neurons), a
// center layer (label 1) and an output layer (label 2), and allocates
// `centroid` as `inputs` reals, which the destructor releases with delete[].
1、include "RadialBasisNetwork.h" #include "Exception.h" #include "Matrix.h" #include "File.h" using namespace std; namespace annie { /** Creates a Radial basis function network. All the outputs will have a bias. * @param inputs Number of inputs taken in by the network * @param cent
2、ers Number of centers the network has. Each center will be * an inputs-dimensional point * @param outputs The number of outputs given by the neuron. All of them will have * a bias */ RadialBasisNetwork::RadialBasisNetwork(int inputs, int centers, int outputs, real (*CenterArray)
3、[1024]) : Network(inputs,outputs) { int i,j; //extern real CenterArray[WORDNUM][inputs]; centroid = new real[inputs]; /// Layer of input. Each member is an InputNeuron //InputLayer *_inputLayer; _inputLayer = new InputLayer(0,inputs); /** Number of centers in the network.
4、 If you plan to extend this class, then the onus of keeping this value
* consistent lies on you
*/
_nCenters = centers;
/// Layer of centers, each member is a CenterNeuron
_centerLayer = new Layer(1);
// NOTE(review): `centroid` appears to be refilled for each center from
// CenterArray[i][j] (see the damaged inner loop), and the same buffer is
// passed to every CenterNeuron. That is only correct if CenterNeuron copies
// the values instead of keeping the pointer -- confirm. CenterArray rows have
// 1024 columns, so `inputs` presumably must not exceed 1024 -- TODO confirm.
for (i=0;i<_nCenters;i++)
{
for(j=0;j 5、enterArray[i][j];
}
//CenterNeuron *c = new CenterNeuron(Layer::MAX_LAYER_SIZE*1+i,inputs);
CenterNeuron *c = new CenterNeuron(Layer::MAX_LAYER_SIZE*1+i,inputs, centroid);
for (j=0;j 6、}
/// Layer of output, each member is a SimpleNeuron
_outputLayer = new Layer(2);
for (i=0;i 7、tNeuron(j));
_outputLayer->addNeuron(n);
}
}
/// Copy constructor, NOT YET IMPLEMENTED
/// From the readable lines, it copies from `src` the input layer's label and
/// size, the center count, and each center's point (via cSrc.getCenter()).
/// NOTE(review): lines are damaged at the "8、".."11、" markers. The
/// output-layer loop contains an Exception("... Copy constructor not fully
/// implemented") that looks like a throw, so copying presumably fails --
/// confirm the condition against an intact source.
/// NOTE(review): `src` is taken by non-const reference, so const networks and
/// temporaries cannot be copied. `centroid` is never assigned here, but the
/// destructor calls delete[] on it -- confirm the header initializes it.
RadialBasisNetwork::RadialBasisNetwork(RadialBasisNetwork &src) : Network(src)
{
int i,j,lbl;
int inputs = src._inputLayer->getSize();
int centers = src._centerLayer->getSize();
int outputs = 8、src._outputLayer->getSize();
_inputLayer = new InputLayer(src._inputLayer->getLabel(),src._inputLayer->getSize());
_nCenters = src._nCenters;
lbl = src._centerLayer->getLabel();
_centerLayer = new Layer(lbl);
for (i=0;i 9、YER_SIZE*lbl+i,inputs);
CenterNeuron &cSrc = (CenterNeuron&)src._centerLayer->getNeuron(i);
c->setCenter(cSrc.getCenter());
for (j=0;j 10、yer(lbl);
for (i=0;i 11、 Exception("RadialBasisNetwork::RadialBasisNetwork() - Copy constructor not fully implemented");
}
_outputLayer->addNeuron(n);
}
}
/** Loads a network from a text file
* @see save
* @param filename Name of the file from which to load network structure
* @throws Exception On any error
*
* Assuming "pare(" is a damaged "s.compare(", the first word of the file must
* equal getClassName(). The remaining words form keyword sections:
* INPUTS <int>, OUTPUTS <int>, CENTERS <int>, CENTER_POINTS, MAX_LAYER_SIZE <int>,
* Biases, BEGIN_META_DATA ... END_META_DATA, and Connections.
*
* NOTE(review): this block has extraction damage ("N、" markers, with text
* lost in most for-loop headers). "npos" looks like a damaged "string::npos",
* and the "basic_string 19、d" line presumably declared a string named `end`.
* Restore from an intact copy rather than editing piecemeal.
* NOTE(review): `centroid` is never assigned here, but the destructor calls
* delete[] on it -- confirm the header initializes it.
*/
RadialBasisNetwork::RadialBasisNetwork(const char *filename) : Network(0,0)
{
File file;
int i,j;
try
{
file.open(filename);
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::" + getClassName() + "() - " + e.what();
throw Exception(err 13、or);
}
string s;
s=file.readWord();
// Reject files written for a different network type.
if (pare(getClassName())!=0)
{
string error(getClassName());
error = error + "::" + getClassName() + "() - File supplied is not about this type of network.";
throw Exception(error);
}
// Neuron labels below are maxLayerSize*layer + index, so a MAX_LAYER_SIZE
// section only affects layers created after it in the file.
int maxLayerSize = Layer::MAX_LAYER_SIZE;
while (!fil 14、e.eof())
{
s=file.readWord();
// NOTE(review): INPUTS/OUTPUTS/CENTERS each allocate a new layer without
// deleting a previous one, so a repeated section leaks. A missing section
// leaves that layer pointer unassigned by this constructor, yet the
// destructor deletes it -- confirm.
if (!pare("INPUTS"))
{
_nInputs=file.readInt();
_inputLayer = new InputLayer(0,_nInputs);
}
else if (!pare("OUTPUTS"))
{
_nOutputs=file.readInt();
_outputLayer = new Layer(2);
for (i=0;i 15、euron *n = new SimpleNeuron(maxLayerSize*2+i,true);
// Output neurons use the identity activation (linear outputs).
n->setActivationFunction(identity,didentity);
_outputLayer->addNeuron(n);
}
}
else if (!pare("CENTERS"))
{
_nCenters = file.readInt();
_centerLayer = new Layer(1);
for (i=0;i 16、rNeuron *n = new CenterNeuron(maxLayerSize*1+i,getInputCount());
_centerLayer->addNeuron(n);
}
}
else if (!pare("CENTER_POINTS"))
{
for (i=0;i 17、putCount();j++)
center.push_back(file.readDouble());
n.setCenter(center);
}
}
else if (!pare("MAX_LAYER_SIZE"))
maxLayerSize=file.readInt();
else if (!pare("Biases"))
{
for (i=0;i 18、SimpleNeuron&)_outputLayer->getNeuron(i);
// Per output: 't' means a bias value follows; anything else removes the bias.
if (file.readChar()=='t')
o.setBias(file.readDouble());
else
o.removeBias();
}
}
else if (!pare("BEGIN_META_DATA"))
{
static const basic_string 19、d("END_META_DATA");
string metaData;
// Accumulate whole lines until one contains the END_META_DATA marker;
// the final trailing newline is stripped before storing.
s = file.readLine();
while (s.find(end,0)==npos)
{
metaData = metaData + s + "\n";
s = file.readLine();
}
if (metaData.length()>0)
metaData.erase(metaData.length()-1);
setMetaData(metaData);
}
else if (!pare("Conne 20、ctions"))
{
//Connect inputs to centers
for (i=0;i 21、OutputCount();i++)
{
SimpleNeuron &o = (SimpleNeuron&)_outputLayer->getNeuron(i);
// NOTE(review): text lost at the "22、" marker (end of Connections and,
// judging by the fragment, a message for unrecognized keywords).
for (j=0;j 22、ring.\n";
} // while (!file.eof())
file.close();
}
/** Releases everything the network owns: the input, center and output layers,
 * then the `centroid` scratch array.
 * NOTE(review): only the (inputs, centers, outputs, CenterArray) constructor
 * visibly allocates `centroid`. Unless the header initializes it, delete[]
 * here acts on an indeterminate pointer for networks built by the copy or
 * file constructors -- confirm.
 */
RadialBasisNetwork::~RadialBasisNetwork()
{
	// Layers go first, in the order the main constructor creates them.
	delete _inputLayer;
	delete _centerLayer;
	delete _outputLayer;

	// Array form: centroid is allocated with new real[...].
	delete [] centroid;
}
/** Returns the point corresponding to the ith center.
 * @param i The center whose point is wanted (0<=i<getCenterCount())
 * @return The getInputCount() dimensional point corresponding to the
 *         ith center
 * @throws Exception Any Exception raised while fetching the center neuron or
 *         its point (presumably an invalid index), rethrown with a
 *         "RadialBasisNetwork::getCenter() - " prefix
 */
VECTOR
RadialBasisNetwork::getCenter(int i)
{
	try
	{
		return ((CenterNeuron&)_centerLayer->getNeuron(i)).getCenter();
	}
	catch (Exception &e)
	{
		string error(getClassName());
		error = error + "::getCenter() - " + e.what();
		throw Exception(error);
	}
}
//CenterNeuron&
//RadialBasisNetwork::getCenterNeuron(int i)
//{
// try
// {
// return (CenterNeuron&)(_centerLayer->getNeuron(i));
// }
// catch (Exception &e)
// {
// string error(getClassName());
// error = error + "::getCenterNeuron() - " + e.what();
// throw Exception(error);
// }
//}
/** Returns the output of the network for the given input.
 * @param input The input vector; expected to hold getInputCount() reals
 * @return The corresponding output of the network
 * @throws Exception Any Exception raised while setting the input or computing
 *         the output, rethrown with a "RadialBasisNetwork::getOutput() - " prefix
 */
VECTOR
RadialBasisNetwork::getOutput(VECTOR &input)
{
	try
	{
		_inputLayer->setInput(input);
		return _outputLayer->getOutput();
	}
	catch (Exception &e)	// by reference: no copy, no slicing of derived exceptions
	{
		string error(getClassName());
		error = error + "::getOutput() - " + e.what();
		throw Exception(error);
	}
}
/** Sets the ith center poi 27、nt to the given point.
* @param i The center that is to be changed
* @param center The getInputCount() dimensional point
*/
void
RadialBasisNetwork::setCenter(int i, VECTOR ¢er)
{
try
{
CenterNeuron &c = (CenterNeuron&)_centerLayer->getNeuron(i);
c.setCenter(center);
}
28、 catch (Exception &e)
{
string error(getClassName());
error = error + "::setCenter() - " + e.what();
throw Exception(e);
}
}
/** Sets the ith center point to the given point.
* @param i The center that is to be changed
* @param center The getInputCount() dimensional point
*/
29、
void
RadialBasisNetwork::setCenter(int i, real *center)
{
try
{
CenterNeuron &c = (CenterNeuron&)_centerLayer->getNeuron(i);
c.setCenter(center);
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::setCenter() - " + e.what();
throw Exception(e);
}
30、
}
/** Sets the weight between the given center and output
 * @param center Index of the center (0<=center<getCenterCount())
 * @param output Index of the output (0<=output<getOutputCount())
 * @param weight Weight of the link from the center neuron to the output neuron
 * @throws Exception if any of the parameters given is invalid; the message is
 *         prefixed with "RadialBasisNetwork::setWeight() - "
 */
void
RadialBasisNetwork::setWeight(int center, int output, real weight)
{
	try
	{
		CenterNeuron &c = (CenterNeuron&)_centerLayer->getNeuron(center);
		SimpleNeuron &o = (SimpleNeuron&)_outputLayer->getNeuron(output);
		o.connect(&c,weight);
	}
	catch (Exception &e)
	{
		string error(getClassName());
		error = error + "::setWeight() - " + e.what();
		throw Exception(error);
	}
}
// NOTE(review): extraction damage. At the "33、"/"34、" markers, the rest of this
// doc comment, the function signature and the start of the body were lost, and
// the doc comment opened below is never closed on a visible line. From the
// readable tail: c is presumably center neuron `center` of _centerLayer, the
// function returns o.getWeight(&c) (o presumably being output neuron `output`),
// and any Exception is rethrown with a "::getWeight() - " prefix. Restore this
// block from an intact copy of the source.
/** Returns the weight of the link between the given center and output
* @param center Index of the center (0<=center 33、 @param output Index of the output (0<=output 34、CenterNeuron&)_centerLayer->getNeuron(center);
return o.getWeight(&c);
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::getWeight() - " + e.what();
throw Exception(error);
}
}
// NOTE(review): extraction damage. At the "35、"/"36、" markers, the rest of this
// doc comment, the signature and the start of the body of setBias were lost.
// Restore from an intact copy of the source.
/** Sets the bias of the ith output.
* @param i The index of the output 35、0<=i 36、as);
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::setBias() - " + e.what();
// NOTE(review): the prefixed `error` built above is discarded because `e` is
// rethrown; setWeight() and removeBias() throw Exception(error) instead.
throw Exception(e);
}
}
/** Returns the bias of the ith output
 * @param i The index of the output (0<=i<getOutputCount())
 * @return The bias of the output. If there is no bias, it returns 0.0
 * @throws Exception Any Exception raised while looking up the output neuron,
 *         rethrown with a "RadialBasisNetwork::getBias() - " prefix
 */
real
RadialBasisNetwork::getBias(int i)
{
	try
	{
		SimpleNeuron &n = (SimpleNeuron&)(_outputLayer->getNeuron(i));
		return n.getBias();
	}
	catch (Exception &e)
	{
		// Name this method (not setBias) and actually throw the prefixed message.
		string error(getClassName());
		error = error + "::getBias() - " + e.what();
		throw Exception(error);
	}
}
/** Array overload of getOutput(): lets callers pass a plain real[] as the input.
 * It defers entirely to Network::getOutput(real*), so it behaves exactly like
 * the base-class version.
 * @param input Array of reals presented to the network
 * @return Whatever Network::getOutput(real*) returns for that input
 */
VECTOR
RadialBasisNetwork::getOutput(real *input)
{
	VECTOR result = Network::getOutput(input);
	return result;
}
const char *
RadialB 39、asisNetwork::getClassName()
{ return "RadialBasisNetwork"; }
/** Number of center (hidden-layer) neurons in the network.
 * @return The stored center count, _nCenters
 */
int
RadialBasisNetwork::getCenterCount()
{
	const int count = _nCenters;
	return count;
}
/** Sets the activation function of the center neurons.
* (The activation function is gaussian by default)
* @param f The activation function to be used.
* @param df The derivation of the activation function, used in gradient descent training
*        NOTE(review): the live signature takes only `f` (ActivationFunction1);
*        `df` belongs to the commented-out two-argument variant below.
*/
void
RadialBasisNetwork::setCenterActivationFunction(ActivationFunction1 f)
//RadialBasisNetwork::setCenterActivationFunction(ActivationFunction f,ActivationFunction df)
{
int i;
// NOTE(review): extraction damage at the "42、" marker. The loop (presumably
// over the centers), the end of this function and the start of the next
// function's doc comment were lost. Restore from an intact copy.
for (i=0;i 42、putCount()).
* @throws Exception if the index given is invalid
*/
void
RadialBasisNetwork::removeBias(int i)
{
try
{
SimpleNeuron &o = (SimpleNeuron&)_outputLayer->getNeuron(i);
o.removeBias();
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::re 43、moveBias() - " + e.what();
throw Exception(error);
}
}
/** Trains the weights of the network, centers are kept fixed.
* @param T The TrainingSet from which input/desired-output pairs will be obtained
* @throws Exception if T's input size differs from getInputCount()
*         ("...::trainWeights() - Invalid TrainingSet provided."), or if a matrix
*         operation inside the try blocks throws (presumably a singular matrix
*         on inverse()); messages are prefixed with "<class>::trainWeights() - "
*
* For each output neuron, this builds V (p x effectiveH: the center outputs for
* each pattern, apparently plus a column of 1.0 when that output has a bias)
* and Y (p x 1: that output's desired value for each pattern). It then solves
* for the weights W in closed form: W = (V^T V)^-1 V^T Y when p != effectiveH,
* or W = V^-1 Y when V is square.
*
* NOTE(review): extraction damage at the "44、".."49、" markers (text lost in the
* output-loop header and in the pattern-loop header); the end of the function
* is not visible here. Restore from an intact copy.
* NOTE(review): Y, V and VT are allocated with new, and the error paths throw
* without deleting them (VTV also leaks if inverse() throws), so a failed
* inversion leaks memory. Consider RAII holders.
*/
void
RadialBasisNetwork::trainWeights(TrainingSet &T)
{
if (T.getInputSize() ! 44、 getInputCount())//getInputCount() is inherited from Network
{
string error(getClassName());
error = error + "::trainWeights() - Invalid TrainingSet provided.";
throw Exception(error);
}
int output;
int i,j;
int p = T.getSize(); //number of training patterns
int h = getCenterCount(); //number of centers
VECTOR in,y;
//do for each output
for (output=0; output 46、s = outNrn.hasBias())
effectiveH++;
//setup matrices
Matrix *Y = new Matrix(p, 1);//stores the current output's desired value for each of the p patterns
Matrix *W = NULL;//stores the weights
Matrix *V = new Matrix(p, effectiveH);//stores the hidden (center) node outputs for each of the p patterns
Matrix *VT = NULL;
// V is non-square, so the normal equations (V^T V) W = V^T Y are used.
if (p!=effectiveH)
VT = new Matrix(effectiveH ,p);
// NOTE(review): damaged line; presumably declares `extern int Hidden_num`,
// which the visible code does not use.
extern int Hi 47、dden_num;
for (i=0;i elementAt(i,j) = _centerLayer->getNeuron(j).getOutput();//assign element (i,j)
if (VT)
VT->elementAt(j,i) = V->elementAt(i,j);
}
if (hasBias)
{
48、 V->elementAt(i,j) = 1.0;
if (VT)
VT->elementAt(j,i) = 1.0;
}
Y->elementAt(i,0) = y[output];
} // for i=[0..p)
if (VT)
{
Matrix *VTVinv, *VTY;
try
{
Matrix *VTV;
VTV = VT->multiply(V);
VTVinv = VTV->inverse();
delete VTV;
}
catc 49、h (Exception &e)
{
string error(getClassName());
error = error + "::trainWeights() - " + e.what();
throw Exception(error);
}
// W = (V^T V)^-1 (V^T Y): least-squares solution.
VTY = VT->multiply(Y);
W = VTVinv->multiply(VTY);
delete VTVinv;
delete VTY;
} // if VT
else
{
// Square V: solve exactly with W = V^-1 Y.
Matrix *Vinv;
try
{
Vinv = V->inverse();
}
catch (Exception &e)
{
string error(getClassName());
error = error + "::trainWeights() - " + e.what();
throw Exception(error);
}
W = Vinv->multiply(Y);
delete Vinv;
}
//set the






