21 #ifndef otbKNearestNeighborsMachineLearningModel_hxx
22 #define otbKNearestNeighborsMachineLearningModel_hxx
// Constructor (fragment: the signature, opening of the initializer list and body
// are on lines not visible in this excerpt).
// Visible member initializers:
//  - m_KNearestModel: a fresh OpenCV model obtained from cv::ml::KNearest::create().
//  - m_DecisionRule: starts as KNN_VOTING.
// NOTE(review): the initializer between these two (original line 39, presumably
// m_K) is not shown here — confirm its default value.
35 template <
class TInputValue,
class TTargetValue>
38 m_KNearestModel(cv::ml::KNearest::create()),
40 m_DecisionRule(KNN_VOTING)
// Training routine (presumably Train(); the signature and several body lines are
// not visible in this excerpt).
// Steps:
//  1. Converts the input and target list samples into OpenCV matrices.
//  2. Makes m_DecisionRule agree with m_RegressionMode.
//  3. Configures the cv::ml::KNearest model: default K, brute-force search,
//     classifier/regressor mode.
//  4. Trains the model with one sample per matrix row.
47 template <
class TInputValue,
class TTargetValue>
// Input list sample -> cv::Mat `samples` (its declaration is not visible here).
52 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
// Target list sample -> cv::Mat `labels` (its declaration is not visible here).
56 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
// In regression mode, a VOTING rule is replaced by MEAN.
59 if (this->m_RegressionMode)
61 if (this->m_DecisionRule == KNN_VOTING)
63 this->SetDecisionRule(KNN_MEAN);
// Otherwise (presumably the non-regression `else` branch, whose line is not shown),
// any rule other than VOTING is reset to VOTING.
68 if (this->m_DecisionRule != KNN_VOTING)
70 this->SetDecisionRule(KNN_VOTING);
// The model's default neighbour count is m_K.
74 m_KNearestModel->setDefaultK(m_K);
// Exhaustive (brute-force) neighbour search.
76 m_KNearestModel->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE);
// Classification unless regression mode is enabled.
77 m_KNearestModel->setIsClassifier(!this->m_RegressionMode);
// cv::ml::ROW_SAMPLE: each row of `samples` is one training sample.
79 m_KNearestModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels));
// Prediction routine (presumably DoPredict(input, quality, proba); the signature,
// local declarations and return statement are not visible in this excerpt).
// Steps:
//  1. Converts `input` into a cv::Mat.
//  2. Queries the OpenCV model for the m_K nearest neighbours.
//  3. Optionally derives a confidence count.
//  4. Applies the MEDIAN decision rule itself when that rule is selected.
//  5. Stores the prediction in target[0].
82 template <
class TInputValue,
class TTargetValue>
// Single input sample -> cv::Mat `sample` (its declaration is not visible here).
91 otb::SampleToMat<InputSampleType>(input, sample);
// `nearest` receives the neighbour responses (findNearest's neighborResponses
// output); `result` is the prediction OpenCV computes for this sample.
// NOTE(review): `nearest` is allocated as 1 x m_K and later indexed up to column
// m_K - 1. If OpenCV limits k to the number of training samples, it may produce
// fewer columns — confirm the training set always holds at least m_K samples.
94 cv::Mat nearest(1, m_K, CV_32FC1);
95 result = m_KNearestModel->findNearest(sample, m_K, cv::noArray(), nearest, cv::noArray());
// Confidence: counts the neighbours whose response compares exactly equal (float
// ==) to `result`.
// The assert against regression mode is only active when NDEBUG is undefined.
// The line that stores `accuracy` into *quality is not shown in this excerpt.
98 if (quality !=
nullptr)
100 assert(!this->m_RegressionMode);
101 unsigned int accuracy = 0;
102 for (
int k = 0; k < m_K; ++k)
104 if (nearest.at<
float>(0, k) == result)
// Requesting per-class probabilities (proba != nullptr) throws unless m_ProbaIndex
// is set.
111 if (proba !=
nullptr && !this->m_ProbaIndex)
112 itkExceptionMacro(
"Probability per class not available for this classifier !");
// MEDIAN rule, handled here rather than by OpenCV.
// The neighbour responses go into a std::multiset: ascending order, duplicates
// kept. An iterator is then advanced m_K >> 1 (= m_K / 2) positions, which for
// even m_K lands on the upper of the two middle values.
// How `median` is used after the loop is on lines not shown here.
118 if (this->m_DecisionRule == KNN_MEDIAN)
120 std::multiset<float> values;
121 for (
int k = 0; k < m_K; ++k)
123 values.insert(nearest.at<
float>(0, k));
125 std::multiset<float>::iterator
median = values.begin();
126 int pos = (m_K >> 1);
127 for (
int k = 0; k < pos; ++k, ++
median)
// The prediction is stored in the first component of the target sample.
133 target[0] =
static_cast<TTargetValue
>(result);
// Save routine (presumably Save(filename, name); the signature is not visible in
// this excerpt).
// Opens `filename` for writing through cv::FileStorage and starts a map node. The
// node is named `name`, or the OpenCV model's default name when `name` is empty.
// The OpenCV KNearest model is written into that node, followed by a
// "DecisionRule" entry holding m_DecisionRule. The loading code in this file reads
// the same "DecisionRule" key from the top-level node.
// NOTE(review): the line that closes the "{" map is not shown here — confirm it is
// emitted before `fs` goes out of scope.
137 template <
class TInputValue,
class TTargetValue>
140 cv::FileStorage fs(filename, cv::FileStorage::WRITE);
141 fs << (name.empty() ? m_KNearestModel->getDefaultName() : cv::String(name)) <<
"{";
142 m_KNearestModel->write(fs);
143 fs <<
"DecisionRule" << m_DecisionRule;
// Load routine (presumably Load(filename, name); the signature and many body lines
// are not visible in this excerpt). Two on-disk formats are handled:
//  - cv::FileStorage format: the first line mentions the OpenCV model's default
//    name (presumably recorded in `isKNNv3`). The model, its "DecisionRule" entry
//    and m_K are restored from it.
//  - Legacy plain-text format, line by line:
//      * a line containing "K", with the value after '=';
//      * a line containing "IsRegression", with the value after '=';
//      * in regression mode only, a decision-rule line, with the value after '=';
//      * then samples written as "<label> <feature> <feature> ...".
//    These samples are rebuilt into new input/target list samples.
// Errors are reported through itkExceptionMacro (the guarding conditions for some
// of them are on lines not shown).
148 template <
class TInputValue,
class TTargetValue>
151 std::ifstream ifs(filename);
// Presumably raised when `ifs` failed to open (the condition line is not shown).
154 itkExceptionMacro(<<
"Could not read file " << filename);
157 bool isKNNv3 =
false;
// Format detection: look for the OpenCV model's default name in the first line.
161 std::getline(ifs, line);
162 if (line.find(m_KNearestModel->getDefaultName()) != std::string::npos)
// cv::FileStorage format: restore the OpenCV model from the first top-level node,
// then the decision rule (via cv::FileNode's int conversion) and m_K, which is
// taken from the restored model's default K.
171 cv::FileStorage fs(filename, cv::FileStorage::READ);
172 m_KNearestModel->read(fs.getFirstTopLevelNode());
173 m_DecisionRule = (int)(fs.getFirstTopLevelNode()[
"DecisionRule"]);
174 m_K = m_KNearestModel->getDefaultK();
// Legacy format, header line 1: neighbour count.
// NOTE(review): if the line has no '=', `pos` is npos and `pos + 1` wraps to 0,
// so the whole line is handed to boost::lexical_cast, which is then likely to
// throw boost::bad_lexical_cast rather than the itk exception used above.
182 std::getline(ifs, line);
183 std::istringstream iss(line);
184 if (line.find(
"K") == std::string::npos)
186 itkExceptionMacro(<<
"Could not read file " << filename);
188 std::string::size_type pos = line.find_first_of(
"=", 0);
189 std::string::size_type nextpos = line.find_first_of(
" \n\r", pos + 1);
190 this->SetK(boost::lexical_cast<int>(line.substr(pos + 1, nextpos - pos - 1)));
// Legacy format, header line 2: regression flag.
193 std::getline(ifs, line);
194 if (line.find(
"IsRegression") == std::string::npos)
196 itkExceptionMacro(<<
"Could not read file " << filename);
198 pos = line.find_first_of(
"=", 0);
199 nextpos = line.find_first_of(
" \n\r", pos + 1);
200 this->SetRegressionMode(boost::lexical_cast<bool>(line.substr(pos + 1, nextpos - pos - 1)));
// Legacy format, header line 3 (regression only): decision rule as an int.
202 if (this->m_RegressionMode)
204 std::getline(ifs, line);
205 pos = line.find_first_of(
"=", 0);
206 nextpos = line.find_first_of(
" \n\r", pos + 1);
207 this->SetDecisionRule(boost::lexical_cast<int>(line.substr(pos + 1, nextpos - pos - 1)));
// Containers that receive the samples stored in the legacy file.
210 typename InputListSampleType::Pointer samples = InputListSampleType::New();
211 typename TargetListSampleType::Pointer labels = TargetListSampleType::New();
214 unsigned int nbFeatures = 0;
217 std::getline(ifs, line);
// Feature count = number of spaces on the line (presumably computed once, from the
// first data line).
// NOTE(review): trailing or doubled spaces would inflate this count.
221 nbFeatures = std::count(line.begin(), line.end(),
' ');
// Label = text before the first space.
// NOTE(review): it is parsed as unsigned int, so a fractional target such as "2.5"
// makes boost::lexical_cast throw. A legacy regression file with non-integer
// targets would fail to load — confirm legacy files only store integer labels.
227 pos = line.find_first_of(
" ", 0);
229 label[0] =
static_cast<TargetValueType>(boost::lexical_cast<unsigned int>(line.substr(0, pos)));
// Features = the space/newline-delimited tokens that follow, converted with atof.
// atof yields 0.0 for unparsable text and reports no error.
234 nextpos = line.find_first_of(
" ", pos + 1);
235 while (nextpos != std::string::npos)
237 nextpos = line.find_first_of(
" \n\r", pos + 1);
238 std::string subline = line.substr(pos + 1, nextpos - pos - 1);
240 sample[id] = atof(subline.c_str());
// Store the parsed sample/label pair.
244 samples->SetMeasurementVectorSize(itk::NumericTraits<InputSampleType>::GetLength(sample));
245 samples->PushBack(sample);
246 labels->PushBack(label);
// Install the rebuilt list samples on this model. Whether retraining follows is on
// lines not shown here.
251 this->SetInputListSample(samples);
252 this->SetTargetListSample(labels);
// Template header of a member definition whose signature and body are not visible
// in this excerpt (original lines 259-269 are missing).
256 template <
class TInputValue,
class TTargetValue>
// Template header of a member definition whose signature and body are not visible
// in this excerpt (original lines 271-276 are missing).
270 template <
class TInputValue,
class TTargetValue>
// Presumably the PrintSelf(os, indent) override (its signature is not visible in
// this excerpt). It forwards to Superclass::PrintSelf; any additional output is on
// lines not shown here.
277 template <
class TInputValue,
class TTargetValue>
281 Superclass::PrintSelf(os, indent);