From 5dfe7168d42898b66da3513eb8cab68ef2b23eeb Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Wed, 10 Apr 2024 21:31:40 +0200 Subject: [PATCH] [R-package] Speed-up lgb.importance() (#6364) --- R-package/R/lgb.model.dt.tree.R | 56 +++---- .../tests/testthat/test_lgb.model.dt.tree.R | 158 ++++++++++++++++++ 2 files changed, 185 insertions(+), 29 deletions(-) create mode 100644 R-package/tests/testthat/test_lgb.model.dt.tree.R diff --git a/R-package/R/lgb.model.dt.tree.R b/R-package/R/lgb.model.dt.tree.R index bf4562e41..be877c40d 100644 --- a/R-package/R/lgb.model.dt.tree.R +++ b/R-package/R/lgb.model.dt.tree.R @@ -90,6 +90,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { #' @importFrom data.table := data.table rbindlist .single_tree_parse <- function(lgb_tree) { + tree_info_cols <- c( + "split_index" + , "split_feature" + , "split_gain" + , "threshold" + , "decision_type" + , "default_left" + , "internal_value" + , "internal_count" + ) # Traverse tree function pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) { @@ -97,7 +107,8 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { if (is.null(env)) { # Setup initial default data.table with default types env <- new.env(parent = emptyenv()) - env$single_tree_dt <- data.table::data.table( + env$single_tree_dt <- list() + env$single_tree_dt[[1L]] <- data.table::data.table( tree_index = integer(0L) , depth = integer(0L) , split_index = integer(0L) @@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { if (!is.null(tree_node_leaf$split_index)) { # update data.table - env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, - c(tree_node_leaf[c("split_index", - "split_feature", - "split_gain", - "threshold", - "decision_type", - "default_left", - "internal_value", - "internal_count")], - "depth" = current_depth, - "node_parent" = parent_index)), - use.names = TRUE, - fill = TRUE) + env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c( + tree_node_leaf[tree_info_cols] + , list("depth" = current_depth, "node_parent" = parent_index) + ) # Traverse tree again both left and right pre_order_traversal( @@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { , current_depth = current_depth + 1L , parent_index = tree_node_leaf$split_index ) - } else if (!is.null(tree_node_leaf$leaf_index)) { - # update data.table - env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt, - c(tree_node_leaf[c("leaf_index", - "leaf_value", - "leaf_count")], - "depth" = current_depth, - "leaf_parent" = parent_index)), - use.names = TRUE, - fill = TRUE) - + # update list + env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c( + tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")] + , list("depth" = current_depth, "leaf_parent" = parent_index) + ) } - } return(env$single_tree_dt) } - # Traverse structure - single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure) + # Traverse structure and rowbind everything + single_tree_dt <- data.table::rbindlist( + pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure) + , use.names = TRUE + , fill = TRUE + ) # Store index single_tree_dt[, tree_index := lgb_tree$tree_index] return(single_tree_dt) - } diff --git a/R-package/tests/testthat/test_lgb.model.dt.tree.R b/R-package/tests/testthat/test_lgb.model.dt.tree.R new file mode 100644 index 000000000..2c26474af --- /dev/null +++ b/R-package/tests/testthat/test_lgb.model.dt.tree.R @@ -0,0 +1,158 @@ +NROUNDS <- 10L +MAX_DEPTH <- 3L +N <- nrow(iris) +X <- data.matrix(iris[2L:4L]) +FEAT <- colnames(X) +NCLASS <- nlevels(iris[, 5L]) + +model_reg <- lgb.train( + params = list( + objective = "regression" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + ) + , data = lgb.Dataset(X, label = iris[, 1L]) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_binary <- lgb.train( + params = list( + objective = "binary" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + ) + , data = lgb.Dataset(X, label = iris[, 5L] == "setosa") + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_multiclass <- lgb.train( + params = list( + objective = "multiclass" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + , num_classes = NCLASS + ) + , data = lgb.Dataset(X, label = as.integer(iris[, 5L]) - 1L) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_rank <- lgb.train( + params = list( + objective = "lambdarank" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + , lambdarank_truncation_level = 3L + ) + , data = lgb.Dataset( + X + , label = as.integer(iris[, 1L] > 5.8) + , group = rep(10L, times = 15L) + ) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +models <- list( + reg = model_reg + , bin = model_binary + , multi = model_multiclass + , rank = model_rank +) + +for (model_name in names(models)) { + model <- models[[model_name]] + expected_n_trees <- NROUNDS + if (model_name == "multi") { + expected_n_trees <- NROUNDS * NCLASS + } + df <- as.data.frame(lgb.model.dt.tree(model)) + df_list <- split(df, f = df$tree_index, drop = TRUE) + + df_leaf <- df[!is.na(df$leaf_index), ] + df_internal <- df[is.na(df$leaf_index), ] + + test_that("lgb.model.dt.tree() returns the right number of trees", { + expect_equal(length(unique(df$tree_index)), expected_n_trees) + }) + + test_that("num_iteration can return less trees", { + expect_equal( + length(unique(lgb.model.dt.tree(model, num_iteration = 2L)$tree_index)) + , 2L * (if (model_name == "multi") NCLASS else 1L) + ) + }) + + test_that("Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)", { + expect_equal(unique(df$tree_index), (0L:(expected_n_trees - 1L))) + }) + + test_that("Depth calculated from lgb.model.dt.tree() respects max.depth", { + expect_true(max(df$depth) <= MAX_DEPTH) + }) + + test_that("Each tree from lgb.model.dt.tree() has single root node", { + expect_equal( + unname(sapply(df_list, function(df) sum(df$depth == 0L))) + , rep(1L, expected_n_trees) + ) + }) + + test_that("Each tree from lgb.model.dt.tree() has two depth 1 nodes", { + expect_equal( + unname(sapply(df_list, function(df) sum(df$depth == 1L))) + , rep(2L, expected_n_trees) + ) + }) + + test_that("leaves from lgb.model.dt.tree() do not have split info", { + internal_node_cols <- c( + "split_index" + , "split_feature" + , "split_gain" + , "threshold" + , "decision_type" + , "default_left" + , "internal_value" + , "internal_count" + ) + expect_true(all(is.na(df_leaf[internal_node_cols]))) + }) + + test_that("leaves from lgb.model.dt.tree() have valid leaf info", { + expect_true(all(df_leaf$leaf_index %in% 0L:(2.0^MAX_DEPTH - 1.0))) + expect_true(all(is.finite(df_leaf$leaf_value))) + expect_true(all(df_leaf$leaf_count > 0L & df_leaf$leaf_count <= N)) + }) + + test_that("non-leaves from lgb.model.dt.tree() do not have leaf info", { + leaf_node_cols <- c( + "leaf_index", "leaf_parent", "leaf_value", "leaf_count" + ) + expect_true(all(is.na(df_internal[leaf_node_cols]))) + }) + + test_that("non-leaves from lgb.model.dt.tree() have valid split info", { + expect_true( + all( + sapply( + split(df_internal, df_internal$tree_index), + function(x) all(x$split_index %in% 0L:(nrow(x) - 1L)) + ) + ) + ) + + expect_true(all(df_internal$split_feature %in% FEAT)) + + num_cols <- c("split_gain", "threshold", "internal_value") + expect_true(all(is.finite(unlist(df_internal[, num_cols])))) + + # range of decision type? + expect_true(all(df_internal$default_left %in% c(TRUE, FALSE))) + + counts <- df_internal$internal_count + expect_true(all(counts > 1L & counts <= N)) + }) +}