20 #ifndef otbTrainVectorBase_hxx
21 #define otbTrainVectorBase_hxx
// Declares descriptions, optionality, defaults, field filters and documentation
// examples for the application's parameters (io.*, layer, feat, valid.*, cfield, v).
// NOTE(review): the function signature and most AddParameter calls are elided in
// this excerpt; presumably this is the application's init step — confirm.
30 template <
class TInputValue,
class TOutputValue>
// Input/output group.
35 this->SetParameterDescription(
"io",
"This group of parameters allows setting input and output data.");
38 this->SetParameterDescription(
"io.vd",
"Input geometries used for training (note: all geometries from the layer will be used)");
// The feature statistics file is optional.
41 this->MandatoryOff(
"io.stats");
42 this->SetParameterDescription(
"io.stats",
"XML file containing mean and variance of each feature.");
45 this->SetParameterDescription(
"io.out",
"Output file containing the model estimated (.txt format).");
// Layer index inside the training vector file: optional, defaults to the first layer.
48 this->SetParameterDescription(
"layer",
"Index of the layer to use in the input vector file.");
49 this->MandatoryOff(
"layer");
50 this->SetDefaultParameterInt(
"layer", 0);
// Feature fields are listed from io.vd and restricted to numeric OGR field types.
53 this->SetParameterDescription(
"feat",
"List of field names in the input vector data to be used as features for training.");
54 this->SetVectorData(
"feat",
"io.vd");
55 this->SetTypeFilter(
"feat", { OFTInteger, OFTInteger64, OFTReal });
// Optional validation data group (file + layer index, layer defaults to 0).
59 this->SetParameterDescription(
"valid",
"This group of parameters defines validation data.");
62 this->SetParameterDescription(
"valid.vd",
63 "Geometries used for validation "
64 "(must contain the same fields used for training, all geometries from the layer will be used)");
65 this->MandatoryOff(
"valid.vd");
68 this->SetParameterDescription(
"valid.layer",
"Index of the layer to use in the validation vector file.");
69 this->MandatoryOff(
"valid.layer");
70 this->SetDefaultParameterInt(
"valid.layer", 0);
// Class label field: a single field of io.vd. String fields are also accepted here;
// their values are parsed as numbers when samples are extracted.
73 this->AddParameter(
ParameterType_Field,
"cfield",
"Field containing the class integer label for supervision");
74 this->SetParameterDescription(
"cfield",
75 "Field containing the class id for supervision. "
76 "The values in this field shall be cast into integers. "
77 "Only geometries with this field available will be taken into account.");
78 this->SetVectorData(
"cfield",
"io.vd");
79 this->SetTypeFilter(
"cfield", { OFTString, OFTInteger, OFTInteger64, OFTReal });
80 this->SetListViewSingleSelectionMode(
"cfield",
true);
// Verbosity flag.
// NOTE(review): this sets the current value rather than a default — confirm intended.
83 this->SetParameterDescription(
"v",
"Verbose mode, display the contingency table result.");
84 this->SetParameterInt(
"v", 1);
// Documentation example values.
87 this->SetDocExampleParameterValue(
"io.vd",
"vectorData.shp");
88 this->SetDocExampleParameterValue(
"io.stats",
"meanVar.xml");
89 this->SetDocExampleParameterValue(
"io.out",
"svmModel.svm");
90 this->SetDocExampleParameterValue(
"feat",
"perimeter area width");
91 this->SetDocExampleParameterValue(
"cfield",
"predicted");
// Presumably registers the random-seed parameter — confirm against the base class.
97 this->AddRANDParameter();
// Rebuilds the 'feat' and 'cfield' choice lists from the fields of the selected
// layer of the training vector data (io.vd).
// NOTE(review): ogrDS, feature and the type filters are set up on lines elided
// from this excerpt.
100 template <
class TInputValue,
class TOutputValue>
104 if (this->HasValue(
"io.vd"))
106 std::vector<std::string> vectorFileList = this->GetParameterStringList(
"io.vd");
// Only the layer selected by the 'layer' parameter is inspected.
108 ogr::Layer layer = ogrDS->GetLayer(
static_cast<size_t>(this->GetParameterInt(
"layer")));
// Start from empty choice lists so fields from a previous input do not linger.
111 this->ClearChoices(
"feat");
112 this->ClearChoices(
"cfield");
// One candidate choice per field of the feature's definition.
116 for (
int iField = 0; iField < feature.
ogr().GetFieldCount(); iField++)
// item is the original field name, used as the user-visible label. key is
// presumably initialized from item on an elided line before it is sanitized.
118 std::string key, item = feature.
ogr().GetFieldDefnRef(iField)->GetNameRef();
// Build a choice key from the field name: drop every non-alphanumeric character
// (remove_if compacts the kept ones to the front and returns the new logical end),
// then lower-case the kept range [key.begin(), end).
// Each char is converted to unsigned char before it reaches isalnum/tolower.
// Passing a negative char, as happens for non-ASCII bytes in field names, is
// undefined behavior for the <cctype> functions.
120 std::string::iterator end = std::remove_if(key.begin(), key.end(), [](
unsigned char c) { return !std::isalnum(c); });
121 std::transform(key.begin(), end, key.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
// Offer the field as a 'feat' and/or 'cfield' choice depending on its OGR type.
123 OGRFieldType fieldType = feature.
ogr().GetFieldDefnRef(iField)->GetType();
// An empty type filter accepts every field type.
125 if (featTypeFilter.empty() ||
std::find(featTypeFilter.begin(), featTypeFilter.end(), fieldType) != std::end(featTypeFilter))
// The key uses only the sanitized prefix [key.begin(), end); the label keeps the original name.
127 std::string tmpKey =
"feat." + key.substr(0,
static_cast<unsigned long>(end - key.begin()));
128 this->AddChoice(tmpKey, item);
130 if (cfieldTypeFilter.empty() ||
std::find(cfieldTypeFilter.begin(), cfieldTypeFilter.end(), fieldType) != std::end(cfieldTypeFilter))
132 std::string tmpKey =
"cfield." + key.substr(0,
static_cast<unsigned long>(end - key.begin()));
133 this->AddChoice(tmpKey, item);
// Main execution:
//   1. resolve the selected feature fields;
//   2. extract training and classification samples;
//   3. train the model written to io.out;
//   4. classify the classification samples with that model.
// NOTE(review): 'measurement' is computed on lines elided from this excerpt.
139 template <
class TInputValue,
class TOutputValue>
142 m_FeaturesInfo.SetFieldNames(this->GetChoiceNames(
"feat"), this->GetSelectedItems(
"feat"));
// Training without any feature field is meaningless: abort.
145 if (m_FeaturesInfo.m_SelectedIdx.empty())
147 otbAppLogFATAL(<<
"No features have been selected to train the classifier on!");
151 ExtractAllSamples(measurement);
153 this->Train(m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString(
"io.out"));
// Classify reuses the model path just written by Train.
155 m_PredictedList = this->Classify(m_ClassificationSamplesWithLabel.listSample, this->GetParameterString(
"io.out"));
// Extracts training samples, then classification samples.
// Order matters: classification extraction may fall back to m_TrainingSamplesWithLabel.
158 template <
class TInputValue,
class TOutputValue>
161 m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
162 m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
// Training samples come from the io.vd files, using the layer given by 'layer'.
165 template <
class TInputValue,
class TOutputValue>
169 return ExtractSamplesWithLabel(
"io.vd",
"layer", measurement);
// Returns the samples used for performance estimation.
// For supervised classifiers they come from valid.vd / valid.layer. When that set
// is empty, the training samples are reused and a warning is logged.
// Otherwise the training samples are returned.
172 template <
class TInputValue,
class TOutputValue>
176 if (this->GetClassifierCategory() == Superclass::Supervised)
179 SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel(
"valid.vd",
"valid.layer", measurement);
188 otbAppLogWARNING(
"The validation set is empty. The performance estimation is done using the input training set in this case.");
// The fallback assigns the training list objects themselves.
// NOTE(review): presumably these are shared smart pointers, not deep copies — confirm.
189 tmpSamplesWithLabel.
listSample = m_TrainingSamplesWithLabel.listSample;
190 tmpSamplesWithLabel.
labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
193 return tmpSamplesWithLabel;
197 return m_TrainingSamplesWithLabel;
// Reads feature statistics from the io.stats XML file when it has a value and is enabled.
// NOTE(review): what is built from the reader, and the behavior when io.stats is
// absent, are on lines elided from this excerpt.
201 template <
class TInputValue,
class TOutputValue>
205 if (this->HasValue(
"io.stats") && this->IsParameterEnabled(
"io.stats"))
208 std::string XMLfile = this->GetParameterString(
"io.stats");
209 statisticsReader->SetFileName(XMLfile);
// Reads every vector file listed in parameterName, using the layer given by
// parameterLayer. For each feature it builds one measurement vector from the
// selected feature fields, plus one target label from the class field.
// The samples are later passed through shiftScaleFilter.
// Nothing is read when parameterName has no value or is disabled.
223 template <
class TInputValue,
class TOutputValue>
229 if (this->HasValue(parameterName) && this->IsParameterEnabled(parameterName))
231 typename ListSampleType::Pointer input = ListSampleType::New();
232 typename TargetListSampleType::Pointer target = TargetListSampleType::New();
233 input->SetMeasurementVectorSize(m_FeaturesInfo.m_NbFeatures);
235 std::vector<std::string> fileList = this->GetParameterStringList(parameterName);
236 for (
unsigned int k = 0; k < fileList.size(); k++)
238 otbAppLogINFO(
"Reading vector file " << k + 1 <<
"/" << fileList.size());
240 ogr::Layer layer = source->GetLayer(
static_cast<size_t>(this->GetParameterInt(parameterLayer)));
// A null first feature means the layer is empty; that file is skipped with a warning.
242 bool goesOn = feature.
addr() != 0;
245 otbAppLogWARNING(
"The layer " << this->GetParameterInt(parameterLayer) <<
" of " << fileList[k] <<
" is empty, input is skipped.");
// Field indices are resolved once per file, from its first feature.
// A missing class field is fatal only when a class field name was actually selected.
251 int cFieldIndex = feature.
ogr().GetFieldIndex(m_FeaturesInfo.m_SelectedCFieldName.c_str());
252 if (cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty())
254 otbAppLogFATAL(
"The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName <<
") has not been found in the vector file " << fileList[k]);
// featureFieldIndex[i] is the OGR field index of the i-th selected feature name.
258 std::vector<int> featureFieldIndex(m_FeaturesInfo.m_NbFeatures, -1);
259 for (
unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++)
261 featureFieldIndex[i] = feature.
ogr().GetFieldIndex(m_FeaturesInfo.m_SelectedNames[i].c_str());
262 if (featureFieldIndex[i] < 0)
263 otbAppLogFATAL(
"The field name for feature " << m_FeaturesInfo.m_SelectedNames[i] <<
" has not been found in the vector file " << fileList[k]);
// Convert each selected field value to ValueType according to its OGR type.
// NOTE(review): the case labels are elided. Judging by the 'feat' type filter, they
// are presumably OFTInteger, OFTInteger64 and OFTReal. If so, GetValue<int> on a
// 64-bit integer field may truncate — confirm.
271 mv.SetSize(m_FeaturesInfo.m_NbFeatures);
272 for (
unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx)
274 switch (feature[featureFieldIndex[idx]].
GetType())
277 mv[idx] =
static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<
int>());
280 mv[idx] =
static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<
int>());
283 mv[idx] =
static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<
double>());
286 itkExceptionMacro(<<
"incorrect field type: " << feature[featureFieldIndex[idx]].
GetType() <<
".");
// Label: read only when the class field exists and is set on this feature.
292 if (cFieldIndex >= 0 &&
ogr::Field(feature, cFieldIndex).HasBeenSet())
294 switch (feature[cFieldIndex].
GetType())
297 target->PushBack(
static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
300 target->PushBack(
static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
303 target->PushBack(
static_cast<ValueType>(feature[cFieldIndex].GetValue<double>()));
// String labels are parsed as numbers.
// NOTE(review): std::stod throws std::invalid_argument / std::out_of_range on
// non-numeric or overflowing text, and that exception is not handled here.
306 target->PushBack(
static_cast<ValueType>(std::stod(feature[cFieldIndex].GetValue<std::string>())));
// Unsupported class-field type: report the type of the class field itself.
// cFieldIndex is an OGR field index, not a position in featureFieldIndex. That
// vector only holds m_NbFeatures entries, so featureFieldIndex[cFieldIndex] could
// read out of bounds and would name the wrong field.
309 itkExceptionMacro(<<
"incorrect field type: " << feature[cFieldIndex].
GetType() <<
".");
// Presumably the else branch of the label check: a feature without a usable class
// field gets label 0, which keeps target aligned with input.
313 target->PushBack(0.);
// Advance to the next feature; a null feature ends the loop.
315 feature = layer.
ogr().GetNextFeature();
316 goesOn = feature.
addr() != 0;
// Normalize the collected samples via shiftScaleFilter.
// NOTE(review): its shift/scale setup is elided; presumably it uses 'measurement' — confirm.
322 shiftScaleFilter->SetInput(input);
325 shiftScaleFilter->Update();
327 samplesWithLabel.
listSample = shiftScaleFilter->GetOutput();
// Detach the output so the returned list does not depend on the local filter pipeline.
329 samplesWithLabel.
listSample->DisconnectPipeline();
332 return samplesWithLabel;