зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Speed-up lgb.importance() (#6364)
This commit is contained in:
Родитель
628e91a959
Коммит
5dfe7168d4
|
@ -90,6 +90,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
|
|||
|
||||
#' @importFrom data.table := data.table rbindlist
|
||||
.single_tree_parse <- function(lgb_tree) {
|
||||
tree_info_cols <- c(
|
||||
"split_index"
|
||||
, "split_feature"
|
||||
, "split_gain"
|
||||
, "threshold"
|
||||
, "decision_type"
|
||||
, "default_left"
|
||||
, "internal_value"
|
||||
, "internal_count"
|
||||
)
|
||||
|
||||
# Traverse tree function
|
||||
pre_order_traversal <- function(env = NULL, tree_node_leaf, current_depth = 0L, parent_index = NA_integer_) {
|
||||
|
@ -97,7 +107,8 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
|
|||
if (is.null(env)) {
|
||||
# Setup initial default data.table with default types
|
||||
env <- new.env(parent = emptyenv())
|
||||
env$single_tree_dt <- data.table::data.table(
|
||||
env$single_tree_dt <- list()
|
||||
env$single_tree_dt[[1L]] <- data.table::data.table(
|
||||
tree_index = integer(0L)
|
||||
, depth = integer(0L)
|
||||
, split_index = integer(0L)
|
||||
|
@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
|
|||
if (!is.null(tree_node_leaf$split_index)) {
|
||||
|
||||
# update data.table
|
||||
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
|
||||
c(tree_node_leaf[c("split_index",
|
||||
"split_feature",
|
||||
"split_gain",
|
||||
"threshold",
|
||||
"decision_type",
|
||||
"default_left",
|
||||
"internal_value",
|
||||
"internal_count")],
|
||||
"depth" = current_depth,
|
||||
"node_parent" = parent_index)),
|
||||
use.names = TRUE,
|
||||
fill = TRUE)
|
||||
env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
|
||||
tree_node_leaf[tree_info_cols]
|
||||
, list("depth" = current_depth, "node_parent" = parent_index)
|
||||
)
|
||||
|
||||
# Traverse tree again both left and right
|
||||
pre_order_traversal(
|
||||
|
@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
|
|||
, current_depth = current_depth + 1L
|
||||
, parent_index = tree_node_leaf$split_index
|
||||
)
|
||||
|
||||
} else if (!is.null(tree_node_leaf$leaf_index)) {
|
||||
|
||||
# update data.table
|
||||
env$single_tree_dt <- data.table::rbindlist(l = list(env$single_tree_dt,
|
||||
c(tree_node_leaf[c("leaf_index",
|
||||
"leaf_value",
|
||||
"leaf_count")],
|
||||
"depth" = current_depth,
|
||||
"leaf_parent" = parent_index)),
|
||||
use.names = TRUE,
|
||||
fill = TRUE)
|
||||
|
||||
# update list
|
||||
env$single_tree_dt[[length(env$single_tree_dt) + 1L]] <- c(
|
||||
tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")]
|
||||
, list("depth" = current_depth, "leaf_parent" = parent_index)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
return(env$single_tree_dt)
|
||||
}
|
||||
|
||||
# Traverse structure
|
||||
single_tree_dt <- pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
|
||||
# Traverse structure and rowbind everything
|
||||
single_tree_dt <- data.table::rbindlist(
|
||||
pre_order_traversal(tree_node_leaf = lgb_tree$tree_structure)
|
||||
, use.names = TRUE
|
||||
, fill = TRUE
|
||||
)
|
||||
|
||||
# Store index
|
||||
single_tree_dt[, tree_index := lgb_tree$tree_index]
|
||||
|
||||
return(single_tree_dt)
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
NROUNDS <- 10L
|
||||
MAX_DEPTH <- 3L
|
||||
N <- nrow(iris)
|
||||
X <- data.matrix(iris[2L:4L])
|
||||
FEAT <- colnames(X)
|
||||
NCLASS <- nlevels(iris[, 5L])
|
||||
|
||||
model_reg <- lgb.train(
|
||||
params = list(
|
||||
objective = "regression"
|
||||
, num_threads = .LGB_MAX_THREADS
|
||||
, max.depth = MAX_DEPTH
|
||||
)
|
||||
, data = lgb.Dataset(X, label = iris[, 1L])
|
||||
, verbose = .LGB_VERBOSITY
|
||||
, nrounds = NROUNDS
|
||||
)
|
||||
|
||||
model_binary <- lgb.train(
|
||||
params = list(
|
||||
objective = "binary"
|
||||
, num_threads = .LGB_MAX_THREADS
|
||||
, max.depth = MAX_DEPTH
|
||||
)
|
||||
, data = lgb.Dataset(X, label = iris[, 5L] == "setosa")
|
||||
, verbose = .LGB_VERBOSITY
|
||||
, nrounds = NROUNDS
|
||||
)
|
||||
|
||||
model_multiclass <- lgb.train(
|
||||
params = list(
|
||||
objective = "multiclass"
|
||||
, num_threads = .LGB_MAX_THREADS
|
||||
, max.depth = MAX_DEPTH
|
||||
, num_classes = NCLASS
|
||||
)
|
||||
, data = lgb.Dataset(X, label = as.integer(iris[, 5L]) - 1L)
|
||||
, verbose = .LGB_VERBOSITY
|
||||
, nrounds = NROUNDS
|
||||
)
|
||||
|
||||
model_rank <- lgb.train(
|
||||
params = list(
|
||||
objective = "lambdarank"
|
||||
, num_threads = .LGB_MAX_THREADS
|
||||
, max.depth = MAX_DEPTH
|
||||
, lambdarank_truncation_level = 3L
|
||||
)
|
||||
, data = lgb.Dataset(
|
||||
X
|
||||
, label = as.integer(iris[, 1L] > 5.8)
|
||||
, group = rep(10L, times = 15L)
|
||||
)
|
||||
, verbose = .LGB_VERBOSITY
|
||||
, nrounds = NROUNDS
|
||||
)
|
||||
|
||||
models <- list(
|
||||
reg = model_reg
|
||||
, bin = model_binary
|
||||
, multi = model_multiclass
|
||||
, rank = model_rank
|
||||
)
|
||||
|
||||
for (model_name in names(models)) {
|
||||
model <- models[[model_name]]
|
||||
expected_n_trees <- NROUNDS
|
||||
if (model_name == "multi") {
|
||||
expected_n_trees <- NROUNDS * NCLASS
|
||||
}
|
||||
df <- as.data.frame(lgb.model.dt.tree(model))
|
||||
df_list <- split(df, f = df$tree_index, drop = TRUE)
|
||||
|
||||
df_leaf <- df[!is.na(df$leaf_index), ]
|
||||
df_internal <- df[is.na(df$leaf_index), ]
|
||||
|
||||
test_that("lgb.model.dt.tree() returns the right number of trees", {
|
||||
expect_equal(length(unique(df$tree_index)), expected_n_trees)
|
||||
})
|
||||
|
||||
test_that("num_iteration can return less trees", {
|
||||
expect_equal(
|
||||
length(unique(lgb.model.dt.tree(model, num_iteration = 2L)$tree_index))
|
||||
, 2L * (if (model_name == "multi") NCLASS else 1L)
|
||||
)
|
||||
})
|
||||
|
||||
test_that("Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)", {
|
||||
expect_equal(unique(df$tree_index), (0L:(expected_n_trees - 1L)))
|
||||
})
|
||||
|
||||
test_that("Depth calculated from lgb.model.dt.tree() respects max.depth", {
|
||||
expect_true(max(df$depth) <= MAX_DEPTH)
|
||||
})
|
||||
|
||||
test_that("Each tree from lgb.model.dt.tree() has single root node", {
|
||||
expect_equal(
|
||||
unname(sapply(df_list, function(df) sum(df$depth == 0L)))
|
||||
, rep(1L, expected_n_trees)
|
||||
)
|
||||
})
|
||||
|
||||
test_that("Each tree from lgb.model.dt.tree() has two depth 1 nodes", {
|
||||
expect_equal(
|
||||
unname(sapply(df_list, function(df) sum(df$depth == 1L)))
|
||||
, rep(2L, expected_n_trees)
|
||||
)
|
||||
})
|
||||
|
||||
test_that("leaves from lgb.model.dt.tree() do not have split info", {
|
||||
internal_node_cols <- c(
|
||||
"split_index"
|
||||
, "split_feature"
|
||||
, "split_gain"
|
||||
, "threshold"
|
||||
, "decision_type"
|
||||
, "default_left"
|
||||
, "internal_value"
|
||||
, "internal_count"
|
||||
)
|
||||
expect_true(all(is.na(df_leaf[internal_node_cols])))
|
||||
})
|
||||
|
||||
test_that("leaves from lgb.model.dt.tree() have valid leaf info", {
|
||||
expect_true(all(df_leaf$leaf_index %in% 0L:(2.0^MAX_DEPTH - 1.0)))
|
||||
expect_true(all(is.finite(df_leaf$leaf_value)))
|
||||
expect_true(all(df_leaf$leaf_count > 0L & df_leaf$leaf_count <= N))
|
||||
})
|
||||
|
||||
test_that("non-leaves from lgb.model.dt.tree() do not have leaf info", {
|
||||
leaf_node_cols <- c(
|
||||
"leaf_index", "leaf_parent", "leaf_value", "leaf_count"
|
||||
)
|
||||
expect_true(all(is.na(df_internal[leaf_node_cols])))
|
||||
})
|
||||
|
||||
test_that("non-leaves from lgb.model.dt.tree() have valid split info", {
|
||||
expect_true(
|
||||
all(
|
||||
sapply(
|
||||
split(df_internal, df_internal$tree_index),
|
||||
function(x) all(x$split_index %in% 0L:(nrow(x) - 1L))
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
expect_true(all(df_internal$split_feature %in% FEAT))
|
||||
|
||||
num_cols <- c("split_gain", "threshold", "internal_value")
|
||||
expect_true(all(is.finite(unlist(df_internal[, num_cols]))))
|
||||
|
||||
# range of decision type?
|
||||
expect_true(all(df_internal$default_left %in% c(TRUE, FALSE)))
|
||||
|
||||
counts <- df_internal$internal_count
|
||||
expect_true(all(counts > 1L & counts <= N))
|
||||
})
|
||||
}
|
Загрузка…
Ссылка в новой задаче