21 #ifndef otbLearningApplicationBase_hxx
22 #define otbLearningApplicationBase_hxx
// NOTE(review): this template header is immediately followed by another
// template header, with no declarator or body in between; confirm which
// member definition it is meant to introduce.
33 template <
class TInputValue,
class TOutputValue>
// Member definition templated on the input and output value types. Its only
// statement here calls ModelFactoryType::CleanFactories(), presumably to
// release the registered model factories; this looks like the destructor
// (its declarator does not appear here), so confirm against the class
// declaration.
39 template <
class TInputValue,
class TOutputValue>
42 ModelFactoryType::CleanFactories();
// Describes the "classifier" choice parameter, then records which of its
// choice keys belong to supervised and which to unsupervised classifiers.
//
// The supervised helpers run first, so the keys read right after
// InitSupervisedClassifierParams() are kept in m_SupervisedClassifier. The
// unsupervised keys are then taken as the tail of the full key list. This
// assumes GetChoiceKeys() returns keys in registration order; confirm.
45 template <
class TInputValue,
class TOutputValue>
52 SetParameterDescription(
"classifier",
"Choice of the classifier to use for the training.");
54 InitSupervisedClassifierParams();
// Snapshot of the keys registered so far, i.e. the supervised classifiers.
55 m_SupervisedClassifier = GetChoiceKeys(
"classifier");
57 InitUnsupervisedClassifierParams();
// Full key list after the unsupervised classifiers were added as well.
58 std::vector<std::string> allClassifier = GetChoiceKeys(
"classifier");
// NOTE(review): the guard compares against m_UnsupervisedClassifier.size(),
// but the slice below starts at m_SupervisedClassifier.size(). If
// m_UnsupervisedClassifier is still empty here, the guard is true whenever
// any key exists. m_SupervisedClassifier.size() looks like the intended
// bound; confirm.
60 if (allClassifier.size() > m_UnsupervisedClassifier.size())
61 m_UnsupervisedClassifier.assign(allClassifier.begin() + m_SupervisedClassifier.size(), allClassifier.end());
// Reports whether the currently selected "classifier" value is one of the
// unsupervised keys stored in m_UnsupervisedClassifier: returns Unsupervised
// when it is found and Supervised otherwise.
64 template <
class TInputValue,
class TOutputValue>
// Special case for an empty unsupervised key list.
// NOTE(review): the statements guarded by this condition do not appear
// here; confirm what is returned when no unsupervised keys are recorded.
67 if (m_UnsupervisedClassifier.empty())
// Linear search for the selected key among the unsupervised keys.
73 bool foundUnsupervised =
74 std::find(m_UnsupervisedClassifier.begin(), m_UnsupervisedClassifier.end(), GetParameterString(
"classifier")) != m_UnsupervisedClassifier.
end();
75 return foundUnsupervised ? Unsupervised : Supervised;
// Calls the Init*Params() helpers of the supervised classifiers: decision
// tree, neural network, normal Bayes, random forests and Shark random
// forests. Each helper presumably registers that classifier's parameters.
// Some calls are conditioned on !m_RegressionFlag, i.e. skipped in
// regression mode.
// NOTE(review): the exact statements each condition guards (and any build
// guards around the Shark call) are not fully present here; confirm.
79 template <
class TInputValue,
class TOutputValue>
92 if (!m_RegressionFlag)
96 InitDecisionTreeParams();
97 InitNeuralNetworkParams();
98 if (!m_RegressionFlag)
100 InitNormalBayesParams();
102 InitRandomForestsParams();
107 InitSharkRandomForestsParams();
// Calls the initialisation helper of the unsupervised classifiers. Only
// InitSharkKMeansParams() appears here, apparently guarded by
// !m_RegressionFlag (i.e. not offered in regression mode); confirm the
// guarded statement list.
111 template <
class TInputValue,
class TOutputValue>
115 if (!m_RegressionFlag)
117 InitSharkKMeansParams();
// Loads the model stored at modelPath and predicts labels for
// validationListSample.
//
// Progress is reported through a placeholder filter registered under the
// label "Validation...". It is set to 0 and bracketed by itk::StartEvent /
// itk::EndEvent, reaching 1 once prediction is done.
//
// Returns the list produced by model->PredictBatch().
122 template <
class TInputValue,
class TOutputValue>
128 dummyFilter->SetProgress(0.0f);
129 this->AddProcess(dummyFilter,
"Validation...");
130 dummyFilter->InvokeEvent(itk::StartEvent());
// The factory is asked for a model able to read modelPath (ReadMode).
133 ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath, ModelFactoryType::ReadMode);
140 model->Load(modelPath);
// Prediction follows the application's regression/classification mode.
141 model->SetRegressionMode(this->m_RegressionFlag);
// NOTE(review): NULL is passed as the second PredictBatch() argument
// (presumably an optional secondary output; confirm). nullptr is the
// modern spelling if the project builds as C++11 or later.
143 typename TargetListSampleType::Pointer predictedList = model->PredictBatch(validationListSample, NULL);
146 dummyFilter->UpdateProgress(1.0f);
147 dummyFilter->InvokeEvent(itk::EndEvent());
149 return predictedList;
// Trains the classifier selected by the "classifier" parameter on
// trainingListSample / trainingLabeledListSample. The model path is passed
// on to the chosen Train*() helper, presumably so it saves the trained model
// there.
//
// Progress is reported through a placeholder filter registered under the
// label "Training model...". It is set to 0 and bracketed by
// itk::StartEvent / itk::EndEvent, reaching 1 once training is done.
//
// Each key maps to one Train*() helper, paired with an otbAppLogFATAL
// message naming the CMake option (OTB_USE_LIBSVM, OTB_USE_SHARK,
// OTB_USE_OPENCV) to enable. The message is presumably the fallback when
// that module is not compiled in; confirm the #else/#endif structure
// around each pair.
152 template <
class TInputValue,
class TOutputValue>
154 typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
159 dummyFilter->SetProgress(0.0f);
160 this->AddProcess(dummyFilter,
"Training model...");
161 dummyFilter->InvokeEvent(itk::StartEvent());
164 const std::string modelName = GetParameterString(
"classifier");
166 if (modelName ==
"libsvm")
168 #ifdef OTB_USE_LIBSVM
169 TrainLibSVM(trainingListSample, trainingLabeledListSample, modelPath);
171 otbAppLogFATAL(
"Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
// NOTE(review): this starts a new `if` rather than continuing with
// `else if` like every later branch. Behavior is the same because the keys
// are distinct, but `else if` would make the dispatch chain uniform.
174 if (modelName ==
"sharkrf")
177 TrainSharkRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
179 otbAppLogFATAL(
"Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
182 else if (modelName ==
"sharkkm")
185 TrainSharkKMeans(trainingListSample, trainingLabeledListSample, modelPath);
187 otbAppLogFATAL(
"Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
// The remaining keys are all OpenCV-backed classifiers.
190 else if (modelName ==
"svm")
192 #ifdef OTB_USE_OPENCV
193 TrainSVM(trainingListSample, trainingLabeledListSample, modelPath);
195 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
198 else if (modelName ==
"boost")
200 #ifdef OTB_USE_OPENCV
201 TrainBoost(trainingListSample, trainingLabeledListSample, modelPath);
203 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
206 else if (modelName ==
"dt")
208 #ifdef OTB_USE_OPENCV
209 TrainDecisionTree(trainingListSample, trainingLabeledListSample, modelPath);
211 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
214 else if (modelName ==
"ann")
216 #ifdef OTB_USE_OPENCV
217 TrainNeuralNetwork(trainingListSample, trainingLabeledListSample, modelPath);
219 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
222 else if (modelName ==
"bayes")
224 #ifdef OTB_USE_OPENCV
225 TrainNormalBayes(trainingListSample, trainingLabeledListSample, modelPath);
227 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
230 else if (modelName ==
"rf")
232 #ifdef OTB_USE_OPENCV
233 TrainRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
235 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
238 else if (modelName ==
"knn")
240 #ifdef OTB_USE_OPENCV
241 TrainKNN(trainingListSample, trainingLabeledListSample, modelPath);
243 otbAppLogFATAL(
"Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
// NOTE(review): no branch for an unrecognised key appears in this chain. An
// unknown value would fall through with nothing trained unless the choice
// parameter restricts the values elsewhere; confirm.
248 dummyFilter->UpdateProgress(1.0f);
249 dummyFilter->InvokeEvent(itk::EndEvent());