21 #ifndef otbVectorPrediction_hxx
22 #define otbVectorPrediction_hxx
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): the function signature and surrounding braces are not visible
// in this chunk; presumably this is the application's initialisation routine.
31 template <
bool RegressionMode>
// Hand over to DoInitSpecialization() (defined elsewhere; not visible here).
34 DoInitSpecialization();
// Sanity checks (active only when assertions are enabled): every parameter key
// used by the rest of this file must resolve to a non-null parameter.
37 assert(GetParameterByKey(
"in") !=
nullptr);
38 assert(GetParameterByKey(
"instat") !=
nullptr);
39 assert(GetParameterByKey(
"model") !=
nullptr);
40 assert(GetParameterByKey(
"cfield") !=
nullptr);
41 assert(GetParameterByKey(
"feat") !=
nullptr);
42 assert(GetParameterByKey(
"out") !=
nullptr);
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): the signature, braces and the declarations of `ogrDS` and
// `typeFilter` are not visible in this chunk; presumably this is the routine
// that refreshes the "feat" choice list from the input vector file.
45 template <
bool RegressionMode>
// Path of the input vector data, taken from the "in" parameter.
50 auto shapefileName = GetParameterString(
"in");
// Only the first layer (index 0) of the data source is inspected.
53 auto layer = ogrDS->GetLayer(0);
54 OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
// Walk every field of the layer definition and offer it as a "feat" choice.
59 for (
int iField = 0; iField < layerDefn.GetFieldCount(); iField++)
61 auto fieldDefn = layerDefn.GetFieldDefn(iField);
62 std::string item = fieldDefn->GetNameRef();
// Build the choice key from the field name: drop every non-alphanumeric
// character, then lower-case what remains.
63 std::string key(item);
// The <cctype> classification functions have undefined behaviour when given a
// negative value other than EOF, which a plain `char` holding a non-ASCII byte
// can be. Taking the character as `unsigned char` keeps the call well defined.
64 key.erase(std::remove_if(key.begin(), key.end(), [](
unsigned char c) { return !std::isalnum(c); }), key.end());
// Same precaution for the case conversion: go through `unsigned char` and
// narrow the int result back to `char` explicitly.
65 std::transform(key.begin(), key.end(), key.begin(), [](
unsigned char c) { return static_cast<char>(std::tolower(c)); });
66 auto fieldType = fieldDefn->GetType();
// An empty filter accepts every field type; otherwise the field's type must be
// listed in the filter.
68 if (typeFilter.empty() ||
std::find(typeFilter.begin(), typeFilter.end(), fieldType) != std::end(typeFilter))
// Choice keys live under the "feat." prefix; the original field name is kept
// as the displayed choice name.
70 std::string tmpKey =
"feat." + key;
71 AddChoice(tmpKey, item);
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): signature, braces, the declarations of `layer`, `it_feat` and
// `mv`, and the switch `case` labels are not visible in this chunk; presumably
// this is ReadInputListSample(), which turns layer features into a list sample.
77 template <
bool RegressionMode>
82 typename ListSampleType::Pointer input = ListSampleType::New();
// One measurement component per feature field selected through "feat".
84 const auto nbFeatures = GetSelectedItems(
"feat").size();
85 input->SetMeasurementVectorSize(nbFeatures);
// Field index of each selected feature; -1 until it has been resolved.
86 std::vector<int> featureFieldIndex(nbFeatures, -1);
// Resolve every selected choice name to its field index.
89 for (
unsigned int i = 0; i < nbFeatures; i++)
93 featureFieldIndex[i] = (*it_feat).GetFieldIndex(GetChoiceNames(
"feat")[GetSelectedItems(
"feat")[i]]);
// Fatal log naming the feature whose field could not be found.
// NOTE(review): the guarding condition is not visible here — presumably a
// negative index check; confirm.
97 otbAppLogFATAL(
"The field name for feature " << GetChoiceNames(
"feat")[GetSelectedItems(
"feat")[i]] <<
" has not been found" << std::endl);
// Convert each feature of the layer into one measurement vector.
101 for (
auto const& feature : layer)
104 for (
unsigned int idx = 0; idx < nbFeatures; ++idx)
106 auto field = feature[featureFieldIndex[idx]];
// Dispatch on the OGR field type: values are read as int or double and cast to
// ValueType; the last visible branch raises an exception reporting the type.
107 switch (field.GetType())
111 mv[idx] =
static_cast<ValueType>(field.template GetValue<int>());
114 mv[idx] =
static_cast<ValueType>(field.template GetValue<double>());
117 itkExceptionMacro(<<
"incorrect field type: " << field.GetType() <<
".");
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): signature, braces and the declarations of `input`,
// `statisticsReader`, the two measurement vectors and
// `trainingShiftScaleFilter` are not visible in this chunk; presumably this is
// NormalizeListSample().
126 template <
bool RegressionMode>
// NOTE(review): size() is narrowed to int here, whereas the previous block
// keeps the same count as `auto` — consider aligning the two.
129 const int nbFeatures = GetSelectedItems(
"feat").size();
// When a statistics file is supplied and enabled, read the "mean" and "stddev"
// vectors from it.
134 if (HasValue(
"instat") && IsParameterEnabled(
"instat"))
137 std::string XMLfile = GetParameterString(
"instat");
138 statisticsReader->SetFileName(XMLfile);
139 meanMeasurementVector = statisticsReader->GetStatisticVectorByName(
"mean");
140 stddevMeasurementVector = statisticsReader->GetStatisticVectorByName(
"stddev");
// Identity normalisation: shift 0 and scale 1 for every feature.
// NOTE(review): presumably the `else` branch of the test above; the `else`
// keyword itself is not visible here.
144 meanMeasurementVector.SetSize(nbFeatures);
145 meanMeasurementVector.Fill(0.);
146 stddevMeasurementVector.SetSize(nbFeatures);
147 stddevMeasurementVector.Fill(1.);
// Run the shift/scale filter over the input with the chosen statistics.
151 trainingShiftScaleFilter->SetInput(input);
152 trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
153 trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
154 trainingShiftScaleFilter->Update();
156 otbAppLogINFO(
"standard deviation used: " << stddevMeasurementVector);
// Hand back the filter's output list sample.
160 return trainingShiftScaleFilter->GetOutput();
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): only one statement of the body is visible in this chunk;
// presumably this is ReopenDataSourceInUpdateMode().
164 template <
bool RegressionMode>
// Copy the input layer into `buffer` under the layer name "Buffer" and rebind
// `layer` to that copy.
173 layer = buffer->CopyLayer(inputLayer, std::string(
"Buffer"));
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): only the field-copy loop is visible in this chunk; presumably
// this is CreateOutputDataSource().
181 template <
bool RegressionMode>
// Visit every field of the input layer definition and build an OGRFieldDefn
// from it. What is done with `fieldDefn` afterwards is not visible here.
190 for (
int k = 0; k < inLayerDefn.GetFieldCount(); k++)
192 OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): signature, braces and the declarations of `layerDefn` and
// `confFieldName` are not visible in this chunk; presumably this is
// AddPredictionField().
198 template <
bool RegressionMode>
// The predicted value is stored as a real in regression mode and as an integer
// otherwise.
203 const OGRFieldType labelType = RegressionMode ? OFTReal : OFTInteger;
// Look for an existing field with the name given by "cfield".
205 int idx = layerDefn.GetFieldIndex(GetParameterString(
"cfield").c_str());
// Reject an existing field whose type differs from the expected label type.
// NOTE(review): GetFieldDefn(idx) is dereferenced directly; the guard on `idx`
// is not visible here — presumably an `idx >= 0` test; confirm.
208 if (layerDefn.GetFieldDefn(idx)->GetType() != labelType)
209 itkExceptionMacro(
"Field name " << GetParameterString(
"cfield") <<
" already exists with a different type!");
// Definition of the prediction field, named after "cfield".
213 OGRFieldDefn predictedField(GetParameterString(
"cfield").c_str(), labelType);
// Same treatment for the confidence field, which must be of type OFTReal.
219 if (computeConfidenceMap)
221 idx = layerDefn.GetFieldIndex(confFieldName.c_str());
224 if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
225 itkExceptionMacro(
"Field name " << confFieldName <<
" already exists with a different type!");
229 OGRFieldDefn confidenceField(confFieldName.c_str(), OFTReal);
// NOTE(review): both setters are given the field's own current value, so these
// two calls leave width and precision unchanged — confirm whether another
// source field was intended.
230 confidenceField.SetWidth(confidenceField.GetWidth());
231 confidenceField.SetPrecision(confidenceField.GetPrecision());
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): the start of the signature, the braces, the declarations of
// `dstFeature`, `target` and `confFieldName`, the switch `case` labels and the
// increment of `count` are not visible in this chunk; presumably this is
// FillOutputLayer().
238 template <
bool RegressionMode>
240 typename ConfidenceListSampleType::Pointer quality,
bool updateMode,
bool computeConfidenceMap)
// Index of the current feature, used to pick the matching entry in `target`
// and `quality`.
242 unsigned int count = 0;
243 std::string classfieldname = GetParameterString(
"cfield");
// Process the features of `layer` in iteration order.
244 for (
auto const& feature : layer)
// Copy the source feature's content and keep its feature id.
247 dstFeature.
SetFrom(feature, TRUE);
248 dstFeature.
SetFID(feature.GetFID());
249 auto field = dstFeature[classfieldname];
// Write the first component of the prediction using the setter that matches
// the field type: int, double, or string (via std::to_string); the last
// visible branch raises an exception reporting the type.
250 switch (field.GetType())
254 field.template SetValue<int>(target->GetMeasurementVector(count)[0]);
257 field.template SetValue<double>(target->GetMeasurementVector(count)[0]);
260 field.template SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
263 itkExceptionMacro(<<
"incorrect field type: " << field.GetType() <<
".");
// Optionally store the first component of the confidence value as a double.
265 if (computeConfidenceMap)
266 dstFeature[confFieldName].template SetValue<double>(quality->GetMeasurementVector(count)[0]);
// Fragment of a member template parameterised on RegressionMode.
// NOTE(review): signature, braces, several `else` keywords and the
// declarations of `source`, `buffer`, `output` and `outLayer` are not visible
// in this chunk; presumably this is the application's execution routine.
279 template <
bool RegressionMode>
// Ask the factory for a model able to read the file named by "model".
282 m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString(
"model"), MachineLearningModelFactoryType::ReadMode);
// A null model means no registered factory accepted the file.
284 if (m_Model.IsNull())
286 otbAppLogFATAL(<<
"Error when loading model " << GetParameterString(
"model") <<
" : unsupported model type");
// Configure regression/classification from the template argument, then load.
289 m_Model->SetRegressionMode(RegressionMode);
291 m_Model->Load(GetParameterString(
"model"));
294 auto shapefileName = GetParameterString(
"in");
// Read the samples from the first layer and normalise them.
297 auto layer = source->GetLayer(0);
298 auto input = ReadInputListSample(layer);
300 ListSampleType::Pointer listSample = NormalizeListSample(input);
301 typename LabelListSampleType::Pointer target;
// Predict, also requesting a confidence value per sample when
// shouldComputeConfidenceMap() says so.
305 const bool computeConfidenceMap = shouldComputeConfidenceMap();
306 typename ConfidenceListSampleType::Pointer quality;
308 if (computeConfidenceMap)
310 quality = ConfidenceListSampleType::New();
311 target = m_Model->PredictBatch(listSample, quality);
315 target = m_Model->PredictBatch(listSample);
// Update mode is selected when "out" is not both enabled and set.
318 const bool updateMode = !(IsParameterEnabled(
"out") && HasValue(
"out"));
// Two ways of obtaining the output data source: reopen the input for update,
// or create a new one.
// NOTE(review): the branch on `updateMode` is not visible here.
328 output = ReopenDataSourceInUpdateMode(source, layer, buffer);
332 output = CreateOutputDataSource(layer);
// NOTE(review): StartTransaction() is called with no capability test in view,
// while the commit below is gated on TestCapability("Transactions") — confirm
// the asymmetry is intended.
337 OGRErr errStart = outLayer.
ogr().StartTransaction();
338 if (errStart != OGRERR_NONE)
340 itkExceptionMacro(<<
"Unable to start transaction for OGR layer " << outLayer.
ogr().GetName() <<
".");
// Declare the prediction/confidence fields, then write one value per feature.
343 AddPredictionField(outLayer, layer, computeConfidenceMap);
344 FillOutputLayer(outLayer, layer, target, quality, updateMode, computeConfidenceMap);
// Commit only when the layer reports the "Transactions" capability.
346 if (outLayer.
ogr().TestCapability(
"Transactions"))
348 const OGRErr errCommitX = outLayer.
ogr().CommitTransaction();
349 if (errCommitX != OGRERR_NONE)
351 itkExceptionMacro(<<
"Unable to commit transaction for OGR layer " << outLayer.
ogr().GetName() <<
".");
// Flush the output data source to disk.
355 output->SyncToDisk();