support to override some parameters in Dataset (#1876)

* add warnings for override parameters of Dataset

* fix pep8

* add feature_penalty

* refactor

* add R's code

* Update basic.py

* Update basic.py

* fix parameter bug

* Update lgb.Dataset.R

* fix a bug
This commit is contained in:
Guolin Ke 2019-01-23 16:07:50 +08:00 коммит произвёл GitHub
Родитель f30809670a
Коммит b37065dbd5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 116 добавлений и 1 удалений

Просмотреть файл

@ -492,6 +492,10 @@ Dataset <- R6::R6Class(
update_params = function(params) {
# Parameter updating
if (!lgb.is.null.handle(private$handle)) {
lgb.call("LGBM_DatasetUpdateParam_R", ret = NULL, private$handle, lgb.params2str(params))
return(invisible(self))
}
private$params <- modifyList(private$params, params)
return(invisible(self))

Просмотреть файл

@ -309,6 +309,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const void** out_ptr,
int* out_type);
/*!
* \brief Update parameters for an already-created Dataset
* \note The implementation casts the handle to a Dataset and forwards the string to Dataset::ResetConfig
* \param handle handle to the Dataset whose parameters should be updated
* \param parameters parameter string to apply to the Dataset
* \return int status code. NOTE(review): presumably 0 when succeed, -1 when failure happens, as documented for the R wrapper; confirm against API_BEGIN/API_END
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters);
/*!
* \brief get number of data.
* \param handle the handle to the dataset

Просмотреть файл

@ -575,6 +575,8 @@ public:
return bufs;
}
/*!
* \brief Re-apply a parameter string to this Dataset.
*        Logs a warning when max_bin, bin_construct_sample_cnt, min_data_in_bin, use_missing or
*        zero_as_missing is given with a value different from the stored one (the stored value is kept),
*        and rebuilds monotone_types_ / feature_penalty_ when monotone_constraints / feature_contri are non-empty.
* \param parameters parameter string, parsed with Config::Str2Map
*/
void ResetConfig(const char* parameters);
/*! \brief Get the number of data (returns the stored num_data_) */
inline data_size_t num_data() const { return num_data_; }
@ -615,6 +617,11 @@ private:
std::vector<int8_t> monotone_types_;
std::vector<double> feature_penalty_;
bool is_finish_load_;
// The five members below are copied from io_config at the end of Dataset::Construct, written to
// and read back from the binary file header, and compared against new values in ResetConfig,
// which only logs a warning when they differ.
/*! \brief io_config.max_bin recorded when the Dataset was constructed */
int max_bin_;
/*! \brief io_config.bin_construct_sample_cnt recorded when the Dataset was constructed */
int bin_construct_sample_cnt_;
/*! \brief io_config.min_data_in_bin recorded when the Dataset was constructed */
int min_data_in_bin_;
/*! \brief io_config.use_missing recorded when the Dataset was constructed */
bool use_missing_;
/*! \brief io_config.zero_as_missing recorded when the Dataset was constructed */
bool zero_as_missing_;
};
} // namespace LightGBM

Просмотреть файл

@ -171,6 +171,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
LGBM_SE field_data,
LGBM_SE call_state);
/*!
* \brief Update parameters for an already-created Dataset (R wrapper around LGBM_DatasetUpdateParam)
* \param handle handle to the Dataset whose parameters should be updated
* \param params parameter string to apply to the Dataset
* \param call_state R-side call status object. NOTE(review): presumably set to -1 on failure by CHECK_CALL / R_API_END; confirm
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
LGBM_SE params,
LGBM_SE call_state);
/*!
* \brief get number of data.
* \param handle the handle to the dataset

Просмотреть файл

@ -1070,6 +1070,8 @@ class Dataset(object):
return self
def _update_params(self, params):
if self.handle is not None and params is not None:
_safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_dict_to_str(params))))
if not self.params:
self.params = params
else:
@ -1080,6 +1082,8 @@ class Dataset(object):
def _reverse_update_params(self):
    """Restore ``self.params`` from ``self.params_back_up`` and clear the backup.

    When a native handle exists and the restored parameters are not None,
    the restored parameters are also pushed to the handle through
    ``LGBM_DatasetUpdateParam``.

    Returns
    -------
    self : Dataset
        This object, with its parameters restored.
    """
    restored = copy.deepcopy(self.params_back_up)
    self.params = restored
    self.params_back_up = None
    # Nothing to synchronise without a native handle or without parameters.
    if self.handle is None or restored is None:
        return self
    param_str = param_dict_to_str(restored)
    _safe_call(_LIB.LGBM_DatasetUpdateParam(self.handle, c_str(param_str)))
    return self
def set_field(self, field_name, data):

Просмотреть файл

@ -859,6 +859,13 @@ int LGBM_DatasetGetField(DatasetHandle handle,
API_END();
}
/*!
 * \brief C API entry point: apply a parameter string to an existing Dataset.
 * \param handle opaque handle, reinterpreted as a Dataset*
 * \param parameters parameter string handed unchanged to Dataset::ResetConfig
 */
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  // The handle is the Dataset object itself; forward the string to it directly.
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}
int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out) {
API_BEGIN();

Просмотреть файл

@ -317,6 +317,58 @@ void Dataset::Construct(
feature_penalty_.clear();
}
}
max_bin_ = io_config.max_bin;
min_data_in_bin_ = io_config.min_data_in_bin;
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
use_missing_ = io_config.use_missing;
zero_as_missing_ = io_config.zero_as_missing;
}
/*!
 * \brief Re-apply a parameter string to an already constructed Dataset.
 *
 * Settings recorded at construction time (max_bin, bin_construct_sample_cnt, min_data_in_bin,
 * use_missing, zero_as_missing) are not modified: if the caller supplies one of them with a
 * value that differs from the stored one, a warning is logged and the stored value is kept.
 * Non-empty monotone_constraints / feature_contri replace monotone_types_ / feature_penalty_.
 *
 * \param parameters parameter string, parsed with Config::Str2Map
 */
void Dataset::ResetConfig(const char* parameters) {
  // Two views of the same input: the raw map tells which keys were explicitly supplied,
  // the typed Config holds their parsed values.
  auto supplied = Config::Str2Map(parameters);
  Config io_config;
  io_config.Set(supplied);

  // Construction-time settings: report an attempted change, never apply it.
  if (supplied.count("max_bin") > 0 && max_bin_ != io_config.max_bin) {
    Log::Warning("Cannot change max_bin after constructed Dataset handle.");
  }
  if (supplied.count("bin_construct_sample_cnt") > 0
      && bin_construct_sample_cnt_ != io_config.bin_construct_sample_cnt) {
    Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
  }
  if (supplied.count("min_data_in_bin") > 0 && min_data_in_bin_ != io_config.min_data_in_bin) {
    Log::Warning("Cannot change min_data_in_bin after constructed Dataset handle.");
  }
  if (supplied.count("use_missing") > 0 && use_missing_ != io_config.use_missing) {
    Log::Warning("Cannot change use_missing after constructed Dataset handle.");
  }
  if (supplied.count("zero_as_missing") > 0 && zero_as_missing_ != io_config.zero_as_missing) {
    Log::Warning("Cannot change zero_as_missing after constructed Dataset handle.");
  }

  // Monotone constraints are given per original feature; store them per inner (used) feature.
  if (!io_config.monotone_constraints.empty()) {
    CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
    const auto& constraints = io_config.monotone_constraints;
    monotone_types_.resize(num_features_);
    for (int raw_idx = 0; raw_idx < num_total_features_; ++raw_idx) {
      const int inner_idx = InnerFeatureIndex(raw_idx);
      if (inner_idx < 0) {
        // Negative inner index: this original feature has no slot in the Dataset.
        continue;
      }
      monotone_types_[inner_idx] = constraints[raw_idx];
    }
    // Drop the vector entirely when CheckAllZero reports that no constraint is set.
    if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
      monotone_types_.clear();
    }
  }

  // Feature contributions follow the same original-to-inner mapping, floored at 0.0.
  if (!io_config.feature_contri.empty()) {
    CHECK(static_cast<size_t>(num_total_features_) == io_config.feature_contri.size());
    const auto& contributions = io_config.feature_contri;
    feature_penalty_.resize(num_features_);
    for (int raw_idx = 0; raw_idx < num_total_features_; ++raw_idx) {
      const int inner_idx = InnerFeatureIndex(raw_idx);
      if (inner_idx < 0) {
        continue;
      }
      feature_penalty_[inner_idx] = std::max(0.0, contributions[raw_idx]);
    }
    // Drop the vector when CheckAll(..., 1.0) holds for it.
    if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
      feature_penalty_.clear();
    }
  }
}
void Dataset::FinishLoad() {
@ -571,7 +623,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
+ sizeof(double) * num_features_;
+ sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
@ -582,6 +634,11 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
writer->Write(&num_features_, sizeof(num_features_));
writer->Write(&num_total_features_, sizeof(num_total_features_));
writer->Write(&label_idx_, sizeof(label_idx_));
writer->Write(&max_bin_, sizeof(max_bin_));
writer->Write(&bin_construct_sample_cnt_, sizeof(bin_construct_sample_cnt_));
writer->Write(&min_data_in_bin_, sizeof(min_data_in_bin_));
writer->Write(&use_missing_, sizeof(use_missing_));
writer->Write(&zero_as_missing_, sizeof(zero_as_missing_));
writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_);
writer->Write(&num_groups_, sizeof(num_groups_));
writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_);

Просмотреть файл

@ -316,6 +316,16 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
mem_ptr += sizeof(dataset->num_total_features_);
dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->label_idx_);
dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->max_bin_);
dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->min_data_in_bin_);
dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
mem_ptr += sizeof(dataset->use_missing_);
dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
mem_ptr += sizeof(dataset->zero_as_missing_);
const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
dataset->used_feature_map_.clear();
for (int i = 0; i < dataset->num_total_features_; ++i) {

Просмотреть файл

@ -270,6 +270,14 @@ LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
R_API_END();
}
/*!
 * \brief R wrapper: unpack the R objects and delegate to LGBM_DatasetUpdateParam.
 * \param handle R object holding the Dataset handle
 * \param params R object holding the parameter string
 * \param call_state R-side call status object used by the R_API_* / CHECK_CALL macros
 */
LGBM_SE LGBM_DatasetUpdateParam_R(LGBM_SE handle,
  LGBM_SE params,
  LGBM_SE call_state) {
  R_API_BEGIN();
  // Unpack the R-side arguments first, then make the single C API call.
  const auto dataset_handle = R_GET_PTR(handle);
  const auto param_str = R_CHAR_PTR(params);
  CHECK_CALL(LGBM_DatasetUpdateParam(dataset_handle, param_str));
  R_API_END();
}
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
LGBM_SE call_state) {
int nrow;