Merged PR 5017: CatImputer
CatImputer Description: This featurizer imputes missing values in an input column with the most frequent value. Design: The underlying implementation of this featurizer is composed of two estimators: 1) HistogramEstimator: This estimator computes the histogram for the input column and creates a HistogramAnnotation. Note that this is purely an annotation estimator, i.e., it doesn't have a transformer. 2) HistogramConsumerEstimator: This class retrieves the HistogramAnnotation created by HistogramEstimator and computes the most frequent value from it. This value is then used to impute missing values. Both of these estimators are chained in a PipelineExecutionEstimator, which is exposed as CatImputer.
This commit is contained in:
Родитель
67d0fd984c
Коммит
a2ab1c8a4e
|
@ -0,0 +1,333 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "../Featurizer.h"
|
||||
#include "../PipelineExecutionEstimator.h"
|
||||
#include "../Archive.h"
|
||||
#include "../Traits.h"
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class HistogramAnnotation
|
||||
/// \brief This is an annotation class which holds all the values and corresponding
|
||||
/// frequencies for an input column.
|
||||
///
|
||||
template <typename T>
|
||||
class HistogramAnnotation : public Annotation {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using Histogram = std::map<T, std::uint32_t>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Histogram Value;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
HistogramAnnotation(Histogram value) :
|
||||
Annotation(this),
|
||||
Value(std::move(value)) {
|
||||
}
|
||||
|
||||
~HistogramAnnotation(void) override = default;
|
||||
|
||||
FEATURIZER_MOVE_CONSTRUCTOR_ONLY(HistogramAnnotation);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class HistogramEstimator
|
||||
/// \brief This class computes the histogram for an input column
|
||||
/// and creates a HistogramAnnotation.
|
||||
///
|
||||
template <typename InputT,typename TransformedT, size_t ColIndexV>
|
||||
class HistogramEstimator : public AnnotationEstimator<InputT const &> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
HistogramEstimator(AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
AnnotationEstimator<InputT const &>("HistogramEstimator", std::move(pAllColumnAnnotations)){
|
||||
}
|
||||
|
||||
FEATURIZER_MOVE_CONSTRUCTOR_ONLY(HistogramEstimator);
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType = AnnotationEstimator<InputT const &>;
|
||||
using Histogram = std::map<TransformedT, std::uint32_t>;
|
||||
|
||||
using TraitsT = Traits<typename std::remove_cv<typename std::remove_reference<InputT>::type>::type>;
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Histogram _histogram;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Estimator::FitResult fit_impl(typename BaseType::FitBufferInputType const *pBuffer, size_t cBuffer) override;
|
||||
|
||||
Estimator::FitResult complete_training_impl(void) override;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class HistogramConsumerEstimator
|
||||
/// \brief This class retrieves a HistogramAnnotation and computes
|
||||
/// the most frequent value from it. This value is used to
|
||||
/// replace null values.
|
||||
///
|
||||
template <typename InputT,typename TransformedT>
|
||||
class HistogramConsumerEstimator : public TransformerEstimator<InputT const &, TransformedT> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType = TransformerEstimator<InputT const &, TransformedT>;
|
||||
using TraitsT = Traits<typename std::remove_cv<typename std::remove_reference<InputT>::type>::type>;
|
||||
|
||||
class Transformer : public BaseType::Transformer {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Transformer(TransformedT mostFreq) : _mostFreq(std::move(mostFreq)) {}
|
||||
Transformer(typename BaseType::Transformer::Archive & ar);
|
||||
|
||||
FEATURIZER_MOVE_CONSTRUCTOR_ONLY(Transformer);
|
||||
|
||||
typename BaseType::TransformedType execute(typename BaseType::InputType input) override;
|
||||
|
||||
void save(typename BaseType::Transformer::Archive & ar) const override;
|
||||
|
||||
TransformedT const & get_most_frequent_value() const;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
TransformedT _mostFreq;
|
||||
};
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
HistogramConsumerEstimator(AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
BaseType("HistogramConsumerEstimator", std::move(pAllColumnAnnotations), true) {
|
||||
}
|
||||
|
||||
~HistogramConsumerEstimator(void) override = default;
|
||||
|
||||
FEATURIZER_MOVE_CONSTRUCTOR_ONLY(HistogramConsumerEstimator);
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using Histogram = std::map<TransformedT, std::uint32_t>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Estimator::FitResult fit_impl(typename BaseType::BaseType::FitBufferInputType *pBuffer, size_t cBuffer) override;
|
||||
Estimator::FitResult complete_training_impl(void) override;
|
||||
typename BaseType::TransformerUniquePtr create_transformer_impl(void) override;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class CatImputerEstimator
|
||||
/// \brief This class 'chains' HistogramEstimator and HistogramConsumerEstimator.
|
||||
/// HistogramEstimator generates HistogramAnnotation which is consumed by
|
||||
/// HistogramConsumerEstimator to compute most frequent value.
|
||||
///
|
||||
template <typename InputT,typename TransformedT>
|
||||
class CatImputerEstimator : public PipelineExecutionEstimator<HistogramEstimator<InputT,TransformedT,0>,HistogramConsumerEstimator<InputT,TransformedT>> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType = PipelineExecutionEstimator<HistogramEstimator<InputT,TransformedT,0>,HistogramConsumerEstimator<InputT,TransformedT>>;
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
CatImputerEstimator(AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
BaseType("CatImputerEstimator", std::move(pAllColumnAnnotations)) {
|
||||
}
|
||||
|
||||
FEATURIZER_MOVE_CONSTRUCTOR_ONLY(CatImputerEstimator);
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Implementation
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | HistogramEstimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename InputT,typename TransformedT, size_t ColIndexV>
|
||||
Estimator::FitResult HistogramEstimator<InputT,TransformedT,ColIndexV>::fit_impl(typename BaseType::FitBufferInputType const *pBuffer, size_t cBuffer) {
|
||||
|
||||
typename BaseType::FitBufferInputType const * const pEndBuffer(pBuffer + cBuffer);
|
||||
|
||||
while(pBuffer != pEndBuffer) {
|
||||
InputT const & input(*pBuffer++);
|
||||
if(TraitsT::IsNull(input))
|
||||
continue;
|
||||
|
||||
typename Histogram::iterator const iter(
|
||||
[this, &input](void) -> typename Histogram::iterator {
|
||||
auto value = TraitsT::GetValue(input);
|
||||
typename Histogram::iterator const i(_histogram.find(value));
|
||||
|
||||
if(i != _histogram.end())
|
||||
return i;
|
||||
|
||||
std::pair<typename Histogram::iterator, bool> const result(_histogram.insert(std::make_pair(value, 0)));
|
||||
|
||||
return result.first;
|
||||
}()
|
||||
);
|
||||
|
||||
iter->second += 1;
|
||||
}
|
||||
|
||||
return Estimator::FitResult::Continue;
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT, size_t ColIndexV>
|
||||
Estimator::FitResult HistogramEstimator<InputT,TransformedT,ColIndexV>::complete_training_impl(void) {
|
||||
|
||||
BaseType::add_annotation(std::make_shared<HistogramAnnotation<TransformedT>>(std::move(_histogram)), ColIndexV);
|
||||
|
||||
return Estimator::FitResult::Complete;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | HistogramConsumerEstimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
HistogramConsumerEstimator<InputT,TransformedT>::Transformer::Transformer(typename BaseType::Transformer::Archive & ar) {
|
||||
_mostFreq = Traits<TransformedT>::deserialize(ar);
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
typename HistogramConsumerEstimator<InputT,TransformedT>::BaseType::TransformedType HistogramConsumerEstimator<InputT,TransformedT>::Transformer::execute(typename BaseType::InputType input) {
|
||||
|
||||
if(TraitsT::IsNull(input))
|
||||
return _mostFreq;
|
||||
|
||||
return TraitsT::GetValue(input);
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
void HistogramConsumerEstimator<InputT,TransformedT>::Transformer::save(typename HistogramConsumerEstimator<InputT,TransformedT>::BaseType::Transformer::Archive & ar) const {
|
||||
Traits<TransformedT>::serialize(ar,_mostFreq);
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
TransformedT const & HistogramConsumerEstimator<InputT,TransformedT>::Transformer::get_most_frequent_value() const {
|
||||
return _mostFreq;
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
Estimator::FitResult HistogramConsumerEstimator<InputT,TransformedT>::fit_impl(typename BaseType::BaseType::FitBufferInputType *, size_t ) {
|
||||
throw std::runtime_error("This should never be called as this class will not be used during training");
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
Estimator::FitResult HistogramConsumerEstimator<InputT,TransformedT>::complete_training_impl(void) {
|
||||
throw std::runtime_error("This should never be called as this class will not be used during training");
|
||||
}
|
||||
|
||||
template <typename InputT,typename TransformedT>
|
||||
typename HistogramConsumerEstimator<InputT,TransformedT>::BaseType::TransformerUniquePtr HistogramConsumerEstimator<InputT,TransformedT>::create_transformer_impl(void) {
|
||||
|
||||
// Retrieve Histogram from Annotation
|
||||
AnnotationMaps const & maps(Estimator::get_column_annotations());
|
||||
// Currently Annnotations are per output column index (0-based)
|
||||
// Since we've only one column as output- hardcoding this to 0 now.
|
||||
// Expect annotation design to be further rationalized in near future
|
||||
// which will address this hard-coding.
|
||||
AnnotationMap const & annotations(maps[0]);
|
||||
AnnotationMap::const_iterator const & iterAnnotations(annotations.find("HistogramEstimator"));
|
||||
|
||||
if(iterAnnotations == annotations.end())
|
||||
throw std::runtime_error("Couldn't retrieve HistogramAnnotation.");
|
||||
|
||||
// An output column can have multiple annotations from same 'kind' of estimator.
|
||||
// However, since we have only one estimator- hence the hard-coded value of 0 for retrieval.
|
||||
// Expect annotation design to be further rationalized in near future
|
||||
// which will address this hard-coding.
|
||||
Annotation const & annotation(*iterAnnotations->second[0]);
|
||||
assert(dynamic_cast<HistogramAnnotation<TransformedT> const *>(&annotation));
|
||||
HistogramAnnotation<TransformedT> const & histogramAnnotation(static_cast<HistogramAnnotation<TransformedT> const &>(annotation));
|
||||
Histogram const & histogram(histogramAnnotation.Value);
|
||||
|
||||
// Compute most frequent value from Histogram
|
||||
typename Histogram::const_iterator iMostCommon(histogram.end());
|
||||
|
||||
for(typename Histogram::const_iterator iter=histogram.begin(); iter != histogram.end(); ++iter) {
|
||||
if(iMostCommon == histogram.end() || iter->second > iMostCommon->second) {
|
||||
iMostCommon = iter;
|
||||
}
|
||||
}
|
||||
if(iMostCommon == histogram.end())
|
||||
throw std::runtime_error("All null values or empty training set.");
|
||||
|
||||
return std::make_unique<Transformer>(iMostCommon->first);
|
||||
}
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
|
@ -28,6 +28,7 @@ add_library(libFeaturizers STATIC
|
|||
../SampleAddFeaturizer.h
|
||||
../SampleAddFeaturizer.cpp
|
||||
../StringFeaturizer.h
|
||||
../CatImputerFeaturizer.h
|
||||
)
|
||||
|
||||
enable_testing()
|
||||
|
@ -36,6 +37,7 @@ foreach(_test_name IN ITEMS
|
|||
SampleAddFeaturizer_UnitTest
|
||||
DateTimeFeaturizer_UnitTests
|
||||
StringFeaturizer_UnitTest
|
||||
CatImputerFeaturizer_UnitTests
|
||||
)
|
||||
add_executable(${_test_name} ${_test_name}.cpp)
|
||||
|
||||
|
|
|
@ -0,0 +1,265 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "../../Shared/optional.h"
|
||||
#include "../../Featurizers/CatImputerFeaturizer.h"
|
||||
|
||||
namespace NS = Microsoft::Featurizer;
|
||||
namespace {

/// Appends one element, constructed in place from `arg`, to `v`.
template <typename T, typename ArgT>
void make_vector(std::vector<T> & v, ArgT && arg) {
    v.emplace_back(std::forward<ArgT>(arg));
}

/// Appends `arg` and then each of `args`, in the order given.
template <typename T, typename ArgT, typename... ArgsT>
void make_vector(std::vector<T> &v, ArgT && arg, ArgsT &&...args) {
    // Handle the head element here, then recurse on the remaining ones.
    v.emplace_back(std::forward<ArgT>(arg));
    make_vector(v, std::forward<ArgsT>(args)...);
}

} // anonymous namespace

/// Creates a std::vector<T> whose elements are constructed from `args`, in order.
template <typename T, typename... ArgsT>
std::vector<T> make_vector(ArgsT &&... args) {
    std::vector<T>                          elements;

    // One allocation up front; the element count is known at compile time.
    elements.reserve(sizeof...(args));
    make_vector(elements, std::forward<ArgsT>(args)...);

    return elements;
}

/// Creates an empty std::vector<T>.
template <typename T>
std::vector<T> make_vector(void) {
    return {};
}
|
||||
|
||||
template <typename PipelineT>
|
||||
std::vector<typename PipelineT::TransformedType> Test(
|
||||
PipelineT pipeline,
|
||||
std::vector<std::vector<std::remove_const_t<std::remove_reference_t<typename PipelineT::InputType>>>> const &inputBatches,
|
||||
std::vector<std::remove_const_t<std::remove_reference_t<typename PipelineT::InputType>>> const &data
|
||||
) {
|
||||
using FitResult = typename NS::Estimator::FitResult;
|
||||
using Batches = std::vector<std::vector<std::remove_const_t<std::remove_reference_t<typename PipelineT::InputType>>>>;
|
||||
|
||||
if(inputBatches.empty() == false) {
|
||||
// Train the pipeline
|
||||
typename Batches::const_iterator iter(inputBatches.begin());
|
||||
|
||||
while(true) {
|
||||
FitResult const result(pipeline.fit(iter->data(), iter->size()));
|
||||
|
||||
if(result == FitResult::Complete)
|
||||
break;
|
||||
else if(result == FitResult::ResetAndContinue)
|
||||
iter = inputBatches.begin();
|
||||
else if(result == FitResult::Continue) {
|
||||
++iter;
|
||||
|
||||
if(iter == inputBatches.end()) {
|
||||
if(pipeline.complete_training() == FitResult::Complete)
|
||||
break;
|
||||
|
||||
iter = inputBatches.begin();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(pipeline.is_training_complete());
|
||||
|
||||
typename PipelineT::TransformerUniquePtr pTransformer(pipeline.create_transformer());
|
||||
std::vector<typename PipelineT::TransformedType> output;
|
||||
|
||||
output.reserve(data.size());
|
||||
|
||||
for(auto const &item : data)
|
||||
output.emplace_back(pTransformer->execute(item));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
/// Supplies the value the tests use to represent "null" for a given input type.
///
/// The primary template returns a value-initialized `A` (for an optional type
/// this is the disengaged state).
template <typename A>
struct InternalValueTraits {
    using Type = A;
    static A GetNullvalue() { return A{}; }
};

/// float has no separate optional wrapper here: NaN stands for the null value.
template <>
struct InternalValueTraits<float> {
    // Exposed for parity with the primary template.
    using Type = float;
    static float GetNullvalue() { return std::numeric_limits<float>::quiet_NaN(); }
};

/// double has no separate optional wrapper here: NaN stands for the null value.
template <>
struct InternalValueTraits<double> {
    // Exposed for parity with the primary template.
    using Type = double;
    static double GetNullvalue() { return std::numeric_limits<double>::quiet_NaN(); }
};
|
||||
|
||||
template<typename inputType, typename transformedType, typename castType = uint8_t>
|
||||
void NumericTestWrapper(){
|
||||
inputType null = InternalValueTraits<inputType>::GetNullvalue();
|
||||
|
||||
// Passing int values to make_vector for an optional type gives following error
|
||||
// error: implicit conversion loses integer precision: 'int' to 'nonstd::optional_lite::
|
||||
// Hence explicit cast to uint8_t.
|
||||
auto trainingBatches = make_vector<std::vector<inputType>>(
|
||||
make_vector<inputType>(static_cast<castType>(10),static_cast<castType>(20),null),
|
||||
make_vector<inputType>(static_cast<castType>(10),static_cast<castType>(30),null),
|
||||
make_vector<inputType>(static_cast<castType>(10),static_cast<castType>(10),null),
|
||||
make_vector<inputType>(static_cast<castType>(11),static_cast<castType>(15),null),
|
||||
make_vector<inputType>(static_cast<castType>(18),static_cast<castType>(8),null));
|
||||
|
||||
auto inferencingInput = make_vector<inputType>(static_cast<castType>(5),static_cast<castType>(8),static_cast<castType>(20)
|
||||
,null,null,null,null);
|
||||
|
||||
auto inferencingOutput = make_vector<transformedType>(5,8,20,10,10,10,10);
|
||||
|
||||
NS::AnnotationMapsPtr const pAllColumnAnnotations(NS::CreateTestAnnotationMapsPtr(1));
|
||||
|
||||
CHECK(
|
||||
Test(
|
||||
NS::CatImputerEstimator<inputType,transformedType>(pAllColumnAnnotations),
|
||||
trainingBatches,
|
||||
inferencingInput
|
||||
) == inferencingOutput
|
||||
);
|
||||
}
|
||||
|
||||
// One test case per supported numeric type. Integer inputs are wrapped in
// nonstd::optional; float and double inputs are passed as plain values.

TEST_CASE("CatImputer- int8_t") {
    // With default castType of uint8_t we get following error
    // error: implicit conversion changes signedness: 'unsigned char' to 'nonstd::optional_lite::detail::storage_t<signed char>::value_type' (aka 'signed char')
    NumericTestWrapper<nonstd::optional<std::int8_t>, std::int8_t, std::int8_t>();
}

TEST_CASE("CatImputer- uint8_t") {
    NumericTestWrapper<nonstd::optional<std::uint8_t>, std::uint8_t>();
}

TEST_CASE("CatImputer- uint16_t") {
    NumericTestWrapper<nonstd::optional<std::uint16_t>, std::uint16_t>();
}

TEST_CASE("CatImputer- int16_t") {
    NumericTestWrapper<nonstd::optional<std::int16_t>, std::int16_t>();
}

TEST_CASE("CatImputer- uint32_t") {
    NumericTestWrapper<nonstd::optional<std::uint32_t>, std::uint32_t>();
}

TEST_CASE("CatImputer- int32_t") {
    NumericTestWrapper<nonstd::optional<std::int32_t>, std::int32_t>();
}

TEST_CASE("CatImputer- uint64_t") {
    NumericTestWrapper<nonstd::optional<std::uint64_t>, std::uint64_t>();
}

TEST_CASE("CatImputer- int64_t") {
    NumericTestWrapper<nonstd::optional<std::int64_t>, std::int64_t>();
}

TEST_CASE("CatImputer- float") {
    NumericTestWrapper<std::float_t, std::float_t>();
}

TEST_CASE("CatImputer- double") {
    NumericTestWrapper<std::double_t, std::double_t>();
}
|
||||
|
||||
// "one" occurs three times in the training data, so a null input is expected
// to be imputed with "one".
TEST_CASE("CatImputer- string") {
    using type                              = nonstd::optional<std::string>;
    using transformedType                   = std::string;

    NS::AnnotationMapsPtr const             pAllColumnAnnotations(NS::CreateTestAnnotationMapsPtr(1));

    auto const trainingBatches = make_vector<std::vector<type>>(
        make_vector<type>("one", "one", "one",type{},type{},"two", "three")
    );
    auto const inferencingInput = make_vector<type>("one", "two", "three",type{});
    auto const inferencingOutput = make_vector<transformedType>("one","two","three","one");

    CHECK(
        Test(
            NS::CatImputerEstimator<type,transformedType>(pAllColumnAnnotations),
            trainingBatches,
            inferencingInput
        ) == inferencingOutput
    );
}
|
||||
|
||||
// Training data that contains only null values yields an empty histogram, so
// creating the transformer is expected to throw.
TEST_CASE("CatImputer- All values Null") {
    using type                              = nonstd::optional<std::int64_t>;
    using transformedType                   = std::int64_t;

    NS::AnnotationMapsPtr const             pAllColumnAnnotations(NS::CreateTestAnnotationMapsPtr(1));

    auto const trainingBatches = make_vector<std::vector<type>>(
        make_vector<type>(type{},type{},type{},type{},type{},type{})
    );
    auto const inferencingInput = make_vector<type>(5, 8, 20,type{});

    CHECK_THROWS_WITH(
        Test(
            NS::CatImputerEstimator<type,transformedType>(pAllColumnAnnotations),
            trainingBatches,
            inferencingInput
        ),
        Catch::Contains("All null values or empty training set.")
    );
}
|
||||
|
||||
// Round-trips a numeric transformer through an Archive: the int64 value is
// expected to occupy 8 bytes and to be restored unchanged.
TEST_CASE("Serialization/Deserialization- Numeric") {
    using type                              = nonstd::optional<std::int64_t>;
    using transformedType                   = std::int64_t;
    using transformerType                   = NS::HistogramConsumerEstimator<type,transformedType>::Transformer;

    transformerType                         original(10);
    NS::Archive                             writer;

    original.save(writer);

    std::vector<unsigned char>              bytes = writer.commit();

    CHECK(bytes.size() == 8);

    NS::Archive                             reader(bytes);
    transformerType                         restored(reader);

    CHECK(restored.get_most_frequent_value() == 10);
}
|
||||
|
||||
// Round-trips a string transformer through an Archive and checks that the
// imputation value is restored unchanged.
TEST_CASE("Serialization/Deserialization- string") {
    using type                              = nonstd::optional<std::string>;
    using transformedType                   = std::string;
    using transformerType                   = NS::HistogramConsumerEstimator<type,transformedType>::Transformer;

    transformerType                         original("one");
    NS::Archive                             writer;

    original.save(writer);

    std::vector<unsigned char>              bytes = writer.commit();

    NS::Archive                             reader(bytes);
    transformerType                         restored(reader);

    CHECK(restored.get_most_frequent_value() == "one");
}
|
||||
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ template <typename T>
|
|||
struct TraitsImpl {
|
||||
using nullable_type = nonstd::optional<T>;
|
||||
static bool IsNull(nullable_type const& value) {
|
||||
return !value.is_initialized();
|
||||
return !value.has_value();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -250,11 +250,18 @@ struct Traits<std::uint64_t> : public TraitsImpl<std::uint64_t> {
|
|||
};
|
||||
|
||||
template <>
|
||||
struct Traits<std::float_t> : public TraitsImpl<std::float_t> {
|
||||
struct Traits<std::float_t> {
|
||||
using nullable_type = std::float_t;
|
||||
static bool IsNull(nullable_type const& value) {
|
||||
return std::isnan(value);
|
||||
}
|
||||
|
||||
static std::float_t const & GetValue(nullable_type const& value) {
|
||||
if (IsNull(value))
|
||||
throw std::runtime_error("GetValue attempt on float_t null.");
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static std::string ToString(nullable_type const& value) {
|
||||
if (IsNull(value))
|
||||
|
@ -278,12 +285,19 @@ struct Traits<std::float_t> : public TraitsImpl<std::float_t> {
|
|||
};
|
||||
|
||||
template <>
|
||||
struct Traits<std::double_t> : public TraitsImpl<std::double_t> {
|
||||
struct Traits<std::double_t> {
|
||||
using nullable_type = std::double_t;
|
||||
static bool IsNull(nullable_type const& value) {
|
||||
return std::isnan(value);
|
||||
}
|
||||
|
||||
static std::double_t const & GetValue(nullable_type const& value) {
|
||||
if (IsNull(value))
|
||||
throw std::runtime_error("GetValue attempt on double_t null.");
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static std::string ToString(nullable_type const& value) {
|
||||
if (IsNull(value))
|
||||
{
|
||||
|
@ -480,8 +494,12 @@ struct Traits<std::map<KeyT, T, CompareT, AllocatorT>> : public TraitsImpl<std::
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
struct Traits<nonstd::optional<T>> : public TraitsImpl<nonstd::optional<T>> {
|
||||
struct Traits<nonstd::optional<T>> {
|
||||
using nullable_type = nonstd::optional<T>;
|
||||
|
||||
static bool IsNull(nullable_type const& value) {
|
||||
return !value.has_value();
|
||||
}
|
||||
|
||||
static std::string ToString(nullable_type const& value) {
|
||||
if (value) {
|
||||
|
@ -489,6 +507,14 @@ struct Traits<nonstd::optional<T>> : public TraitsImpl<nonstd::optional<T>> {
|
|||
}
|
||||
return "NULL";
|
||||
}
|
||||
|
||||
static T const & GetValue(nullable_type const& value) {
|
||||
if (value){
|
||||
return value.value();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("GetValue attempt on Optional type null.");
|
||||
}
|
||||
|
||||
template <typename ArchiveT>
|
||||
static ArchiveT & serialize(ArchiveT &ar, nonstd::optional<T> const &value) {
|
||||
|
|
|
@ -37,7 +37,24 @@ static_assert(std::is_same<Traits<std::tuple<int>>::nullable_type, nonstd::optio
|
|||
|
||||
// Exercises the nullable-related Traits functions: ToString on a null
// optional, GetValue on non-null values, and the exceptions thrown by
// GetValue for a null optional and for NaN float/double values.
TEST_CASE("Transformer_Nullable") {
    // Null inputs: a disengaged optional and NaN for the floating-point types.
    nonstd::optional<std::int8_t> arg_null;
    std::float_t arg_f_ini = std::numeric_limits<std::float_t>::quiet_NaN();
    std::double_t arg_d_ini = std::numeric_limits<std::double_t>::quiet_NaN();

    // Non-null inputs.
    nonstd::optional<std::int64_t> arg_64(-7799);
    std::float_t arg_f = 123;
    std::double_t arg_d = 123.45;

    CHECK(Traits<nonstd::optional<std::int8_t>>::ToString(arg_null) == "NULL");
    CHECK(Traits<std::float_t>::ToString(Traits<std::float_t>::GetValue(arg_f)) == "123");
    CHECK(Traits<nonstd::optional<std::int64_t>>::GetValue(arg_64) == -7799);
    CHECK(Traits<std::double_t>::ToString(Traits<std::double_t>::GetValue(arg_d)) == "123.45");

    CHECK_THROWS_WITH(Traits<nonstd::optional<std::int8_t>>::GetValue(arg_null)
        , Catch::Contains("GetValue attempt on Optional type null."));
    // float_t/double_t are spelled with std:: like the rest of this test;
    // <cmath> only guarantees the names in namespace std.
    CHECK_THROWS_WITH(Traits<std::float_t>::GetValue(arg_f_ini)
        , Catch::Contains("GetValue attempt on float_t null."));
    CHECK_THROWS_WITH(Traits<std::double_t>::GetValue(arg_d_ini)
        , Catch::Contains("GetValue attempt on double_t null"));
}
|
||||
|
||||
TEST_CASE("Transformer_Binary") {
|
||||
|
|
Загрузка…
Ссылка в новой задаче