зеркало из https://github.com/microsoft/LightGBM.git
support to override some parameters in Dataset (#1876)
* add warnings for override parameters of Dataset * fix pep8 * add feature_penalty * refactor * add R's code * Update basic.py * Update basic.py * fix parameter bug * Update lgb.Dataset.R * fix a bug
This commit is contained in:
Родитель
f30809670a
Коммит
b37065dbd5
|
@ -492,6 +492,10 @@ Dataset <- R6::R6Class(
|
|||
update_params = function(params) {
|
||||
|
||||
# Parameter updating
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
lgb.call("LGBM_DatasetUpdateParam_R", ret = NULL, private$handle, lgb.params2str(params))
|
||||
return(invisible(self))
|
||||
}
|
||||
private$params <- modifyList(private$params, params)
|
||||
return(invisible(self))
|
||||
|
||||
|
|
|
@ -309,6 +309,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
|
|||
const void** out_ptr,
|
||||
int* out_type);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Update parameters for a Dataset
|
||||
* \param handle a instance of data matrix
|
||||
* \param parameters parameters
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters);
|
||||
|
||||
/*!
|
||||
* \brief get number of data.
|
||||
* \param handle the handle to the dataset
|
||||
|
|
|
@ -575,6 +575,8 @@ public:
|
|||
return bufs;
|
||||
}
|
||||
|
||||
void ResetConfig(const char* parameters);
|
||||
|
||||
/*! \brief Get Number of data */
|
||||
inline data_size_t num_data() const { return num_data_; }
|
||||
|
||||
|
@ -615,6 +617,11 @@ private:
|
|||
std::vector<int8_t> monotone_types_;
|
||||
std::vector<double> feature_penalty_;
|
||||
bool is_finish_load_;
|
||||
int max_bin_;
|
||||
int bin_construct_sample_cnt_;
|
||||
int min_data_in_bin_;
|
||||
bool use_missing_;
|
||||
bool zero_as_missing_;
|
||||
};
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
|
@ -171,6 +171,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
|
|||
LGBM_SE field_data,
|
||||
LGBM_SE call_state);
|
||||
|
||||
/*!
|
||||
* \brief Update parameters for a Dataset
|
||||
* \param handle a instance of data matrix
|
||||
* \param parameters parameters
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
|
||||
LGBM_SE params,
|
||||
LGBM_SE call_state);
|
||||
|
||||
/*!
|
||||
* \brief get number of data.
|
||||
* \param handle the handle to the dataset
|
||||
|
|
|
@ -1070,6 +1070,8 @@ class Dataset(object):
|
|||
return self
|
||||
|
||||
def _update_params(self, params):
|
||||
if self.handle is not None and params is not None:
|
||||
_safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_dict_to_str(params))))
|
||||
if not self.params:
|
||||
self.params = params
|
||||
else:
|
||||
|
@ -1080,6 +1082,8 @@ class Dataset(object):
|
|||
def _reverse_update_params(self):
|
||||
self.params = copy.deepcopy(self.params_back_up)
|
||||
self.params_back_up = None
|
||||
if self.handle is not None and self.params is not None:
|
||||
_safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_dict_to_str(self.params))))
|
||||
return self
|
||||
|
||||
def set_field(self, field_name, data):
|
||||
|
|
|
@ -859,6 +859,13 @@ int LGBM_DatasetGetField(DatasetHandle handle,
|
|||
API_END();
|
||||
}
|
||||
|
||||
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
|
||||
API_BEGIN();
|
||||
auto dataset = reinterpret_cast<Dataset*>(handle);
|
||||
dataset->ResetConfig(parameters);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int LGBM_DatasetGetNumData(DatasetHandle handle,
|
||||
int* out) {
|
||||
API_BEGIN();
|
||||
|
|
|
@ -317,6 +317,58 @@ void Dataset::Construct(
|
|||
feature_penalty_.clear();
|
||||
}
|
||||
}
|
||||
max_bin_ = io_config.max_bin;
|
||||
min_data_in_bin_ = io_config.min_data_in_bin;
|
||||
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
|
||||
use_missing_ = io_config.use_missing;
|
||||
zero_as_missing_ = io_config.zero_as_missing;
|
||||
}
|
||||
|
||||
void Dataset::ResetConfig(const char* parameters) {
|
||||
auto param = Config::Str2Map(parameters);
|
||||
Config io_config;
|
||||
io_config.Set(param);
|
||||
if (param.count("max_bin") && io_config.max_bin != max_bin_) {
|
||||
Log::Warning("Cannot change max_bin after constructed Dataset handle.");
|
||||
}
|
||||
if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
|
||||
Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
|
||||
}
|
||||
if (param.count("min_data_in_bin") && io_config.min_data_in_bin != min_data_in_bin_) {
|
||||
Log::Warning("Cannot change min_data_in_bin after constructed Dataset handle.");
|
||||
}
|
||||
if (param.count("use_missing") && io_config.use_missing != use_missing_) {
|
||||
Log::Warning("Cannot change use_missing after constructed Dataset handle.");
|
||||
}
|
||||
if (param.count("zero_as_missing") && io_config.zero_as_missing != zero_as_missing_) {
|
||||
Log::Warning("Cannot change zero_as_missing after constructed Dataset handle.");
|
||||
}
|
||||
if (!io_config.monotone_constraints.empty()) {
|
||||
CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
|
||||
monotone_types_.resize(num_features_);
|
||||
for (int i = 0; i < num_total_features_; ++i) {
|
||||
int inner_fidx = InnerFeatureIndex(i);
|
||||
if (inner_fidx >= 0) {
|
||||
monotone_types_[inner_fidx] = io_config.monotone_constraints[i];
|
||||
}
|
||||
}
|
||||
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
|
||||
monotone_types_.clear();
|
||||
}
|
||||
}
|
||||
if (!io_config.feature_contri.empty()) {
|
||||
CHECK(static_cast<size_t>(num_total_features_) == io_config.feature_contri.size());
|
||||
feature_penalty_.resize(num_features_);
|
||||
for (int i = 0; i < num_total_features_; ++i) {
|
||||
int inner_fidx = InnerFeatureIndex(i);
|
||||
if (inner_fidx >= 0) {
|
||||
feature_penalty_[inner_fidx] = std::max(0.0, io_config.feature_contri[i]);
|
||||
}
|
||||
}
|
||||
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
|
||||
feature_penalty_.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Dataset::FinishLoad() {
|
||||
|
@ -571,7 +623,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
|
|||
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
|
||||
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
|
||||
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
|
||||
+ sizeof(double) * num_features_;
|
||||
+ sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
|
||||
// size of feature names
|
||||
for (int i = 0; i < num_total_features_; ++i) {
|
||||
size_of_header += feature_names_[i].size() + sizeof(int);
|
||||
|
@ -582,6 +634,11 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
|
|||
writer->Write(&num_features_, sizeof(num_features_));
|
||||
writer->Write(&num_total_features_, sizeof(num_total_features_));
|
||||
writer->Write(&label_idx_, sizeof(label_idx_));
|
||||
writer->Write(&max_bin_, sizeof(max_bin_));
|
||||
writer->Write(&bin_construct_sample_cnt_, sizeof(bin_construct_sample_cnt_));
|
||||
writer->Write(&min_data_in_bin_, sizeof(min_data_in_bin_));
|
||||
writer->Write(&use_missing_, sizeof(use_missing_));
|
||||
writer->Write(&zero_as_missing_, sizeof(zero_as_missing_));
|
||||
writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_);
|
||||
writer->Write(&num_groups_, sizeof(num_groups_));
|
||||
writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_);
|
||||
|
|
|
@ -316,6 +316,16 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
|
|||
mem_ptr += sizeof(dataset->num_total_features_);
|
||||
dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->label_idx_);
|
||||
dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->max_bin_);
|
||||
dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
|
||||
dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->min_data_in_bin_);
|
||||
dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->use_missing_);
|
||||
dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
|
||||
mem_ptr += sizeof(dataset->zero_as_missing_);
|
||||
const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
|
||||
dataset->used_feature_map_.clear();
|
||||
for (int i = 0; i < dataset->num_total_features_; ++i) {
|
||||
|
|
|
@ -270,6 +270,14 @@ LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
|
|||
R_API_END();
|
||||
}
|
||||
|
||||
LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
|
||||
LGBM_SE params,
|
||||
LGBM_SE call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetUpdateParam(R_GET_PTR(handle), R_CHAR_PTR(params)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
|
||||
LGBM_SE call_state) {
|
||||
int nrow;
|
||||
|
|
Загрузка…
Ссылка в новой задаче