This commit is contained in:
Tong Wu 2022-01-27 16:27:02 +08:00
Родитель 21f8f49d19
Коммит 9d6695016f
3 изменённых файлов: 7 добавлений и 6 удалений

Просмотреть файл

@ -157,6 +157,7 @@ struct EncodeResult {
class CategoryFeatureEncoderManager {
public:
// NOLINTNEXTLINE
CategoryFeatureEncoderManager(std::vector<std::unordered_map<int, std::vector<std::unique_ptr<CategoryFeatureEncoder>>>>& train_category_feature_encoders, std::unordered_map<int, std::vector<std::unique_ptr<CategoryFeatureEncoder>>>& category_feature_encoders)
: train_category_feature_encoders_(std::move(train_category_feature_encoders)), category_feature_encoders_(std::move(category_feature_encoders)) { }

Просмотреть файл

@ -152,9 +152,9 @@ namespace LightGBM {
std::unique_ptr<CategoryFeatureEncoderManager> CategoryFeatureEncoderManager::Create(json11::Json settings, const CategoryFeatureTargetInformationCollector& informationCollector) {
const std::vector<int>& categorical_features = informationCollector.GetCategoricalFeatures();
const std::vector<std::unordered_map<int, CategoryFeatureTargetInformation>>& category_target_information = informationCollector.GetCategoryTargetInformation();
const std::unordered_map<int, CategoryFeatureTargetInformation>& global_category_target_information = informationCollector.GetGlobalCategoryTargetInformation();
int fold_count = category_target_information.size();
const std::vector<std::unordered_map<int, CategoryFeatureTargetInformation>>& category_target_information = informationCollector.GetCategoryTargetInformation();
const std::unordered_map<int, CategoryFeatureTargetInformation>& global_category_target_information = informationCollector.GetGlobalCategoryTargetInformation();
size_t fold_count = category_target_information.size();
std::vector<std::unordered_map<int, std::vector<std::unique_ptr<CategoryFeatureEncoder>>>> train_category_feature_encoders(fold_count);
std::unordered_map<int, std::vector<std::unique_ptr<CategoryFeatureEncoder>>> category_feature_encoders;

Просмотреть файл

@ -37,16 +37,16 @@ namespace LightGBM {
count_.reserve(count_.size() + target_count_record.size());
count_.insert(count_.end(), target_count_record.begin(), target_count_record.end());
const std::vector<double>& target_sum_record = collector.GetLabelSum();
const std::vector<double>& target_sum_record = collector.GetLabelSum();
label_sum_.reserve(label_sum_.size() + target_sum_record.size());
label_sum_.insert(label_sum_.end(), target_sum_record.begin(), target_sum_record.end());
const std::vector<std::unordered_map<int, CategoryFeatureTargetInformation>>& target_category_target_information = collector.GetCategoryTargetInformation();
const std::vector<std::unordered_map<int, CategoryFeatureTargetInformation>>& target_category_target_information = collector.GetCategoryTargetInformation();
for (auto& entry : target_category_target_information) {
category_target_information_.push_back(entry);
}
const std::unordered_map<int, CategoryFeatureTargetInformation>& global_category_target_information_record = collector.GetGlobalCategoryTargetInformation();
const std::unordered_map<int, CategoryFeatureTargetInformation>& global_category_target_information_record = collector.GetGlobalCategoryTargetInformation();
for (auto& feature_information : global_category_target_information_record) {
for (auto& category_count : feature_information.second.category_count) {
global_category_target_information_[feature_information.first].category_count[category_count.first] += category_count.second;