#ifndef otbTrainNeuralNetwork_hxx
#define otbTrainNeuralNetwork_hxx

#include <boost/lexical_cast.hpp>
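// Parameter registration and training wrapper for the OpenCV artificial neural
// network (multi-layer perceptron) classifier used by LearningApplicationBase.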
template <class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::InitNeuralNetworkParams()
{
  AddChoice("classifier.ann", "Artificial Neural Network classifier");
  SetParameterDescription("classifier.ann",
                          "http://docs.opencv.org/modules/ml/doc/neural_networks.html");
  // Train method
  AddParameter(ParameterType_Choice, "classifier.ann.t", "Train Method Type");
  AddChoice("classifier.ann.t.back", "Back-propagation algorithm");
  SetParameterDescription("classifier.ann.t.back",
                          "Method to compute the gradient of the loss function and adjust weights "
                          "in the network to optimize the result.");
  AddChoice("classifier.ann.t.reg", "Resilient Back-propagation algorithm");
  SetParameterDescription("classifier.ann.t.reg",
                          "Almost the same as the Back-prop algorithm, except that it does not "
                          "take into account the magnitude of the partial derivative (coordinate "
                          "of the gradient) but only its sign.");
  SetParameterString("classifier.ann.t", "reg");
  SetParameterDescription("classifier.ann.t",
                          "Type of training method for the multilayer perceptron (MLP) neural network.");
  // Layer sizes
  AddParameter(ParameterType_StringList, "classifier.ann.sizes",
               "Number of neurons in each intermediate layer");
  SetParameterDescription("classifier.ann.sizes",
                          "The number of neurons in each intermediate layer (excluding input and output layers).");
  // Activation function
  AddParameter(ParameterType_Choice, "classifier.ann.f", "Neuron activation function type");
  AddChoice("classifier.ann.f.ident", "Identity function");
  AddChoice("classifier.ann.f.sig", "Symmetrical Sigmoid function");
  AddChoice("classifier.ann.f.gau", "Gaussian function (not completely supported)");
  SetParameterString("classifier.ann.f", "sig");
  SetParameterDescription("classifier.ann.f",
                          "This function determines whether the output of the node is positive or not, "
                          "depending on the output of the transfer function.");
  // Alpha
  AddParameter(ParameterType_Float, "classifier.ann.a", "Alpha parameter of the activation function");
  SetParameterFloat("classifier.ann.a", 1.);
  SetParameterDescription("classifier.ann.a",
                          "Alpha parameter of the activation function (used only with sigmoid and gaussian functions).");
  // Beta
  AddParameter(ParameterType_Float, "classifier.ann.b", "Beta parameter of the activation function");
  SetParameterFloat("classifier.ann.b", 1.);
  SetParameterDescription("classifier.ann.b",
                          "Beta parameter of the activation function (used only with sigmoid and gaussian functions).");
  // BACKPROP: weight gradient strength
  AddParameter(ParameterType_Float, "classifier.ann.bpdw",
               "Strength of the weight gradient term in the BACKPROP method");
  SetParameterFloat("classifier.ann.bpdw", 0.1);
  SetParameterDescription("classifier.ann.bpdw",
                          "Strength of the weight gradient term in the BACKPROP method. The "
                          "recommended value is about 0.1.");
  // BACKPROP: momentum strength
  AddParameter(ParameterType_Float, "classifier.ann.bpms",
               "Strength of the momentum term (the difference between weights on the 2 previous iterations)");
  SetParameterFloat("classifier.ann.bpms", 0.1);
  SetParameterDescription("classifier.ann.bpms",
                          "Strength of the momentum term (the difference between weights on the 2 previous "
                          "iterations). This parameter provides some inertia to smooth the random "
                          "fluctuations of the weights. It can vary from 0 (the feature is disabled) "
                          "to 1 and beyond. A value of about 0.1 is usually good enough.");
  // RPROP: initial update-value
  AddParameter(ParameterType_Float, "classifier.ann.rdw",
               "Initial value Delta_0 of update-values Delta_{ij} in RPROP method");
  SetParameterFloat("classifier.ann.rdw", 0.1);
  SetParameterDescription("classifier.ann.rdw",
                          "Initial value Delta_0 of update-values Delta_{ij} in RPROP method (default = 0.1).");
  // RPROP: update-value lower bound
  AddParameter(ParameterType_Float, "classifier.ann.rdwm",
               "Update-values lower limit Delta_{min} in RPROP method");
  SetParameterFloat("classifier.ann.rdwm", 1e-7);
  SetParameterDescription("classifier.ann.rdwm",
                          "Update-values lower limit Delta_{min} in RPROP method. It must be positive "
                          "(default = 1e-7).");
  // Termination criteria
  AddParameter(ParameterType_Choice, "classifier.ann.term", "Termination criteria");
  AddChoice("classifier.ann.term.iter", "Maximum number of iterations");
  SetParameterDescription("classifier.ann.term.iter",
                          "Set the number of iterations allowed to the network for its "
                          "training. Training will stop regardless of the result when this "
                          "number is reached.");
  AddChoice("classifier.ann.term.eps", "Epsilon");
  SetParameterDescription("classifier.ann.term.eps",
                          "Training will focus on result and will stop once the precision is "
                          "at most epsilon.");
  AddChoice("classifier.ann.term.all", "Max. iterations + Epsilon");
  SetParameterDescription("classifier.ann.term.all",
                          "Both termination criteria are used. Training stops as soon as the "
                          "first one is reached.");
  SetParameterString("classifier.ann.term", "all");
  SetParameterDescription("classifier.ann.term", "Termination criteria.");
  // Epsilon
  AddParameter(ParameterType_Float, "classifier.ann.eps",
               "Epsilon value used in the termination criteria");
  SetParameterFloat("classifier.ann.eps", 0.01);
  SetParameterDescription("classifier.ann.eps",
                          "Epsilon value used in the termination criteria.");
  // Maximum number of iterations
  AddParameter(ParameterType_Int, "classifier.ann.iter",
               "Maximum number of iterations used in the termination criteria");
  SetParameterInt("classifier.ann.iter", 1000);
  SetParameterDescription("classifier.ann.iter",
                          "Maximum number of iterations used in the termination criteria.");
}
template <class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::TrainNeuralNetwork(typename ListSampleType::Pointer trainingListSample,
                                                                            typename TargetListSampleType::Pointer trainingLabeledListSample,
                                                                            std::string modelPath)
{
  typename NeuralNetworkType::Pointer classifier = NeuralNetworkType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);
  // Choice indices follow the declaration order in InitNeuralNetworkParams:
  // 0 = back-propagation, 1 = resilient back-propagation (the default)
  switch (GetParameterInt("classifier.ann.t"))
  {
  case 0: // back
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::BACKPROP);
    break;
  case 1: // reg
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
    break;
  default: // fall back to resilient back-propagation
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
    break;
  }
  // First layer: one input neuron per image band
  std::vector<unsigned int> layerSizes;
  std::vector<std::string>  sizes = GetParameterStringList("classifier.ann.sizes");

  unsigned int nbImageBands = trainingListSample->GetMeasurementVectorSize();
  layerSizes.push_back(nbImageBands);

  // Intermediate layers: sizes given by the user as strings
  for (unsigned int i = 0; i < sizes.size(); i++)
  {
    unsigned int nbNeurons = boost::lexical_cast<unsigned int>(sizes[i]);
    layerSizes.push_back(nbNeurons);
  }
  // Last layer: a single neuron in regression mode, otherwise one neuron per
  // distinct label found in the training set
  unsigned int nbClasses = 0;
  if (this->m_RegressionFlag)
  {
    layerSizes.push_back(1);
  }
  else
  {
    std::set<TargetValueType> labelSet;
    TargetSampleType          currentLabel;
    for (unsigned int itLab = 0; itLab < trainingLabeledListSample->Size(); ++itLab)
    {
      currentLabel = trainingLabeledListSample->GetMeasurementVector(itLab);
      labelSet.insert(currentLabel[0]);
    }
    nbClasses = labelSet.size();
    layerSizes.push_back(nbClasses);
  }
  classifier->SetLayerSizes(layerSizes);
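  // Illustration with made-up numbers: for a 4-band image, "classifier.ann.sizes"
  // set to "100 50" and 3 distinct labels in the training set, layerSizes is
  // {4, 100, 50, 3}; in regression mode it would be {4, 100, 50, 1}.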
  switch (GetParameterInt("classifier.ann.f"))
  {
  case 0: // ident
    classifier->SetActivateFunction(CvANN_MLP::IDENTITY);
    break;
  case 1: // sig
    classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
    break;
  case 2: // gau
    classifier->SetActivateFunction(CvANN_MLP::GAUSSIAN);
    break;
  default: // fall back to the symmetrical sigmoid
    classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
    break;
  }
  classifier->SetAlpha(GetParameterFloat("classifier.ann.a"));
  classifier->SetBeta(GetParameterFloat("classifier.ann.b"));
  classifier->SetBackPropDWScale(GetParameterFloat("classifier.ann.bpdw"));
  classifier->SetBackPropMomentScale(GetParameterFloat("classifier.ann.bpms"));
  classifier->SetRegPropDW0(GetParameterFloat("classifier.ann.rdw"));
  classifier->SetRegPropDWMin(GetParameterFloat("classifier.ann.rdwm"));
  switch (GetParameterInt("classifier.ann.term"))
  {
  case 0: // iter
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
    break;
  case 1: // eps
    classifier->SetTermCriteriaType(CV_TERMCRIT_EPS);
    break;
  case 2: // all
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
    break;
  default: // fall back to both criteria
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
    break;
  }
  classifier->SetEpsilon(GetParameterFloat("classifier.ann.eps"));
  classifier->SetMaxIter(GetParameterInt("classifier.ann.iter"));

  // Train on the provided samples, then serialize the model
  classifier->Train();
  classifier->Save(modelPath);
}

#endif