ENH: Add traits for MachineLearningModel to accommodate containers as template parameters
author Julien Michel <julien.michel@cnes.fr>
Tue, 13 Jun 2017 09:25:29 +0000 (11:25 +0200)
committer Julien Michel <julien.michel@cnes.fr>
Tue, 13 Jun 2017 09:25:29 +0000 (11:25 +0200)
Modules/Learning/LearningBase/include/otbMachineLearningModel.h
Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h [new file with mode: 0644]
Modules/Learning/LearningBase/test/CMakeLists.txt
Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx [new file with mode: 0644]

index c5fb59f..552880c 100644 (file)
@@ -22,8 +22,8 @@
 #define otbMachineLearningModel_h
 
 #include "itkObject.h"
-#include "itkVariableLengthVector.h"
 #include "itkListSample.h"
+#include "otbMachineLearningModelTraits.h"
 
 namespace otb
 {
@@ -66,6 +66,7 @@ namespace otb
  *
  * \ingroup OTBLearningBase
  */
+
 template <class TInputValue, class TTargetValue, class TConfidenceValue = double >
 class ITK_EXPORT MachineLearningModel
   : public itk::Object
@@ -81,22 +82,22 @@ public:
 
   /**\name Input related typedefs */
   //@{
-  typedef TInputValue                                   InputValueType;
-  typedef itk::VariableLengthVector<InputValueType>     InputSampleType;
-  typedef itk::Statistics::ListSample<InputSampleType>  InputListSampleType;
+  typedef typename MLMSampleTraits<TInputValue>::ValueType  InputValueType;
+  typedef typename MLMSampleTraits<TInputValue>::SampleType InputSampleType;
+  typedef itk::Statistics::ListSample<InputSampleType>      InputListSampleType;
   //@}
 
   /**\name Target related typedefs */
   //@{
-  typedef TTargetValue                                  TargetValueType;
-  typedef itk::FixedArray<TargetValueType,1>            TargetSampleType;
-  typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
+  typedef typename MLMTargetTraits<TTargetValue>::ValueType  TargetValueType;
+  typedef typename MLMTargetTraits<TTargetValue>::SampleType TargetSampleType;
+  typedef itk::Statistics::ListSample<TargetSampleType>      TargetListSampleType;
   //@}
 
   /**\name Confidence value typedef */
-  typedef TConfidenceValue                                  ConfidenceValueType;
-  typedef itk::FixedArray<ConfidenceValueType,1>            ConfidenceSampleType;
-  typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
+  typedef typename MLMTargetTraits<TConfidenceValue>::ValueType  ConfidenceValueType;
+  typedef typename MLMTargetTraits<TConfidenceValue>::SampleType ConfidenceSampleType;
+  typedef itk::Statistics::ListSample<ConfidenceSampleType>      ConfidenceListSampleType;
 
   /**\name Standard macros */
   //@{
diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h b/Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h
new file mode 100644 (file)
index 0000000..e9bc4ce
--- /dev/null
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
+ *
+ * This file is part of Orfeo Toolbox
+ *
+ *     https://www.orfeo-toolbox.org/
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#ifndef otbMachineLearningModelTraits_h
+#define otbMachineLearningModelTraits_h
+
+
+#include "itkVariableLengthVector.h"
+#include "itkFixedArray.h"
+#include "itkIsNumber.h"
+#include "itkMetaProgrammingLibrary.h"
+
+namespace otb
+{
+
+/**
+ * This is the struct defining the sample implementation for
+ * MachineLearningModel. It offers two type definitions: SampleType
+ * and ValueType.
+ *
+ * \tparam TInput : input sample type (can be either a scalar type or
+ * a VariableLenghtVector
+ * \tparam isNumber either TrueType or FalseType for partial
+ * specialization
+  
+ */
+template <typename TInput, typename isNumber> struct MLMSampleTraitsImpl;
+
+
+/// \cond SPECIALIZATION_IMPLEMENTATION
+// For Numbers
+template <typename TInput> struct MLMSampleTraitsImpl<TInput, itk::mpl::TrueType> {
+  typedef TInput                             ValueType;
+   typedef itk::VariableLengthVector<TInput> SampleType;
+};
+
+// For Vectors
+template <typename TInput> struct MLMSampleTraitsImpl<TInput, itk::mpl::FalseType> {
+  typedef typename TInput::ValueType        ValueType;
+  typedef TInput                            SampleType;
+};
+/// \endcond
+
+/**
+ * Simplified implementation of SampleTraits using MLMSampleTraitsImpl
+ */
+template <typename TInput> using MLMSampleTraits = MLMSampleTraitsImpl< TInput, typename itk::mpl::IsNumber<TInput>::Type >;
+
+
+/**
+ * This is the struct defining the sample implementation for
+ * MachineLearningModel. It offers two type definitions: TargetType
+ * and ValueType.
+ *
+ * \tparam TInput : input sample type (can be either a scalar type or
+ * a VariableLenghtVector or a FixedArray
+ * \tparam isNumber either TrueType or FalseType for partial
+ * specialization
+  
+ */
+template <typename TInput, typename isNumber> struct MLMTargetTraitsImpl;
+
+
+/// \cond SPECIALIZATION_IMPLEMENTATION
+// For Numbers
+template <typename TInput> struct MLMTargetTraitsImpl<TInput, itk::mpl::TrueType> {
+  typedef TInput                             ValueType;
+  typedef itk::FixedArray<TInput,1>          SampleType;
+};
+
+// For Vectors
+template <typename TInput> struct MLMTargetTraitsImpl<TInput, itk::mpl::FalseType> {
+  typedef typename TInput::ValueType        ValueType;
+  typedef TInput                            SampleType;
+};
+/// \endcond
+
+/**
+ * Simplified implementation of TargetTraits using MLMTargetTraitsImpl
+ */
+template <typename TInput> using MLMTargetTraits = MLMTargetTraitsImpl< TInput, typename itk::mpl::IsNumber<TInput>::Type >;
+
+
+} // End namespace otb
+
+#endif
index 74a67d4..d1d16c3 100644 (file)
@@ -29,6 +29,7 @@ otbDecisionTreeWithRealValues.cxx
 otbSEMClassifierNew.cxx
 otbDecisionTreeNew.cxx
 otbKMeansImageClassificationFilterNew.cxx
+otbMachineLearningModelTemplates.cxx
 )
 
 add_executable(otbLearningBaseTestDriver ${OTBLearningBaseTests})
diff --git a/Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx b/Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx
new file mode 100644 (file)
index 0000000..4acaf2c
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
+ *
+ * This file is part of Orfeo Toolbox
+ *
+ *     https://www.orfeo-toolbox.org/
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <otbMachineLearningModel.h>
+
+typedef otb::MachineLearningModel<float,short>          MachineLearningModelType1;
+typedef MachineLearningModelType1::InputValueType       InputValueType1;
+typedef MachineLearningModelType1::InputSampleType      InputSampleType1;
+typedef MachineLearningModelType1::InputListSampleType  InputListSampleType1;
+typedef MachineLearningModelType1::TargetValueType      TargetValueType1;
+typedef MachineLearningModelType1::TargetSampleType     TargetSampleType1;
+typedef MachineLearningModelType1::TargetListSampleType TargetListSampleType1;
+
+typedef otb::MachineLearningModel<float,itk::VariableLengthVector<double>> MachineLearningModelType2;
+typedef MachineLearningModelType2::InputValueType       InputValueType2;
+typedef MachineLearningModelType2::InputSampleType      InputSampleType2;
+typedef MachineLearningModelType2::InputListSampleType  InputListSampleType2;
+typedef MachineLearningModelType2::TargetValueType      TargetValueType2;
+typedef MachineLearningModelType2::TargetSampleType     TargetSampleType2;
+typedef MachineLearningModelType2::TargetListSampleType TargetListSampleType2;
+
+
+