Merged PR 4950: Updates to Featurizers to support additional scenarios
Updates to Featurizers to support additional scenarios
This commit is contained in:
Родитель
623496d798
Коммит
2e3d623850
|
@ -5,75 +5,144 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/serialization/access.hpp>
|
||||
#include <boost/serialization/base_object.hpp>
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Transformer
|
||||
/// \brief Transforms a single "value" and output the result.
|
||||
/// A value can be anything from an integer to a collection
|
||||
/// of integers.
|
||||
/// \class Annotation
|
||||
/// \brief Base class for an individual datum associated with a column that is produced
|
||||
/// by an `Estimator`. Once an `Annotation` is created and associated with a column,
|
||||
/// any downstream `Estimator` can query for the `Annotation` and retrieve its
|
||||
/// associated values. With this system in place, we ensure that data associated with
|
||||
/// a column isn't calculated repeatedly.
|
||||
///
|
||||
template <typename ReturnT, typename ArgT>
|
||||
class Transformer {
|
||||
/// Examples of possible derived classes:
|
||||
/// - Mean of all values in a column, as calculated by a Mean `AnnotationEstimator`
|
||||
/// - Most common value in a column, as calculated by a histogram-like `AnnotationEstimator`
|
||||
/// - etc.
|
||||
///
|
||||
/// This base class doesn't contain the data produced by an `Estimator`, but
|
||||
/// `Estimators` should implement functionality to easily retrieve annotation
|
||||
/// data that they have created for a specific column. Note the virtual destructor
|
||||
/// to ensure proper cleanup when these values are destroyed.
|
||||
///
|
||||
class Annotation {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
using return_type = ReturnT;
|
||||
using arg_type = ArgT;
|
||||
using transformer_type = Transformer<ReturnT, ArgT>;
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \typedef EstimatorUniqueId
|
||||
/// \brief Different estimators of the same name produce annotations that must be uniquely
|
||||
/// identifiable. For example, the results produced by a mean `AnnotationEstimator`
|
||||
/// based on the first 200 values must be distinguishable from the results produced
|
||||
/// by a mean `AnnotationEstimator` based on the first 500 values.
|
||||
///
|
||||
/// The estimator's address in memory is used to determine uniqueness.
|
||||
///
|
||||
/// THIS IMPLEMENTATION WILL BREAK IF WE EVER SUPPORT SERIALIZATION, as the memory
|
||||
/// address of a deserialized object will be different from the memory address used
|
||||
/// when originally creating the `Annotation`. If we want to support serialization
|
||||
/// during training, we will need to introduce a mechanism to uniquely identify
|
||||
/// object instances in a way that persists after objects are serialized/deserialized.
|
||||
/// `Annotations` are only used during training, so this will not be an issue when
|
||||
/// serializing `Transformers`.
|
||||
///
|
||||
using EstimatorUniqueId = void const *;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
EstimatorUniqueId const CreatorId;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
Transformer(void) = default;
|
||||
virtual ~Transformer(void) = default;
|
||||
|
||||
Transformer(Transformer const &) = delete;
|
||||
Transformer & operator =(Transformer const &) = delete;
|
||||
|
||||
Transformer(Transformer &&) = default;
|
||||
Transformer & operator =(Transformer &&) = delete;
|
||||
|
||||
virtual return_type transform(arg_type const &arg) const = 0;
|
||||
|
||||
private:
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
virtual ~Annotation(void) = default;
|
||||
|
||||
Annotation(Annotation const &) = delete;
|
||||
Annotation & operator =(Annotation const &) = delete;
|
||||
|
||||
Annotation(Annotation &&) = default;
|
||||
Annotation & operator =(Annotation &&) = delete;
|
||||
|
||||
protected:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &, unsigned int const /*version*/);
|
||||
// |
|
||||
// | Protected Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Annotation(EstimatorUniqueId creator_id);
|
||||
};
|
||||
|
||||
using AnnotationPtr = std::shared_ptr<Annotation>;
|
||||
|
||||
// `Estimators` with the same name may generate different Annotations based on
|
||||
// the settings provided when it was constructed...
|
||||
using AnnotationPtrs = std::vector<AnnotationPtr>;
|
||||
// TODO: Updating the vector should be thread safe when executing the DAGs in parallel.
|
||||
|
||||
// A single column supports `Annotations` from different `Estimators`...
|
||||
using AnnotationMap = std::map<std::string, AnnotationPtrs>;
|
||||
// TODO: Updating the map should be thread safe when executing the DAGs in parallel.
|
||||
|
||||
// An `Estimator` may support multiple columns...
|
||||
using AnnotationMaps = std::vector<AnnotationMap>;
|
||||
|
||||
// All `Estimators` within a DAG should use the same collection of column `Annotations`.
|
||||
using AnnotationMapsPtr = std::shared_ptr<AnnotationMaps>;
|
||||
|
||||
// TODO: Expect more classes with regards to Annotation as we use the functionality more.
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function CreateTestAnnotationMapsPtr
|
||||
/// \brief An `Estimator` requires an `AnnotationMapsPtr` upon
|
||||
/// construction. This method can be used to quickly create
|
||||
/// one of these objects during testing.
|
||||
///
|
||||
/// THIS FUNCTION SHOULD NOT BE USED IN ANY PRODUCTION CODE!
|
||||
///
|
||||
AnnotationMapsPtr CreateTestAnnotationMapsPtr(size_t num_cols);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Estimator
|
||||
/// \brief Collects state over a collection of data, then produces
|
||||
/// a `Transformer` that is able to operate on that collected
|
||||
/// state.
|
||||
/// \brief An `Estimator` collects data during the training process and ultimately generates
|
||||
/// state data that is used during training or inferencing. Note that
|
||||
/// `Estimators` are only used during the training process.
|
||||
///
|
||||
template <typename ReturnT, typename ArgT>
|
||||
class Estimator {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
using transformer_type = Transformer<ReturnT, ArgT>;
|
||||
using TransformerUniquePtr = std::unique_ptr<transformer_type>;
|
||||
|
||||
using estimator_type = Estimator<ReturnT, ArgT>;
|
||||
|
||||
using apache_arrow = unsigned long; // TODO: Temp type as we figure out what will eventually be here
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using AnnotationMapsPtr = AnnotationMapsPtr;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
std::string const Name;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
Estimator(void) = default;
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
virtual ~Estimator(void) = default;
|
||||
|
||||
Estimator(Estimator const &) = delete;
|
||||
|
@ -82,34 +151,297 @@ public:
|
|||
Estimator(Estimator &&) = default;
|
||||
Estimator & operator =(Estimator &&) = delete;
|
||||
|
||||
// This method can be called repeatedly in the support of streaming scenarios
|
||||
Estimator & fit(apache_arrow const &data);
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function get_column_annotations
|
||||
/// \brief Returns the column annotations for all columns. Note
|
||||
/// that this information is shared across all `Estimators`
|
||||
/// within the DAG.
|
||||
///
|
||||
/// Note that in most cases, callers should prefer to retrieve
|
||||
/// `Annotation` values from a derived class method rather than
|
||||
/// this method. This method is here so that the final
|
||||
/// `Annotation` states for all columns can be used in
|
||||
/// framework-specific code.
|
||||
///
|
||||
AnnotationMaps const & get_column_annotations(void) const;
|
||||
|
||||
// Calls to `commit` are destructive - all previously generated state should
|
||||
// be reset. `Estimator` objects that want to share state prior to calls to commit
|
||||
// should implement a `copy` method.
|
||||
TransformerUniquePtr commit(void);
|
||||
protected:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Protected Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Estimator(std::string name, AnnotationMapsPtr pAllColumnAnnotations);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function add_annotation
|
||||
/// \brief Adds an `Annotation` to the specified column.
|
||||
///
|
||||
void add_annotation(AnnotationPtr pAnnotation, size_t col_index) const;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function get_annotation_impl
|
||||
/// \brief Helper method that can be used by derived classes when implementing functionality
|
||||
/// to retrieve `Annotation` data created by the derived class itself.
|
||||
///
|
||||
template <typename DerivedAnnotationT>
|
||||
boost::optional<DerivedAnnotationT &> get_annotation_impl(size_t col_index) const;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
bool _committed = false;
|
||||
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &, unsigned int const /*version*/);
|
||||
|
||||
virtual Estimator & fit_impl(apache_arrow const &data) = 0;
|
||||
virtual TransformerUniquePtr commit_impl(void) = 0;
|
||||
AnnotationMapsPtr const _all_column_annotations;
|
||||
};
|
||||
|
||||
template <typename EstimatorT, typename... EstimatorConstructorArgsT>
|
||||
typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args);
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class FitEstimatorImpl
|
||||
/// \brief Common base class for an `Estimator` that supports fit functionality. Derived
|
||||
/// classes can produce `Annotations` used by other `Estimators` during the training
|
||||
/// process and/or state that is returned to the caller during runtime as a part
|
||||
/// of training and inferencing activities.
|
||||
///
|
||||
template <typename InputT>
|
||||
class FitEstimatorImpl : public Estimator {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using InputType = InputT;
|
||||
using ThisType = FitEstimatorImpl<InputType>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \enum FitResult
|
||||
/// \brief Result returned by the `fit` method.
|
||||
///
|
||||
enum class FitResult {
|
||||
Complete, /// Fitting is complete and there is no need to call `fit` on this `Estimator` any more.
|
||||
Continue /// Continue providing data to `fit` (if such data is available).
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
FitEstimatorImpl(std::string name, AnnotationMapsPtr pAllColumnAnnotations, bool is_training_complete=false);
|
||||
~FitEstimatorImpl(void) override = default;
|
||||
|
||||
FitEstimatorImpl(FitEstimatorImpl const &) = delete;
|
||||
FitEstimatorImpl & operator =(FitEstimatorImpl const &) = delete;
|
||||
|
||||
FitEstimatorImpl(FitEstimatorImpl &&) = default;
|
||||
FitEstimatorImpl & operator =(FitEstimatorImpl &&) = delete;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function is_training_complete
|
||||
/// \brief Returns true if the `complete_training` method has been called
|
||||
/// for this `Estimator`. `fit` should not be invoked on
|
||||
/// an `Estimator` where training has been completed.
|
||||
///
|
||||
bool is_training_complete(void) const;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function fit
|
||||
/// \brief Method invoked during training. This method will be invoked until it returns `FitResult::Complete`
|
||||
/// or no additional data is available. Derived classes should use this columnar data to create
|
||||
/// state (either in the form of `Annotations`) used during the training process or state data that
|
||||
/// is used in future calls to `transform`. This method should not be invoked on an object that
|
||||
/// has already been completed.
|
||||
///
|
||||
FitResult fit(InputType value);
|
||||
FitResult fit(InputType const *pInputBuffer, size_t cInputBuffer, boost::optional<std::uint64_t> const &optionalNumTrailingNulls=boost::optional<std::uint64_t>());
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function complete_training
|
||||
/// \brief Completes the training process. Derived classes should use this method to produce any final state
|
||||
/// that is used in calls to `transform` or to add `Annotations` for a column. This method should not be
|
||||
/// invoked on an object that has already been completed.
|
||||
///
|
||||
FitEstimatorImpl & complete_training(void);
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
bool _is_training_complete;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function fit_impl
|
||||
/// \brief `fit` performs common object state and parameter validation before invoking
|
||||
/// this abstract method.
|
||||
///
|
||||
virtual FitResult fit_impl(InputType const *pBuffer, size_t cBuffer, boost::optional<std::uint64_t> const &optionalNumTrailingNulls) = 0;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function complete_training_impl
|
||||
/// \brief `complete_training` performs common object state validation before invoking this
|
||||
/// abstract method.
|
||||
///
|
||||
virtual void complete_training_impl(void) = 0;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class AnnotationEstimator
|
||||
/// \brief An `Estimator` that generates `Annotations` when completed. It is no longer
|
||||
/// invoked once its training is complete (i.e. it doesn't have a `transform`
|
||||
/// method).
|
||||
///
|
||||
template <typename InputT>
|
||||
class AnnotationEstimator : public FitEstimatorImpl<InputT> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using InputType = InputT;
|
||||
using ThisType = AnnotationEstimator<InputType>;
|
||||
using BaseType = FitEstimatorImpl<InputT>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType::BaseType;
|
||||
~AnnotationEstimator(void) override = default;
|
||||
|
||||
AnnotationEstimator(AnnotationEstimator const &) = delete;
|
||||
AnnotationEstimator & operator =(AnnotationEstimator const &) = delete;
|
||||
|
||||
AnnotationEstimator(AnnotationEstimator &&) = default;
|
||||
AnnotationEstimator & operator =(AnnotationEstimator &&) = delete;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class TransformerEstimator
|
||||
/// \brief An `Estimator` that performs some activity during calls
|
||||
/// to `transform`. Derived classes may or may not generate
|
||||
/// training state (in the form of `Annotations`) or transforming
|
||||
/// state.
|
||||
///
|
||||
template <typename InputT, typename TransformedT>
|
||||
class TransformerEstimator : public FitEstimatorImpl<InputT> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using InputType = InputT;
|
||||
using TransformedType = TransformedT;
|
||||
using ThisType = TransformerEstimator<InputType, TransformedType>;
|
||||
using BaseType = FitEstimatorImpl<InputT>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Transformer
|
||||
/// \brief Object that uses state to produce a result during
|
||||
/// inferencing activities.
|
||||
///
|
||||
class Transformer {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Types
|
||||
using InputType = InputType;
|
||||
using TransformedType = TransformedType;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
|
||||
// TODO: Add a method that can be used to create a derived transformer from
|
||||
// persistent state.
|
||||
//
|
||||
// template <typename DerivedTransformerT>
|
||||
// static std::shared_ptr<DerivedTransformerT> Create(??? persisted_state);
|
||||
|
||||
Transformer(void) = default;
|
||||
virtual ~Transformer(void) = default;
|
||||
|
||||
Transformer(Transformer const &) = delete;
|
||||
Transformer & operator =(Transformer const &) = delete;
|
||||
|
||||
Transformer(Transformer &&) = default;
|
||||
Transformer & operator =(Transformer &&) = delete;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function execute
|
||||
/// \brief Produces a result for a given input.
|
||||
///
|
||||
virtual TransformedType execute(InputType input) = 0;
|
||||
|
||||
// TODO: Add a method that can be used to save state
|
||||
//
|
||||
// virtual ??? PersistState(???) const = 0;
|
||||
};
|
||||
|
||||
using TransformerPtr = std::shared_ptr<Transformer>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType::BaseType;
|
||||
~TransformerEstimator(void) override = default;
|
||||
|
||||
TransformerEstimator(TransformerEstimator const &) = delete;
|
||||
TransformerEstimator & operator =(TransformerEstimator const &) = delete;
|
||||
|
||||
TransformerEstimator(TransformerEstimator &&) = default;
|
||||
TransformerEstimator & operator =(TransformerEstimator &&) = delete;
|
||||
|
||||
ThisType & complete_training(void);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function has_created_transformer
|
||||
/// \brief Returns true if this object has been used to create
|
||||
/// a `Transformer`. No methods should be called on the object
|
||||
/// once it has been used to create a transformer.
|
||||
///
|
||||
bool has_created_transformer(void) const;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function create_transformer
|
||||
/// \brief Creates a `Transformer` using the trained state of the
|
||||
/// object. No methods should be called on the object once
|
||||
/// it has been used to create a transformer.
|
||||
///
|
||||
TransformerPtr create_transformer(void);
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
bool _created_transformer = false;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \function create_transformer_impl
|
||||
/// \brief `create_transformer` performs common object state validation before
|
||||
/// calling this method.
|
||||
///
|
||||
virtual TransformerPtr create_transformer_impl(void) = 0;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
|
@ -123,12 +455,30 @@ typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::ap
|
|||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Transformer
|
||||
// | Annotation
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename ReturnT, typename ArgT>
|
||||
template <typename ArchiveT>
|
||||
void Transformer<ReturnT, ArgT>::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) {
|
||||
/// Records which estimator produced this `Annotation`.
///
/// \param creator_id   Identity of the creating estimator; must not be null.
/// \throws std::runtime_error ("Invalid id") when `creator_id` is null.
inline Annotation::Annotation(EstimatorUniqueId creator_id) :
    // `CreatorId` is const, so the validation has to happen inside the initializer.
    CreatorId(creator_id != nullptr ? creator_id : throw std::runtime_error("Invalid id")) {
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
/// Test-only helper: builds a shared `AnnotationMaps` holding `num_cols`
/// empty per-column maps (an empty collection when `num_cols` is 0).
inline AnnotationMapsPtr CreateTestAnnotationMapsPtr(size_t num_cols) {
    // The vector's sized constructor yields the same contents as default
    // construction followed by `resize(num_cols)`, including the zero case.
    return std::make_shared<AnnotationMaps>(num_cols);
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
|
@ -136,40 +486,183 @@ void Transformer<ReturnT, ArgT>::serialize(ArchiveT & /*ar*/, unsigned int const
|
|||
// | Estimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename ReturnT, typename ArgT>
|
||||
Estimator<ReturnT, ArgT> & Estimator<ReturnT, ArgT>::fit(apache_arrow const &data) {
|
||||
if(_committed)
|
||||
throw std::runtime_error("This instance has already been committed");
|
||||
|
||||
return fit_impl(data);
|
||||
/// Read-only view of the annotation maps this estimator was constructed with
/// (one `AnnotationMap` per column).
inline AnnotationMaps const & Estimator::get_column_annotations() const {
    AnnotationMaps const & shared_maps(*_all_column_annotations);

    return shared_maps;
}
|
||||
|
||||
template <typename ReturnT, typename ArgT>
|
||||
typename Estimator<ReturnT, ArgT>::TransformerUniquePtr Estimator<ReturnT, ArgT>::commit(void) {
|
||||
if(_committed)
|
||||
throw std::runtime_error("This instance has already been committed");
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
/// Validates and stores the estimator's name and the shared column annotations.
///
/// \param name                     Name of the estimator; must not be empty.
/// \param pAllColumnAnnotations    Shared per-column annotation maps; must be
///                                 non-null and contain at least one column.
/// \throws std::runtime_error ("Invalid name") when `name` is empty.
/// \throws std::runtime_error ("Empty annotations") when the pointer is null or
///         the collection it points to is empty.
///
/// Both members are const, so each argument is checked by an immediately
/// invoked lambda inside the initializer list and then moved into place.
inline Estimator::Estimator(std::string name, AnnotationMapsPtr pAllColumnAnnotations) :
    Name(
        std::move(
            [&name](void) -> std::string & {
                if(name.empty())
                    throw std::runtime_error("Invalid name");

                return name;
            }()
        )
    ),
    _all_column_annotations(
        std::move(
            [&pAllColumnAnnotations](void) -> AnnotationMapsPtr & {
                if(!pAllColumnAnnotations || pAllColumnAnnotations->empty())
                    throw std::runtime_error("Empty annotations");

                return pAllColumnAnnotations;
            }()
        )
    ) {
}
|
||||
|
||||
/// Appends `pAnnotation` to the annotations stored under this estimator's
/// `Name` for column `col_index`, creating the per-name list on first use.
///
/// \throws std::runtime_error ("Invalid annotation") when `pAnnotation` is null.
/// \throws std::runtime_error ("Invalid annotation index") when `col_index` is
///         not a valid column.
/// \throws std::runtime_error ("Invalid insertion") when the per-name list could
///         not be inserted.
///
/// The method is `const` because only the shared annotation collection behind
/// `_all_column_annotations` is modified, not the estimator itself.
inline void Estimator::add_annotation(AnnotationPtr pAnnotation, size_t col_index) const {
    if(!pAnnotation)
        throw std::runtime_error("Invalid annotation");

    AnnotationMaps & maps(*_all_column_annotations);

    if(col_index >= maps.size())
        throw std::runtime_error("Invalid annotation index");

    AnnotationMap & per_column(maps[col_index]);

    // TODO: Acquire read map lock
    AnnotationMap::iterator entry(per_column.find(Name));

    if(entry == per_column.end()) {
        // TODO: Promote read lock to write lock
        std::pair<AnnotationMap::iterator, bool> const inserted(per_column.emplace(std::make_pair(Name, AnnotationPtrs())));

        if(inserted.first == per_column.end() || !inserted.second)
            throw std::runtime_error("Invalid insertion");

        entry = inserted.first;
    }

    // TODO: Acquire vector lock (the vector is modified here)
    entry->second.emplace_back(std::move(pAnnotation));
}
|
||||
|
||||
template <typename DerivedAnnotationT>
|
||||
boost::optional<DerivedAnnotationT &> Estimator::get_annotation_impl(size_t col_index) const {
|
||||
AnnotationMaps const & all_annotations(*_all_column_annotations);
|
||||
|
||||
if(col_index >= all_annotations.size())
|
||||
throw std::runtime_error("Invalid annotation index");
|
||||
|
||||
AnnotationMap const & column_annotations(all_annotations[col_index]);
|
||||
|
||||
// TODO: Acquire read map lock
|
||||
AnnotationMap::const_iterator const column_annotations_iter(column_annotations.find(Name));
|
||||
|
||||
if(column_annotations_iter != column_annotations.end()) {
|
||||
AnnotationPtrs const & annotations(column_annotations_iter->second);
|
||||
|
||||
// TODO: Acquire vector read lock
|
||||
for(auto const & annotation : annotations) {
|
||||
if(annotation->CreatorId == this) {
|
||||
return static_cast<DerivedAnnotationT &>(*annotation);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return boost::optional<DerivedAnnotationT &>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | FitEstimatorImpl
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename InputT>
|
||||
FitEstimatorImpl<InputT>::FitEstimatorImpl(std::string name, AnnotationMapsPtr pAllColumnAnnotations, bool is_training_complete /*=false*/) :
|
||||
Estimator(std::move(name), std::move(pAllColumnAnnotations)),
|
||||
_is_training_complete(std::move(is_training_complete)) {
|
||||
}
|
||||
|
||||
template <typename InputT>
|
||||
bool FitEstimatorImpl<InputT>::is_training_complete(void) const {
|
||||
return _is_training_complete;
|
||||
}
|
||||
|
||||
template <typename InputT>
|
||||
typename FitEstimatorImpl<InputT>::FitResult FitEstimatorImpl<InputT>::fit(InputType value) {
|
||||
return fit(&value, 1);
|
||||
}
|
||||
|
||||
template <typename InputT>
|
||||
typename FitEstimatorImpl<InputT>::FitResult FitEstimatorImpl<InputT>::fit(InputType const *pInputBuffer, size_t cInputBuffer, boost::optional<std::uint64_t> const &optionalNumTrailingNulls) {
|
||||
if(_is_training_complete)
|
||||
throw std::runtime_error("`fit` should not be invoked on an estimator that is already complete");
|
||||
|
||||
if(pInputBuffer && cInputBuffer == 0)
|
||||
throw std::runtime_error("Invalid buffer");
|
||||
|
||||
if(pInputBuffer == nullptr && cInputBuffer != 0)
|
||||
throw std::runtime_error("Invalid buffer");
|
||||
|
||||
if(pInputBuffer == nullptr && cInputBuffer == 0 && !optionalNumTrailingNulls)
|
||||
throw std::runtime_error("Invalid invocation");
|
||||
|
||||
if(optionalNumTrailingNulls && *optionalNumTrailingNulls == 0)
|
||||
throw std::runtime_error("Invalid number of nulls");
|
||||
|
||||
FitResult result(fit_impl(pInputBuffer, cInputBuffer, optionalNumTrailingNulls));
|
||||
|
||||
if(result == FitResult::Complete)
|
||||
complete_training();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename InputT>
|
||||
FitEstimatorImpl<InputT> & FitEstimatorImpl<InputT>::complete_training(void) {
|
||||
if(_is_training_complete)
|
||||
throw std::runtime_error("`complete_training` should not be invoked on an estimator that is already complete");
|
||||
|
||||
complete_training_impl();
|
||||
_is_training_complete = true;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | TransformerEstimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename InputT, typename TransformedT>
|
||||
TransformerEstimator<InputT, TransformedT> & TransformerEstimator<InputT, TransformedT>::complete_training(void) {
|
||||
return static_cast<ThisType &>(BaseType::complete_training());
|
||||
}
|
||||
|
||||
template <typename InputT, typename TransformedT>
|
||||
bool TransformerEstimator<InputT, TransformedT>::has_created_transformer(void) const {
|
||||
return _created_transformer;
|
||||
}
|
||||
|
||||
template <typename InputT, typename TransformedT>
|
||||
typename TransformerEstimator<InputT, TransformedT>::TransformerPtr TransformerEstimator<InputT, TransformedT>::create_transformer(void) {
|
||||
if(!BaseType::is_training_complete())
|
||||
throw std::runtime_error("`create_transformer` should not be invoked on an estimator that is not yet complete");
|
||||
|
||||
if(_created_transformer)
|
||||
throw std::runtime_error("`create_transformer` should not be invoked on an estimator that has been used to create a `Transformer`");
|
||||
|
||||
TransformerPtr result(create_transformer_impl());
|
||||
|
||||
if(!result)
|
||||
throw std::runtime_error("Invalid result");
|
||||
|
||||
_committed = true;
|
||||
_created_transformer = true;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename ReturnT, typename ArgT>
|
||||
template <typename ArchiveT>
|
||||
void Estimator<ReturnT, ArgT>::serialize(ArchiveT & /*ar*/, unsigned int const /*version*/) {
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
/// Convenience helper that constructs an `EstimatorT` from the forwarded
/// constructor arguments, fits it with `data`, and returns the result of
/// committing it.
///
/// \param data Input handed to the estimator's `fit`.
/// \param args Arguments perfectly forwarded to the `EstimatorT` constructor.
/// \returns    The value returned by `commit()` on the fitted estimator.
template <typename EstimatorT, typename... EstimatorConstructorArgsT>
typename EstimatorT::TransformerUniquePtr fit_and_commit(typename EstimatorT::apache_arrow const &data, EstimatorConstructorArgsT &&...args) {
    // The estimator only needs to live for the duration of this call.
    EstimatorT                              estimator(std::forward<EstimatorConstructorArgsT>(args)...);

    return estimator.fit(data).commit();
}
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -14,43 +14,58 @@ inline struct tm *gmtime_r(time_t const* const timer, struct tm* const result)
|
|||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
namespace DateTimeFeaturizer {
|
||||
|
||||
TimePoint::TimePoint(const std::chrono::system_clock::time_point& sysTime) {
|
||||
// Get to a tm to get what we need.
|
||||
// Eventually C++202x will have expanded chrono support that might
|
||||
// have what we need, but not yet!
|
||||
std::tm tmt;
|
||||
time_t tt = std::chrono::system_clock::to_time_t(sysTime);
|
||||
std::tm* res = gmtime_r(&tt, &tmt);
|
||||
if (res) {
|
||||
year = static_cast<std::int32_t>(tmt.tm_year) + 1900;
|
||||
month = static_cast<std::uint8_t>(tmt.tm_mon) + 1;
|
||||
day = static_cast<std::uint8_t>(tmt.tm_mday);
|
||||
hour = static_cast<std::uint8_t>(tmt.tm_hour);
|
||||
minute = static_cast<std::uint8_t>(tmt.tm_min);
|
||||
second = static_cast<std::uint8_t>(tmt.tm_sec);
|
||||
dayOfWeek = static_cast<std::uint8_t>(tmt.tm_wday);
|
||||
dayOfYear = static_cast<std::uint16_t>(tmt.tm_yday);
|
||||
quarterOfYear = (month + 2) / 3;
|
||||
weekOfMonth = (day - 1) / 7;
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | TimePoint
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
TimePoint::TimePoint(const std::chrono::system_clock::time_point& sysTime) {
|
||||
// Get to a tm to get what we need.
|
||||
// Eventually C++202x will have expanded chrono support that might
|
||||
// have what we need, but not yet!
|
||||
std::tm tmt;
|
||||
time_t tt = std::chrono::system_clock::to_time_t(sysTime);
|
||||
std::tm* res = gmtime_r(&tt, &tmt);
|
||||
if (res) {
|
||||
year = static_cast<std::int32_t>(tmt.tm_year) + 1900;
|
||||
month = static_cast<std::uint8_t>(tmt.tm_mon) + 1;
|
||||
day = static_cast<std::uint8_t>(tmt.tm_mday);
|
||||
hour = static_cast<std::uint8_t>(tmt.tm_hour);
|
||||
minute = static_cast<std::uint8_t>(tmt.tm_min);
|
||||
second = static_cast<std::uint8_t>(tmt.tm_sec);
|
||||
dayOfWeek = static_cast<std::uint8_t>(tmt.tm_wday);
|
||||
dayOfYear = static_cast<std::uint16_t>(tmt.tm_yday);
|
||||
quarterOfYear = (month + 2) / 3;
|
||||
weekOfMonth = (day - 1) / 7;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tt < 0) {
|
||||
throw std::invalid_argument("Dates prior to 1970 are not supported.");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tt < 0) {
|
||||
throw std::invalid_argument("Dates prior to 1970 are not supported.");
|
||||
}
|
||||
else {
|
||||
throw std::invalid_argument("Unknown error converting input date.");
|
||||
}
|
||||
else {
|
||||
throw std::invalid_argument("Unknown error converting input date.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `TimePoint` from the provided argument and returns it.
Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ {
    // `TimePoint` lives in the enclosing namespace, so no qualification is needed.
    return TimePoint(arg);
}
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | DateTimeTransformer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
/// Constructs a `TimePoint` from `input` and returns it.
DateTimeTransformer::TransformedType DateTimeTransformer::execute(InputType input) /*override*/ {
    return TimePoint{input};
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | DateTimeFeaturizer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
/// Forwards the featurizer name and the column annotations to `BaseType`.
DateTimeFeaturizer::DateTimeFeaturizer(AnnotationMapsPtr pAllColumnAnnotations)
    : BaseType("DateTimeFeaturizer", std::move(pAllColumnAnnotations))
{
    // All state lives in the base class
}
|
||||
|
||||
} // namespace DateTimeFeaturizer
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -4,85 +4,103 @@
|
|||
// ----------------------------------------------------------------------
|
||||
#pragma once
|
||||
|
||||
#include "../Featurizer.h"
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../InferenceOnlyFeaturizerImpl.h"
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \namespace DateTimeTransformer
|
||||
/// \brief A Transformer that takes a chrono::system_clock::time_point and
|
||||
/// returns a struct with all the data split out.
|
||||
/// \struct TimePoint
|
||||
/// \brief Struct to hold various components of DateTime information
|
||||
///
|
||||
namespace DateTimeFeaturizer {
|
||||
struct TimePoint {
|
||||
std::int32_t year = 0;
|
||||
std::uint8_t month = 0; /* 1-12 */
|
||||
std::uint8_t day = 0; /* 1-31 */
|
||||
std::uint8_t hour = 0; /* 0-23 */
|
||||
std::uint8_t minute = 0; /* 0-59 */
|
||||
std::uint8_t second = 0; /* 0-59 */
|
||||
std::uint8_t dayOfWeek = 0; /* 0-6 */
|
||||
std::uint16_t dayOfYear = 0; /* 0-365 */
|
||||
std::uint8_t quarterOfYear = 0; /* 1-4 */
|
||||
std::uint8_t weekOfMonth = 0; /* 0-4 */
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \struct TimePoint
|
||||
/// \brief Struct to hold various components of DateTime information
|
||||
///
|
||||
struct TimePoint {
|
||||
std::int32_t year = 0;
|
||||
std::uint8_t month = 0; /* 1-12 */
|
||||
std::uint8_t day = 0; /* 1-31 */
|
||||
std::uint8_t hour = 0; /* 0-23 */
|
||||
std::uint8_t minute = 0; /* 0-59 */
|
||||
std::uint8_t second = 0; /* 0-59 */
|
||||
std::uint8_t dayOfWeek = 0; /* 0-6 */
|
||||
std::uint16_t dayOfYear = 0; /* 0-365 */
|
||||
std::uint8_t quarterOfYear = 0; /* 1-4 */
|
||||
std::uint8_t weekOfMonth = 0; /* 0-4 */
|
||||
TimePoint(const std::chrono::system_clock::time_point& sysTime);
|
||||
|
||||
TimePoint(const std::chrono::system_clock::time_point& sysTime);
|
||||
TimePoint(TimePoint const &) = delete;
|
||||
TimePoint & operator =(TimePoint const &) = delete;
|
||||
|
||||
TimePoint(TimePoint&&) = default;
|
||||
TimePoint(const TimePoint&) = delete;
|
||||
TimePoint& operator=(const TimePoint&) = delete;
|
||||
TimePoint(TimePoint &&) = default;
|
||||
TimePoint & operator =(TimePoint &&) = delete;
|
||||
|
||||
enum {
|
||||
JANUARY = 1, FEBRUARY, MARCH, APRIL, MAY, JUNE,
|
||||
JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER
|
||||
};
|
||||
enum {
|
||||
SUNDAY = 0, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY
|
||||
};
|
||||
enum {
|
||||
JANUARY = 1, FEBRUARY, MARCH, APRIL, MAY, JUNE,
|
||||
JULY, AUGUST, SEPTEMBER, OCTOBER, NOVEMBER, DECEMBER
|
||||
};
|
||||
|
||||
inline TimePoint SystemToDPTimePoint(const std::chrono::system_clock::time_point& sysTime) {
|
||||
return TimePoint (sysTime);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class DateTimeTransformer
|
||||
/// \brief Transformer
|
||||
///
|
||||
class Transformer : public Microsoft::Featurizer::Transformer<Microsoft::Featurizer::DateTimeFeaturizer::TimePoint, std::chrono::system_clock::time_point> {
|
||||
public:
|
||||
Transformer(void) = default;
|
||||
~Transformer(void) override = default;
|
||||
|
||||
Transformer(Transformer const &) = delete;
|
||||
Transformer & operator =(Transformer const &) = delete;
|
||||
|
||||
Transformer(Transformer &&) = default;
|
||||
Transformer & operator =(Transformer &&) = delete;
|
||||
|
||||
return_type transform(arg_type const &arg) const override;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const version);
|
||||
enum {
|
||||
SUNDAY = 0, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY
|
||||
};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class DateTimeTransformer
|
||||
/// \brief A Transformer that takes a chrono::system_clock::time_point and
|
||||
/// returns a struct with all the data split out.
|
||||
///
|
||||
class DateTimeTransformer : public TransformerEstimator<std::chrono::system_clock::time_point, TimePoint>::Transformer {
public:
    // ----------------------------------------------------------------------
    // |  Public Types
    // ----------------------------------------------------------------------
    using BaseType                          = TransformerEstimator<std::chrono::system_clock::time_point, TimePoint>::Transformer;

    // ----------------------------------------------------------------------
    // |  Public Methods
    // ----------------------------------------------------------------------
    DateTimeTransformer() = default;
    ~DateTimeTransformer() override = default;

    // Move-constructible only: copying and every form of assignment are disabled.
    DateTimeTransformer(DateTimeTransformer &&) = default;
    DateTimeTransformer(const DateTimeTransformer &) = delete;

    DateTimeTransformer & operator =(DateTimeTransformer &&) = delete;
    DateTimeTransformer & operator =(const DateTimeTransformer &) = delete;

    /// Converts the input into the transformed type; defined out of line.
    TransformedType execute(InputType input) override;
};
|
||||
|
||||
/// Featurizer built on `InferenceOnlyFeaturizerImpl` for `DateTimeTransformer`.
class DateTimeFeaturizer : public InferenceOnlyFeaturizerImpl<DateTimeTransformer> {
public:
    // ----------------------------------------------------------------------
    // |  Public Types
    // ----------------------------------------------------------------------
    using BaseType                          = InferenceOnlyFeaturizerImpl<DateTimeTransformer>;

    // ----------------------------------------------------------------------
    // |  Public Methods
    // ----------------------------------------------------------------------

    /// \param pAllColumnAnnotations Annotation maps handed on to `BaseType`.
    DateTimeFeaturizer(AnnotationMapsPtr pAllColumnAnnotations);
    ~DateTimeFeaturizer() override = default;

    // Move-constructible only: copying and every form of assignment are disabled.
    DateTimeFeaturizer(DateTimeFeaturizer &&) = default;
    DateTimeFeaturizer(const DateTimeFeaturizer &) = delete;

    DateTimeFeaturizer & operator =(DateTimeFeaturizer &&) = delete;
    DateTimeFeaturizer & operator =(const DateTimeFeaturizer &) = delete;
};
|
||||
|
||||
} // Namespace DateTimeFeaturizer
|
||||
} // Namespace Featurizer
|
||||
} // Namespace Microsoft
|
||||
|
|
|
@ -6,35 +6,63 @@
|
|||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
namespace SampleAdd {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class SampleAddTransformer
|
||||
/// \brief Adds a delta to the provided value.
|
||||
///
|
||||
class SampleAddTransformer : public SampleAddFeaturizer::BaseType::Transformer {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Data
|
||||
std::uint32_t const Delta;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
SampleAddTransformer(std::uint32_t delta) :
|
||||
Delta(delta) {
|
||||
}
|
||||
|
||||
~SampleAddTransformer(void) override = default;
|
||||
|
||||
SampleAddTransformer(SampleAddTransformer const &) = delete;
|
||||
SampleAddTransformer & operator =(SampleAddTransformer const &) = delete;
|
||||
|
||||
SampleAddTransformer(SampleAddTransformer &&) = default;
|
||||
SampleAddTransformer & operator =(SampleAddTransformer &&) = delete;
|
||||
|
||||
TransformedType execute(InputType input) override {
|
||||
return input + Delta;
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Transformer
|
||||
// | SampleAddFeaturizer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Transformer::Transformer(std::uint16_t delta) :
|
||||
_delta(delta) {
|
||||
SampleAddFeaturizer::SampleAddFeaturizer(AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
BaseType("SampleAddFeaturizer", std::move(pAllColumnAnnotations)) {
|
||||
}
|
||||
|
||||
Transformer::return_type Transformer::transform(arg_type const &arg) const /*override*/ {
|
||||
return _delta + arg;
|
||||
/// Adds each of the `cBuffer` values starting at `pBuffer` to the
/// accumulated delta. The third argument is ignored.
///
/// \returns `FitResult::Continue` in every case.
SampleAddFeaturizer::FitResult SampleAddFeaturizer::fit_impl(InputType const *pBuffer, size_t cBuffer, boost::optional<std::uint64_t> const &) /*override*/ {
    for(size_t index = 0; index < cBuffer; ++index)
        _accumulated_delta += pBuffer[index];

    return FitResult::Continue;
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Estimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
Estimator & Estimator::fit_impl(apache_arrow const &data) /*override*/ {
|
||||
_accumulated_delta += static_cast<std::uint16_t>(data);
|
||||
return *this;
|
||||
/// Training-completion hook; this featurizer does no work here.
void SampleAddFeaturizer::complete_training_impl() /*override*/ {
    // Nothing to do here
}
|
||||
|
||||
Estimator::TransformerUniquePtr Estimator::commit_impl(void) /*override*/ {
|
||||
return std::make_unique<SampleAdd::Transformer>(_accumulated_delta);
|
||||
/// Creates a `SampleAddTransformer` initialized with the accumulated delta.
SampleAddFeaturizer::TransformerPtr SampleAddFeaturizer::create_transformer_impl(void) /*override*/ {
    TransformerPtr                          transformer = std::make_shared<SampleAddTransformer>(_accumulated_delta);

    return transformer;
}
|
||||
|
||||
} // namespace SampleAdd
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -10,117 +10,51 @@ namespace Microsoft {
|
|||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \namespace SampleAdd
|
||||
/// \class SampleAddFeaturizer
|
||||
/// \brief A Transformer and Estimator that add values. This is a
|
||||
/// sample intended to demonstrate patterns within the
|
||||
/// implementation of these types.
|
||||
///
|
||||
namespace SampleAdd {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Transformer
|
||||
/// \brief Transformer that adds an integer value to a saved delta
|
||||
/// and returns the result.
|
||||
///
|
||||
class Transformer : public Microsoft::Featurizer::Transformer<std::uint32_t, std::uint16_t> {
|
||||
class SampleAddFeaturizer : public TransformerEstimator<std::uint16_t, std::uint32_t> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using BaseType = TransformerEstimator<std::uint16_t, std::uint32_t>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
Transformer(std::uint16_t delta=0);
|
||||
~Transformer(void) override = default;
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
SampleAddFeaturizer(AnnotationMapsPtr pAllColumnAnnotations);
|
||||
~SampleAddFeaturizer(void) override = default;
|
||||
|
||||
Transformer(Transformer const &) = delete;
|
||||
Transformer & operator =(Transformer const &) = delete;
|
||||
SampleAddFeaturizer(SampleAddFeaturizer const &) = delete;
|
||||
SampleAddFeaturizer & operator =(SampleAddFeaturizer const &) = delete;
|
||||
|
||||
Transformer(Transformer &&) = default;
|
||||
Transformer & operator =(Transformer &&) = delete;
|
||||
|
||||
return_type transform(arg_type const &arg) const override;
|
||||
SampleAddFeaturizer(SampleAddFeaturizer &&) = default;
|
||||
SampleAddFeaturizer & operator =(SampleAddFeaturizer &&) = delete;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Data
|
||||
std::uint32_t const _delta;
|
||||
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const version);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Estimator
|
||||
/// \brief Estimator that accumulates a delta value and then
|
||||
/// creates a Transformer with that value when requested.
|
||||
///
|
||||
class Estimator : public Microsoft::Featurizer::Estimator<std::uint32_t, std::uint16_t> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
Estimator(void) = default;
|
||||
~Estimator(void) override = default;
|
||||
|
||||
Estimator(Estimator const &) = delete;
|
||||
Estimator & operator =(Estimator const &) = delete;
|
||||
|
||||
Estimator(Estimator &&) = default;
|
||||
Estimator & operator =(Estimator &&) = delete;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Data
|
||||
std::uint32_t _accumulated_delta = 0;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const version);
|
||||
|
||||
Estimator & fit_impl(apache_arrow const &data) override;
|
||||
TransformerUniquePtr commit_impl(void) override;
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
FitResult fit_impl(InputType const *pBuffer, size_t cBuffer, boost::optional<std::uint64_t> const &optionalNumTrailingNulls) override;
|
||||
void complete_training_impl(void) override;
|
||||
TransformerPtr create_transformer_impl(void) override;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Implementation
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Transformer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename ArchiveT>
|
||||
void Transformer::serialize(ArchiveT &ar, unsigned int const version) {
|
||||
ar & boost::serialization::base_object<Microsoft::Featurizer::Transformer>(*this);
|
||||
ar & boost::serialization::make_nvp("delta", _delta);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Estimator
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename ArchiveT>
|
||||
void Estimator::serialize(ArchiveT &ar, unsigned int const version) {
|
||||
ar & boost::serialization::base_object<Microsoft::Featurizer::Estimator>(*this);
|
||||
ar & boost::serialization::make_nvp("accumulated_delta", _accumulated_delta);
|
||||
}
|
||||
|
||||
} // namespace SampleAdd
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
#pragma once
|
||||
|
||||
#include "../InferenceOnlyFeaturizerImpl.h"
|
||||
#include "../Traits.h"
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class StringTransformer
|
||||
/// \brief Converts input into strings.
|
||||
///
|
||||
template <typename T>
|
||||
class StringTransformer : public TransformerEstimator<T, std::string>::Transformer {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using Type = T;
|
||||
using ThisType = StringTransformer<Type>;
|
||||
using BaseType = typename TransformerEstimator<Type, std::string>::Transformer;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
StringTransformer(void) = default;
|
||||
~StringTransformer(void) override = default;
|
||||
|
||||
StringTransformer(StringTransformer const &) = delete;
|
||||
StringTransformer & operator =(StringTransformer const &) = delete;
|
||||
|
||||
StringTransformer(StringTransformer && other) = default;
|
||||
StringTransformer & operator =(StringTransformer &&) = delete;
|
||||
|
||||
typename BaseType::TransformedType execute(typename BaseType::InputType input) override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class StringFeaturizer : public InferenceOnlyFeaturizerImpl<StringTransformer<T>> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
using Type = T;
|
||||
using ThisType = StringFeaturizer<Type>;
|
||||
using BaseType = InferenceOnlyFeaturizerImpl<StringTransformer<Type>>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
StringFeaturizer(AnnotationMapsPtr pAllCoumnAnnotations);
|
||||
~StringFeaturizer(void) override = default;
|
||||
|
||||
StringFeaturizer(StringFeaturizer const &) = delete;
|
||||
StringFeaturizer & operator =(StringFeaturizer const &) = delete;
|
||||
|
||||
StringFeaturizer(StringFeaturizer &&) = default;
|
||||
StringFeaturizer & operator =(StringFeaturizer &&) = delete;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Implementation
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | StringTransformer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
typename StringTransformer<T>::BaseType::TransformedType StringTransformer<T>::execute(typename BaseType::InputType input) /*override*/ {
|
||||
return Traits::Traits<T>::ToString(input);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | StringFeaturizer
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
StringFeaturizer<T>::StringFeaturizer(AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
BaseType("StringFeaturizer", std::move(pAllColumnAnnotations)) {
|
||||
}
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
|
@ -1,58 +0,0 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
#pragma once
|
||||
|
||||
#include "../Featurizer.h"
|
||||
#include "../Traits.h"
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \namespace StringTransformer
|
||||
/// \brief A Transformer that converts values of an arbitrary type
|
||||
/// into their string representation by way of
|
||||
/// `Traits<T>::ToString`.
|
||||
///
|
||||
namespace StringTransformer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class Transformer
|
||||
/// \brief Transformer that converts a value of type `argT` into a
|
||||
/// `std::string` and returns the result.
|
||||
///
|
||||
template <typename argT>
|
||||
class Transformer : public Microsoft::Featurizer::Transformer<std::string, argT> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
Transformer() = default;
|
||||
~Transformer() override = default;
|
||||
|
||||
Transformer(Transformer const &) = delete;
|
||||
Transformer & operator =(Transformer const &) = delete;
|
||||
|
||||
Transformer(Transformer &&) = default;
|
||||
Transformer & operator =(Transformer &&) = delete;
|
||||
|
||||
std::string transform(argT const &arg) const override
|
||||
{
|
||||
return Traits::Traits<argT>::ToString(arg);
|
||||
}
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
//friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const version);
|
||||
};
|
||||
|
||||
} // namespace StringTransformer
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
|
@ -29,7 +29,7 @@ add_library(libFeaturizers STATIC
|
|||
../SampleAdd.cpp
|
||||
../DateTimeFeaturizer.h
|
||||
../DateTimeFeaturizer.cpp
|
||||
../StringTransformer.h
|
||||
../StringFeaturizer.h
|
||||
)
|
||||
|
||||
enable_testing()
|
||||
|
@ -37,7 +37,7 @@ enable_testing()
|
|||
foreach(_test_name IN ITEMS
|
||||
SampleAdd_UnitTest
|
||||
DateTimeFeaturizer_UnitTests
|
||||
StringTransformer_UnitTest
|
||||
StringFeaturizer_UnitTest
|
||||
)
|
||||
add_executable(${_test_name} ${_test_name}.cpp)
|
||||
|
||||
|
|
|
@ -10,14 +10,19 @@
|
|||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
namespace DateTimeFeaturizer {
|
||||
|
||||
using SysClock = std::chrono::system_clock;
|
||||
|
||||
// Smoke test for a freshly constructed DateTimeFeaturizer. Each CHECK builds
// its own instance and verifies, in order: the `Name` value, that training is
// reported as complete without any call to `fit`, and that
// `create_transformer` yields an object that casts to `DateTimeTransformer`.
TEST_CASE("DateTimeFeaturizer") {
|
||||
CHECK(DateTimeFeaturizer(CreateTestAnnotationMapsPtr(2)).Name == "DateTimeFeaturizer");
|
||||
CHECK(DateTimeFeaturizer(CreateTestAnnotationMapsPtr(2)).is_training_complete());
|
||||
CHECK(std::dynamic_pointer_cast<DateTimeTransformer>(DateTimeFeaturizer(CreateTestAnnotationMapsPtr(2)).create_transformer()));
|
||||
}
|
||||
|
||||
TEST_CASE("Past - 1976 Nov 17, 12:27:04", "[DateTimeFeaturizer][DateTime]") {
|
||||
const time_t date = 217081624;
|
||||
SysClock::time_point stp = SysClock::from_time_t(date);
|
||||
|
||||
|
||||
// Constructor
|
||||
TimePoint tp(stp);
|
||||
CHECK(tp.year == 1976);
|
||||
|
@ -38,18 +43,18 @@ TEST_CASE("Past - 1976 Nov 17, 12:27:04", "[DateTimeFeaturizer][DateTime]") {
|
|||
CHECK(tp1.day == 17);
|
||||
|
||||
// function
|
||||
TimePoint tp2 = SystemToDPTimePoint(stp);
|
||||
TimePoint tp2 = TimePoint(stp);
|
||||
CHECK(tp2.year == 1976);
|
||||
CHECK(tp2.month == TimePoint::NOVEMBER);
|
||||
CHECK(tp2.day == 17);
|
||||
}
|
||||
|
||||
TEST_CASE("Past - 1976 Nov 17, 12:27:05", "[DateTimeFeaturizer][Transformer]") {
|
||||
TEST_CASE("Past - 1976 Nov 17, 12:27:05", "[DateTimeFeaturizer][DateTimeTransformer]") {
|
||||
const time_t date = 217081625;
|
||||
SysClock::time_point stp = SysClock::from_time_t(date);
|
||||
|
||||
Transformer dt;
|
||||
TimePoint tp = dt.transform(stp);
|
||||
|
||||
DateTimeTransformer dt;
|
||||
TimePoint tp = dt.execute(stp);
|
||||
CHECK(tp.year == 1976);
|
||||
CHECK(tp.month == TimePoint::NOVEMBER);
|
||||
CHECK(tp.day == 17);
|
||||
|
@ -63,12 +68,12 @@ TEST_CASE("Past - 1976 Nov 17, 12:27:05", "[DateTimeFeaturizer][Transformer]") {
|
|||
|
||||
}
|
||||
|
||||
TEST_CASE("Future - 2025 June 30", "[DateTimeFeaturizer][Transformer]") {
|
||||
TEST_CASE("Future - 2025 June 30", "[DateTimeFeaturizer][DateTimeTransformer]") {
|
||||
const time_t date = 1751241600;
|
||||
SysClock::time_point stp = SysClock::from_time_t(date);
|
||||
|
||||
Transformer dt;
|
||||
TimePoint tp = dt.transform(stp);
|
||||
DateTimeTransformer dt;
|
||||
TimePoint tp = dt.execute(stp);
|
||||
CHECK(tp.year == 2025);
|
||||
CHECK(tp.month == TimePoint::JUNE);
|
||||
CHECK(tp.day == 30);
|
||||
|
@ -84,12 +89,12 @@ TEST_CASE("Future - 2025 June 30", "[DateTimeFeaturizer][Transformer]") {
|
|||
#ifdef _MSC_VER
|
||||
// others define system_clock::time_point as nanoseconds (64-bit),
|
||||
// which rolls over somewhere around 2260. Still a couple hundred years!
|
||||
TEST_CASE("Far Future - 2998 March 2, 14:03:02", "[DateTimeFeaturizer][Transformer]") {
|
||||
TEST_CASE("Far Future - 2998 March 2, 14:03:02", "[DateTimeFeaturizer][DateTimeTransformer]") {
|
||||
const time_t date = 32445842582;
|
||||
SysClock::time_point stp = SysClock::from_time_t(date);
|
||||
|
||||
Transformer dt;
|
||||
TimePoint tp = dt.transform(stp);
|
||||
DateTimeTransformer dt;
|
||||
TimePoint tp = dt.execute(stp);
|
||||
CHECK(tp.year == 2998);
|
||||
CHECK(tp.month == TimePoint::MARCH);
|
||||
CHECK(tp.day == 2);
|
||||
|
@ -105,19 +110,19 @@ TEST_CASE("Far Future - 2998 March 2, 14:03:02", "[DateTimeFeaturizer][Transform
|
|||
#else
|
||||
|
||||
// msvcrt doesn't support negative time_t, so nothing before 1970
|
||||
TEST_CASE("Pre-Epoch - 1776 July 4", "[DateTimeFeaturizer][Transformer]")
|
||||
TEST_CASE("Pre-Epoch - 1776 July 4", "[DateTimeFeaturizer][DateTimeTransformer]")
|
||||
{
|
||||
const time_t date = -6106060800;
|
||||
SysClock::time_point stp = SysClock::from_time_t(date);
|
||||
|
||||
// Constructor
|
||||
Transformer dt;
|
||||
TimePoint tp = dt.transform(stp);
|
||||
DateTimeTransformer dt;
|
||||
TimePoint tp = dt.execute(stp);
|
||||
CHECK(tp.year == 1776);
|
||||
CHECK(tp.month == TimePoint::JULY);
|
||||
CHECK(tp.day == 4);
|
||||
}
|
||||
#endif /* _MSC_VER */
|
||||
} // namespace DateTimeFeaturizer
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -8,15 +8,29 @@
|
|||
|
||||
#include "../SampleAdd.h"
|
||||
|
||||
TEST_CASE("Transformer") {
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Transformer(10).transform(20) == 30);
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Transformer(20).transform(1) == 21);
|
||||
// ----------------------------------------------------------------------
|
||||
using Microsoft::Featurizer::SampleAddFeaturizer;
|
||||
using Microsoft::Featurizer::CreateTestAnnotationMapsPtr;
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
// Test helper: constructs a SampleAddFeaturizer over CreateTestAnnotationMapsPtr(2),
// fits it on the whole `input` buffer in a single call, completes training, and
// returns the transformer created by the featurizer.
// NOTE(review): an empty `input` would hand a zero-length buffer to fit(); how the
// featurizer reacts to that is not visible here - confirm before calling it that way.
SampleAddFeaturizer::TransformerPtr Train(std::vector<std::uint16_t> const &input) {
|
||||
SampleAddFeaturizer featurizer(CreateTestAnnotationMapsPtr(2));
|
||||
|
||||
featurizer.fit(input.data(), input.size());
|
||||
featurizer.complete_training();
|
||||
|
||||
return featurizer.create_transformer();
|
||||
}
|
||||
|
||||
TEST_CASE("Estimator") {
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).commit()->transform(20) == 30);
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Estimator().fit(20).commit()->transform(1) == 21);
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).commit()->transform(20) == 50);
|
||||
CHECK(Microsoft::Featurizer::SampleAdd::Estimator().fit(10).fit(20).fit(30).commit()->transform(20) == 80);
|
||||
// End-to-end check of the trained transformer. Each expected value equals the sum of
// the fitted inputs plus the value passed to execute() (e.g. 10 + 20 + 20 == 50).
TEST_CASE("SampleAddFeaturizer") {
|
||||
CHECK(Train({10})->execute(20) == 30);
|
||||
CHECK(Train({20})->execute(1) == 21);
|
||||
|
||||
CHECK(Train({10, 20})->execute(20) == 50);
|
||||
CHECK(Train({10, 20, 30})->execute(20) == 80);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "../StringFeaturizer.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// Tests are lighter for this featurizer because we are directly
|
||||
/// leveraging functionality in the traits class
|
||||
///
|
||||
using namespace Microsoft::Featurizer;
|
||||
|
||||
// Construction-level checks on a freshly built StringFeaturizer<int>: it reports the
// name "StringFeaturizer", reports training as complete without any fit() call, and
// create_transformer() yields an object that dynamic-casts to StringTransformer<int>.
TEST_CASE("StringFeaturizer") {
|
||||
CHECK(StringFeaturizer<int>(CreateTestAnnotationMapsPtr(2)).Name == "StringFeaturizer");
|
||||
CHECK(StringFeaturizer<int>(CreateTestAnnotationMapsPtr(2)).is_training_complete());
|
||||
CHECK(std::dynamic_pointer_cast<StringTransformer<int>>(StringFeaturizer<int>(CreateTestAnnotationMapsPtr(2)).create_transformer()));
|
||||
}
|
||||
|
||||
// Booleans are expected to be rendered as the capitalized words "False" and "True".
TEST_CASE("Transformer_Binary") {
|
||||
CHECK(StringTransformer<bool>().execute(false) == "False");
|
||||
CHECK(StringTransformer<bool>().execute(true) == "True");
|
||||
}
|
||||
|
||||
// A std::string input is expected to come back unchanged.
TEST_CASE("Transformer_Strings") {
|
||||
std::string const arg_s("isstring");
|
||||
|
||||
CHECK(StringTransformer<std::string>().execute(arg_s) == arg_s);
|
||||
}
|
||||
|
||||
// Covers every fixed-width signed and unsigned integer type from 8 to 64 bits.
// Each value is expected as its plain decimal text, with a leading '-' for negatives.
TEST_CASE("Transformer_Integers") {
|
||||
std::int8_t arg_8 = 20;
|
||||
std::int16_t arg_16 = -250;
|
||||
std::int32_t arg_32 = 480;
|
||||
std::int64_t arg_64 = -7799;
|
||||
|
||||
std::uint8_t arg_u8 = 20;
|
||||
std::uint16_t arg_u16 = 250;
|
||||
std::uint32_t arg_u32 = 480;
|
||||
std::uint64_t arg_u64 = 7799;
|
||||
|
||||
// Signed types. The 8-bit case pins that the value is rendered as the number "20"
// rather than as a character.
CHECK(StringTransformer<std::int8_t>().execute(arg_8) == "20");
|
||||
CHECK(StringTransformer<std::int16_t>().execute(arg_16) == "-250");
|
||||
CHECK(StringTransformer<std::int32_t>().execute(arg_32) == "480");
|
||||
CHECK(StringTransformer<std::int64_t>().execute(arg_64) == "-7799");
|
||||
|
||||
// Unsigned types, same expectations.
CHECK(StringTransformer<std::uint8_t>().execute(arg_u8) == "20");
|
||||
CHECK(StringTransformer<std::uint16_t>().execute(arg_u16) == "250");
|
||||
CHECK(StringTransformer<std::uint32_t>().execute(arg_u32) == "480");
|
||||
CHECK(StringTransformer<std::uint64_t>().execute(arg_u64) == "7799");
|
||||
}
|
||||
|
||||
// Floating-point inputs. A whole-valued float is expected without a decimal point
// ("123"), a short double as written ("123.45"), and a large double in scientific
// notation with six significant digits ("1.35454e+14").
// NOTE(review): these strings look like default stream formatting - confirm in Traits.h.
TEST_CASE("Transformer_Numbers") {
|
||||
std::float_t arg_f = 123;
|
||||
std::double_t arg_d1 = 123.45;
|
||||
std::double_t arg_d2 = 135453984983490.5473;
|
||||
|
||||
CHECK(StringTransformer<std::float_t>().execute(arg_f) == "123");
|
||||
CHECK(StringTransformer<std::double_t>().execute(arg_d1) == "123.45");
|
||||
CHECK(StringTransformer<std::double_t>().execute(arg_d2) == "1.35454e+14");
|
||||
}
|
||||
|
||||
// A std::array is expected as its elements in order, comma-separated with no spaces,
// wrapped in square brackets.
TEST_CASE("Transformer_Array") {
|
||||
std::array<std::double_t, 4> arr{ 1.3,2,-306.2,0.04 };
|
||||
std::string arr_s{ "[1.3,2,-306.2,0.04]" };
|
||||
CHECK(StringTransformer<std::array<std::double_t, 4>>().execute(arr) == arr_s);
|
||||
}
|
||||
|
||||
// A std::vector is expected in the same bracketed, comma-separated form as an array.
TEST_CASE("Transformer_Vector") {
|
||||
std::vector<std::double_t> vect{ 1.03, -20.1, 305.8 };
|
||||
std::string vect_s{ "[1.03,-20.1,305.8]" };
|
||||
CHECK(StringTransformer<std::vector<std::double_t>>().execute(vect) == vect_s);
|
||||
}
|
||||
|
||||
// A std::map is expected as "key:value" pairs, comma-separated with no spaces, wrapped
// in curly braces. The pairs appear in the map's key order (5 before 93).
TEST_CASE("Transformer_Maps") {
|
||||
std::map<std::int16_t, std::double_t> m;
|
||||
m.insert(std::pair<std::int16_t, std::double_t>(5, 35.8));
|
||||
m.insert(std::pair<std::int16_t, std::double_t>(93, 0.147));
|
||||
std::string map_s{ "{5:35.8,93:0.147}" };
|
||||
CHECK(StringTransformer<std::map<std::int16_t, std::double_t>>().execute(m) == map_s);
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
|
||||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
#include "../StringTransformer.h"
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// Tests are lighter for this featurizer because we are directly
|
||||
/// leveraging functionality in the traits class
|
||||
///
|
||||
|
||||
TEST_CASE("Transformer_Binary") {
|
||||
bool arg_false = false;
|
||||
bool arg_true = true;
|
||||
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<bool>().transform(arg_false) == "False");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<bool>().transform(arg_true) == "True");
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Strings") {
|
||||
std::string arg_s = "isstring";
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::string>().transform(arg_s) == "isstring");
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Integers") {
|
||||
std::int8_t arg_8 = 20;
|
||||
std::int16_t arg_16 = -250;
|
||||
std::int32_t arg_32 = 480;
|
||||
std::int64_t arg_64 = -7799;
|
||||
|
||||
std::uint8_t arg_u8 = 20;
|
||||
std::uint16_t arg_u16 = 250;
|
||||
std::uint32_t arg_u32 = 480;
|
||||
std::uint64_t arg_u64 = 7799;
|
||||
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::int8_t>().transform(arg_8) == "20");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::int16_t>().transform(arg_16) == "-250");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::int32_t>().transform(arg_32) == "480");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::int64_t>().transform(arg_64) == "-7799");
|
||||
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::uint8_t>().transform(arg_u8) == "20");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::uint16_t>().transform(arg_u16) == "250");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::uint32_t>().transform(arg_u32) == "480");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::uint64_t>().transform(arg_u64) == "7799");
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Numbers") {
|
||||
std::float_t arg_f = 123;
|
||||
std::double_t arg_d1 = 123.45;
|
||||
std::double_t arg_d2 = 135453984983490.5473;
|
||||
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::float_t>().transform(arg_f) == "123");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::double_t>().transform(arg_d1) == "123.45");
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::double_t>().transform(arg_d2) == "1.35454e+14");
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Array") {
|
||||
std::array<std::double_t, 4> arr{ 1.3,2,-306.2,0.04 };
|
||||
std::string arr_s{ "[1.3,2,-306.2,0.04]" };
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::array<std::double_t, 4>>().transform(arr) == arr_s);
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Vector") {
|
||||
std::vector<std::double_t> vect{ 1.03, -20.1, 305.8 };
|
||||
std::string vect_s{ "[1.03,-20.1,305.8]" };
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::vector<std::double_t>>().transform(vect) == vect_s);
|
||||
}
|
||||
|
||||
TEST_CASE("Transformer_Maps") {
|
||||
std::map<std::int16_t, std::double_t> m;
|
||||
m.insert(std::pair<std::int16_t, std::double_t>(5, 35.8));
|
||||
m.insert(std::pair<std::int16_t, std::double_t>(93, 0.147));
|
||||
std::string map_s{ "{5:35.8,93:0.147}" };
|
||||
CHECK(Microsoft::Featurizer::StringTransformer::Transformer<std::map<std::int16_t, std::double_t>>().transform(m) == map_s);
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
#pragma once
|
||||
|
||||
#include "Featurizer.h"
|
||||
|
||||
namespace Microsoft {
|
||||
namespace Featurizer {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
/// \class InferenceOnlyFeaturizerImpl
|
||||
/// \brief Featurizer that only participates in inferencing
|
||||
/// activities - no training is required. This class implements
|
||||
/// the scaffolding necessary to produce a transformer.
|
||||
///
|
||||
/// \tparam TransformerT   Transformer type produced by this featurizer.
/// \tparam InputT         Input type; defaults to `TransformerT::InputType`.
/// \tparam TransformedT   Output type; defaults to `TransformerT::TransformedType`.
template <
|
||||
typename TransformerT,
|
||||
typename InputT=typename TransformerT::InputType,
|
||||
typename TransformedT=typename TransformerT::TransformedType
|
||||
>
|
||||
class InferenceOnlyFeaturizerImpl : public TransformerEstimator<InputT, TransformedT> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Types
|
||||
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// Re-export the template arguments under fixed names.
using TransformerType = TransformerT;
|
||||
using InputType = InputT;
|
||||
using TransformedType = TransformedT;
|
||||
|
||||
// Convenience aliases for this instantiation and its direct base class.
using ThisType = InferenceOnlyFeaturizerImpl<TransformerType, InputType, TransformedType>;
|
||||
using BaseType = TransformerEstimator<InputType, TransformedType>;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Public Methods
|
||||
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// Both arguments are taken by value.
InferenceOnlyFeaturizerImpl(std::string name, AnnotationMapsPtr pAllColumnAnnotations);
|
||||
~InferenceOnlyFeaturizerImpl(void) override = default;
|
||||
|
||||
// Copy construction and copy assignment are disabled.
InferenceOnlyFeaturizerImpl(InferenceOnlyFeaturizerImpl const &) = delete;
|
||||
InferenceOnlyFeaturizerImpl & operator =(InferenceOnlyFeaturizerImpl const &) = delete;
|
||||
|
||||
// Move construction is allowed; move assignment is disabled.
InferenceOnlyFeaturizerImpl(InferenceOnlyFeaturizerImpl &&) = default;
|
||||
InferenceOnlyFeaturizerImpl & operator =(InferenceOnlyFeaturizerImpl &&) = delete;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Private Methods
|
||||
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// Overrides of the base-class hooks; their definitions follow the class body.
typename BaseType::FitResult fit_impl(InputType const *pBuffer, size_t cBuffer, boost::optional<std::uint64_t> const &optionalNumTrailingNulls) override;
|
||||
void complete_training_impl(void) override;
|
||||
typename BaseType::TransformerPtr create_transformer_impl(void) override;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// |
|
||||
// | Implementation
|
||||
// |
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
/// Moves `name` and `pAllColumnAnnotations` into the `TransformerEstimator` base.
/// NOTE(review): the trailing `true` is presumably the base's "training is already
/// complete" flag - confirm against Featurizer.h.
template <typename TransformerT, typename InputT, typename TransformedT>
|
||||
InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::InferenceOnlyFeaturizerImpl(std::string name, AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
BaseType(std::move(name), std::move(pAllColumnAnnotations), true) {
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------
|
||||
/// Training hook override. It ignores all of its arguments and always throws
/// `std::runtime_error("This should never be called")`.
/// NOTE(review): presumably unreachable because the featurizer is constructed as
/// already trained - confirm against the base class.
template <typename TransformerT, typename InputT, typename TransformedT>
|
||||
typename InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::BaseType::FitResult InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::fit_impl(InputType const *, size_t, boost::optional<std::uint64_t> const &) /*override*/ {
|
||||
throw std::runtime_error("This should never be called");
|
||||
}
|
||||
|
||||
/// Training-completion hook override. It always throws
/// `std::runtime_error("This should never be called")`.
/// NOTE(review): presumably unreachable because the featurizer is constructed as
/// already trained - confirm against the base class.
template <typename TransformerT, typename InputT, typename TransformedT>
|
||||
void InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::complete_training_impl(void) /*override*/ {
|
||||
throw std::runtime_error("This should never be called");
|
||||
}
|
||||
|
||||
/// Returns a new, default-constructed `TransformerT` created with `std::make_shared`.
/// `TransformerT` must therefore be default-constructible, and the resulting
/// `std::shared_ptr<TransformerT>` must convert to `BaseType::TransformerPtr`.
template <typename TransformerT, typename InputT, typename TransformedT>
|
||||
typename InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::BaseType::TransformerPtr InferenceOnlyFeaturizerImpl<TransformerT, InputT, TransformedT>::create_transformer_impl(void) /*override*/ {
|
||||
return std::make_shared<TransformerT>();
|
||||
}
|
||||
|
||||
} // namespace Featurizer
|
||||
} // namespace Microsoft
|
|
@ -20,17 +20,17 @@ namespace Traits {
|
|||
/// \brief We have a range of types we are dealing with. Many types
|
||||
/// have different ways to represent what a `NULL` value is
|
||||
/// (float has NAN for example) as well as different ways to
|
||||
/// convert the value to a string representation. By using
|
||||
/// convert the value to a string representation. By using
|
||||
/// templates combined with partial template specialization
|
||||
/// we can handle scenarios like these that vary based on the data type.
|
||||
///
|
||||
///
|
||||
/// Example: This allows us to do things like `Traits<std::int8_t>::IsNull()`
|
||||
/// and `Traits<float>::IsNull()` and let the trait itself deal with the
|
||||
/// actual implementation and allows us as developers to not worry about that.
|
||||
///
|
||||
/// This benefit is magnified because we are also using templates for our
|
||||
///
|
||||
/// This benefit is magnified because we are also using templates for our
|
||||
/// transformers. When we declare that a transformer has type T = std::int8_t,
|
||||
/// we can then also use `Traits<T>::IsNull()` and the compiler will know that
|
||||
/// we can then also use `Traits<T>::IsNull()` and the compiler will know that
|
||||
/// `T` is a `std::int8_t` and call the appropriate template specialization.
|
||||
///
|
||||
template <typename T>
|
||||
|
@ -44,7 +44,7 @@ struct Traits {};
|
|||
/// defined there. If you have methods defined in that base template,
|
||||
/// it makes it very difficult to debug what is going on. By
|
||||
/// putting no implementation in the `Traits<>` template and
|
||||
/// having the real base struct be `TraitsImpl<>`, if you try and
|
||||
/// having the real base struct be `TraitsImpl<>`, if you try and
|
||||
/// specify a trait that doesn't have a specialization, the compiler
|
||||
/// can detect that and throw an error during compilation.
|
||||
///
|
||||
|
@ -64,11 +64,14 @@ struct TraitsImpl {
|
|||
template <>
|
||||
struct Traits<bool> : public TraitsImpl<bool> {
|
||||
static std::string const & ToString(bool const& value) {
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wexit-time-destructors"
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wexit-time-destructors"
|
||||
|
||||
static std::string const _TRUE_VALUE("True");
|
||||
static std::string const _FALSE_VALUE("False");
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
return value != 0 ? _TRUE_VALUE : _FALSE_VALUE;
|
||||
}
|
||||
};
|
||||
|
@ -178,7 +181,7 @@ struct Traits <std::array<T, size>> : public TraitsImpl<std::array<T, size>> {
|
|||
static std::string ToString(std::array<T, size> const& value) {
|
||||
std::ostringstream streamObj;
|
||||
streamObj << "[";
|
||||
|
||||
|
||||
for (unsigned int count = 0; count < size - 1; ++count)
|
||||
{
|
||||
streamObj << Traits<T>::ToString (value[count]) << ",";
|
||||
|
@ -194,7 +197,7 @@ struct Traits<std::vector<T, AllocatorT>> : public TraitsImpl<std::vector<T, All
|
|||
static std::string ToString(std::vector<T, AllocatorT> const& value) {
|
||||
std::ostringstream streamObj;
|
||||
streamObj << "[";
|
||||
|
||||
|
||||
for (unsigned int count = 0; count < value.size() - 1; ++count)
|
||||
{
|
||||
streamObj << Traits<T>::ToString(value.at(count)) << ",";
|
||||
|
@ -209,11 +212,11 @@ struct Traits<std::map<KeyT, T, CompareT, AllocatorT>> : public TraitsImpl<std::
|
|||
static std::string ToString(std::map<KeyT, T, CompareT, AllocatorT> const& value) {
|
||||
std::ostringstream streamObj;
|
||||
streamObj << "{";
|
||||
|
||||
|
||||
for (auto it = value.cbegin(); it != value.cend(); ++it)
|
||||
{
|
||||
streamObj << Traits<KeyT>::ToString(it->first) << ":" << Traits<T>::ToString(it->second);
|
||||
if (std::next(it) != value.cend())
|
||||
if (std::next(it) != value.cend())
|
||||
{
|
||||
streamObj << ",";
|
||||
}
|
||||
|
@ -229,7 +232,7 @@ struct Traits <boost::optional<T>> : public TraitsImpl<boost::optional<T>> {
|
|||
static std::string ToString(nullable_type const& value) {
|
||||
if (value) {
|
||||
return Traits<T>::ToString(value.get());
|
||||
}
|
||||
}
|
||||
return "NULL";
|
||||
}
|
||||
};
|
||||
|
|
|
@ -28,6 +28,7 @@ enable_testing()
|
|||
|
||||
foreach(_test_name IN ITEMS
|
||||
Featurizer_UnitTest
|
||||
InferenceOnlyFeaturizerImpl_UnitTest
|
||||
Traits_UnitTests
|
||||
)
|
||||
add_executable(${_test_name} ${_test_name}.cpp)
|
||||
|
|
|
@ -8,52 +8,50 @@
|
|||
|
||||
#include "../Featurizer.h"
|
||||
|
||||
class MyTransformer : public Microsoft::Featurizer::Transformer<bool, int> {
|
||||
// ----------------------------------------------------------------------
|
||||
using Microsoft::Featurizer::AnnotationPtr;
|
||||
using Microsoft::Featurizer::AnnotationMaps;
|
||||
using Microsoft::Featurizer::AnnotationMapsPtr;
|
||||
using Microsoft::Featurizer::CreateTestAnnotationMapsPtr;
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
class MyAnnotation : public Microsoft::Featurizer::Annotation {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Data
|
||||
int const State;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
MyTransformer(bool true_on_odd=false) :
|
||||
_true_on_odd(true_on_odd) {
|
||||
MyAnnotation(EstimatorUniqueId id, int state, bool valid_construction=true) :
|
||||
Microsoft::Featurizer::Annotation(valid_construction ? id : nullptr),
|
||||
State(std::move(state)) {
|
||||
}
|
||||
|
||||
~MyTransformer(void) override = default;
|
||||
~MyAnnotation(void) override = default;
|
||||
|
||||
MyTransformer(MyTransformer const &) = delete;
|
||||
MyTransformer & operator =(MyTransformer const &) = delete;
|
||||
MyAnnotation(MyAnnotation const &) = delete;
|
||||
MyAnnotation & operator =(MyAnnotation const &) = delete;
|
||||
|
||||
MyTransformer(MyTransformer &&) = default;
|
||||
MyTransformer & operator =(MyTransformer &&) = delete;
|
||||
|
||||
return_type transform(arg_type const &arg) const override {
|
||||
bool const is_odd(arg & 1);
|
||||
|
||||
return _true_on_odd ? is_odd : !is_odd;
|
||||
}
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Data
|
||||
bool const _true_on_odd;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const /*version*/) {
|
||||
ar & boost::serialization::base_object<transformer_type>(*this);
|
||||
ar & boost::serialization::make_nvp("true_on_odd", const_cast<bool &>(_true_on_odd));
|
||||
}
|
||||
MyAnnotation(MyAnnotation &&) = default;
|
||||
MyAnnotation & operator =(MyAnnotation &&) = delete;
|
||||
};
|
||||
|
||||
class MyEstimator : public Microsoft::Featurizer::Estimator<bool, int> {
|
||||
// Verifies that MyAnnotation exposes the creator id and state it was built with, and
// that construction throws "Invalid id" when MyAnnotation passes a null id to the
// Annotation base (which it does when its third argument is false).
TEST_CASE("Annotation") {
|
||||
// Arbitrary non-null id value; this test only compares it, it is never dereferenced here.
Microsoft::Featurizer::Annotation::EstimatorUniqueId const id(reinterpret_cast<void *>(10));
|
||||
MyAnnotation const annotation(id, 10);
|
||||
|
||||
CHECK(annotation.CreatorId == id);
|
||||
CHECK(annotation.State == 10);
|
||||
CHECK_THROWS_WITH(MyAnnotation(id, 10, false), Catch::Matches("Invalid id"));
|
||||
}
|
||||
|
||||
class MyEstimator : public Microsoft::Featurizer::Estimator {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
MyEstimator(bool return_invalid_transformer=false) :
|
||||
_return_invalid_transformer(return_invalid_transformer) {
|
||||
MyEstimator(std::string name, AnnotationMapsPtr pAllColumnAnnotations) :
|
||||
Microsoft::Featurizer::Estimator(std::move(name), std::move(pAllColumnAnnotations)) {
|
||||
}
|
||||
|
||||
~MyEstimator(void) override = default;
|
||||
|
@ -64,65 +62,175 @@ public:
|
|||
MyEstimator(MyEstimator &&) = default;
|
||||
MyEstimator & operator =(MyEstimator &&) = delete;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Relationships
|
||||
friend class boost::serialization::access;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Data
|
||||
bool const _return_invalid_transformer;
|
||||
bool _true_on_odd_state;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
MyEstimator & fit_impl(apache_arrow const &data) override {
|
||||
_true_on_odd_state = static_cast<bool>(data);
|
||||
return *this;
|
||||
void add_annotation(size_t col_index, int state, bool invalid_add=false) const {
|
||||
Microsoft::Featurizer::Estimator::add_annotation(invalid_add ? AnnotationPtr() : std::make_shared<MyAnnotation>(this, state), col_index);
|
||||
}
|
||||
|
||||
TransformerUniquePtr commit_impl(void) override {
|
||||
if(_return_invalid_transformer)
|
||||
return TransformerUniquePtr();
|
||||
|
||||
return std::make_unique<MyTransformer>(_true_on_odd_state);
|
||||
}
|
||||
|
||||
template <typename ArchiveT>
|
||||
void serialize(ArchiveT &ar, unsigned int const /*version*/) {
|
||||
ar & boost::serialization::base_object<estimator_type>(*this);
|
||||
ar & boost::serialization::make_nvp("return_invalid_transformer", const_cast<bool &>(_return_invalid_transformer));
|
||||
ar & boost::serialization::make_nvp("true_on_odd_state", const_cast<bool &>(_true_on_odd_state));
|
||||
boost::optional<MyAnnotation &> get_annotation(size_t col_index) const {
|
||||
return get_annotation_impl<MyAnnotation>(col_index);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_CASE("Transformer: Functionality") {
|
||||
CHECK(MyTransformer(true).transform(1) == true);
|
||||
CHECK(MyTransformer(false).transform(1) == false);
|
||||
CHECK(MyTransformer(true).transform(2) == false);
|
||||
CHECK(MyTransformer(false).transform(2) == true);
|
||||
// Verifies the constructor argument checks: an empty name is rejected with
// "Invalid name"; a null annotation-maps pointer and one created with
// CreateTestAnnotationMapsPtr(0) are both rejected with "Empty annotations".
TEST_CASE("Estimator - Invalid construction") {
|
||||
CHECK_THROWS_WITH(MyEstimator(std::string(), CreateTestAnnotationMapsPtr(1)), "Invalid name");
|
||||
CHECK_THROWS_WITH(MyEstimator("Name", AnnotationMapsPtr()), "Empty annotations");
|
||||
CHECK_THROWS_WITH(MyEstimator("Name", CreateTestAnnotationMapsPtr(0)), "Empty annotations");
|
||||
|
||||
}
|
||||
|
||||
TEST_CASE("Estimator: Functionality") {
|
||||
CHECK(MyEstimator().fit(1).commit()->transform(1) == true);
|
||||
CHECK(MyEstimator().fit(0).commit()->transform(1) == false);
|
||||
CHECK(MyEstimator().fit(1).commit()->transform(2) == false);
|
||||
CHECK(MyEstimator().fit(0).commit()->transform(2) == true);
|
||||
// Exercises a validly constructed estimator: name and annotation-map access, adding
// and reading back a per-column annotation, and the annotation error paths.
TEST_CASE("Estimator") {
|
||||
AnnotationMapsPtr const pAllAnnotations(CreateTestAnnotationMapsPtr(2));
|
||||
MyEstimator const estimator("MyNewEstimator", pAllAnnotations);
|
||||
|
||||
CHECK(estimator.Name == "MyNewEstimator");
|
||||
// The estimator must refer to the very same annotation maps object it was given.
CHECK(&estimator.get_column_annotations() == pAllAnnotations.get());
|
||||
|
||||
// No annotations exist on either column yet.
CHECK(!estimator.get_annotation(0));
|
||||
CHECK(!estimator.get_annotation(1));
|
||||
|
||||
estimator.add_annotation(1, 100);
|
||||
|
||||
// Only column 1 gained an annotation, and it carries the state that was added.
CHECK(!estimator.get_annotation(0));
|
||||
CHECK(estimator.get_annotation(1)->State == 100);
|
||||
|
||||
// Annotation-related errors
|
||||
// Passing `true` makes the test helper hand an empty AnnotationPtr to the base.
CHECK_THROWS_WITH(estimator.add_annotation(0, 200, true), "Invalid annotation");
|
||||
CHECK_THROWS_WITH(estimator.add_annotation(99999, 200), "Invalid annotation index");
|
||||
|
||||
CHECK_THROWS_WITH(estimator.get_annotation(99999), "Invalid annotation index");
|
||||
}
|
||||
|
||||
TEST_CASE("Estimator: Errors") {
|
||||
MyEstimator e;
|
||||
class MyFitEstimator : public Microsoft::Featurizer::FitEstimatorImpl<int> {
|
||||
public:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Public Methods
|
||||
MyFitEstimator(bool return_complete_from_fit, bool is_training_complete=false) :
|
||||
Microsoft::Featurizer::FitEstimatorImpl<int>("Name", CreateTestAnnotationMapsPtr(2), is_training_complete),
|
||||
_return_complete_from_fit(return_complete_from_fit) {
|
||||
}
|
||||
|
||||
CHECK(e.commit());
|
||||
CHECK_THROWS_WITH(e.fit(1), Catch::Contains("has already been committed"));
|
||||
CHECK_THROWS_WITH(e.commit(), Catch::Contains("has already been committed"));
|
||||
~MyFitEstimator(void) override = default;
|
||||
|
||||
CHECK_THROWS_WITH(MyEstimator(true).commit(), Catch::Matches("Invalid result"));
|
||||
MyFitEstimator(MyFitEstimator const &) = delete;
|
||||
MyFitEstimator & operator =(MyFitEstimator const &) = delete;
|
||||
|
||||
MyFitEstimator(MyFitEstimator &&) = default;
|
||||
MyFitEstimator & operator =(MyFitEstimator &&) = delete;
|
||||
|
||||
private:
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Data
|
||||
bool const _return_complete_from_fit;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// | Private Methods
|
||||
FitResult fit_impl(InputType const *, size_t, boost::optional<std::uint64_t> const &) override {
|
||||
return _return_complete_from_fit ? FitResult::Complete : FitResult::Continue;
|
||||
}
|
||||
|
||||
void complete_training_impl(void) override {
|
||||
}
|
||||
};
|
||||
|
||||
// Exercises the training state checks of FitEstimatorImpl through MyFitEstimator:
// an estimator constructed as complete, one completed manually via complete_training(),
// and one completed automatically when fit() returns FitResult::Complete.
TEST_CASE("FitEstimatorImpl") {
|
||||
CHECK(MyFitEstimator(true).is_training_complete() == false);
|
||||
CHECK(MyFitEstimator(true, true).is_training_complete());
|
||||
|
||||
// Constructed as already complete: fit() must be rejected.
MyFitEstimator completed(false, true);
|
||||
|
||||
CHECK(completed.is_training_complete());
|
||||
CHECK_THROWS_WITH(completed.fit(reinterpret_cast<int *>(&completed), 1), Catch::Contains("should not be invoked on an estimator that is already complete"));
|
||||
|
||||
// Note that in all these cases, the actual value passed to fit doesn't matter because
|
||||
// MyFitEstimator::fit_impl ignores its arguments; any non-null pointer will do.
|
||||
MyFitEstimator manual_complete(false);
|
||||
|
||||
// Invalid invocation
|
||||
// A pointer and a count must be given together. A null, zero-length buffer is only
// accepted when a non-zero trailing-null count is supplied (see the valid calls below).
CHECK_THROWS_WITH(manual_complete.fit(reinterpret_cast<int *>(&manual_complete), 0), "Invalid buffer");
|
||||
CHECK_THROWS_WITH(manual_complete.fit(nullptr, 10), "Invalid buffer");
|
||||
CHECK_THROWS_WITH(manual_complete.fit(nullptr, 0), "Invalid invocation");
|
||||
CHECK_THROWS_WITH(manual_complete.fit(reinterpret_cast<int *>(&manual_complete), 1, 0), "Invalid number of nulls");
|
||||
|
||||
// Valid calls: fit_impl reports Continue, so training stays open.
CHECK(manual_complete.fit(reinterpret_cast<int *>(&manual_complete), 1) == MyFitEstimator::FitResult::Continue);
|
||||
CHECK(manual_complete.fit(nullptr, 0, 1) == MyFitEstimator::FitResult::Continue);
|
||||
|
||||
CHECK(manual_complete.is_training_complete() == false);
|
||||
manual_complete.complete_training();
|
||||
CHECK(manual_complete.is_training_complete());
|
||||
// Once complete, both fit() and complete_training() must be rejected.
CHECK_THROWS_WITH(manual_complete.fit(reinterpret_cast<int *>(&manual_complete), 1), Catch::Contains("should not be invoked on an estimator that is already complete"));
|
||||
CHECK_THROWS_WITH(manual_complete.complete_training(), Catch::Contains("should not be invoked on an estimator that is already complete"));
|
||||
|
||||
// fit_impl reports Complete here, which is expected to mark training complete.
MyFitEstimator auto_complete(true);
|
||||
|
||||
CHECK(auto_complete.is_training_complete() == false);
|
||||
CHECK(auto_complete.fit(reinterpret_cast<int *>(&auto_complete), 1) == MyFitEstimator::FitResult::Complete);
|
||||
CHECK(auto_complete.is_training_complete());
|
||||
// NOTE(review): the next call passes &manual_complete rather than &auto_complete. The
// pointer only needs to be non-null here, but it looks like a copy/paste leftover.
CHECK_THROWS_WITH(auto_complete.fit(reinterpret_cast<int *>(&manual_complete), 1), Catch::Contains("should not be invoked on an estimator that is already complete"));
|
||||
CHECK_THROWS_WITH(auto_complete.complete_training(), Catch::Contains("should not be invoked on an estimator that is already complete"));
|
||||
}
|
||||
|
||||
TEST_CASE("fit_and_commit") {
|
||||
CHECK(Microsoft::Featurizer::fit_and_commit<MyEstimator>(1, false)->transform(1) == true);
|
||||
CHECK(Microsoft::Featurizer::fit_and_commit<MyEstimator>(0, false)->transform(1) == false);
|
||||
CHECK(Microsoft::Featurizer::fit_and_commit<MyEstimator>(1, false)->transform(2) == false);
|
||||
CHECK(Microsoft::Featurizer::fit_and_commit<MyEstimator>(0, false)->transform(2) == true);
|
||||
/////////////////////////////////////////////////////////////////////////
/// \class         MyTransformerEstimator
/// \brief         `TransformerEstimator<int, bool>` used by the tests in this file.
///                The constructor flag selects whether `create_transformer_impl`
///                hands back a real `MyTransformer` or a default-constructed
///                (empty) `TransformerPtr`. Both training hooks throw, since the
///                tests never expect them to run.
///
class MyTransformerEstimator : public Microsoft::Featurizer::TransformerEstimator<int, bool> {
public:
    // ----------------------------------------------------------------------
    // |  Public Types

    /// Transformer whose result reports whether the lowest bit of the input is set.
    struct MyTransformer : public Transformer {
        MyTransformer() = default;
        ~MyTransformer() override = default;

        // Not copyable
        MyTransformer(MyTransformer const &) = delete;
        MyTransformer & operator =(MyTransformer const &) = delete;

        // Move-constructible, but not move-assignable
        MyTransformer(MyTransformer &&) = default;
        MyTransformer & operator =(MyTransformer &&) = delete;

        bool execute(int value) override {
            bool const low_bit_set((value & 1) != 0);

            return low_bit_set;
        }
    };

    // ----------------------------------------------------------------------
    // |  Public Methods

    /// \param return_invalid_transformer  When true, `create_transformer_impl` returns an empty `TransformerPtr`.
    /// \param is_training_completed       Forwarded as the third argument of the base class constructor.
    MyTransformerEstimator(bool return_invalid_transformer, bool is_training_completed=true) :
        Microsoft::Featurizer::TransformerEstimator<int, bool>("Name", CreateTestAnnotationMapsPtr(2), is_training_completed),
        _return_invalid_transformer(return_invalid_transformer) {
    }

    ~MyTransformerEstimator() override = default;

    // Not copyable
    MyTransformerEstimator(MyTransformerEstimator const &) = delete;
    MyTransformerEstimator & operator =(MyTransformerEstimator const &) = delete;

    // Move-constructible, but not move-assignable
    MyTransformerEstimator(MyTransformerEstimator &&) = default;
    MyTransformerEstimator & operator =(MyTransformerEstimator &&) = delete;

private:
    // ----------------------------------------------------------------------
    // |  Private Data
    bool const                              _return_invalid_transformer;

    // ----------------------------------------------------------------------
    // |  Private Methods

    /// Always throws `std::runtime_error`.
    FitResult fit_impl(InputType const *, size_t, boost::optional<std::uint64_t> const &) override {
        throw std::runtime_error("This should never be called");
    }

    /// Always throws `std::runtime_error`.
    void complete_training_impl() override {
        throw std::runtime_error("This should never be called");
    }

    /// Returns an empty pointer when the estimator was configured to misbehave,
    /// otherwise a freshly allocated `MyTransformer`.
    TransformerPtr create_transformer_impl() override {
        if(_return_invalid_transformer)
            return TransformerPtr();

        return std::make_shared<MyTransformer>();
    }
};
|
||||
|
||||
// Covers the error paths and the success path of `create_transformer` on
// `MyTransformerEstimator`.
TEST_CASE("TransformerEstimator") {
    // Estimator constructed with `is_training_completed == false`
    CHECK_THROWS_WITH(
        MyTransformerEstimator(false, false).create_transformer(),
        Catch::Contains("should not be invoked on an estimator that is not yet complete")
    );

    // Estimator constructed with the default `is_training_completed == true`
    MyTransformerEstimator                  estimator(false);

    CHECK(estimator.is_training_complete());

    // The first request yields a transformer, a second request is rejected
    CHECK(estimator.create_transformer());
    CHECK_THROWS_WITH(
        estimator.create_transformer(),
        Catch::Contains("should not be invoked on an estimator that has been used to create a")
    );

    // Estimator configured to return an empty `TransformerPtr`
    CHECK_THROWS_WITH(MyTransformerEstimator(true).create_transformer(), "Invalid result");
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
// ----------------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "../InferenceOnlyFeaturizerImpl.h"
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
using Microsoft::Featurizer::CreateTestAnnotationMapsPtr;
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
/// \class         MyTransformer
/// \brief         Transformer used by the tests in this file; `execute`
///                reports whether the lowest bit of the input is set.
///
class MyTransformer : public Microsoft::Featurizer::TransformerEstimator<int, bool>::Transformer {
public:
    MyTransformer() = default;
    ~MyTransformer() override = default;

    // Not copyable
    MyTransformer(MyTransformer const &) = delete;
    MyTransformer & operator =(MyTransformer const &) = delete;

    // Move-constructible, but not move-assignable
    MyTransformer(MyTransformer &&) = default;
    MyTransformer & operator =(MyTransformer &&) = delete;

    TransformedType execute(InputType input) override {
        return (input & 1) != 0;
    }
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
/// \class         MyFeaturizer
/// \brief         Featurizer built on `InferenceOnlyFeaturizerImpl<MyTransformer>`.
///                It adds nothing beyond a constructor that supplies the
///                name "MyFeaturizer" to the base class.
///
class MyFeaturizer : public Microsoft::Featurizer::InferenceOnlyFeaturizerImpl<MyTransformer> {
public:
    // ----------------------------------------------------------------------
    // |  Public Types
    using BaseType                          = Microsoft::Featurizer::InferenceOnlyFeaturizerImpl<MyTransformer>;

    // ----------------------------------------------------------------------
    // |  Public Methods

    /// \param pAllColumnAnnotations  Moved into the base class constructor.
    MyFeaturizer(AnnotationMapsPtr pAllColumnAnnotations) :
        BaseType("MyFeaturizer", std::move(pAllColumnAnnotations)) {
    }

    ~MyFeaturizer() override = default;

    // Not copyable
    MyFeaturizer(MyFeaturizer const &) = delete;
    MyFeaturizer & operator =(MyFeaturizer const &) = delete;

    // Move-constructible, but not move-assignable
    MyFeaturizer(MyFeaturizer &&) = default;
    MyFeaturizer & operator =(MyFeaturizer &&) = delete;
};
|
||||
|
||||
// Checks the state of a freshly constructed `MyFeaturizer`, then creates its
// transformer and verifies the results for one odd and one even input.
TEST_CASE("MyFeaturizer") {
    MyFeaturizer                            featurizer(CreateTestAnnotationMapsPtr(2));

    // State before a transformer has been requested
    CHECK(featurizer.Name == "MyFeaturizer");
    CHECK(featurizer.is_training_complete());
    CHECK_FALSE(featurizer.has_created_transformer());

    auto const                              transformer = featurizer.create_transformer();

    CHECK(featurizer.has_created_transformer());

    // 3 has its lowest bit set, 4 does not
    CHECK(transformer->execute(3));
    CHECK_FALSE(transformer->execute(4));
}
|
Загрузка…
Ссылка в новой задаче