otbTrainNeuralNetwork.hxx
/*
 * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 * https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbTrainNeuralNetwork_hxx
#define otbTrainNeuralNetwork_hxx
#include <boost/lexical_cast.hpp>
#include "otbLearningApplicationBase.h"
#include "otbNeuralNetworkMachineLearningModel.h"

namespace otb
{
namespace Wrapper
{

template <class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::InitNeuralNetworkParams()
{
  AddChoice("classifier.ann", "Artificial Neural Network classifier");
  SetParameterDescription("classifier.ann", "http://docs.opencv.org/modules/ml/doc/neural_networks.html");

  // TrainMethod
  AddParameter(ParameterType_Choice, "classifier.ann.t", "Train Method Type");
  AddChoice("classifier.ann.t.back", "Back-propagation algorithm");
  SetParameterDescription("classifier.ann.t.back",
                          "Method to compute the gradient of the loss function and adjust weights "
                          "in the network to optimize the result.");
  AddChoice("classifier.ann.t.reg", "Resilient Back-propagation algorithm");
  SetParameterDescription("classifier.ann.t.reg",
                          "Almost the same as the Back-prop algorithm except that it does not "
                          "take into account the magnitude of the partial derivative (coordinate "
                          "of the gradient) but only its sign.");

  SetParameterString("classifier.ann.t", "reg");
  SetParameterDescription("classifier.ann.t", "Type of training method for the multilayer perceptron (MLP) neural network.");

  // LayerSizes
  // There is no ParameterType_IntList, so I use a ParameterType_StringList and convert it.
  /*std::vector<std::string> layerSizes;
  layerSizes.push_back("100");
  layerSizes.push_back("100"); */
  AddParameter(ParameterType_StringList, "classifier.ann.sizes", "Number of neurons in each intermediate layer");
  // SetParameterStringList("classifier.ann.sizes", layerSizes);
  SetParameterDescription("classifier.ann.sizes", "The number of neurons in each intermediate layer (excluding input and output layers).");

  // ActivateFunction
  AddParameter(ParameterType_Choice, "classifier.ann.f", "Neuron activation function type");
  AddChoice("classifier.ann.f.ident", "Identity function");
  AddChoice("classifier.ann.f.sig", "Symmetrical Sigmoid function");
  AddChoice("classifier.ann.f.gau", "Gaussian function (Not completely supported)");
  SetParameterString("classifier.ann.f", "sig");
  SetParameterDescription("classifier.ann.f",
69  "This function determine whether the output of the node is positive or not "
70  "depending on the output of the transfer function.");

  // Alpha
  AddParameter(ParameterType_Float, "classifier.ann.a", "Alpha parameter of the activation function");
  SetParameterFloat("classifier.ann.a", 1.);
  SetParameterDescription("classifier.ann.a", "Alpha parameter of the activation function (used only with sigmoid and gaussian functions).");

  // Beta
  AddParameter(ParameterType_Float, "classifier.ann.b", "Beta parameter of the activation function");
  SetParameterFloat("classifier.ann.b", 1.);
  SetParameterDescription("classifier.ann.b", "Beta parameter of the activation function (used only with sigmoid and gaussian functions).");

  // BackPropDWScale
  AddParameter(ParameterType_Float, "classifier.ann.bpdw", "Strength of the weight gradient term in the BACKPROP method");
  SetParameterFloat("classifier.ann.bpdw", 0.1);
  SetParameterDescription("classifier.ann.bpdw",
                          "Strength of the weight gradient term in the BACKPROP method. The "
                          "recommended value is about 0.1.");

  // BackPropMomentScale
  AddParameter(ParameterType_Float, "classifier.ann.bpms", "Strength of the momentum term (the difference between weights on the 2 previous iterations)");
  SetParameterFloat("classifier.ann.bpms", 0.1);
  SetParameterDescription("classifier.ann.bpms",
                          "Strength of the momentum term (the difference between weights on the 2 previous "
                          "iterations). This parameter provides some inertia to smooth the random "
                          "fluctuations of the weights. It can vary from 0 (the feature is disabled) "
                          "to 1 and beyond. A value of about 0.1 is usually good enough.");

  // RegPropDW0
  AddParameter(ParameterType_Float, "classifier.ann.rdw", "Initial value Delta_0 of update-values Delta_{ij} in RPROP method");
  SetParameterFloat("classifier.ann.rdw", 0.1);
  SetParameterDescription("classifier.ann.rdw", "Initial value Delta_0 of update-values Delta_{ij} in RPROP method (default = 0.1).");

  // RegPropDWMin
  AddParameter(ParameterType_Float, "classifier.ann.rdwm", "Update-values lower limit Delta_{min} in RPROP method");
  SetParameterFloat("classifier.ann.rdwm", 1e-7);
  SetParameterDescription("classifier.ann.rdwm",
                          "Update-values lower limit Delta_{min} in RPROP method. It must be positive "
                          "(default = 1e-7).");

  // TermCriteriaType
  AddParameter(ParameterType_Choice, "classifier.ann.term", "Termination criteria");
  AddChoice("classifier.ann.term.iter", "Maximum number of iterations");
  SetParameterDescription("classifier.ann.term.iter",
                          "Set the number of iterations allowed to the network for its "
                          "training. Training will stop regardless of the result when this "
                          "number is reached.");
  AddChoice("classifier.ann.term.eps", "Epsilon");
  SetParameterDescription("classifier.ann.term.eps",
                          "Training will focus on result and will stop once the precision is "
                          "at most epsilon.");
  AddChoice("classifier.ann.term.all", "Max. iterations + Epsilon");
  SetParameterDescription("classifier.ann.term.all", "Both termination criteria are used. Training stops as soon as the first one is reached.");
  SetParameterString("classifier.ann.term", "all");
  SetParameterDescription("classifier.ann.term", "Termination criteria.");

  // Epsilon
  AddParameter(ParameterType_Float, "classifier.ann.eps", "Epsilon value used in the Termination criteria");
  SetParameterFloat("classifier.ann.eps", 0.01);
  SetParameterDescription("classifier.ann.eps", "Epsilon value used in the Termination criteria.");

  // MaxIter
  AddParameter(ParameterType_Int, "classifier.ann.iter", "Maximum number of iterations used in the Termination criteria");
  SetParameterInt("classifier.ann.iter", 1000);
  SetParameterDescription("classifier.ann.iter", "Maximum number of iterations used in the Termination criteria.");
}
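
// Usage sketch: the "classifier.ann.*" keys declared above are exposed by the
// OTB applications built on LearningApplicationBase, e.g. TrainImagesClassifier.
// A hypothetical command line (file names are placeholders) could look like:
//   otbcli_TrainImagesClassifier -io.il input.tif -io.vd samples.shp \
//     -classifier ann -classifier.ann.t reg -classifier.ann.sizes 100 100 \
//     -io.out model.ann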

template <class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::TrainNeuralNetwork(typename ListSampleType::Pointer trainingListSample,
                                                                            typename TargetListSampleType::Pointer trainingLabeledListSample,
                                                                            std::string modelPath)
{
  typedef otb::NeuralNetworkMachineLearningModel<InputValueType, TargetValueType> NeuralNetworkType;
  typename NeuralNetworkType::Pointer classifier = NeuralNetworkType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);

  switch (GetParameterInt("classifier.ann.t"))
  {
  case 0: // BACKPROP
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::BACKPROP);
    break;
  case 1: // RPROP
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
    break;
  default: // DEFAULT = RPROP
    classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
    break;
  }

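  // Build the MLP topology handed to OpenCV: the first entry is the input
  // layer (one neuron per feature, i.e. per band of the training samples),
  // followed by the user-supplied hidden layer sizes, and finally the output
  // layer appended below (a single neuron for regression, one per class otherwise).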
  std::vector<unsigned int> layerSizes;
  std::vector<std::string> sizes = GetParameterStringList("classifier.ann.sizes");

  unsigned int nbImageBands = trainingListSample->GetMeasurementVectorSize();
  layerSizes.push_back(nbImageBands);
  for (unsigned int i = 0; i < sizes.size(); i++)
  {
    unsigned int nbNeurons = boost::lexical_cast<unsigned int>(sizes[i]);
    layerSizes.push_back(nbNeurons);
  }

  unsigned int nbClasses = 0;
  if (this->m_RegressionFlag)
  {
    layerSizes.push_back(1);
  }
  else
  {
    // Count the distinct labels in the training set to size the output layer
    std::set<TargetValueType> labelSet;
    TargetSampleType currentLabel;
    for (unsigned int itLab = 0; itLab < trainingLabeledListSample->Size(); ++itLab)
    {
      currentLabel = trainingLabeledListSample->GetMeasurementVector(itLab);
      labelSet.insert(currentLabel[0]);
    }
    nbClasses = labelSet.size();
    layerSizes.push_back(nbClasses);
  }

  classifier->SetLayerSizes(layerSizes);

  switch (GetParameterInt("classifier.ann.f"))
  {
  case 0: // ident
    classifier->SetActivateFunction(CvANN_MLP::IDENTITY);
    break;
  case 1: // sig
    classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
    break;
  case 2: // gaussian
    classifier->SetActivateFunction(CvANN_MLP::GAUSSIAN);
    break;
  default: // DEFAULT = SIGMOID_SYM
    classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
    break;
  }

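  // Forward the remaining tuning parameters. Alpha and beta only affect the
  // sigmoid and Gaussian activation functions; bpdw/bpms are specific to
  // BACKPROP and rdw/rdwm to RPROP (see the parameter descriptions above).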
  classifier->SetAlpha(GetParameterFloat("classifier.ann.a"));
  classifier->SetBeta(GetParameterFloat("classifier.ann.b"));
  classifier->SetBackPropDWScale(GetParameterFloat("classifier.ann.bpdw"));
  classifier->SetBackPropMomentScale(GetParameterFloat("classifier.ann.bpms"));
  classifier->SetRegPropDW0(GetParameterFloat("classifier.ann.rdw"));
  classifier->SetRegPropDWMin(GetParameterFloat("classifier.ann.rdwm"));

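  // Map the termination choice to the OpenCV criteria flags: stop on the
  // iteration count, on epsilon, or on whichever of the two is reached first.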
  switch (GetParameterInt("classifier.ann.term"))
  {
  case 0: // CV_TERMCRIT_ITER
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
    break;
  case 1: // CV_TERMCRIT_EPS
    classifier->SetTermCriteriaType(CV_TERMCRIT_EPS);
    break;
  case 2: // CV_TERMCRIT_ITER + CV_TERMCRIT_EPS
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
    break;
  default: // DEFAULT = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS
    classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
    break;
  }
  classifier->SetEpsilon(GetParameterFloat("classifier.ann.eps"));
  classifier->SetMaxIter(GetParameterInt("classifier.ann.iter"));
  classifier->Train();
  classifier->Save(modelPath);
}
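
// Minimal call sketch (hypothetical, assuming an application deriving from
// LearningApplicationBase<float, int> with the sample lists already filled):
//   this->TrainNeuralNetwork(trainingSamples, trainingLabels, "model.ann");
// Train() runs the OpenCV MLP training and Save() writes the model to disk.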

} // end namespace Wrapper
} // end namespace otb

#endif