# agnostic to objective function or data splits

# grid_partition -----------------
#' Grid Partition
#'
#' A \code{\link{grid_partition}} defines a grid over a feature-space. It can be built by composing \code{\link{partition_split}}s.
#'
#' The partition is typically built by a search algorithm such as \code{\link{fit_partition}}.
#' @name GridPartition
NULL
#> NULL
#' Create a null \code{grid_partition}
#'
#' Create an empty partition. Splits can be added using \code{\link{add_partition_split}}.
#' Information about a split can be retrieved using \code{\link{num_cells}}, \code{\link{get_desc_df}} and \code{\link{print}}.
#' With data, one can determine the cell for each observation using \code{\link{predict}}.
#'
#' @param X_range Such as from \code{\link{get_X_range}}
#' @param varnames Names of the X-variables
#'
#' @return Grid Partition
#' @export
grid_partition <- function(X_range, varnames=NULL) {
  K = length(X_range)
  # Per-dimension split store: a list of level-groups for categorical dims,
  # a numeric vector of cut-points for continuous dims.
  s_by_dim = vector("list", length=K)
  dim_cat = c() # indices of categorical dimensions (NULL when none)
  for (k in seq_len(K)) { # seq_len (not 1:K) so K==0 yields an empty partition
    if (mode(X_range[[k]])=="character") {
      dim_cat = c(dim_cat, k)
      s_by_dim[[k]] = list()
    }
    else {
      s_by_dim[[k]] = vector("numeric")
    }
  }
  nsplits_by_dim = rep(0, K)

  return(structure(list(s_by_dim = s_by_dim, nsplits_by_dim = nsplits_by_dim, varnames=varnames, dim_cat=dim_cat,
                        X_range=X_range), class = c("grid_partition")))
}
#' Is grid_partition
#'
#' Test whether an object is a \code{grid_partition}
#'
#' @param x an R object
#'
#' @return True if x is a grid_partition
#' @export
#' @describeIn grid_partition is grid_partition
is_grid_partition <- function(x) inherits(x, "grid_partition")
#' Get X_range
#'
#' Gets the "range" of each variable in X. For numeric variables this is (min, max).
#' For factors this means the vector of levels.
#'
#' @param X data
#'
#' @return list of length K with each element being the "range" along that dimension
#' @export
get_X_range <- function(X) {
  # With separate samples per estimate, pool them so the range covers all samples.
  if(is_sep_sample(X))
    X = do.call("rbind", X)
  if(is.matrix(X)) {
    # are_equal() only *returns* a flag; it must be wrapped in assert_that()
    # to actually stop on failure (bare calls here were no-ops).
    assert_that(are_equal(mode(X), "numeric"), msg="X matrix must have numeric mode")
  }
  else {
    assert_that(is.data.frame(X), msg="X is not a matrix or data.frame")
    # tibbles return a tibble (rather than a vector) for X[,k], making is.factor(X[,k])
    # and others fail, so coerce to a plain data.frame.
    if(inherits(X, "tbl")) X = as.data.frame(X)
    # factor columns are fine: mode(factor) is "numeric"
    for(k in seq_len(ncol(X))) assert_that(are_equal(mode(X[[k]]), "numeric"), msg=paste0("X column ", k, " must be numeric or factor"))
  }
  assert_that(ncol(X)>=1, msg="X has no columns")

  X_range = list()
  K = ncol(X)
  for(k in seq_len(K)) {
    X_k = X[, k]
    X_range[[k]] = if(is.factor(X_k)) levels(X_k) else range(X_k) #c(min, max)
  }
  return(X_range)
}
#' Get factor describing cell number for each observation
#'
#' Note that currently if X has values more extreme (e.g., for numeric or factor levels) than was used to generate the partition
#' then we will return NA unless you provide an updated X_range.
#'
#' @param object partition
#' @param X X data or list of X
#' @param X_range (Optional) overrides the partition$X_range
#' @param ... Additional arguments. Unused.
#'
#' @return Factor
#' @export
predict.grid_partition <- function(object, X, X_range=NULL, ...) {
  # One factor per dimension, then crossed into a single cell factor.
  per_dim_factors = get_factors_from_partition(object, X, X_range=X_range)
  interaction_m(per_dim_factors, is_sep_sample(X))
}
#' @describeIn num_cells grid_partition
#' @export
num_cells.grid_partition <- function(obj) {
  # each dimension has (splits + 1) segments; cells are their product
  prod(obj$nsplits_by_dim + 1)
}
#' Print grid_partition
#'
#' Prints a data.frame with options
#'
#' @param x partition object
#' @param do_str If True, use a string like "(a, b]", otherwise have two separate columns with a and b
#' @param drop_unsplit If True, drop columns for variables over which the partition did not split
#' @param digits digits Option
#' @param ... Additional arguments. Passed to print
#'
#' @return string (and displayed)
#' @export
print.grid_partition <- function(x, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, ...) {
  assert_that(is.flag(do_str), is.flag(drop_unsplit), msg="One of do_str or drop_unsplit are not flags")
  desc = get_desc_df(x, do_str=do_str, drop_unsplit=drop_unsplit, digits=digits)
  print(desc, digits=digits, ...)
}
#' Get descriptive data.frame
#'
#' Get information for each cell
#'
#' @param obj Grid Partition object
#' @param cont_bounds_inf If TRUE, treat the outer bounds of continuous dimensions as (-Inf, Inf)
#' @param do_str If TRUE, describe each cell window as a string like "(a, b]"
#' @param drop_unsplit If TRUE, drop columns for dimensions the partition never split on
#' @param digits digits option passed to format()
#' @param unsplit_cat_star If TRUE, display unsplit categorical dimensions as "*"
#' @param ... Additional arguments. Unused.
#'
#' @return data.frame with one row per cell and the partitioning columns
#' @export
get_desc_df.grid_partition <- function(obj, cont_bounds_inf=TRUE, do_str=FALSE, drop_unsplit=FALSE,
                                       digits=NULL, unsplit_cat_star=TRUE, ...) {
  assert_that(is.flag(cont_bounds_inf), is.flag(do_str), is.flag(drop_unsplit), is.flag(unsplit_cat_star), msg="One of (cont_bounds_inf, do_str, drop_unsplit, unsplit_cat_star) are not flags.")
  # A split at x_k means that we split to those <= and >

  n_segs = obj$nsplits_by_dim+1
  n_cells = prod(n_segs)

  # Nothing was split and unsplit dims are dropped: a single cell, zero columns.
  if(n_cells==1 && drop_unsplit) return(as.data.frame(matrix(NA, nrow=1, ncol=0)))

  K = length(obj$nsplits_by_dim)
  X_range = obj$X_range
  if(cont_bounds_inf) {
    for(k in seq_len(K)) {
      if(!k %in% obj$dim_cat) X_range[[k]] = c(-Inf, Inf)
    }
  }
  colnames=obj$varnames
  if(is.null(colnames)) colnames = paste("X", seq_len(K), sep="")

  # Per-dimension list of windows (level-groups for categorical, (lo, hi] for continuous).
  list_of_windows = list()
  for(k in seq_len(K)) {
    list_of_windows[[k]] = if(k %in% obj$dim_cat) get_windows_cat(obj$s_by_dim[[k]], X_range[[k]]) else get_window_cont(obj$s_by_dim[[k]], X_range[[k]])
  }

  format_cell_cat <- function(win, unsplit_cat_star, n_tot_dim, sep=", ") {
    if(unsplit_cat_star && n_tot_dim==1) return("*")
    return(paste(win, collapse=sep))
  }
  format_cell_cont <- function(win) {
    if(is.infinite(win[1]) && is.infinite(win[2])) return("*")
    if(is.infinite(win[1])) return(paste0("<=", format(win[2], digits=digits)))
    if(is.infinite(win[2])) return(paste0(">", format(win[1], digits=digits)))
    return(paste0("(", format(win[1], digits=digits), ", ", format(win[2], digits=digits), "]"))
  }

  raw_data = data.frame(row.names=seq_len(n_cells))
  str_data = data.frame(row.names=seq_len(n_cells))
  for(k in seq_len(K)) {
    raw_data_k = list()
    str_data_k = c()
    for(cell_i in seq_len(n_cells)) {
      segment_indexes = segment_indexes_from_cell_i(cell_i, n_segs)
      win = list_of_windows[[k]][[segment_indexes[k]]]
      raw_data_k[[cell_i]] = win
      str_data_k[cell_i] = if(k %in% obj$dim_cat) format_cell_cat(win, unsplit_cat_star, length(list_of_windows[[k]])) else format_cell_cont(win)
    }
    raw_data[[colnames[k]]] = cbind(raw_data_k) #make a list-column: https://stackoverflow.com/a/51308306
    str_data[[colnames[k]]] = factor(str_data_k, levels=unique(str_data_k)) #will be in low-high order
  }
  desc_df = if(do_str) str_data else raw_data
  if(drop_unsplit) desc_df = desc_df[n_segs>1]

  return(desc_df)
}
#' Adds partition_split to grid_partition
#'
#' Update the partition with an additional split.
#'
#' @param obj Grid Partition object
#' @param s Partition Split object
#'
#' @return updated Grid Partition
#' @export
add_partition_split <- function(obj, s) {
  dim_k = s[[1]]
  cut_val = s[[2]]

  if (dim_k %in% obj$dim_cat) {
    # categorical: append the new level-group to this dimension's split list
    obj$s_by_dim[[dim_k]][[obj$nsplits_by_dim[dim_k] + 1]] = cut_val
  } else {
    # continuous: insert the cut-point, keeping the vector sorted
    obj$s_by_dim[[dim_k]] = sort(c(cut_val, obj$s_by_dim[[dim_k]]))
  }
  obj$nsplits_by_dim[dim_k] = obj$nsplits_by_dim[dim_k] + 1

  obj
}
# Map one data column to a factor whose levels are this dimension's windows.
# X_k: data for dimension k; X_k_range: levels (categorical) or c(min, max);
# s_by_dim_k: this dimension's splits (list of level-groups, or numeric cuts).
get_factors_from_splits_dim <- function(X_k, X_k_range, s_by_dim_k) {
  if(mode(X_k_range)=="character") {
    # Categorical: collapse the factor's levels into one level per window.
    windows = get_windows_cat(s_by_dim_k, X_k_range)
    fac = X_k
    new_name_map = levels(fac)
    new_names = c()
    for(window in windows) {
      new_name = if(length(window)>1) paste0("{", paste(window, collapse=","), "}") else window[1]
      new_name_map[levels(fac) %in% window] = new_name
      new_names = c(new_names, new_name)
    }
    levels(fac) <- new_name_map
    fac = factor(fac, new_names) # reorder levels to window order
  }
  else {
    # Continuous: pad the outer breaks so observations at the range endpoints
    # fall strictly inside cut()'s (a, b] segments.
    bottom_break = X_k_range[1] - 1
    top_break = X_k_range[2] + 1
    breaks = c(bottom_break, s_by_dim_k, top_break)
    # Was `include.lower=TRUE`, which only worked via silent partial matching
    # of cut()'s `include.lowest` argument; spell it out.
    fac = cut(X_k, breaks, labels=NULL, include.lowest=TRUE) #right=FALSE would make [a,b) segments; labels=FALSE would make just a numeric vector
  }
  return(fac)
}
# Per-dimension cell factor for X's column k; with separate samples (list X)
# returns one factor per sample.
get_factors_from_splits_dim_m <- function(X, X_k_range, s_by_dim_k, k) {
  if (is_sep_sample(X)) {
    return(lapply(X, function(X_s) get_factors_from_splits_dim(X_s[,k], X_k_range, s_by_dim_k)))
  }
  get_factors_from_splits_dim(X[,k], X_k_range, s_by_dim_k)
}
# for a continuous variable, splits are just values
# for a factor variable, a split is a vector of levels (strings)
# Build a placeholder X_range of K unbounded continuous dimensions.
# Uses rep() instead of a `1:K` loop, which errored for K == 0.
dummy_X_range <- function(K) {
  rep(list(c(-Inf, Inf)), K)
}
#First element is most insignificant (fastest changing), rather than lexicographic
#cell_i and return value are 1-indexed
segment_indexes_from_cell_i <- function(cell_i, n_segments) {
  K = length(n_segments)
  size = cumprod(n_segments)
  # Was print("Error: too big"), which did not halt and let an invalid index
  # propagate; raise a real error instead.
  if(cell_i > size[K])
    stop("cell_i (", cell_i, ") exceeds the number of cells (", size[K], ")")
  index = rep(0, K)
  cell_i_rem = cell_i - 1 #convert to 0-indexing
  for(k in seq_len(K)) {
    # mixed-radix decomposition with n_segments[k] as the radix of digit k
    index[k] = cell_i_rem %% n_segments[k]
    cell_i_rem = cell_i_rem %/% n_segments[k]
  }
  index + 1 #convert from 0-indexing
}
# Build a grid_partition by applying the splits of split_seq in order,
# using at most max_include of them.
partition_from_split_seq <- function(split_seq, X_range, varnames=NULL, max_include=Inf) {
  n_use = min(length(split_seq), max_include)
  Reduce(add_partition_split, split_seq[seq_len(n_use)],
         init=grid_partition(X_range, varnames))
}
# For each dimension, the factor giving each observation's segment; with
# separate samples (list X) returns that per-dimension list for each sample.
get_factors_from_partition <- function(partition, X, X_range=NULL) {
  if (is.null(X_range)) X_range = partition$X_range
  sep = is_sep_sample(X)
  K = if (sep) ncol(X[[1]]) else ncol(X)
  dim_factors <- function(X_m) {
    lapply(seq_len(K), function(k)
      get_factors_from_splits_dim(X_m[, k], X_range[[k]], partition$s_by_dim[[k]]))
  }
  if (sep) lapply(X, dim_factors) else dim_factors(X)
}
# partition_split ---------------------
#' Create partition_split
#'
#' Describes a single partition split. Used with \code{\link{add_partition_split}}.
#'
#' @param k dimension
#' @param X_k_cut cut value
#'
#' @return Partition Split
#' @export
partition_split <- function(k, X_k_cut) {
  structure(list(k=k, X_k_cut=X_k_cut), class=c("partition_split"))
}
#' Is \code{partition_split}
#'
#' Tests whether or not an object is a \code{partition_split}.
#'
#' @param x an R object
#'
#' @return Boolean
#' @export
#' @describeIn partition_split is partition_split
is_partition_split <- function(x) inherits(x, "partition_split")
#' Print partition_split
#'
#' Prints information for a \code{partition_split} as "dimension: cut value".
#'
#' @param x Object
#' @param ... Additional arguments. Unused.
#'
#' @return None
#' @export
print.partition_split <- function(x, ...) {
  line = paste0(x[[1]], ": ", x[[2]], "\n")
  cat(line)
}
|
|
# Search algo --------------------
|
|
|
|
|
|
#' Fit grid_partition
#'
#' Fit partition on some data, optionally finding best lambda using CV and then re-fitting on full data.
#'
#' Returns the partition and information about the fitting process
#'
#' @section Multiple estimates:
#' With multiple core estimates (M) there are 3 options (the first two have the same sample across treatment effects).\enumerate{
#' \item DS.MULTI_SAMPLE: Multiple pairs of (Y_{m},W_{m}). y,X,d are then lists of length M. Each element then has the typical size.
#' The N_m may differ across m. The number of columns of X will be the same across m.
#' \item DS.MULTI_D: Multiple treatments and a single outcome. d is then a NxM matrix.
#' \item DS.MULTI_Y: A single treatment and multiple outcomes. y is then a NxM matrix.
#' }
#'
#' @param y Nx1 matrix of outcome (label/target) data. With multiple core estimates see Details below.
#' @param X NxK matrix of features (covariates). With multiple core estimates see Details below.
#' @param d (Optional) NxP matrix (with colnames) of treatment data. If all equally important they
#'  should be normalized to have the same variance. With multiple core estimates see Details below.
#'
#' @param X_aux aux X sample to compute statistics on (OOS data)
#' @param d_aux aux d sample to compute statistics on (OOS data)
#' @param max_splits Maximum number of splits even if splits continue to improve OOS fit
#' @param max_cells Maximum number of cells even if more splits continue to improve OOS fit
#' @param min_size Minimum cell size when building full grid, cv_tr will use (F-1)/F*min_size, cv_te doesn't use any.
#' @param cv_folds Number of CV Folds or a vector of foldids.
#'  If m_mode==DS.MULTI_SAMPLE, then a list with foldids per Dataset.
#' @param verbosity 0 print no message.
#'  1 prints progress bar for high-level loops.
#'  2 prints detailed output for high-level loops.
#'  Nested operations decrease verbosity by 1.
#' @param breaks_per_dim NULL (for all possible breaks);
#'  K-length vector with # of break (chosen by quantiles); or
#'  K-dim list of vectors giving potential split points for non-categorical
#'  variables (can put c(0) for categorical).
#'  Similar to 'discrete splitting' in CausalTree though there they do separate split-points
#'  for treated and controls.
#' @param potential_lambdas potential lambdas to search through in CV
#' @param X_range list of min/max for each dimension (e.g., from \code{\link{get_X_range}})
#' @param bucket_min_n Minimum number of observations needed between different split checks
#' @param bucket_min_d_var Ensure positive variance of d for the observations between different split checks
#' @param obj_fn Default is \code{\link{eval_mse_hat}}. User-provided must allow same signature.
#' @param est_plan \link{EstimatorPlan}.
#' @param partition_i Default NA. Use this to avoid CV
#' @param pr_cl Default NULL. Parallel cluster. Used for:\enumerate{
#' \item CVing the optimal lambda,
#' \item fitting full tree (at each split going across dimensions),
#' \item fitting trees over the bumped samples
#' }
#' @param bump_samples Number of bump bootstraps (default 0), or list of such length where each items is a bootstrap sample.
#'  If m_mode==DS.MULTI_SAMPLE then each item is a sublist with such bootstrap samples over each dataset.
#' @param bump_ratio For bootstraps the ratio of sample size to sample (between 0 and 1, default 1)
#' @param ... Additional params.
#'
#' @return An object.
#' \item{partition}{Grid Partition (type=\code{\link{grid_partition}})}
#' \item{is_obj_val_seq}{Full sequence of in-sample objective function values}
#' \item{complexity_seq}{Full sequence of partition complexities (num_cells - 1)}
#' \item{partition_i}{Index of partition chosen}
#' \item{partition_seq}{Full sequence of Grid Partitions}
#' \item{split_seq}{Full sequence of splits (type=\code{\link{partition_split}})}
#' \item{lambda}{lambda chosen}
#' \item{folds_index_out}{List of the held-out observations for each fold (e.g., we might have generated them)}
#' @export
fit_partition <- function(y, X, d=NULL, X_aux=NULL, d_aux=NULL, max_splits=Inf, max_cells=Inf,
                          min_size=3, cv_folds=2, verbosity=0, breaks_per_dim=NULL, potential_lambdas=NULL,
                          X_range=NULL, bucket_min_n=NA, bucket_min_d_var=FALSE, obj_fn,
                          est_plan, partition_i=NA, pr_cl=NULL, bump_samples=0, bump_ratio=1, ...) {
  #Hidden params (passed via ...):
  # - @param lambda_1se Use the 1se rule to pick the best lambda
  # - @param valid_fn Function to quickly check if partition could be valid. User can override.
  # - @param split_check_fn Alternative split-check function
  # - @param N_est N of samples in the Estimation dataset
  # - @param nsplits_k_warn_limit
  # - @param bump_complexity, method 1 is c(FALSE, FALSE), method 2 is c(FALSE, TRUE), and method 3 is c(TRUE)
  extra_params = list(...)
  valid_fn = split_check_fn = NULL
  lambda_1se=FALSE
  N_est=NA
  nsplits_k_warn_limit=200
  bump_complexity=list(doCV=FALSE, incl_comp_in_pick=FALSE)
  if(length(extra_params)>0) {
    if("valid_fn" %in% names(extra_params)) valid_fn = extra_params[['valid_fn']]
    if("split_check_fn" %in% names(extra_params)) split_check_fn = extra_params[['split_check_fn']]
    if("lambda_1se" %in% names(extra_params)) lambda_1se = extra_params[['lambda_1se']]
    if("N_est" %in% names(extra_params)) N_est = extra_params[['N_est']]
    if("nsplits_k_warn_limit" %in% names(extra_params)) nsplits_k_warn_limit = extra_params[['nsplits_k_warn_limit']]
    if("bump_complexity" %in% names(extra_params)) bump_complexity = extra_params[['bump_complexity']]
    good_args = c("valid_fn", "split_check_fn", "lambda_1se", "N_est","nsplits_k_warn_limit", "bump_complexity")
    bad_names = names(extra_params)[!(names(extra_params) %in% good_args)]
    assert_that(length(bad_names)==0, msg=paste(c(list("Illegal arguments:"), bad_names), collapse = " "))
  }

  #To check: y, X, d, N_est, X_aux, d_aux, breaks_per_dim, potential_lambdas, X_range, bucket_min_n
  assert_that(max_splits>0, max_cells>0, min_size>0, msg="max_splits, max_cells, min_size need to be positive")
  assert_that(is.flag(lambda_1se), is.flag(bucket_min_d_var), msg="One of (lambda_1se, bucket_min_d_var) are not flags.")
  assert_that(inherits(est_plan, "estimator_plan") || (is.list(est_plan) && inherits(est_plan[[1]], "estimator_plan")), msg="estimator_plan argument (or it's first element) doesn't inherit from estimator_plan class")
  #verbosity can be negative if decremented from a fit_estimate call
  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE)
  if(is_sep_sample(X) && length(cv_folds)>1) {
    assert_that(is.list(cv_folds) && length(cv_folds)==M, msg="When separate samples and length(cv_folds)>1, need is.list(cv_folds) && length(cv_folds)==M.")
  }
  check_M_K(M, m_mode, K, X_aux, d_aux)
  do_cv = is.na(partition_i) && (is.null(potential_lambdas) || length(potential_lambdas)>0)
  do_bump = length(bump_samples)>1 || bump_samples > 0
  if(!do_cv) assert_that(bump_complexity$doCV==FALSE, msg="When not doing CV, can't including bumping in CV.")
  if(do_bump && bump_complexity$doCV) {
    # Bug fix: was `if(length(bump_samples==1))`, which measured the length of
    # the comparison vector and was therefore truthy for any non-empty input,
    # re-wrapping already-paired (cv, bump) sample lists.
    if(length(bump_samples)==1) bump_samples = list(bump_samples, bump_samples)
    cv_bump_samples = bump_samples[[1]]
    bump_samples = bump_samples[[2]]
  }
  else cv_bump_samples=0

  if(is.null(X_range)) X_range = get_X_range(X)
  if(!is.list(breaks_per_dim) && length(breaks_per_dim)==1) breaks_per_dim = get_quantile_breaks(X, X_range, g=breaks_per_dim)
  if(is.null(valid_fn)) valid_fn = valid_partition

  # Build the default bucket-based split check when requested. Bug fix: the
  # old else-branch reset split_check_fn to NULL, silently discarding a
  # user-supplied split_check_fn.
  if(is.null(split_check_fn) && (!is.na(bucket_min_n) || bucket_min_d_var)) {
    split_check_fn = purrr::partial(rolling_split_check, bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var)
  }

  if(verbosity>0) cat("Grid: Started.\n")

  if(verbosity>0) cat("Grid: Fitting grid structure on full set\n")
  # NOTE(review): N_est is passed positionally among named arguments; it binds
  # to fit_partition_full's next unmatched formal — verify that is N_est.
  fit_ret = fit_partition_full(y, X, d, X_aux, d_aux, X_range=X_range, max_splits=max_splits,
                               max_cells=max_cells, min_size=min_size, verbosity=verbosity-1,
                               breaks_per_dim=breaks_per_dim, N_est, split_check_fn=split_check_fn,
                               obj_fn=obj_fn, allow_empty_aux=FALSE,
                               allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan,
                               pr_cl=pr_cl, valid_fn=valid_fn, nsplits_k_warn_limit=nsplits_k_warn_limit)
  list[partition_seq, is_obj_val_seq, split_seq] = fit_ret
  complexity_seq = sapply(partition_seq, num_cells) - 1

  foldids = NA
  if(!is.na(partition_i)) {
    # Caller pinned the partition index: skip CV, no lambda.
    lambda = NA
    max_splits = partition_i-1
    if(length(partition_seq)< partition_i) {
      cat("Note: Couldn't build grid to desired granularity. Using most granular")
      partition_i = length(partition_seq)
    }
    assert_that(bump_complexity$incl_comp_in_pick==FALSE, msg="When no complexity penalization used, can't include complexity cost in bumping calculation.")
  }
  else {
    if(do_cv) {
      list[nfolds, folds_ret, foldids] = expand_fold_info(y, cv_folds, m_mode)
      list[lambda,lambda_oos, n_cell_table] = cv_pick_lambda(y=y, X=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas, N_est=N_est, max_splits=max_splits, max_cells=max_cells,
                                                            min_size=min_size, verbosity=verbosity, breaks_per_dim=breaks_per_dim, X_range=X_range, lambda_1se=lambda_1se,
                                                            split_check_fn=split_check_fn, obj_fn=obj_fn,
                                                            est_plan=est_plan, pr_cl=pr_cl, valid_fn=valid_fn, cv_bump_samples=cv_bump_samples, bump_ratio=bump_ratio)
    }
    else {
      lambda = potential_lambdas[1]
    }
    # Pick the partition minimizing the penalized in-sample objective.
    partition_i = which.min(is_obj_val_seq + lambda*complexity_seq)
  }

  if(do_bump) {
    if(verbosity>0) cat("Grid > Bumping: Started.\n")

    if(bump_complexity$incl_comp_in_pick) {
      best_val = is_obj_val_seq[partition_i] + lambda*complexity_seq[partition_i]
    }
    else {
      best_val = is_obj_val_seq[partition_i]
    }

    b_rets = gen_bumped_partitions(bump_samples, bump_ratio, N, m_mode, verbosity, pr_cl, min_size=min_size*bump_ratio,
                                   y=y, X_d=X, d=d, X_aux=X_aux, d_aux=d_aux, X_range=X_range, max_splits=max_splits,
                                   max_cells=max_cells,
                                   breaks_per_dim=breaks_per_dim, N_est=N_est, split_check_fn=split_check_fn, obj_fn=obj_fn,
                                   min_size_aux=min_size, est_plan=est_plan,
                                   valid_fn=valid_fn)
    bump_B = length(b_rets)

    # Evaluate each bumped fit on the original data and keep the best.
    best_b = NA
    for(b in seq_len(bump_B)) {
      b_ret = b_rets[[b]]
      if(do_cv || bump_complexity$incl_comp_in_pick) b_complexity_seq = sapply(b_ret$partition_seq, num_cells) - 1
      if(do_cv) {
        partition_i_b = which.min(b_ret$is_obj_val_seq + lambda*b_complexity_seq)
      }
      else {
        partition_i_b = partition_i #default
        if(length(b_ret$partition_seq)<partition_i_b) {
          cat("Note: Couldn't build grid to desired granularity. Using most granular\n")
          partition_i_b = length(b_ret$partition_seq)
        }
      }
      partition_b = b_ret$partition_seq[[partition_i_b]]

      obj_ret = obj_fn(y, X, d, N_est=N_est, partition=partition_b, est_plan=est_plan, sample="trtr")
      if(obj_ret[2]>0 || obj_ret[3]>0) next #N_cell_empty, N_cell_error
      if(bump_complexity$incl_comp_in_pick) {
        bump_val = obj_ret[1] + lambda*b_complexity_seq[partition_i_b]
      }
      else {
        bump_val = obj_ret[1]
      }

      if(bump_val < best_val){
        best_val = bump_val
        best_b = b
      }
    }
    if(!is.na(best_b)) {
      if(verbosity>0) {
        cat(paste("Grid > Bumping: Finished. Picking bumped partition."))
        cat(paste(" Old (unbumped) is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "]."))
        cat(paste(" Old (unbumped) complexity_seq=[", paste(complexity_seq, collapse=" "), "].\n"))
      }
      list[partition_seq, is_obj_val_seq_best_b, split_seq] = b_rets[[best_b]]
      if(do_cv) {
        b_complexity_seq = sapply(b_rets[[best_b]]$partition_seq, num_cells) - 1
        partition_i = which.min(b_rets[[best_b]]$is_obj_val_seq + lambda*b_complexity_seq)
      }
      # Recompute sequences on the original data for the chosen bumped fit.
      complexity_seq = sapply(partition_seq, num_cells) - 1
      is_obj_val_seq = sapply(partition_seq, function(p){
        obj_fn(y, X, d, N_est=N_est, partition=p, est_plan=est_plan, sample="trtr")[1]
      })
    }
    else {
      if(verbosity>0) cat(paste("Grid > Bumping: Finished. No bumped partitions better than original.\n"))
    }
  }

  if(verbosity>0) {
    cat(paste("Grid: Finished. is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "]."))
    if(do_cv) {
      cat(paste(" complexity_seq=[", paste(complexity_seq, collapse=" "), "]."))
      cat(paste(" best partition=", paste(partition_i, collapse=" "), "."))
    }
    cat("\n")
  }
  partition = partition_seq[[partition_i]]
  return(list(partition=partition, is_obj_val_seq=is_obj_val_seq, complexity_seq=complexity_seq,
              partition_i=partition_i, partition_seq=partition_seq, split_seq=split_seq, lambda=lambda,
              foldids=foldids))
}
#' Get break-points by looking at quantiles
#'
#' Provides a set of potential split points for data according to quantiles (if possible)
#'
#' @param X Features
#' @param X_range X-range
#' @param g # of quantiles
#' @param type Quantile type (see ?quantile and https://mathworld.wolfram.com/Quantile.html).
#'   Types 1-3 are discrete and this is good for passing to unique() when there are clumps
#'
#' @return list of potential breaks
get_quantile_breaks <- function(X, X_range, g=20, type=3) {
  if(is.null(g)) g = 20 #fit_estimate has a different default that might get passed in.
  if(is_sep_sample(X)) X = X[[1]]
  X = ensure_good_X(X)

  K = ncol(X)
  breaks_per_dim = vector("list", K)
  for(k in seq_len(K)) {
    X_k = X[,k]
    if(is.factor(X_k)) {
      breaks_per_dim[[k]] = c(0) # dummy placeholder for categorical dims
    }
    else if(storage.mode(X_k)=="integer" && (X_range[[k]][2]-X_range[[k]][1])<=g) {
      # Few distinct integers: every interior value is a candidate break
      # (drop the smallest and largest).
      distinct_vals = sort(unique(X_k))
      breaks_per_dim[[k]] = distinct_vals[-c(length(distinct_vals), 1)]
    }
    else {
      # For g interior cut-points there are g+2 quantile nodes including the
      # two endpoints, which are then dropped.
      qs = unique(quantile(X_k, seq(0, 1, length.out=g+2), names=FALSE, type=type))
      breaks_per_dim[[k]] = qs[-c(length(qs), 1)]
    }
  }
  breaks_per_dim
}
# Quick validity check of a candidate partition: cells must meet min_size and
# the treatment d must vary within every cell (otherwise effects are not
# identified). If d vectors are empty this doesn't return fail.
# Returns list(fail=FALSE) on success, or list(fail=TRUE, <reason>=...).
valid_partition <- function(cell_factor, d=NULL, cell_factor_aux=NULL, d_aux=NULL, min_size=0) {
  #check none of the cells are too small
  if(min_size>0) {
    if(length(cell_factor)==0) return(list(fail=TRUE, min_size=0))
    lowest_size = min(table(cell_factor))
    if(lowest_size<min_size) return(list(fail=TRUE, min_size=lowest_size))

    if(!is.null(cell_factor_aux)) {
      if(length(cell_factor_aux)==0) return(list(fail=TRUE, min_size_aux=0))
      lowest_size_aux = min(table(cell_factor_aux))
      if(lowest_size_aux<min_size) return(list(fail=TRUE, min_size_aux=lowest_size_aux))
    }
  }

  # Training sample: each treatment column must be non-constant in every cell.
  if(!is.null(d)) {
    if(!is_vec(d)) {
      for(m in 1:ncol(d)) {
        if(any(by(d[,m], cell_factor, FUN=const_vect))) {
          return(list(fail=TRUE, always_d_var=FALSE))
        }
      }
    }
    else {
      if(any(by(as.vector(d), cell_factor, FUN=const_vect))) {
        return(list(fail=TRUE, always_d_var=FALSE))
      }
    }
  }
  # Aux sample: same check against the aux cell factor.
  if(!is.null(d_aux)) {
    if(!is_vec(d_aux)) {
      for(m in 1:ncol(d_aux)) {
        # Bug fix: this branch grouped d_aux by cell_factor (the training
        # cells) and reported always_d_var; like the vector branch below it
        # must use cell_factor_aux and report always_d_var_aux.
        if(any(by(d_aux[,m], cell_factor_aux, FUN=const_vect))) {
          return(list(fail=TRUE, always_d_var_aux=FALSE))
        }
      }
    }
    else {
      if(any(by(as.vector(d_aux), cell_factor_aux, FUN=const_vect))) {
        return(list(fail=TRUE, always_d_var_aux=FALSE))
      }
    }
  }
  return(list(fail=FALSE))
}
# Fit partitions on bump (bootstrap) samples; returns the list of per-sample
# fit results from fit_partition_bump_b.
gen_bumped_partitions <- function(bump_samples, bump_ratio, N, m_mode, verbosity, pr_cl, allow_empty_aux=FALSE, allow_est_errors_aux=FALSE, ...) {
  assert_that(bump_ratio>0, bump_ratio<=1, msg="bump_ratio needs to be in (0,1]")
  bump_samples = expand_bump_samples(bump_samples, bump_ratio, N, m_mode)
  bump_B = length(bump_samples)

  # Bug fix: allow_empty_aux / allow_est_errors_aux were hard-coded to FALSE
  # here, silently ignoring the function's own parameters; pass them through.
  params = c(list(samples=bump_samples, verbosity=verbosity-1, allow_empty_aux=allow_empty_aux, allow_est_errors_aux=allow_est_errors_aux, pr_cl=NULL, m_mode=m_mode),
             list(...))

  b_rets = my_apply(seq_len(bump_B), fit_partition_bump_b, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  return(b_rets)
}
# Normalize the candidate break points per dimension.
# If breaks_per_dim is NULL, derive candidates from the observed data:
# midpoints between consecutive unique values when mid_point=TRUE, otherwise
# all but the last unique value. If breaks were supplied, drop a break equal
# to the dimension's maximum (splitting at the top end is a no-op since a
# split separates <= from >) and strip names.
get_usable_break_points <- function(breaks_per_dim, X, X_range, dim_cat, mid_point=TRUE) {
  # With separate samples, candidate breaks are derived from the first sample.
  if(is_sep_sample(X)) X = X[[1]]
  K = ncol(X)
  #old code
  if(is.null(breaks_per_dim)) {
    breaks_per_dim = list()
    for(k in 1:K) {
      if(!k %in% dim_cat) {
        u = unique(sort(X[, k]))
        if(mid_point) {
          # midpoints between consecutive unique values
          breaks_per_dim[[k]] = u[-length(u)] + diff(u) / 2
        }
        else {
          breaks_per_dim[[k]] = u[-length(u)] #skip last point
        }
      }
      else {
        breaks_per_dim[[k]] = c(0) #Dummy just for place=holder
      }
    }
  }
  else { #make sure they didn't include the highest point (a break at the max would be a no-op)
    for(k in 1:K) {
      if(!k %in% dim_cat) {
        n_k = length(breaks_per_dim[[k]])
        if(breaks_per_dim[[k]][n_k]==X_range[[k]][2]) {
          breaks_per_dim[[k]] = breaks_per_dim[[k]][-n_k]
        }
      }
      breaks_per_dim[[k]] = unname(breaks_per_dim[[k]]) #names messed up the get_desc_df() (though not in debugSource)
    }
  }
  return(breaks_per_dim)
}
|
|
#Typically is_obj_val_seq trends negative. If first element is min, then return c()
|
|
get_lambda_ties <- function(is_obj_val_seq, complexity_seq) {
|
|
n_seq = length(is_obj_val_seq)
|
|
slopes = c() #will go from strongly negative and increases and we stop before reaching 0
|
|
hull_i = 1
|
|
while(hull_i < n_seq) {
|
|
i_slopes = rep(NA, n_seq)
|
|
for(i in (hull_i+1):n_seq) {
|
|
i_slopes[i] = (is_obj_val_seq[i] - is_obj_val_seq[hull_i])/(complexity_seq[i]- complexity_seq[hull_i])
|
|
}
|
|
best_slope = min(i_slopes, na.rm=TRUE)
|
|
if(best_slope>=0) break
|
|
slopes = c(slopes, best_slope)
|
|
hull_i = which.min(i_slopes)
|
|
}
|
|
if(length(slopes)>1) {
|
|
lambda_ties = abs(slopes) #slightly bigger will go will pick the index earlier, slightly bigger later
|
|
}
|
|
else {
|
|
lambda_ties = c()
|
|
}
|
|
return(lambda_ties)
|
|
}
|
|
|
|
# Enumerate the distinct binary splits of a window's levels: all subsets of
# size 1..floor(n/2). At exactly m == n/2 only half the subsets are kept,
# since a subset and its complement describe the same split.
gen_cat_window_splits <- function(chr_vec) {
  n = length(chr_vec)
  splits = list()
  for(m in seq_len(floor(n/2))) {
    # was simplify=F; spell out FALSE (T/F are reassignable)
    cs = combn(chr_vec, m, simplify=FALSE)
    if(m==n/2) cs = cs[1:(length(cs)/2)] #or just filter by those that contain chr_vec[1]
    splits = c(splits, cs)
  }
  return(splits)
}
# Count the distinct binary splits of a categorical window of a given size:
# sum over subset sizes m = 1..floor(len/2), halving the count at m == len/2
# because a subset and its complement describe the same split.
n_cat_window_splits <- function(window_len) {
  total = 0
  for (m in seq_len(floor(window_len / 2))) {
    n_choose = choose(window_len, m)
    if (m == window_len / 2) n_choose = n_choose / 2
    total = total + n_choose
  }
  total
}
# Total number of candidate categorical splits across all current windows
# of a dimension.
n_cat_splits <- function(s_by_dim_k, X_range_k) {
  windows = get_windows_cat(s_by_dim_k, X_range_k)
  sum(vapply(windows, function(w) n_cat_window_splits(length(w)), numeric(1)))
}
# The windows of a categorical dimension: the explicit split level-groups
# plus a final window of all levels not claimed by any split.
get_windows_cat <- function(s_by_dim_k, X_k_range) {
  used_levels = unlist(c(s_by_dim_k))
  remaining = X_k_range[!X_k_range %in% used_levels]
  c(s_by_dim_k, list(remaining))
}
# The windows of a continuous dimension: consecutive (lower, upper) pairs
# formed by the range endpoints and the sorted cut-points.
get_window_cont <- function(s_by_dim_k, X_k_range) {
  lowers = c(X_k_range[1], s_by_dim_k)
  uppers = c(s_by_dim_k, X_k_range[2])
  lapply(seq_along(lowers), function(w) c(lowers[w], uppers[w]))
}
# Cross all per-dimension factors except dimension k. With a single dimension
# there is nothing left, so return a constant one-level factor.
gen_holdout_interaction <- function(factors_by_dim, k) {
  if (length(factors_by_dim) == 1) {
    return(factor(rep("|", length(factors_by_dim[[1]]))))
  }
  interaction(factors_by_dim[-k])
}
# Number of candidate splits along dimension k: enumerated level-group splits
# for categorical dims, otherwise the count of supplied break points.
n_breaks_k <- function(breaks_per_dim, k, partition, X_range) {
  if (k %in% partition$dim_cat) {
    n_cat_splits(partition$s_by_dim[[k]], X_range[[k]])
  } else {
    length(breaks_per_dim[[k]])
  }
}
|
|
rolling_split_check <- function(shifted_N, shifted_d=NULL, shifted_cell_factor_nk, m_mode, bucket_min_n=NA, bucket_min_d_var=FALSE) {
|
|
if(!is.na(bucket_min_n) && min(shifted_N)<bucket_min_n){
|
|
#cat("Skipped: increment not big enough\n")
|
|
return(FALSE)
|
|
}
|
|
|
|
if(bucket_min_d_var && !is.null(shifted_d) && any_const_m(shifted_d, shifted_cell_factor_nk, m_mode)) {
|
|
return(FALSE)
|
|
}
|
|
|
|
return(TRUE)
|
|
}
|
|
|
|
# Search dimension k for the single best new split to add to `partition`.
#
# Called once per dimension (possibly in parallel) by fit_partition_full().
# For each still-valid candidate break along dimension k it: builds the
# tentative partition, runs validity checks on the training sample (and,
# optionally, the aux sample) restricted to the affected window, evaluates
# the objective, and keeps the candidate with the lowest objective value.
# Candidates found invalid are marked FALSE in the returned bookkeeping so
# later iterations skip them.
#
# Args (selection):
#   k: dimension (column) index to search.
#   y, X_d, d: training outcome, features, and (optional) treatment.
#   X_range: per-dimension ranges (see get_X_range()).
#   pb: progress bar or NULL; advanced only when verbosity>0.
#   valid_breaks: per-dimension validity bookkeeping; element k is a list of
#     logical vectors (one vector per categorical window, a single vector for
#     continuous dims).
#   factors_by_dim / factors_by_dim_aux: per-dimension cell factors of the
#     current partition on the training / aux data.
#   allow_empty_aux, allow_est_errors_aux, min_size_aux, d_aux: strictness of
#     the aux-sample validity check (see comments above fit_partition_full).
#   split_check_fn: optional extra screen (e.g. rolling_split_check) applied
#     to observations between the previously checked cut and this one.
#   valid_fn: partition-validity checker; defaults to valid_partition.
#
# Returns: list(search_ret, valid_breaks_k). search_ret is list() when no
#   valid split was found, else list(val, new_split, new_factors_by_dim,
#   new_factors_by_dim_aux). valid_breaks_k is the updated bookkeeping for k.
fit_partition_full_k <- function(k, y, X_d, d, X_range, pb, debug, valid_breaks, factors_by_dim, X_aux,
                                 factors_by_dim_aux, partition, verbosity, allow_empty_aux=TRUE, d_aux,
                                 allow_est_errors_aux, min_size, min_size_aux, obj_fn, N_est, est_plan,
                                 split_check_fn = NULL, breaks_per_dim, valid_fn=NULL) { #, n_cut
  assert_that(is.flag(allow_empty_aux), msg="allow_empty_aux needs to be logical flags.")
  list[M, m_mode, N, K] = get_sample_type(y, X_d, d, checks=FALSE)
  if(is.null(valid_fn)) valid_fn = valid_partition
  search_ret = list()
  best_new_val = Inf
  valid_breaks_k = valid_breaks[[k]]
  # Cell factor from every dimension EXCEPT k; candidate splits along k are
  # crossed with this to form the tentative full cell factor.
  cell_factor_nk = gen_holdout_interaction_m(factors_by_dim, k, is_sep_sample(X_d))
  if(!is.null(X_aux)) {
    cell_factor_nk_aux = gen_holdout_interaction_m(factors_by_dim_aux, k, is_sep_sample(X_aux))
  }

  if(!is_factor_dim_k_m(X_d, k, m_mode==DS.MULTI_SAMPLE)) {
    # ---- Continuous dimension ----
    n_pot_break_points_k = length(breaks_per_dim[[k]])
    vals = rep(NA, n_pot_break_points_k) # NOTE(review): assigned but never read below
    # [win_LB, win_UB] brackets the existing-partition segment containing the
    # current candidate cut; masks restrict checks to that window only.
    prev_split_checked = X_range[[k]][1]
    win_LB = X_range[[k]][1]-1
    win_UB = if(length(partition$s_by_dim[[k]])>0) partition$s_by_dim[[k]][1] else X_range[[k]][2]
    win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB)
    win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB)
    for(X_k_cut_i in seq_len(n_pot_break_points_k)) { #cut-point is top end of segment,
      if (verbosity>0 && !is.null(pb)) setTxtProgressBar(pb, getTxtProgressBar(pb)+1)
      X_k_cut = breaks_per_dim[[k]][X_k_cut_i]
      if(X_k_cut %in% partition$s_by_dim[[k]]) {
        # Already a split here: advance the rolling window to the next
        # segment of the existing partition and move on.
        prev_split_checked = X_k_cut
        win_LB = X_k_cut
        higher_prev_split = partition$s_by_dim[[k]][partition$s_by_dim[[k]]>X_k_cut]
        win_UB = if(length(higher_prev_split)>0) min(higher_prev_split) else X_range[[k]][2]
        win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB)
        win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB)
        next
      }
      # Skip breaks previously marked invalid.
      if(!valid_breaks_k[[1]][X_k_cut_i]) next
      new_split = partition_split(k, X_k_cut)
      tent_partition = add_partition_split(partition, new_split)

      # Tentative dim-k factor and full cell factor under the candidate split.
      tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k)
      tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==DS.MULTI_SAMPLE)
      if(!is.null(X_aux)) {
        tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k)
        tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux))
      }

      # Optional extra screen on just the observations between the last
      # checked cut and this one (the "shifted" increment).
      if(!is.null(split_check_fn)){
        shifted_mask = gen_cont_window_mask_m(X_d, k, prev_split_checked, X_k_cut)
        shifted_N = sum_m(shifted_mask, m_mode==DS.MULTI_SAMPLE)
        shifted_cell_factor_nk = droplevels_m(apply_mask_m(cell_factor_nk, shifted_mask, m_mode==DS.MULTI_SAMPLE), m_mode==DS.MULTI_SAMPLE)
        shifted_d = if(is.null(d)) NULL else apply_mask_m(d, shifted_mask, m_mode==DS.MULTI_SAMPLE)
        split_OK = split_check_fn(shifted_N, shifted_d, shifted_cell_factor_nk, m_mode)
        if(!split_OK) {
          valid_breaks_k[[1]][X_k_cut_i] = FALSE
          next
        }
      }

      # do_window_approach
      #The bucket checks don't help much.
      #- Though I do check for non-zero var of D, that's just on the left so to check on right side too
      #- Note that though not min_size as different than bucket_min_n)
      # Validity is checked only within the affected window: the candidate
      # only changes cells inside [win_LB, win_UB].
      win_split_cond = gen_cont_win_split_cond_m(X_d, win_mask, k, X_k_cut)
      win_cell_factor_nk = apply_mask_m(cell_factor_nk, win_mask, m_mode==DS.MULTI_SAMPLE)
      win_cell_factor = interaction2_m(win_cell_factor_nk, win_split_cond, m_mode==DS.MULTI_SAMPLE)
      win_d = apply_mask_m(d, win_mask, m_mode==DS.MULTI_SAMPLE)
      valid_ret = valid_partition_m(m_mode==DS.MULTI_SAMPLE, valid_fn, win_cell_factor, d=win_d, min_size=min_size)
      if(!valid_ret$fail) {
        # Optionally also require validity on the aux sample.
        if(!allow_empty_aux && !is.null(X_aux)) {
          win_split_cond_aux = gen_cont_win_split_cond_m(X_aux, win_mask_aux, k, X_k_cut)
          win_cell_factor_aux = interaction2_m(apply_mask_m(cell_factor_nk_aux, win_mask_aux, is_sep_sample(X_aux)),
                                               win_split_cond_aux, is_sep_sample(X_aux), drop=allow_empty_aux)
          win_d_aux = if(!allow_est_errors_aux) apply_mask_m(d_aux, win_mask_aux, is_sep_sample(X_aux)) else NULL
          valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, win_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
        }
      }
      # Global approach
      # valid_ret = valid_fn(tent_cell_factor, d=d, min_size=min_size)
      # if(!valid_ret$fail) {
      #   valid_ret = valid_fn(tent_cell_factor_aux, d=d_aux, min_size=2)
      # }
      if(valid_ret$fail) {
        #cat("Invalid partition\n")
        valid_breaks_k[[1]][X_k_cut_i] = FALSE
        next
      }
      if(debug) cat(paste("k", k, ". X_k", X_k_cut, "\n"))
      obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor_tr = tent_cell_factor, debug=debug, est_plan=est_plan,
                       sample="trtr")
      if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
        #cat("Estimation errors\n")
        valid_breaks_k[[1]][X_k_cut_i] = FALSE
        next
      }
      val = obj_ret[1]
      stopifnot(is.finite(val))
      # Candidate evaluated cleanly: future rolling checks start from here.
      prev_split_checked = X_k_cut

      if(val<best_new_val) {
        #if(verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", split_res$val))
        best_new_val = val
        new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
        if(!is.null(X_aux)) {
          new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
        }
        else new_factors_by_dim_aux = NULL
        search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim,
                          new_factors_by_dim_aux=new_factors_by_dim_aux)
      }
    }

  }
  else { #categorical variable
    # ---- Categorical dimension ----
    # Candidate splits are subset-splits within each current window of
    # not-yet-separated categories.
    windows = get_windows_cat(partition$s_by_dim[[k]], X_range[[k]])
    for(window_i in seq_len(length(windows))) {
      window = windows[[window_i]]
      win_mask = gen_cat_window_mask_m(X_d, k, window)
      win_mask_aux = gen_cat_window_mask_m(X_aux, k, window)
      pot_splits = gen_cat_window_splits(window)
      for(win_split_i in seq_len(length(pot_splits))) {
        win_split_val = pot_splits[[win_split_i]]
        #TODO: Refactor with continuous case
        if (verbosity>0 && !is.null(pb)) setTxtProgressBar(pb, getTxtProgressBar(pb)+1)
        if(!valid_breaks_k[[window_i]][win_split_i]) next

        new_split = partition_split(k, win_split_val)
        tent_partition = add_partition_split(partition, new_split)

        tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k)
        tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==DS.MULTI_SAMPLE)
        if(!is.null(X_aux)) {
          tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k)
          tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux))
        }

        # do_window_approach
        # Validity checked only within the affected window (see cont. case).
        win_split_cond = gen_cat_win_split_cond_m(X_d, win_mask, k, win_split_val)
        win_cell_factor = interaction2_m(apply_mask_m(cell_factor_nk, win_mask, m_mode==DS.MULTI_SAMPLE), win_split_cond, m_mode==DS.MULTI_SAMPLE)
        win_d = apply_mask_m(d, win_mask, m_mode==DS.MULTI_SAMPLE)
        valid_ret = valid_partition_m(m_mode==DS.MULTI_SAMPLE, valid_fn, win_cell_factor, d=win_d, min_size=min_size)
        if(!valid_ret$fail) {
          if(!is.null(X_aux) && !allow_empty_aux) {
            # Explicit levels so both sides of the split are represented even
            # if one is empty in the aux sample.
            win_split_cond_aux = factor(gen_cat_win_split_cond_m(X_aux, win_mask_aux, k, win_split_val), levels=c(FALSE, TRUE))
            win_cell_factor_aux = interaction2_m(apply_mask_m(cell_factor_nk_aux, win_mask_aux, is_sep_sample(X_aux)),
                                                 win_split_cond_aux, is_sep_sample(X_aux), drop=allow_empty_aux)
            win_d_aux = if(!allow_est_errors_aux) apply_mask_m(d_aux, win_mask_aux, is_sep_sample(X_aux)) else NULL
            valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, win_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
          }
        }
        if(valid_ret$fail) {
          #cat("Invalid partition\n")
          valid_breaks_k[[window_i]][win_split_i] = FALSE
          next
        }
        if(debug) cat(paste("k", k, ". X_k", win_split_val, "\n"))
        obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor_tr = tent_cell_factor, debug=debug, est_plan=est_plan,
                         sample="trtr")
        if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
          #cat("Estimation errors\n")
          valid_breaks_k[[window_i]][win_split_i] = FALSE
          next
        }
        val = obj_ret[1]
        stopifnot(is.finite(val))

        if(val<best_new_val) {
          #if(verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", split_res$val))
          best_new_val = val
          new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
          if(!is.null(X_aux)) {
            new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
          }
          else new_factors_by_dim_aux = NULL
          search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim,
                            new_factors_by_dim_aux=new_factors_by_dim_aux)
        }

      }
    }
  }

  return(list(search_ret, valid_breaks_k))
}
|
|
|
|
# There are three general problems with a partition.
|
|
# 1) Empty cells
|
|
# 2) Non-empty cells where objective can't be calculated
|
|
# 3) Cells where it can be calculated but due to small sizes we don't want
|
|
# Main sample: We assume that a valid partition removes #1 and #2. Use min_size for 3
|
|
# For Aux: Use allow_empty_aux, allow_est_errors_aux, min_size_aux
|
|
# Include d_aux if you want to make sure that non-empty cells in aux have positive variance in d
|
|
# For CV: allow_empty_aux=TRUE, allow_est_errors_aux=FALSE, min_size_aux=1 (weaker check than removing estimation errors)
|
|
# Can set nsplits_k_warn_limit=Inf to disable
|
|
# Greedy forward search for a sequence of grid partitions.
#
# Starting from `partition` (or the null partition), repeatedly adds the
# single best split across all dimensions — delegating the per-dimension
# search to fit_partition_full_k() — until max_splits/max_cells is reached or
# no valid split remains. Returns the whole fitted sequence so a caller can
# later pick a complexity level (e.g. by CV over lambda).
#
# Args (selection):
#   y, X, d: training outcome, features, and (optional) treatment.
#   X_aux, d_aux: optional auxiliary (e.g. estimation or CV) sample used for
#     extra validity checks, controlled by allow_est_errors_aux/min_size_aux
#     (and allow_empty_aux via ...).
#   breaks_per_dim: candidate break points; refined by
#     get_usable_break_points().
#   nsplits_k_warn_limit: warn when a dimension has many candidate splits;
#     NA (or Inf) disables the warning.
#   pr_cl: optional 'parallel' cluster for the per-dimension search.
#   ...: forwarded to fit_partition_full_k (e.g. split_check_fn, valid_fn).
#
# Returns: list(partition_seq, is_obj_val_seq, split_seq) — the partitions
#   after 0,1,2,... splits, their in-sample objective values, and the splits
#   chosen at each step.
fit_partition_full <- function(y, X, d=NULL, X_aux=NULL, d_aux=NULL, X_range, max_splits=Inf, max_cells=Inf,
                               min_size=2, verbosity=0, breaks_per_dim, N_est, obj_fn, allow_est_errors_aux=TRUE,
                               min_size_aux=2, est_plan, partition=NULL, nsplits_k_warn_limit=200, pr_cl=NULL,
                               ...) {
  assert_that(max_splits>=0, max_cells>=1, min_size>=1, msg="Need max_splits>=0, max_cells>=1, min_size>=1.")
  assert_that(is.flag(allow_est_errors_aux), msg="allow_est_errors_aux needs to be a flag")
  assert_that(is.na(nsplits_k_warn_limit) || nsplits_k_warn_limit>=1, msg="nsplits_k_warn_limit not understood")
  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE)
  # Smallest cell size for which estimation is possible (one more obs needed
  # when there is a treatment d).
  est_min = ifelse(is.null(d), 2, 3) #If don't always need variance calc: ifelse(is.null(d), ifelse(honest, 2, 1), ifelse(honest, 3, 2))
  min_size = max(min_size, est_min)
  if(!allow_est_errors_aux) min_size_aux = max(min_size_aux, est_min)
  debug = FALSE
  if(is.null(partition)) partition = grid_partition(X_range, colnames(X))
  breaks_per_dim = get_usable_break_points(breaks_per_dim, X, X_range, partition$dim_cat)
  # Per-dimension bookkeeping of which candidate breaks are still viable.
  valid_breaks = vector("list", length=K) #splits_by_dim(s_seq) #stores Xk_val's

  for(k in 1:K) {
    n_split_breaks_k = n_breaks_k(breaks_per_dim, k, partition, X_range)
    valid_breaks[[k]] = list(rep(TRUE, n_split_breaks_k))
    if(!is.na(nsplits_k_warn_limit) && n_split_breaks_k>nsplits_k_warn_limit) warning(paste("Warning: Many splits (", n_split_breaks_k, ") along dimension", k, "\n"))
  }
  factors_by_dim = get_factors_from_partition(partition, X)
  if(!is.null(X_aux)) {
    factors_by_dim_aux = get_factors_from_partition(partition, X_aux)
  }
  # NOTE(review): when X_aux is NULL, factors_by_dim_aux is never defined here
  # yet is referenced unconditionally in `params` below — presumably callers
  # always supply X_aux; verify.

  if(verbosity>0){
    cat("Grid > Fitting: Started.\n")
    t0 = Sys.time()
  }
  split_i = 1
  seq_val = c()
  # Objective of the initial (possibly null) partition; index 1 of the seq.
  obj_ret = obj_fn(y, X, d, N_est=N_est, cell_factor_tr = interaction_m(factors_by_dim, is_sep_sample(X)), est_plan=est_plan, sample="trtr")
  if(obj_ret[3]>0 || !is.finite(obj_ret[1])) {
    stop("Estimation error with initial partition")
  }
  seq_val[1] = obj_ret[1]
  partition_seq = list()
  split_seq = list()
  partition_seq[[1]] = partition
  tent_cell_factor_aux = NULL # NOTE(review): appears unused below
  # Fancier progress bar only when attached to a real terminal.
  style = if(summary(stdout())$class=="terminal") 3 else 1
  if(!is.null(pr_cl) & !requireNamespace("parallel", quietly = TRUE)) {
    stop("Package \"parallel\" needed for this function to work. Please install it.", call. = FALSE)
  }
  # NOTE(review): do_pbapply is computed but not referenced in this function.
  do_pbapply = requireNamespace("pbapply", quietly = TRUE) & (verbosity>0) & (is.null(pr_cl) || length(pr_cl)<K)

  # Main greedy loop: one accepted split per iteration.
  while(TRUE) {
    if(split_i>max_splits) break
    if(num_cells(partition)==max_cells) break
    n_cuts_k = rep(0, K)
    for(k in 1:K) {
      n_cuts_k[k] = n_breaks_k(breaks_per_dim, k, partition, X_range)
    }
    n_cuts_total = sum(n_cuts_k)
    if(n_cuts_total==0) break
    if(verbosity>0) {
      cat(paste("Grid > Fitting > split ", split_i, ": Started\n"))
      t1 = Sys.time()
      if(is.null(pr_cl)) pb = txtProgressBar(0, n_cuts_total, style = style)
    }

    params = c(list(y=y, X_d=X, d=d, X_range=X_range, pb=NULL, debug=debug, valid_breaks=valid_breaks,
                    factors_by_dim=factors_by_dim, X_aux=X_aux, factors_by_dim_aux=factors_by_dim_aux, partition=partition,
                    verbosity=verbosity, d_aux=d_aux, allow_est_errors_aux=allow_est_errors_aux,
                    min_size=min_size, min_size_aux=min_size_aux, obj_fn=obj_fn, N_est=N_est, est_plan=est_plan,
                    breaks_per_dim=breaks_per_dim), list(...))

    # Search each dimension (serially or on the cluster) for its best split.
    col_rets = my_apply(1:K, fit_partition_full_k, verbosity, pr_cl, params)

    # Reduce over dimensions: keep the overall best candidate and fold the
    # updated validity bookkeeping back in.
    best_new_val = Inf
    best_new_split = NULL
    for(k in 1:K) {
      col_ret = col_rets[[k]]
      search_ret = col_ret[[1]]
      valid_breaks[[k]] = col_ret[[2]]
      if(length(search_ret)>0 && search_ret$val<best_new_val) {
        #if(verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", split_res$val))
        best_new_val = search_ret$val
        best_new_split = search_ret$new_split
        best_new_factors_by_dim = search_ret$new_factors_by_dim
        if(!is.null(X_aux)) {
          best_new_factors_by_dim_aux = search_ret$new_factors_by_dim_aux
        }
      }
    }

    if (verbosity>0) {
      t2 = Sys.time() #can us as.numeric(t1) to convert to seconds
      td = t2-t1
      if(is.null(pr_cl)) close(pb)
    }
    if(is.null(best_new_split)) {
      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". No valid splits\n"))
      break
    }
    best_new_partition = add_partition_split(partition, best_new_split)
    if(num_cells(best_new_partition)>max_cells) {
      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". Best split has results in too many cells\n"))
      break
    }
    # Accept the split and record it in the sequence.
    partition = best_new_partition
    factors_by_dim = best_new_factors_by_dim
    if(!is.null(X_aux)) {
      factors_by_dim_aux = best_new_factors_by_dim_aux
    }
    split_i = split_i + 1
    seq_val[split_i] = best_new_val
    partition_seq[[split_i]] = partition
    split_seq[[split_i-1]] = best_new_split
    # A categorical split changes that dimension's windows, so its validity
    # bookkeeping must be rebuilt from scratch.
    if(best_new_split[[1]] %in% partition$dim_cat) {
      k = best_new_split[[1]]
      windows = get_windows_cat(partition$s_by_dim[[k]], X_range[[k]])
      nwindows = length(windows)
      v_breaks = vector("list", length=nwindows)
      for(window_i in seq_len(nwindows)) {
        v_breaks[[window_i]] = rep(TRUE, n_cat_window_splits(length(windows[[window_i]])))
      }
      valid_breaks[[k]] = v_breaks
    }
    if (verbosity>0) {
      cat(paste("Grid > Fitting > split ", split_i, ": Finished.",
                " Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ".",
                " New split: k=", best_new_split[[1]], ", cut=", best_new_split[[2]], ", val=", best_new_val, "\n"))
    }
  }
  if (verbosity>0) {
    tn = Sys.time()
    td = tn-t0
    cat("Grid > Fitting: Finished.")
    cat(paste(" Entire Search Duration: ", format(as.numeric(td)), " ", attr(td, "units"), "\n"))
  }
  return(list(partition_seq=partition_seq, is_obj_val_seq=seq_val, split_seq=split_seq))
}
|
|
|
|
#Allows two lists or two datasets
|
|
#Allows two lists or two datasets
# Combine a primary object with an optional auxiliary one. When X_aux is
# NULL the primary is returned unchanged; when M_mult is TRUE the two are
# concatenated (both already lists of samples); otherwise they are wrapped
# as a two-element list.
add_samples <- function (X, X_aux, M_mult) {
  if (is.null(X_aux)) {
    return(X)
  }
  if (M_mult) c(X, X_aux) else list(X, X_aux)
}
|
|
|
|
|
|
# Fit one bumped partition sequence: refit fit_partition_full() on the b-th
# bootstrap/subsample, while using the FULL training data (plus any original
# aux data) as the aux sample so validity is still checked against it.
#
# Args:
#   b: index into `samples`.
#   samples: list of subsample index sets, one per bump draw.
#   y, X_d, d: full training data to subsample from.
#   m_mode: sample-mode code. NOTE(review): not referenced in this body —
#     presumably captured here so it is not forwarded via ...; verify.
#   X_aux, d_aux: original aux data, appended to the full training data.
#   nsplits_k_warn_limit: captured from ... and deliberately overridden to NA
#     in the call below, suppressing the many-splits warning on bump refits.
#   ...: remaining fit_partition_full() arguments.
#
# Returns: fit_partition_full() result for the subsample.
fit_partition_bump_b <- function(b, samples, y, X_d, d=NULL, m_mode, X_aux, d_aux, verbosity, nsplits_k_warn_limit=NA, ...){
  if(verbosity>0) cat(paste("Grid > Bumping > b = ", b, "\n"))
  sample = samples[[b]]
  list[y_b, X_b, d_b] = subsample_m(y, X_d, d, sample)
  # Aux = full training data (+ original aux, if any).
  X_aux2 = add_samples(X_d, X_aux, is_sep_sample(X_d))
  d_aux2 = add_samples(d, d_aux, is_sep_sample(X_d))
  fit_partition_full(y=y_b, X=X_b, d=d_b, X_aux=X_aux2, d_aux=d_aux2, verbosity=verbosity, nsplits_k_warn_limit=NA, ...)
}
|
|
|
|
# These are bump wrappers
|
|
# Pick the partition that minimizes the penalized objective
# is_obj_val + lambda * (num_cells - 1) for a given lambda.
#
# @param obj A fit_partition_full() result, or (when is_bumped) a list of
#   such results whose sequences are pooled before selection.
# @param lambda Complexity penalty per extra cell.
# @param is_bumped Whether `obj` is a list of bumped fits.
# @return list(index into the pooled sequence, selected partition).
get_part_for_lambda <- function(obj, lambda, is_bumped=FALSE) {
  if (is_bumped) {
    # Pool the sequences across all bump draws.
    is_obj_val_seq = unlist(lapply(obj, function(f) f$is_obj_val_seq))
    complexity_seq = unlist(lapply(obj, function(f) sapply(f$partition_seq, num_cells) - 1))
    partition_seq = unlist(lapply(obj, function(f) f$partition_seq), recursive = FALSE)
  } else {
    is_obj_val_seq = obj$is_obj_val_seq
    complexity_seq = sapply(obj$partition_seq, num_cells) - 1
    partition_seq = obj$partition_seq
  }
  penalized = is_obj_val_seq + lambda * complexity_seq
  partition_i = which.min(penalized)
  list(partition_i, partition_seq[[partition_i]])
}
|
|
# Number of partitions in a fit's sequence. For a bumped fit (a list of
# fits) the per-fit sequence lengths are summed.
#
# @param cvtr_fit A fit_partition_full() result, or list of them.
# @param is_bumped Whether cvtr_fit is a list of bumped fits.
# @return Total count of partitions across the sequence(s).
get_num_parts <- function(cvtr_fit, is_bumped=FALSE) {
  if (is_bumped) {
    per_fit = sapply(cvtr_fit, function(part) length(part$is_obj_val_seq))
    return(sum(per_fit))
  }
  length(cvtr_fit$is_obj_val_seq)
}
|
|
|
|
# Lambda values at which the penalized-objective minimizer changes partition
# ("ties"), for a single fit or pooled across bumped fits.
#
# @param cvtr_fit A fit_partition_full() result, or list of them.
# @param is_bumped Whether cvtr_fit is a list of bumped fits.
# @return Numeric vector of tie lambdas.
get_all_lambda_ties <- function(cvtr_fit, is_bumped=FALSE) {
  ties_of = function(f) {
    get_lambda_ties(f$is_obj_val_seq, sapply(f$partition_seq, num_cells) - 1)
  }
  if (is_bumped) {
    return(unlist(lapply(cvtr_fit, ties_of)))
  }
  ties_of(cvtr_fit)
}
|
|
|
|
# ... params sent to fit_partition_full()
|
|
# ... params sent to fit_partition_full()
# Per-fold worker for cv_pick_lambda(): fit the partition sequence on the
# fold's training split (with the CV split as aux data), optionally add
# bumped refits, and — when the lambda grid is already known — evaluate the
# out-of-sample objective for each lambda while the fold data is in hand.
#
# Args (selection):
#   f: fold index.
#   potential_lambdas: known lambda grid, or NULL (lambdas chosen later).
#   cv_bump_samples: 0 = no bumping; a scalar count or per-fold list of
#     subsample sets enables bumped refits.
#   recal_is_obj_b: if TRUE, re-evaluate each bumped sequence's in-sample
#     objective on the UNbumped fold-training data so values are comparable
#     across bump draws.
#   min_size_aux/allow_empty_aux/allow_est_errors_aux: captured from ... so
#     the hard-coded CV settings below are used for the main fit (they are
#     still forwarded to the bump fits).
#
# Returns: the fold's fit (possibly a list incl. bump fits) when lambdas are
#   not supplied; otherwise the vector of per-lambda OOS objective values.
cv_pick_lambda_f <- function(f, y, X_d, d, folds_ret, nfolds, potential_lambdas, N_est,
                             verbosity, obj_fn, cv_tr_min_size, est_plan, cv_bump_samples, bump_ratio,
                             nsplits_k_warn_limit=NA, min_size_aux=1, allow_empty_aux=TRUE, allow_est_errors_aux=FALSE, recal_is_obj_b=TRUE, ...) { #catch some of the params that might still be in ...
  if(verbosity>0) cat(paste("Grid > CV > Fold", f, "\n"))
  supplied_lambda = !is.null(potential_lambdas)
  if(supplied_lambda) n_lambda = length(potential_lambdas)

  list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X_d, d, folds_ret, f)

  do_bump = (length(cv_bump_samples)>1 || cv_bump_samples>0)
  # Main fit on the fold-training data; the CV split is the aux sample with
  # the weakest checks that still rule out estimation errors.
  cvtr_fit = fit_partition_full(y_f_tr, X_f_tr, d_f_tr, X_f_cv, d_f_cv,
                                min_size=cv_tr_min_size, verbosity=verbosity,
                                N_est=N_est, obj_fn=obj_fn, allow_empty_aux=TRUE,
                                allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan,
                                nsplits_k_warn_limit=NA, ...) #min_size_aux is weaker than removing est errors

  if(do_bump) {
    if(length(cv_bump_samples)>1) cv_bump_samples = cv_bump_samples[[f]]
    list[M, m_mode, N_tr, K] = get_sample_type(y_f_tr, X_f_tr, d_f_tr, checks=FALSE)
    cvtr_fit_bumps = gen_bumped_partitions(bump_samples=cv_bump_samples, bump_ratio, N_tr, m_mode, verbosity=verbosity, pr_cl=NULL,
                                           min_size=cv_tr_min_size*bump_ratio,
                                           y=y_f_tr, X_d=X_f_tr, d=d_f_tr, X_aux=X_f_cv, d_aux=d_f_cv,
                                           N_est=N_est, obj_fn=obj_fn,
                                           min_size_aux=1, est_plan=est_plan, nsplits_k_warn_limit=NA,
                                           allow_empty_aux=allow_empty_aux, allow_est_errors_aux=allow_est_errors_aux,
                                           ...)
    if(recal_is_obj_b) {
      #Use the updated values on the unbumped sample
      for(b in 1:length(cvtr_fit_bumps)) {
        #partition_seq=partition_seq, is_obj_val_seq
        cvtr_fit_bumps[[b]]$is_obj_val_seq = sapply(cvtr_fit_bumps[[b]]$partition_seq, function(p){
          obj_fn(y_f_tr, X_f_tr, d_f_tr, partition=p, est_plan=est_plan, sample="trtr")[1]
        })
      }
    }
    # Main fit first, then bump fits.
    cvtr_fit = c(list(cvtr_fit), cvtr_fit_bumps)
  }

  if(!supplied_lambda) {
    return(cvtr_fit)
  }
  #If we know the lambdas, eval data while we have it
  return(eval_lambdas(obj_fn, est_plan, potential_lambdas, cvtr_fit, y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, do_bump))
}
|
|
|
|
|
|
# Out-of-sample objective for each candidate lambda on one CV fold.
#
# Many lambdas map to the same selected partition, so each distinct
# partition's OOS objective is computed once and cached by its index in the
# (pooled) sequence.
#
# @param obj_fn Objective function; called with sample="trcv".
# @param cvtr_fit Fold fit (or list of fits when is_bumped).
# @param y_f_tr,...,d_f_cv Fold training / CV data.
# @param is_bumped Whether cvtr_fit pools bumped fits.
# @return Numeric vector of OOS objective values, one per lambda.
eval_lambdas <- function(obj_fn, est_plan, potential_lambdas, cvtr_fit, y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, is_bumped) {
  partition_oos_cache = rep(NA, get_num_parts(cvtr_fit, is_bumped))
  n_lambda = length(potential_lambdas)

  lambda_oos = rep(NA, n_lambda)
  for (lambda_i in seq_len(n_lambda)) {
    lambda = potential_lambdas[lambda_i]
    list[partition_i, part] = get_part_for_lambda(cvtr_fit, lambda, is_bumped)
    # Evaluate each distinct partition at most once.
    if (is.na(partition_oos_cache[partition_i])) {
      obj_ret = obj_fn(y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est=N_est, partition=part, debug=FALSE,
                       est_plan=est_plan, sample="trcv")
      partition_oos_cache[partition_i] = obj_ret[1]
    }
    lambda_oos[lambda_i] = partition_oos_cache[partition_i]
  }
  lambda_oos
}
|
|
|
|
# "1-SE"-style lambda selection: choose the largest lambda whose mean OOS
# error is within one standard error of the minimum.
#
# potential_lambdas is assumed sorted in decreasing order (as built by
# cv_pick_lambda), so min(which(...)) picks the largest qualifying lambda.
# When there are fewer folds than min_obs_1se, the SE is computed from the
# min column widened with neighboring lambda columns until min_obs_1se
# observations are collected.
#
# Fixes vs. previous version: removed a leftover unconditional print(obs)
# debug statement, and replaced `1:min(...)` — which yields c(1, 0) when the
# min is 0 (duplicating the min column) and capped widening at the SHORTER
# side of the grid — with seq_len(max(...)); the in-loop bounds checks
# already guard each side.
#
# @param potential_lambdas Candidate lambdas (decreasing).
# @param lambda_oos nfolds x n_lambda matrix of per-fold OOS errors.
# @param min_obs_1se Minimum observations for the SE estimate.
# @param max_oos_err_allowed Ignored: recomputed below as min + SE (kept in
#   the signature for backward compatibility).
# @param verbosity Print the threshold when > 0.
# @return The selected lambda (length-0 if none qualifies, handled by caller).
lambda_1se_selector <- function(potential_lambdas, lambda_oos, min_obs_1se, max_oos_err_allowed, verbosity) {
  n_lambda = ncol(lambda_oos)
  nfolds = nrow(lambda_oos)
  lambda_oos_means = colMeans(lambda_oos)
  lambda_oos_min_i = which.min(lambda_oos_means)
  #lambda_oos_sd = apply(X=lambda_oos, MARGIN=2, FUN=sd) #don't need all for now
  obs = lambda_oos[, lambda_oos_min_i]
  if(min_obs_1se>nfolds) {
    # Widen symmetrically around the min column until we have enough obs.
    for(delta in seq_len(max(n_lambda-lambda_oos_min_i, lambda_oos_min_i-1))) {
      if(lambda_oos_min_i+delta<=n_lambda) {
        obs = c(obs, lambda_oos[, lambda_oos_min_i+delta])
      }
      if(lambda_oos_min_i-delta>=1) {
        obs = c(obs, lambda_oos[, lambda_oos_min_i-delta])
      }
      if(length(obs)>=min_obs_1se) break
    }
  }
  max_oos_err_allowed = min(lambda_oos_means) + sd(obs)
  if(verbosity>0) cat(paste("max_oos_err_allowed:", paste(max_oos_err_allowed, collapse=" "), "\n"))
  lambda_star_i = min(which(lambda_oos_means <= max_oos_err_allowed))
  lambda_star = potential_lambdas[lambda_star_i]

  return(lambda_star)
}
|
|
|
|
# cv_tr_min_size: We don't want this too large as (since we have less data) otherwise we might not find the
|
|
# MSE-min lambda and if the most detailed partition on the full data is best we might have a
|
|
# lambda too large and choose one coarser. could choose 2
|
|
# On the other hand the types of partitions we generate when this param is too small will be different
|
|
# and incomparable. Could choose (nfolds-2)/nfolds
|
|
# Therefore I take the average of the above two approaches.
|
|
# Used to warn if best lambda was the smallest, but since there's not much to do about it (we already scale
|
|
# cv_tr_min_size), stopped reporting
|
|
# Note: We do not want to first fit the full grid and then take potential lambdas as one from each segment that
|
|
# picks another grid. Those lambdas aren't guaranteed to include the true lambda min. We are basically
# roughly sampling the true CV lambda function (which is a step-function) and we might miss it and
|
|
# wrongly evaluate the benefit of each subgrid and therefore pick the wrong one.
|
|
# ... params sent to cv_pick_lambda_f
|
|
# use cv_obj_fn if you want to a different obj fun for cv eval (rather than tr,tr training)
|
|
# Cross-validated selection of the complexity penalty lambda.
#
# Runs cv_pick_lambda_f() on each fold. If potential_lambdas is supplied, the
# folds directly return per-lambda OOS errors. Otherwise each fold returns
# its fitted sequence; the candidate lambdas are then constructed from the
# union of every fold's "tie" lambdas (midpoints between consecutive ties,
# plus endpoints) so the step-function of selected partitions is sampled at
# every lambda at which any fold's selection changes; OOS errors are then
# evaluated for each fold over that common grid.
#
# Args (selection):
#   potential_lambdas: known lambda grid, or NULL to derive it from ties.
#   lambda_1se: use the 1-SE rule instead of the OOS-mean minimizer.
#   cv_tr_min_size: min cell size on fold-training data; defaults to a
#     compromise scaling of min_size (see the long comment above).
#   cv_bump_samples / bump_ratio: optional bumping within each fold.
#   cv_obj_fn: optional separate objective for OOS evaluation (defaults to
#     obj_fn).
#   ...: forwarded to cv_pick_lambda_f (then fit_partition_full).
#
# Returns: list(lambda_star, lambda_oos matrix, n_cell_table) — n_cell_table
#   is NULL when lambdas were supplied (or left NA otherwise; the code that
#   fills it is disabled below).
cv_pick_lambda <- function(y, X, d, folds_ret, nfolds, potential_lambdas=NULL, N_est=NA, min_size=5, verbosity=0, lambda_1se=FALSE,
                           min_obs_1se=5, obj_fn, cv_tr_min_size=NA, est_plan, pr_cl=NULL, cv_bump_samples=0, bump_ratio=1, cv_obj_fn=NULL, ...) {
  #If potential_lambdas is NULL, then only have to iterate through lambda values that change partition_i (for any fold)
  #If is_obj_val_seq is monotonic then this is easy and can do sequentially, but not sure if this is the case
  supplied_lambda = !is.null(potential_lambdas)
  if(supplied_lambda) {
    n_lambda = length(potential_lambdas)
    lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda)
  }
  else {
    lambda_ties = list()
    cvtr_fits = list()
  }
  if(verbosity>0) cat("Grid > CV: Started.\n")
  # Default fold-training min size: average of the "not too big"/"not too
  # small" scalings discussed in the comment block above.
  if(is.na(cv_tr_min_size)) cv_tr_min_size = as.integer(ifelse(nfolds==2, (2+min_size/2)/2, (nfolds-2)/nfolds)*min_size)

  params = c(list(y=y, X_d=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas,
                  N_est=N_est, verbosity=verbosity-1,
                  obj_fn=obj_fn, cv_tr_min_size=cv_tr_min_size,
                  est_plan=est_plan, cv_bump_samples=cv_bump_samples, bump_ratio=bump_ratio), list(...))

  # One worker per fold (serial or on the cluster).
  col_rets = my_apply(1:nfolds, cv_pick_lambda_f, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  do_bump = (length(cv_bump_samples)>1 || cv_bump_samples>0)

  if(is.null(cv_obj_fn)) cv_obj_fn = obj_fn

  # Process nfolds loop
  if(!supplied_lambda) {
    for(f in 1:nfolds) {
      cvtr_fits[[f]] = col_rets[[f]]
      lambda_ties[[f]] = get_all_lambda_ties(cvtr_fits[[f]], do_bump) #build lambdas. Assuming no slope ties
    }
  }
  else {
    for(f in 1:nfolds) {
      lambda_oos[f, ] = col_rets[[f]]
    }
    n_cell_table = NULL
  }

  if(!supplied_lambda) {
    # Build the candidate grid from the pooled ties: one point above the
    # largest tie, the midpoints between consecutive ties, and half the
    # smallest midpoint — so every selection regime is sampled.
    union_lambda_ties = sort(unlist(lambda_ties), decreasing=TRUE)
    mid_points = union_lambda_ties[-length(union_lambda_ties)] + diff(union_lambda_ties)/2
    potential_lambdas = c(union_lambda_ties[1]+1, mid_points, mid_points[length(mid_points)]/2)
    if(length(potential_lambdas)==0) {
      if(verbosity>0) cat("Note: CV folds consistently picked initial model (complexity didn't improve in-sample objective). Defaulting to lambda=0.\n")
      potential_lambdas=c(0)
    }
    n_lambda = length(potential_lambdas)
    lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda)
    n_cell_table = matrix(NA, nrow=nfolds, ncol=n_lambda)

    for(f in 1:nfolds) {
      list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X, d, folds_ret, f)
      lambda_oos[f,] = eval_lambdas(cv_obj_fn, est_plan, potential_lambdas, cvtr_fits[[f]], y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est, do_bump)
      # Disabled diagnostic: would record cells-per-lambda in n_cell_table.
      if(FALSE) {
        #ns = get_num_parts(cvtr_fits[[f]])

        for(lambda_i in 1:length(potential_lambdas)) {
          lambda = potential_lambdas[lambda_i]
          list[partition_i, part] = get_part_for_lambda(cvtr_fits[[f]], lambda)
          n_cell_table[f, lambda_i] = num_cells(part)
          #print(paste("fit=good. lambda: ", lambda))
          #good_obj_ret = cv_obj_fn(y_f_tr, X_f_tr, d_f_tr, y_f_cv, X_f_cv, d_f_cv, N_est=N_est, partition=part, debug=TRUE,
          #                         est_plan=est_plan, sample="trcv")
        }
      }
    }
  }
  lambda_oos_means = colMeans(lambda_oos)
  lambda_oos_min_i = which.min(lambda_oos_means)
  #if(lambda_oos_min_i==length(lambda_oos_means)) cat(paste("Warning: MSE-min lambda is the smallest (of",length(lambda_oos_means),"potential lambdas)\n"))
  if(lambda_1se) {
    # NOTE(review): max_oos_err_allowed is not defined anywhere in this
    # function, so this branch would error when lambda_1se=TRUE. The
    # selector recomputes/overwrites that argument anyway — verify intent.
    lambda_star = lambda_1se_selector(potential_lambdas, lambda_oos, min_obs_1se, max_oos_err_allowed, verbosity)
  }
  else {
    lambda_star = potential_lambdas[lambda_oos_min_i]
  }
  if(verbosity>0){
    cat("Grid > CV: Finished.")
    cat(paste(" lambda_oos_means=[", paste(lambda_oos_means, collapse=" "), "]."))
    if(length(lambda_star)==0) cat(" Couldn't find any suitable lambdas, returning 1.\n")
    else {
      cat(paste(" potential_lambdas=[", paste(potential_lambdas, collapse=" "), "]."))
      cat(paste(" lambda_star=", lambda_star, ".\n"))
    }
  }
  # Fallback when no lambda qualified.
  if(length(lambda_star)==0) {
    lambda_star=1
  }

  return(list(lambda_star,lambda_oos, n_cell_table))
}
|