From 5647912dd2e94f6730f059ca986b0721756722c6 Mon Sep 17 00:00:00 2001 From: Brian Quistorff Date: Mon, 14 Dec 2020 14:54:36 -0500 Subject: [PATCH] v0.2 --- .Rbuildignore | 11 + .Rprofile | 1 + .gitignore | 8 + .lintr | 7 + DESCRIPTION | 28 + NAMESPACE | 59 + R/.gitignore | 2 + R/CausalGrid.R | 58 + R/Estimator_plans.R | 482 +++++++ R/RcppExports.R | 7 + R/fit_estimate.R | 790 +++++++++++ R/graphing.R | 32 + R/grid_partition.R | 1259 ++++++++++++++++++ R/utils.R | 371 ++++++ SUPPORT.md | 7 +- SubgroupAnalysis.Rproj | 18 + archive/utils.R | 54 + man/CausalGrid.Rd | 10 + man/Do_Residualize.Rd | 25 + man/Do_Residualize.grid_rf.Rd | 25 + man/Do_Residualize.lm_X_est.Rd | 25 + man/Do_Residualize.simple_est.Rd | 25 + man/Fit_InitTr.Rd | 37 + man/Fit_InitTr.grid_rf.Rd | 41 + man/Fit_InitTr.lm_X_est.Rd | 37 + man/Fit_InitTr.simple_est.Rd | 37 + man/Param_Est.Rd | 27 + man/Param_Est.grid_rf.Rd | 27 + man/Param_Est.lm_X_est.Rd | 27 + man/Param_Est.simple_est.Rd | 27 + man/any_sign_effect.Rd | 35 + man/change_complexity.Rd | 27 + man/est_full_stats.Rd | 43 + man/estimate_cell_stats.Rd | 45 + man/fit_estimate_partition.Rd | 112 ++ man/fit_partition.Rd | 89 ++ man/get_X_range.Rd | 17 + man/get_desc_df.estimated_partition.Rd | 31 + man/get_desc_df.grid_partition.Rd | 34 + man/get_factor_from_partition.Rd | 21 + man/grid_rf.Rd | 23 + man/is.estimated_partition.Rd | 17 + man/is.grid_partition.Rd | 17 + man/is.grid_rf.Rd | 17 + man/is.lm_X_est.Rd | 17 + man/is.simple_est.Rd | 17 + man/num_cells.Rd | 17 + man/num_cells.estimated_partition.Rd | 17 + man/num_cells.grid_partition.Rd | 17 + man/plot_2D_partition.estimated_partition.Rd | 19 + man/predict_te.estimated_partition.Rd | 19 + man/print.estimated_partition.Rd | 34 + man/print.grid_partition.Rd | 25 + man/print.partition_split.Rd | 19 + man/quantile_breaks.Rd | 24 + project/.gitignore | 1 + project/ct_utils.R | 175 +++ project/sims.R | 443 ++++++ renv.lock | 953 +++++++++++++ renv/.gitignore | 4 + renv/activate.R | 349 +++++ src/CausalGrid.cpp | 12 + src/RcppExports.cpp | 28 + tests/dgps.R | 96 ++ tests/testthat.R | 4 + tests/testthat/test_multi.R | 82 ++ tests/testthat/testres.R | 64 + tests/testthat/testrun.R | 146 ++ vignettes/.gitignore | 2 + vignettes/vignette.Rmd | 33 + writeups/oneline algo.lyx | 236 ++++ 71 files changed, 6942 insertions(+), 3 deletions(-) create mode 100644 .Rbuildignore create mode 100644 .Rprofile create mode 100644 .lintr create mode 100644 DESCRIPTION create mode 100644 NAMESPACE create mode 100644 R/.gitignore create mode 100644 R/CausalGrid.R create mode 100644 R/Estimator_plans.R create mode 100644 R/RcppExports.R create mode 100644 R/fit_estimate.R create mode 100644 R/graphing.R create mode 100644 R/grid_partition.R create mode 100644 R/utils.R create mode 100644 SubgroupAnalysis.Rproj create mode 100644 archive/utils.R create mode 100644 man/CausalGrid.Rd create mode 100644 man/Do_Residualize.Rd create mode 100644 man/Do_Residualize.grid_rf.Rd create mode 100644 man/Do_Residualize.lm_X_est.Rd create mode 100644 man/Do_Residualize.simple_est.Rd create mode 100644 man/Fit_InitTr.Rd create mode 100644 man/Fit_InitTr.grid_rf.Rd create mode 100644 man/Fit_InitTr.lm_X_est.Rd create mode 100644 man/Fit_InitTr.simple_est.Rd create mode 100644 man/Param_Est.Rd create mode 100644 man/Param_Est.grid_rf.Rd create mode 100644 man/Param_Est.lm_X_est.Rd create mode 100644 man/Param_Est.simple_est.Rd create mode 100644 man/any_sign_effect.Rd create mode 100644 man/change_complexity.Rd create mode 100644 man/est_full_stats.Rd create 
mode 100644 man/estimate_cell_stats.Rd create mode 100644 man/fit_estimate_partition.Rd create mode 100644 man/fit_partition.Rd create mode 100644 man/get_X_range.Rd create mode 100644 man/get_desc_df.estimated_partition.Rd create mode 100644 man/get_desc_df.grid_partition.Rd create mode 100644 man/get_factor_from_partition.Rd create mode 100644 man/grid_rf.Rd create mode 100644 man/is.estimated_partition.Rd create mode 100644 man/is.grid_partition.Rd create mode 100644 man/is.grid_rf.Rd create mode 100644 man/is.lm_X_est.Rd create mode 100644 man/is.simple_est.Rd create mode 100644 man/num_cells.Rd create mode 100644 man/num_cells.estimated_partition.Rd create mode 100644 man/num_cells.grid_partition.Rd create mode 100644 man/plot_2D_partition.estimated_partition.Rd create mode 100644 man/predict_te.estimated_partition.Rd create mode 100644 man/print.estimated_partition.Rd create mode 100644 man/print.grid_partition.Rd create mode 100644 man/print.partition_split.Rd create mode 100644 man/quantile_breaks.Rd create mode 100644 project/.gitignore create mode 100644 project/ct_utils.R create mode 100644 project/sims.R create mode 100644 renv.lock create mode 100644 renv/.gitignore create mode 100644 renv/activate.R create mode 100644 src/CausalGrid.cpp create mode 100644 src/RcppExports.cpp create mode 100644 tests/dgps.R create mode 100644 tests/testthat.R create mode 100644 tests/testthat/test_multi.R create mode 100644 tests/testthat/testres.R create mode 100644 tests/testthat/testrun.R create mode 100644 vignettes/.gitignore create mode 100644 vignettes/vignette.Rmd create mode 100644 writeups/oneline algo.lyx diff --git a/.Rbuildignore b/.Rbuildignore new file mode 100644 index 0000000..dfc6996 --- /dev/null +++ b/.Rbuildignore @@ -0,0 +1,11 @@ +^renv$ +^renv\.lock$ +^CausalGrid\.Rproj$ +^\.Rproj\.user$ +^\.vs +^archive +^writeups +^project +^tests/sim.RData$ +^\.lintr$ +^dev_notes\.md diff --git a/.Rprofile b/.Rprofile new file mode 100644 index 0000000..81b960f --- /dev/null +++ b/.Rprofile @@ -0,0 +1 @@ +source("renv/activate.R") diff --git a/.gitignore b/.gitignore index fae8299..7c57783 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,11 @@ +#Sims stuff +project/sim.RData +project/log.txt + +#Cpp builds +src/*.o +src/*.dll + # History files .Rhistory .Rapp.history diff --git a/.lintr b/.lintr new file mode 100644 index 0000000..3a7e9be --- /dev/null +++ b/.lintr @@ -0,0 +1,7 @@ +linters: with_defaults(line_length_linter(120), + assignment_linter = NULL, + spaces_left_parentheses_linter=NULL, + object_name_linter=NULL, + commented_code_linter=NULL, + infix_spaces_linter=NULL, + trailing_whitespace_linter=NULL) diff --git a/DESCRIPTION b/DESCRIPTION new file mode 100644 index 0000000..64b8e1d --- /dev/null +++ b/DESCRIPTION @@ -0,0 +1,28 @@ +Package: CausalGrid +Title: Analysis of Subgroups +Version: 0.2 +Authors@R: person("Brian", "Quistorff", email = "Brian.Quistorff@microsoft.com", + role = c("aut", "cre")) +Description: Analysis of Subgroups. 
+Depends: R (>= 3.1.0),
+    caret,
+    gsubfn,
+    assertthat
+License: MIT + file LICENSE
+LazyData: true
+RoxygenNote: 7.1.1
+Suggests:
+    ggplot2,
+    glmnet,
+    gglasso,
+    ranger,
+    parallel,
+    pbapply,
+    testthat,
+    knitr,
+    rmarkdown
+BuildVignettes: false
+Imports: Rcpp (>= 1.0.1)
+LinkingTo: Rcpp
+Encoding: UTF-8
+VignetteBuilder: knitr
diff --git a/NAMESPACE b/NAMESPACE
new file mode 100644
index 0000000..16e8b22
--- /dev/null
+++ b/NAMESPACE
@@ -0,0 +1,59 @@
+# Generated by roxygen2: do not edit by hand
+
+S3method(Do_Residualize,grid_rf)
+S3method(Do_Residualize,lm_X_est)
+S3method(Do_Residualize,simple_est)
+S3method(Fit_InitTr,grid_rf)
+S3method(Fit_InitTr,lm_X_est)
+S3method(Fit_InitTr,simple_est)
+S3method(Param_Est,grid_rf)
+S3method(Param_Est,lm_X_est)
+S3method(Param_Est,simple_est)
+S3method(num_cells,estimated_partition)
+S3method(num_cells,grid_partition)
+S3method(print,estimated_partition)
+S3method(print,grid_partition)
+S3method(print,partition_split)
+export(Do_Residualize)
+export(Fit_InitTr)
+export(Param_Est)
+export(any_sign_effect)
+export(change_complexity)
+export(est_full_stats)
+export(estimate_cell_stats)
+export(fit_estimate_partition)
+export(fit_partition)
+export(get_X_range)
+export(get_desc_df.estimated_partition)
+export(get_desc_df.grid_partition)
+export(get_factor_from_partition)
+export(grid_rf)
+export(is.estimated_partition)
+export(is.grid_partition)
+export(is.grid_rf)
+export(is.lm_X_est)
+export(is.simple_est)
+export(num_cells)
+export(plot_2D_partition.estimated_partition)
+export(predict_te.estimated_partition)
+export(quantile_breaks)
+import(Rcpp)
+import(assertthat)
+import(caret)
+import(gsubfn)
+importFrom(Rcpp,evalCpp)
+importFrom(stats,coef)
+importFrom(stats,formula)
+importFrom(stats,lm)
+importFrom(stats,model.matrix)
+importFrom(stats,p.adjust)
+importFrom(stats,predict)
+importFrom(stats,pt)
+importFrom(stats,qt)
+importFrom(stats,quantile)
+importFrom(stats,rnorm)
+importFrom(stats,sd)
+importFrom(stats,var)
+importFrom(stats,vcov)
+importFrom(utils,combn)
+useDynLib(CausalGrid, .registration = TRUE)
diff --git a/R/.gitignore b/R/.gitignore
new file mode 100644
index 0000000..d48adf2
--- /dev/null
+++ b/R/.gitignore
@@ -0,0 +1,2 @@
+*.html
+*.Rdata
diff --git a/R/CausalGrid.R b/R/CausalGrid.R
new file mode 100644
index 0000000..7b2bb04
--- /dev/null
+++ b/R/CausalGrid.R
@@ -0,0 +1,58 @@
+#' CausalGrid: A package for subgroup effects
+#'
+#' Intervals are (a,b], and [a,b] for the lowest. A split at x means <= and >.
+#' We randomize when generating train/est and trtr/trcv splits. cv.glmnet and cv.gglasso may randomize as well.
+#'
+#'
+#' @useDynLib CausalGrid, .registration = TRUE
+#' @importFrom Rcpp evalCpp
+#' @importFrom stats coef formula lm model.matrix p.adjust pt qt quantile sd vcov var predict rnorm
+#' @importFrom utils combn
+#' @import caret
+#' @import gsubfn
+#' @import Rcpp
+#' @import assertthat
+#' @docType package
+#' @name CausalGrid
+NULL
+#> NULL
+
+
+#TODO:
+# Correctness:
+# - Handle the case where estimation might not have any observations in a cell
+# Cleanup:
+# - Encapsulate valid_partition() + bucket-splits with est_plan (deal with what happens with est error). + Doc.
+# - Styler and lintr; https://style.tidyverse.org/
+# - Clean up the _m functions (with their bare counterparts)
+# Functionality:
+# - Allow for picking the partition with # cells closest to an 'ideal' number
+# - Allow for integer types with range <=g to pick those values rather than the quantiles.
+# - Allow initial splits to be pre-determined.
+
+# - Clean up the Predict function and allow y_hat (how do other packages distinguish y_hat from d_hat?). Allow mult. te
+# - Like GRF, when considering a split, require that each child node have min.node.size samples with treatment value
+#   less than the average, and at least that many samples with treatment value greater than or equal to the average.
+# - summary method?
+# - Warn if CV picks end-point (and maybe reduce minsize of cv_tr to 2 in the case we need more complex)
+# - Check that tr+est and trtr's each had all the categories in sufficient quantity
+# - Provide a double-selection version of lm_X_est (think through multiple D case)
+# - Update importance weights in change_complexity
+# - Provide partial dependency functions/graphs
+# - Graphs: Show a rug (indicators for data points on the x-axis) or a histogram along the axis.
+# Usability:
+# - msg for assertions
+# - Have nicer factor labels (especially if split at bottom point, make [T,T] rather than [T,T+1], and redo top
+#   (to not have -1))
+# - ?? switch to have min_size apply only to train_folds*splits rather than smallest (test) fold * splits. User can
+#   work around with math.
+# Performance: Low-priority as doing pretty well so far
+# - For each dim, save stats for each existing section and only update the stats for the section being split.
+# - Additionally, use incremental update formulas to recompute the metrics after moving the split point over a
+#   few observations (for affected cells). Can look at [fromo](https://cran.r-project.org/web/packages/fromo/) and
+#   [twextras](https://github.com/twolodzko/twextras/blob/master/R/cumxxx.R)
+# - Pre-compute the allowable range for new splits in each slice given the cell min size.
+# - See if we can swap findInterval for cut() (do I need the labels)
+# Checks: Check all user input types of exported functions
+# Tests: More!
+# R check (currently ignoring): License, top-level dev_notes.md, checking dependencies in R code,
+#   Undefined global functions or variables, tests
\ No newline at end of file
diff --git a/R/Estimator_plans.R b/R/Estimator_plans.R
new file mode 100644
index 0000000..82c1bbc
--- /dev/null
+++ b/R/Estimator_plans.R
@@ -0,0 +1,482 @@
+
+# Estimator Fns -----------
+cont_te_estimator <- function(y, d, ...) {
+  if(is_vec(d)) {
+    # The closed-form formula is much faster than lm()
+    # formula reference: http://cameron.econ.ucdavis.edu/e240a/reviewbivariate.pdf
+    y_avg = mean(y)
+    d_avg = mean(d)
+    d_demean = d-d_avg
+    sum_d_dev = sum(d_demean^2)
+    param_est = sum((y-y_avg)*d_demean)/sum_d_dev
+  }
+  else {
+    ols_fit = lm(y~d)
+    param_est = coef(ols_fit)[-1]
+  }
+  return(list(param_est=param_est))
+}
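+
+# Quick equivalence sketch for the closed-form slope above (illustrative only,
+# kept as a comment so nothing runs at package load):
+#   y <- rnorm(100); d <- rbinom(100, 1, 0.5)
+#   all.equal(unname(coef(lm(y ~ d))[2]), cont_te_estimator(y, d)$param_est)  # TRUE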
+
+cont_te_var_estimator <- function(y, d, ...) {
+  if(is_vec(d)) {
+    # The closed-form formulas are much faster than lm()
+    # formula reference: http://cameron.econ.ucdavis.edu/e240a/reviewbivariate.pdf
+    y_avg = mean(y)
+    d_avg = mean(d)
+    d_demean = d-d_avg
+    sum_d_dev = sum(d_demean^2)
+    param_est = sum((y-y_avg)*d_demean)/sum_d_dev
+    b0 = y_avg - param_est*d_avg
+    y_hat = b0+param_est*d
+    err = y - y_hat
+    var_est = (sum(err^2)/(length(y)-2))/sum_d_dev
+  }
+  else {
+    ols_fit = lm(y~d)
+    param_est = coef(ols_fit)[-1]
+    var_est = diag(vcov(ols_fit))[-1] #reuse the fit rather than refitting lm()
+  }
+  return(list(param_est=param_est, var_est=var_est))
+}
+
+#Handles removing factors with only 1 level
+robust_lm_d <- function(y, d, X, ctrl_names) {
+  ctrl_str = if(length(ctrl_names)>0) paste0("+", paste(ctrl_names, collapse="+")) else ""
+  #Return the refit from the error handler (avoids assigning via <<-)
+  ols_fit = tryCatch(lm(formula(paste0("y~d", ctrl_str)), data=as.data.frame(X)),
+                     error=function(e) {
+                       ctrl_names2 <- ctrl_names[sapply(ctrl_names, function(ctrl_name){length(unique(X[, ctrl_name]))}) > 1]
+                       ctrl_str2 <- if(length(ctrl_names2)>0) paste0("+", paste(ctrl_names2, collapse="+")) else ""
+                       lm(formula(paste0("y~d", ctrl_str2)), data=as.data.frame(X))
+                     })
+  return(ols_fit)
+}
+
+cont_te_X_estimator <- function(y, d, X, ctrl_names) {
+  d_ncols = if(is_vec(d)) 1 else ncol(d)
+  ols_fit = robust_lm_d(y, d, X, ctrl_names)
+  param_est=coef(ols_fit)[2:(1+d_ncols)]
+  return(list(param_est=param_est))
+}
+
+
+
+cont_te_var_X_estimator <- function(y, d, X, ctrl_names) {
+  d_ncols = if(is_vec(d)) 1 else ncol(d)
+  ols_fit = robust_lm_d(y, d, X, ctrl_names)
+  param_est=coef(ols_fit)[2:(1+d_ncols)]
+  var_est=diag(vcov(ols_fit))[2:(1+d_ncols)]
+  return(list(param_est=param_est, var_est=var_est))
+}
+
+lcl_colMeans <- function(y) {
+  if(is.list(y)) #list of dataframe
+    return(sapply(y, mean))
+  if(is_vec(y)) #vector
+    return(mean(y))
+  #matrix
+  return(colMeans(y))
+}
+
+lcl_colVars_est <- function(y) {
+  # Estimated variance of the mean: s^2/n (matches vcov() of an intercept-only lm)
+  if(is.list(y)) #list of dataframe
+    return(sapply(y, function(c) var(c)/length(c)))
+  if(is_vec(y)) #vector
+    return(var(y)/length(y))
+  #matrix
+  return(apply(y, 2, function(c) var(c)/length(c)))
+}
+
+mean_var_estimator <- function(y, ...) {
+  #int_str = "(Intercept)" #"const"
+  #ols_fit <- lm(y~1)
+  #param_est=coef(ols_fit)[int_str]
+  #var_est=vcov(ols_fit)[int_str, int_str]
+  # The below is much faster
+  return(list(param_est=lcl_colMeans(y), var_est=lcl_colVars_est(y)))
+}
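+
+# Sketch of the lm() equivalence noted above (the commented-out lm(y~1) in
+# mean_var_estimator); illustrative comment only:
+#   y <- rnorm(50)
+#   all.equal(unname(vcov(lm(y ~ 1))[1, 1]), mean_var_estimator(y)$var_est)  # TRUE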
+
+mean_estimator <- function(y, ...) {
+  return(list(param_est=lcl_colMeans(y)))
+}
+
+
+# Generics ---------------
+
+#Aside from these generics, subclasses must have a $dof scalar
+
+#' Fit_InitTr
+#'
+#' @param obj Object
+#' @param X_tr X
+#' @param y_tr y
+#' @param d_tr d_tr
+#' @param cv_folds CV folds
+#' @param verbosity verbosity
+#' @param dim_cat vector of dimensions that are categorical
+#'
+#' @return Updated Object
+#' @export
+Fit_InitTr <- function(obj, X_tr, y_tr, d_tr=NULL, cv_folds, verbosity=0, dim_cat=c()) { UseMethod("Fit_InitTr", obj)}
+
+
+#' Do_Residualize
+#'
+#' @param obj Object
+#' @param y y
+#' @param X X
+#' @param d d (Default=NULL)
+#' @param sample one of 'tr' or 'est'
+#'
+#' @return list(y=) or list(y=, d=)
+#' @export
+Do_Residualize <- function(obj, y, X, d, sample) { UseMethod("Do_Residualize", obj)}
+
+#' Param_Est
+#'
+#' @param obj Object
+#' @param y y A N-vector
+#' @param d d A N-vector or Nxm matrix (so that they can be estimated jointly)
+#' @param X X A NxK matrix or data.frame
+#' @param sample Sample: "trtr", "trcv", "est"
+#' @param ret_var Return Variance in the return list
+#'
+#' @return list(param_est=...)
+#' @export
+Param_Est <- function(obj, y, d=NULL, X, sample="est", ret_var=FALSE) { UseMethod("Param_Est", obj)}
+
+# lm_X_est ---------------
+
+lm_X_est <- function(lasso=FALSE, control_est=TRUE) {
+  return(structure(list(lasso=lasso, control_est=control_est), class = c("Estimator_plan","lm_X_est")))
+}
+
+#' is.lm_X_est
+#'
+#' @param x Object
+#'
+#' @return Boolean
+#' @export
+is.lm_X_est <- function(x) {inherits(x, "lm_X_est")}
+
+dummyVar_common <- function(X, dim_cat) {
+  X_new = NULL
+  groups = c()
+  #Since regularizing, be careful about the reference class (so can't use dummyVars or one_hot easily)
+  for(k in 1:ncol(X)) {
+    n_cols=1
+    X_k = X[[k]]
+    if(k %in% dim_cat) {
+      n_l = nlevels(X_k)
+      level_common = dimnames(sort(table(factor(X_k)), decreasing=T))[[1]][1]
+      level_common_int = match(level_common, levels(X_k))
+      X_k = model.matrix(~1+C(X_k, contr.treatment(n_l, base=level_common_int)))[, 2:n_l] #col=1 is intercept
+      n_cols = ncol(X_k)
+    }
+    if(is.null(X_new)) X_new = X_k
+    else X_new = cbind(X_new, X_k)
+    groups = c(groups, rep(k, n_cols))
+  }
+  return(list(X_new, groups))
+}
+
+lasso_select <- function(obj, X_tr, y_tr, cv_folds, verbosity, dim_cat) {
+  if(length(dim_cat)<1) {
+    if (!requireNamespace("glmnet", quietly = TRUE)) {
+      stop("Package \"glmnet\" needed for this function to work. Please install it.",
+           call. = FALSE)
+    }
+    if(is.data.frame(X_tr)) X_tr = as.matrix(X_tr)
+    if(length(cv_folds)==1)
+      lasso_fit = glmnet::cv.glmnet(X_tr, y_tr, nfolds=cv_folds)
+    else
+      lasso_fit = glmnet::cv.glmnet(X_tr, y_tr, foldid=cv_folds)
+    c = coef(lasso_fit, s = "lambda.min")
+    sel = c[2:length(c), ]!=0
+  }
+  else {
+    list[X_new, groups] = dummyVar_common(X_tr, dim_cat)
+    if (!requireNamespace("gglasso", quietly = TRUE)) {
+      stop("Package \"gglasso\" needed for this function to work. Please install it.",
+           call.
= FALSE)
+    }
+    if(length(cv_folds)==1) {
+      gg_fit = gglasso::cv.gglasso(X_new, y_tr, nfolds=cv_folds, loss="ls", groups)
+    }
+    else {
+      gg_fit = gglasso::cv.gglasso(X_new, y_tr, foldid=cv_folds, loss="ls", groups)
+    }
+    c = coef(gg_fit, s="lambda.min")
+    sel = sort(unique(groups[c[2:length(c)]!=0]))
+  }
+  if(verbosity>0) print(c)
+  return(colnames(X_tr)[sel])
+}
+
+#' Fit_InitTr.lm_X_est
+#'
+#' @param obj lm_X_est object
+#' @param X_tr X_tr
+#' @param y_tr y_tr
+#' @param d_tr d_tr
+#' @param cv_folds cv_folds
+#' @param verbosity verbosity
+#' @param dim_cat dim_cat
+#'
+#' @return Updated object
+#' @export
+#' @method Fit_InitTr lm_X_est
+Fit_InitTr.lm_X_est <- function(obj, X_tr, y_tr, d_tr=NULL, cv_folds, verbosity=0, dim_cat=c()) {
+  assert_that(!is.null(d_tr))
+  if(obj$lasso && length(dim_cat)<1) {
+    if (!requireNamespace("glmnet", quietly = TRUE)) {
+      stop("Package \"glmnet\" needed for this function to work. Please install it.", call. = FALSE)
+    }
+  }
+
+  list[M, m_mode, N, K] = get_sample_type(y_tr, X_tr, d_tr, checks=TRUE)
+
+  if(m_mode==0 || m_mode==2) {
+    if(obj$lasso)
+      obj$ctrl_names = lasso_select(obj, X_tr, y_tr, cv_folds, verbosity, dim_cat)
+    else
+      obj$ctrl_names = colnames(X_tr)
+    obj$dof = 2+length(obj$ctrl_names)
+  }
+  else {
+    if(obj$lasso) {
+      if(m_mode==1)
+        obj$ctrl_names = mapply(function(X_s, y_s, cv_folds_s) lasso_select(obj, X_s, y_s, cv_folds_s, verbosity, dim_cat), X_tr, y_tr, cv_folds)
+      if(m_mode==3)
+        obj$ctrl_names = apply(y_tr, 2, function(y_col) lasso_select(obj, X_tr, y_col, cv_folds, verbosity, dim_cat))
+    }
+    else {
+      obj$ctrl_names = rep(list(colnames(X_tr)), M)
+    }
+    obj$dof = 2+sapply(obj$ctrl_names, length)
+  }
+  if(verbosity>0) cat(paste("Selected control variables: ", paste(obj$ctrl_names, collapse=" "), "\n"))
+
+  return(obj)
+}
+
+#' Do_Residualize.lm_X_est
+#'
+#' @param obj obj
+#' @param y y
+#' @param X X
+#' @param d d
+#' @param sample one of 'tr' or 'est'
+#'
+#' @return list(y=...) or list(y=..., d=...)
+#' @export
+#' @method Do_Residualize lm_X_est
+Do_Residualize.lm_X_est <- function(obj, y, X, d, sample) {return(list(y=y, d=d))}
+
+#' Param_Est.lm_X_est
+#'
+#' @param obj obj
+#' @param y y
+#' @param d d
+#' @param X X
+#' @param sample Sample: "trtr", "trcv", "est"
+#' @param ret_var Return variance in return list
+#'
+#' @return list(param_est=...) or list(param_est=..., var_est=...)
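+#' @examples
+#' # Illustrative sketch only (assumes `plan` came from Fit_InitTr.lm_X_est and
+#' # y, d, X are the matching data):
+#' \dontrun{
+#' est <- Param_Est(plan, y, d, X, sample="est", ret_var=TRUE)
+#' est$param_est
+#' }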
+#' @export +#' @method Param_Est lm_X_est +Param_Est.lm_X_est <- function(obj, y, d=NULL, X, sample="est", ret_var=FALSE) { + assert_that(!is.null(d)) + + if(sample=="trtr" || (sample=="est" && obj$control_est)) { + if(ret_var) return(cont_te_var_X_estimator(y, d, X, obj$ctrl_names)) + return(cont_te_X_estimator(y, d, X, obj$ctrl_names)) + } + if(ret_var) return(cont_te_var_estimator(y, d, X)) + return(cont_te_estimator(y, d, X)) +} + + +# simple_est --------------- + +simple_est <- function(te_fn, te_var_fn, dof=2) { + return(structure(list(te_fn=te_fn, te_var_fn=te_var_fn, dof=dof), class = c("Estimator_plan", "simple_est"))) +} + +#' is.simple_est +#' +#' @param x Object +#' +#' @return Boolean +#' @export +is.simple_est <- function(x) {inherits(x, "simple_est")} + +gen_simple_est_plan <- function(has_d=TRUE) { + if(has_d) { + return(simple_est(cont_te_estimator, cont_te_var_estimator)) + } + return(simple_est(mean_estimator, mean_var_estimator, dof=1)) +} + +#' Fit_InitTr.simple_est +#' +#' @param obj obj +#' @param X_tr X_tr +#' @param y_tr y_tr +#' @param d_tr d_tr +#' @param cv_folds cv_folds +#' @param verbosity verbosity +#' @param dim_cat dim_cat +#' +#' @return Updated object +#' @export +#' @method Fit_InitTr simple_est +Fit_InitTr.simple_est <- function(obj, X_tr, y_tr, d_tr=NULL, cv_folds, verbosity=0, dim_cat=c()) {return(obj)} + +#' Do_Residualize.simple_est +#' +#' @param obj obj +#' @param y y +#' @param X X +#' @param d d +#' @param sample one of 'tr' or 'est' +#' +#' @return list(y=...) and list(y=..., d=...) +#' @export +#' @method Do_Residualize simple_est +Do_Residualize.simple_est <- function(obj, y, X, d, sample) {return(list(y=y, d=d))} + +#' Param_Est.simple_est +#' +#' @param obj obj +#' @param y y +#' @param d d +#' @param X X +#' @param sample Sample: "trtr", "trcv", "est" +#' @param ret_var Return variance in return list +#' +#' @return list(param_est=...) +#' @export +#' @method Param_Est simple_est +Param_Est.simple_est <- function(obj, y, d=NULL, X, sample="est", ret_var=FALSE) { + if(ret_var) return(obj$te_var_fn(y, d, X)) + return(obj$te_fn(y, d, X)) +} + +# grid_rf --------------- + +#' grid_rf +#' +#' @param num.trees number of trees in the random forest +#' @param num.threads num.threads +#' @param dof degrees-of-freedom +#' @param resid_est Residualize the Estimation sample (using fit from training) +#' +#' @return grid_rf object +#' @export +grid_rf <- function(num.trees=500, num.threads=NULL, dof=2, resid_est=TRUE) { + return(structure(list(num.trees=num.trees, num.threads=num.threads, dof=dof, resid_est=resid_est), + class = c("Estimator_plan","grid_rf"))) +} + +#' is.grid_rf +#' +#' @param x Object +#' +#' @return Boolean +#' @export +is.grid_rf <- function(x) {inherits(x, "grid_rf")} + +rf_fit_data <- function(obj, target, X) { + if(is_vec(target)) + return(ranger::ranger(y=target, x=X, + num.trees = obj$num.trees, num.threads = obj$num.threads)) + fits = list() + for(m in 1:ncol(target)){ + fits[[m]] = ranger::ranger(y=target[,m], x=X, + num.trees = obj$num.trees, num.threads = obj$num.threads) + } + return(fits) +} + +#' Fit_InitTr.grid_rf +#' Note that for large data, the rf_y_fit and potentially rf_d_fit objects may be large. 
+#' They can be null'ed out after fitting.
+#'
+#' @param obj Object
+#' @param X_tr X
+#' @param y_tr y
+#' @param d_tr d_tr
+#' @param cv_folds CV folds
+#' @param verbosity verbosity
+#' @param dim_cat vector of dimensions that are categorical
+#'
+#' @return Updated Object
+#' @export
+#' @method Fit_InitTr grid_rf
+Fit_InitTr.grid_rf <- function(obj, X_tr, y_tr, d_tr=NULL, cv_folds, verbosity=0, dim_cat=c()) {
+  assert_that(!is.null(d_tr)) #Only residualize when having treatment
+  if (!requireNamespace("ranger", quietly = TRUE)) {
+    stop("Package \"ranger\" needed for this function to work. Please install it.", call. = FALSE)
+  }
+  obj$rf_y_fit = rf_fit_data(obj, y_tr, X_tr)
+
+  if(!is.null(d_tr)) {
+    obj$rf_d_fit = rf_fit_data(obj, d_tr, X_tr)
+  }
+  return(obj)
+}
+
+rf_predict_data <- function(fit, target, X) {
+  if(is_vec(target))
+    return(predict(fit, X, type="response")$predictions)
+  preds = matrix(NA, nrow=nrow(X), ncol=ncol(target)) #one column per outcome/treatment
+  for(m in 1:ncol(target)){
+    preds[,m] = predict(fit[[m]], X, type="response")$predictions
+  }
+  return(preds)
+}
+
+#' Do_Residualize.grid_rf
+#'
+#' @param obj Object
+#' @param y y
+#' @param X X
+#' @param d d (Default=NULL)
+#' @param sample one of 'tr' or 'est'
+#'
+#' @return list(y=) or list(y=, d=)
+#' @export
+#' @method Do_Residualize grid_rf
+Do_Residualize.grid_rf <- function(obj, y, X, d, sample) {
+  if(sample=="est" && !obj$resid_est) return(list(y=y, d=d))
+  if (!requireNamespace("ranger", quietly = TRUE)) {
+    stop("Package \"ranger\" needed for this function to work. Please install it.", call. = FALSE)
+  }
+  y_res = y - rf_predict_data(obj$rf_y_fit, y, X)
+  d_res = if(is.null(d)) NULL else d - rf_predict_data(obj$rf_d_fit, d, X)
+  return(list(y=y_res, d=d_res))
+}
+
+#' Param_Est.grid_rf
+#'
+#' @param obj Object
+#' @param y y
+#' @param d d
+#' @param X X
+#' @param sample Sample: "trtr", "trcv", "est"
+#' @param ret_var Return Variance in the return list
+#'
+#' @return list(param_est=...)
+#' @export
+#' @method Param_Est grid_rf
+Param_Est.grid_rf <- function(obj, y, d=NULL, X, sample="est", ret_var=FALSE) {
+  assert_that(is.flag(ret_var), sample %in% c("est", "trtr", "trcv"), !is.null(d))
+
+  if(ret_var) return(cont_te_var_estimator(y, d, X))
+  return(cont_te_estimator(y, d, X))
+}
\ No newline at end of file
diff --git a/R/RcppExports.R b/R/RcppExports.R
new file mode 100644
index 0000000..99f63d3
--- /dev/null
+++ b/R/RcppExports.R
@@ -0,0 +1,7 @@
+# Generated by using Rcpp::compileAttributes() -> do not edit by hand
+# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
+
+const_vect <- function(var) {
+    .Call(`_CausalGrid_const_vect`, var)
+}
+
diff --git a/R/fit_estimate.R b/R/fit_estimate.R
new file mode 100644
index 0000000..71532bd
--- /dev/null
+++ b/R/fit_estimate.R
@@ -0,0 +1,790 @@
+#Notes:
+# - In order to work with both X as matrix and data.frame I used X[,k], but this misbehaves
+#   with incoming tibbles, so convert those.
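+# A minimal sketch of that conversion (with a plain data.frame, X[, k] drops to a
+# vector; with a tibble it stays a tibble), presumably done in ensure_good_X()
+# elsewhere in the package. Kept as a comment:
+#   if (inherits(X, "tbl_df")) X <- as.data.frame(X)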
+
+# Empty and error cells are removed from the calculation, but their counts are returned
+# estimator_var: takes the sample for a cell and returns the coefficient estimate and the estimated variance of that estimate
+emse_hat_obj <-function(y, X , d, N_est, partition=NULL, cell_factor=NULL, estimator_var=NULL, debug=FALSE,
+                        warn_on_error=FALSE, alpha=NULL, est_plan=NULL, sample="trtr") {
+  if(is.null(est_plan)) {
+    if(is.null(estimator_var)) est_plan = gen_simple_est_plan(has_d=!is.null(d))
+    else est_plan = simple_est(estimator_var, estimator_var)
+  }
+  if(is.null(cell_factor)) {
+    cell_factor = get_factor_from_partition(partition, X)
+  }
+  if(!is.null(alpha)) {
+    stopifnot(alpha>=0 & alpha<=1)
+  }
+  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=FALSE)
+  list[lvls, n_cells] = lcl_levels(cell_factor)
+  cell_contribs1 = rep(0, n_cells)
+  cell_contribs2 = rep(0, n_cells)
+  N_eff = 0
+  N_cell_empty = 0
+  N_cell_err = 0
+  for(cell_i in 1:n_cells) {
+    list[y_cell, d_cell, X_cell, N_l] <- get_cell(y, X, d, cell_factor, cell_i, lvls)
+    if(any(N_l==0)) {
+      N_cell_empty = N_cell_empty+1
+      next
+    }
+
+    list[param_est, var_est] = Param_Est_m(est_plan, y_cell, d_cell, X_cell, sample=sample, ret_var=TRUE, m_mode)
+    if(!all(is.finite(param_est)) || !all(is.finite(var_est))) {
+      N_cell_err = N_cell_err+1
+      msg = paste("Failed estimation: (N_l=", N_l, ", param_est=", param_est, ", var_est=", var_est,
+                  ifelse(!is.null(d), paste(", var_d=", var(d_cell)) , ""),
+                  ")\n")
+      if(warn_on_error) warning(msg)
+      next
+    }
+    t1_mse = sum(N_l*param_est^2)
+    t2_var = sum(N_l*var_est)
+    N_eff = N_eff + sum(N_l)
+    cell_contribs1[cell_i] = t1_mse
+    cell_contribs2[cell_i] = t2_var
+    if(debug) print(paste("N=", N, "; cell_i=", cell_i, "; N_l=", N_l, "; param_est=", param_est,
+                          "; var_est=", var_est))
+  }
+  t1_mse = -1/N_eff*sum(cell_contribs1)
+  t2_var = (1/N_eff + 1/N_est)*sum(cell_contribs2)
+  val = if(is.null(alpha)) t1_mse + t2_var else alpha*t1_mse + (1-alpha)*t2_var
+  if(debug) print(paste("cell sums", val))
+  if(!is.finite(val)) stop("Non-finite val")
+  return(c(val, N_cell_empty, N_cell_err))
+}
+
+get_cell <- function(y, X, d, cell_factor, cell_i, lvls) {
+  list[M, m_mode, N, K] = get_sample_type(y, X, d)
+  if(is_sep_sample(X)) {
+    y_cell = d_cell = X_cell = list()
+    N_l = rep(0, M)
+    for(m in 1:M) {
+      cell_ind = cell_factor[[m]]==lvls[[m]][cell_i]
+      y_cell[[m]] = y[[m]][cell_ind]
+      d_cell[[m]] = d[[m]][cell_ind]
+      X_cell[[m]] = X[[m]][cell_ind, , drop=FALSE]
+      N_l[m] = sum(cell_ind)
+    }
+  }
+  else {
+    cell_ind = cell_factor==lvls[cell_i]
+    N_l = if(M==1) sum(cell_ind) else rep(sum(cell_ind), M)
+    y_cell = if(is_vec(y)) y[cell_ind] else y[cell_ind, , drop=FALSE]
+    d_cell = if(is_vec(d)) d[cell_ind] else d[cell_ind, , drop=FALSE]
+    X_cell = X[cell_ind, , drop=FALSE]
+  }
+  return(list(y_cell, d_cell, X_cell, N_l))
+}
+
+lcl_levels <- function(cell_factor) {
+  if(!is.list(cell_factor)) {
+    lvls = levels(cell_factor)
+    return(list(lvls, length(lvls)))
+  }
+  lvls = lapply(cell_factor, levels)
+  return(list(lvls, length(lvls[[1]])))
+}
+
+Param_Est_m <- function(est_plan, y_cell, d_cell, X_cell, sample=sample, ret_var=FALSE, m_mode) {
+  if(!is_sep_estimators(m_mode)) { #single estimation
+    return(Param_Est(est_plan, y_cell, d_cell, X_cell, sample=sample, ret_var))
+  }
+  if(m_mode==1){
+    if(ret_var) {
+      rets = mapply(function(est_plan_s, y_cell_s, d_cell_s, X_cell_s)
+        unlist(Param_Est(est_plan_s, y_cell_s, d_cell_s, X_cell_s, sample=sample, ret_var)),
+        est_plan, y_cell,
d_cell, X_cell, SIMPLIFY = TRUE) + return(list(param_ests=rets[1,], var_ests=rets[2,])) + } + rets = mapply(function(est_plan_s, y_cell_s, d_cell_s, X_cell_s) + Param_Est(est_plan_s, y_cell_s, d_cell_s, X_cell_s, sample=sample, ret_var)[[1]], + est_plan, y_cell, d_cell, X_cell, SIMPLIFY = TRUE) + return(list(param_ests = rets)) + } + + M = ncol(y_cell) + if(ret_var) { + rets = sapply(1:M, function(m) unlist(Param_Est(est_plan[[m]], y_cell[,m], d_cell, X_cell, sample=sample, ret_var))) + return(list(param_ests=rets[1,], var_ests=rets[2,])) + } + rets = sapply(1:M, function(m) Param_Est(est_plan[[m]], y_cell[,m], d_cell, X_cell, sample=sample, ret_var)[[1]]) + return(list(param_ests = rets)) +} + +mse_hat_obj <-function(y, X, d, partition=NULL, cell_factor=NULL, estimator=NULL, debug=FALSE, + warn_on_error=FALSE, est_plan=NULL, sample="trtr", ...) { + if(is.null(est_plan)) { + if(is.null(estimator)) est_plan = gen_simple_est_plan(has_d=!is.null(d)) + else est_plan = simple_est(estimator, estimator) + } + if(is.null(cell_factor)) { + cell_factor = get_factor_from_partition(partition, X) + } + list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=FALSE) + list[lvls, n_cells] = lcl_levels(cell_factor) + cell_contribs = rep(0, n_cells) + N_eff = 0 + N_cell_empty = 0 + N_cell_error = 0 + for(cell_i in 1:n_cells) { + list[y_cell, d_cell, X_cell, N_l] <- get_cell(y, X, d, cell_factor, cell_i, lvls) + if(any(N_l==0)) { + N_cell_empty = N_cell_empty+1 + next + } + list[param_est] = Param_Est_m(est_plan, y_cell, d_cell, X_cell, sample=sample, ret_var=FALSE, m_mode) + if(!all(is.finite(param_est))) { + N_cell_error = N_cell_error+1 + msg = paste("Failed estimation: (N_l=", N_l, + ifelse(!is.null(d), paste(", var_d=", var(d_cell)), ""), + ")\n") + if(warn_on_error) warning(msg) + next + } + t1_mse = sum(N_l*param_est^2) + N_eff = N_eff + sum(N_l) + cell_contribs[cell_i] = t1_mse + if(debug) print(paste("cell_i=", cell_i, "; N_l=", N_l, "; param_est=", param_est)) + } + val = -1/N_eff*sum(cell_contribs) # Use N_eff to remove from average given errors + if(debug) print(paste("cell sums", val)) + return(c(val, N_cell_empty, N_cell_error)) +} + + +#' estimate_cell_stats +#' +#' @param y Nx1 matrix of outcome (label/target) data +#' @param X NxK matrix of features (covariates) +#' @param d (Optional) NxP matrix (with colnames) of treatment data. If all equally important they should +#' be normalized to have the same variance. +#' @param partition (Optional, need this or cell_factor) partitioning returned from fit_estimate_partition +#' @param cell_factor (Optional, need this or partition) +#' @param estimator_var (Optional) a function with signature list(param_est, var_est) = function(y, d) +#' (where if no d then can pass in null). If NULL then will choose between built-in +#' mean-estimator and scalar_te_estimator +#' @param est_plan Estimation plan +#' @param alpha Alpha +#' +#' @return list +#' \item{cell_factor}{Factor with levels for each cell for X. 
Length N.} +#' \item{stats}{data.frame(cell_i, N_est, param_ests, var_ests, tstats, pval, ci_u, ci_l, p_fwer, p_fdr)} +#' @export +estimate_cell_stats <- function(y, X, d=NULL, partition=NULL, cell_factor=NULL, estimator_var=NULL, + est_plan=NULL, alpha=0.05) { + list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE) + X = ensure_good_X(X) + + if(is.null(est_plan)) { + if(is.null(estimator_var)) est_plan = gen_simple_est_plan(has_d=!is.null(d)) + else est_plan = simple_est(NULL, estimator_var) + } + if(is.null(cell_factor)) { + cell_factor = get_factor_from_partition(partition, X) + } + list[lvls, n_cells] = lcl_levels(cell_factor) + param_ests = matrix(NA, nrow=n_cells, ncol=M) + var_ests = matrix(NA, nrow=n_cells, ncol=M) + cell_sizes = matrix(NA, nrow=n_cells, ncol=M) + for(cell_i in 1:n_cells) { + list[y_cell, d_cell, X_cell, N_l] <- get_cell(y, X, d, cell_factor, cell_i, lvls) + cell_sizes[cell_i,] = N_l + list[param_ests_c, var_ests_c] = Param_Est_m(est_plan, y_cell, d_cell, X_cell, sample="est", ret_var=TRUE, m_mode=m_mode) + param_ests[cell_i,] = param_ests_c + var_ests[cell_i,] = var_ests_c + } + dofs = t(t(cell_sizes) - get_dofs(est_plan, M, m_mode)) #subtract dofs from each row + colnames(cell_sizes) = if(M==1) "N_est" else paste("N_est", 1:M, sep="") + colnames(param_ests) = if(M==1) "param_ests" else paste("param_ests", 1:M, sep="") + colnames(var_ests) = if(M==1) "var_ests" else paste("var_ests", 1:M, sep="") + base_df = cbind(data.frame(cell_i=1:n_cells), cell_sizes, param_ests, var_ests) + list[stat_df, pval] = exp_stats(base_df, param_ests, var_ests, dofs, alpha=alpha, M) + p_fwer = matrix(p.adjust(pval, "hommel"), ncol=ncol(pval)) #slightly more powerful than "hochberg". Given indep these are better than "bonferroni" and "holm" + p_fdr = matrix(p.adjust(pval, "BH"), ncol=ncol(pval)) #ours are independent so don't need "BY" + colnames(p_fwer) = if(M==1) "p_fwer" else paste("p_fwer", 1:M, sep="") + colnames(p_fdr) = if(M==1) "p_fdr" else paste("p_fdr", 1:M, sep="") + stat_df = cbind(stat_df, p_fwer, p_fdr) + return(list(cell_factor=cell_factor, stats=stat_df)) +} + +get_dofs <- function(est_plan, M, m_mode) { + if(m_mode==0) + return(est_plan$dof) + + if(is_sep_estimators(m_mode)) + return(sapply(est_plan, function(plan) plan$dof)) + + #Only 1 estimator but multiple d, so make sure right length + dof = est_plan$dof + if(length(dof)==1) dof=rep(dof, M) + return(dof) +} + +#' est_full_stats +#' +#' @param y y +#' @param d d +#' @param X X +#' @param est_plan est_plan +#' @param y_es y_es +#' @param d_es d_es +#' @param X_es X_es +#' @param index_tr index_tr +#' @param alpha alpha +#' +#' @return Stats df +#' @export +est_full_stats <- function(y, d, X, est_plan, y_es=NULL, d_es=NULL, X_es=NULL, index_tr=NULL, alpha=0.05) { + list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE) + X = ensure_good_X(X) + + if(is.null(y_es)) { + list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, index_tr) + } + N_es = nrow_m(X_es, M) + full_Ns = rbind(N, N_es) + colnames(full_Ns) = if(M==1) "N_est" else paste("N_est", 1:M, sep="") + list[full_param_ests_all, full_var_ests_all] = Param_Est_m(est_plan, y, d, X, sample="est", ret_var=TRUE, m_mode=m_mode) + list[full_param_ests_es, full_var_ests_es] = Param_Est_m(est_plan, y_es, d_es, X_es, sample="est", ret_var=TRUE, m_mode=m_mode) + M = length(full_param_ests_all) + full_param_ests = rbind(full_param_ests_all, full_param_ests_es) + colnames(full_param_ests) = if(M==1) "param_ests" else paste("param_ests", 1:M, 
sep="") + full_var_ests = rbind(full_var_ests_all, full_var_ests_es) + colnames(full_var_ests) = if(M==1) "var_ests" else paste("var_ests", 1:M, sep="") + base_df = cbind(data.frame(sample=c("all", "est")), full_Ns, full_param_ests, full_var_ests) + dofs = t(t(full_Ns) - get_dofs(est_plan, M, m_mode)) #subtract dofs from each row + list[full_stat_df, pval] = exp_stats(base_df, full_param_ests, full_var_ests, dofs, alpha=alpha, M) + return(full_stat_df) +} + +exp_stats <- function(stat_df, param_ests, var_ests, dofs, alpha=0.05, M) { + tstats = param_ests/sqrt(var_ests) + colnames(tstats) = if(M==1) "tstats" else paste("tstats", 1:M, sep="") + t.half.alpha = qt(1-alpha/2, df=dofs)*sqrt(var_ests) + ci_u = param_ests + t.half.alpha + colnames(ci_u) = if(M==1) "ci_u" else paste("ci_u", 1:M, sep="") + ci_l = param_ests - t.half.alpha + colnames(ci_l) = if(M==1) "ci_l" else paste("ci_l", 1:M, sep="") + pval = 2*pt(abs(tstats), df=dofs, lower.tail=FALSE) + colnames(pval) = if(M==1) "pval" else paste("pval", 1:M, sep="") + stat_df = cbind(stat_df, tstats, ci_u, ci_l, pval) + #pval_right= pt(tstats, df=dofs, lower.tail=FALSE) #right-tailed. Checking for just a positive effect (H_a is "greater") + #pval_left = pt(tstats, df=dofs, lower.tail=TRUE) #left-tailed. Checking for just a negative effect (H_a is "less") + return(list(stat_df, pval)) +} + +#' predict_te.estimated_partition +#' +#' Predicted unit-level treatment effect +#' +#' @param obj estimated_partition object +#' @param new_X new X +#' +#' @return predicted treatment effect +#' @export +predict_te.estimated_partition <- function(obj, new_X) { + #TODO: for mode 1 &2 maybe return a matrix rather than list + new_X = ensure_good_X(new_X) + new_X_range = get_X_range(new_X) + + cell_factor = get_factor_from_partition(obj$partition, new_X, new_X_range) + if(obj$M==1) { + N=nrow(new_X) + cell_factor_df = data.frame(id=1:N, cell_i = as.integer(cell_factor)) + m_df = merge(cell_factor_df, obj$cell_stats$stats) + m_df = m_df[order(m_df[["id"]]), ] + return(m_df[["param_ests"]]) + } + N = nrow_m(X, M) + rets = list() + for(m in 1:M) { + cell_factor_df = data.frame(id=1:N[m], cell_i = as.integer(cell_factor[[m]])) + m_df = merge(cell_factor_df, obj$cell_stats$stats) + m_df = m_df[order(m_df[["id"]]), ] + rets[[m]] = m_df[["param_ests"]] + } + return(rets) +} + +get_importance_weights_full_k <- function(k_i, to_compute, X_d, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, pot_break_points, verbosity, ...) { + if(verbosity>0) cat(paste("Feature weight > ", k_i, "of", length(to_compute),"\n")) + k = to_compute[k_i] + X_k = drop_col_k_m(X_d, k) + X_tr_k = drop_col_k_m(X_tr, k) + X_es_k = drop_col_k_m(X_es, k) + main_ret = fit_estimate_partition_int(X_k, y, d, X_tr_k, y_tr, d_tr, y_es, X_es_k, d_es, X_range[-k], pot_break_points=pot_break_points[-k], verbosity=verbosity, ...) + nk_val = mse_hat_obj(y_es, X_es_k, d=d_es, partition=main_ret$partition, est_plan=main_ret$est_plan, sample="est")[1] #use oos version instead of main_ret$is_obj_val_seq[partition_i] + return(list(nk_val, main_ret$partition$nsplits_by_dim)) +} + +# Just use mse_hat as we're working not on the Tr sample, but the est sample +# The ... params are passed to get_importance_weights_full_k -> fit_estimate_partition_int +# There's an undocumented "fast" version. Not very great as assings 0 to any feature not split on +get_importance_weights <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, pot_break_points, partition, est_plan, type, verbosity, pr_cl, ...) 
{ + if(verbosity>0) cat("Feature weights: Started.\n") + K = length(X_range) + if(sum(partition$nsplits_by_dim)==0) return(rep(0, K)) + full_val = mse_hat_obj(y_es, X_es, d=d_es, partition = partition, est_plan=est_plan, sample="est")[1] + + if(K==1) { + null_val = mse_hat_obj(y_es, X_es, d=d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1] + if(verbosity>0) cat("Feature weights: Finished.\n") + return(null_val - full_val) + } + + if("fast"==type) { + new_vals = rep(0, K) + factors_by_dim = get_factors_from_partition(partition, X_es) + for(k in 1:K) { + if(partition$nsplits_by_dim[k]>0) { + cell_factor_nk = gen_holdout_interaction_m(factors_by_dim, k, is_sep_sample(X_tr)) + new_vals[k] = mse_hat_obj(y_es, X_es, d=d_es, cell_factor = cell_factor_nk, est_plan=est_plan, sample="est")[1] + } + } + if(verbosity>0) cat("Feature weights: Finished.\n") + return(new_vals - full_val) + } + + #if("full"==type) + new_vals = rep(full_val, K) + to_compute = which(partition$nsplits_by_dim>0) + + params = c(list(to_compute, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, pot_break_points=pot_break_points, est_plan=est_plan, verbosity=verbosity-1), + list(...)) + + rets = my_apply(1:length(to_compute), get_importance_weights_full_k, verbosity==1 || !is.null(pr_cl), pr_cl, params) + for(k_i in 1:length(to_compute)) { + k = to_compute[k_i] + new_vals[k] = rets[[k_i]][[1]] + } + if(verbosity>0) cat("Feature weights: Finished.\n") + return(new_vals - full_val) +} + +get_feature_interactions_k12 <- function(ks_i, to_compute, X_d, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, pot_break_points, verbosity, ...) { + if(verbosity>0) cat(paste("Feature interaction weight > ", ks_i, "of", length(to_compute),"\n")) + ks = to_compute[[ks_i]] + X_k = drop_col_k_m(X_d, ks) + X_tr_k = drop_col_k_m(X_tr, ks) + X_es_k = drop_col_k_m(X_es, ks) + main_ret = fit_estimate_partition_int(X_k, y, d, X_tr_k, y_tr, d_tr, y_es, X_es_k, d_es, X_range[-ks], pot_break_points=pot_break_points[-ks], verbosity=verbosity, ...) + nk_val = mse_hat_obj(y_es, X_es_k, d=d_es, partition=main_ret$partition, est_plan=main_ret$est_plan, sample="est")[1] #use oos version instead of main_ret$is_obj_val_seq[partition_i] + return(nk_val) + +} + +get_feature_interactions <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, pot_break_points, partition, est_plan, verbosity, pr_cl, ...) 
{
+
+  if(verbosity>0) cat("Feature weights: Started.\n")
+  K = length(X_range)
+  delta_k12 = matrix(as.integer(diag(rep(NA, K))), ncol=K) #dummy for K<3 cases
+  if(sum(partition$nsplits_by_dim)==0){
+    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction weights: Finished.\n")
+    return(list(delta_k=rep(0, K), delta_k12=delta_k12))
+  }
+  full_val = mse_hat_obj(y_es, X_es, d=d_es, partition = partition, est_plan=est_plan, sample="est")[1]
+
+  if(K==1) {
+    null_val = mse_hat_obj(y_es, X_es, d=d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1]
+    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction weights: Finished.\n")
+    return(list(delta_k=null_val - full_val, delta_k12=delta_k12))
+  }
+
+  #compute the single-removed values (and keep around the nsplits from each new partition)
+  new_val_k = rep(full_val, K)
+  to_compute_k = which(partition$nsplits_by_dim>0)
+  params = c(list(to_compute=to_compute_k, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, pot_break_points=pot_break_points, est_plan=est_plan, verbosity=verbosity-1),
+             list(...))
+  rets_k = my_apply(1:length(to_compute_k), get_importance_weights_full_k, verbosity==1 || !is.null(pr_cl), pr_cl, params)
+  for(k_i in 1:length(to_compute_k)) {
+    k = to_compute_k[k_i]
+    new_val_k[k] = rets_k[[k_i]][[1]]
+  }
+  delta_k = new_val_k - full_val
+  if(K==2) {
+    null_val = mse_hat_obj(y_es, X_es, d=d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1]
+    delta_k12 = matrix(null_val - full_val, ncol=2) + diag(rep(NA, K))
+    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction weights: Finished.\n")
+    return(list(delta_k=delta_k, delta_k12=delta_k12))
+  }
+  if(verbosity>0) cat("Feature weights: Finished.\n")
+
+
+  #Compute the pair-removed values
+  if(verbosity>0) cat("Feature interaction weights: Started.\n")
+  new_val_k12 = matrix(full_val, ncol=K, nrow=K)
+  to_compute = list()
+  for(k1 in 1:(K-1)) {
+    if(partition$nsplits_by_dim[k1]==0) {
+      new_val_k12[k1,] = new_val_k
+      new_val_k12[,k1] = new_val_k
+    }
+    else {
+      k1_i = which(to_compute_k==k1)
+      nsplits_by_dim_k1= rets_k[[k1_i]][[2]]
+      for(k2 in (k1+1):K) {
+        if(nsplits_by_dim_k1[k2-1]==0) { #nsplits_by_dim_k1 is missing k1 so drop k2 back one
+          new_val_k12[k1,k2] = new_val_k[k1]
+          new_val_k12[k2,k1] = new_val_k[k1]
+        }
+        else {
+          to_compute = c(list(c(k1, k2)), to_compute)
+        }
+      }
+    }
+  }
+  params = c(list(to_compute=to_compute, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, pot_break_points=pot_break_points, est_plan=est_plan, verbosity=verbosity-1),
+             list(...))
+  rets_k12 = my_apply(1:length(to_compute), get_feature_interactions_k12, verbosity==1 || !is.null(pr_cl), pr_cl, params)
+  for(ks_i in 1:length(to_compute)) {
+    k1 = to_compute[[ks_i]][1]
+    k2 = to_compute[[ks_i]][2]
+    new_val = rets_k12[[ks_i]]
+    new_val_k12[k1, k2] = new_val
+    new_val_k12[k2, k1] = new_val
+  }
+  delta_k12 = t(t((new_val_k12 - full_val) - delta_k) - delta_k) + diag(rep(NA, K))
+  if(verbosity>0) cat("Feature interaction weights: Finished.\n")
+  return(list(delta_k=delta_k, delta_k12=delta_k12))
+}
+
+
+fit_and_residualize <- function(est_plan, X_tr, y_tr, d_tr, cv_folds, y_es, X_es, d_es, verbosity, dim_cat) {
+  est_plan =
Fit_InitTr(est_plan, X_tr, y_tr, d_tr, cv_folds, verbosity=verbosity, dim_cat=dim_cat)
+  list[y_tr, d_tr] = Do_Residualize(est_plan, y_tr, X_tr, d_tr, sample="tr")
+  list[y_es, d_es] = Do_Residualize(est_plan, y_es, X_es, d_es, sample="est")
+  return(list(est_plan, y_tr, d_tr, y_es, d_es))
+}
+
+# ... params are passed to fit_partition()
+fit_estimate_partition_int <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, dim_cat, X_range, est_plan, honest, cv_folds, verbosity, M, m_mode,
+                                       alpha, partition_i, ...) {
+  K = length(X_range)
+  obj_fn = if(honest) emse_hat_obj else mse_hat_obj
+
+  list[est_plan, y_tr, d_tr, y_es, d_es] = fit_and_residualize_m(est_plan, X_tr, y_tr, d_tr, cv_folds, y_es, X_es, d_es, m_mode, M, verbosity, dim_cat)
+
+  if(verbosity>0) cat("Training partition on training set\n")
+  fit_ret = fit_partition(y=y_tr, X=X_tr, d=d_tr, X_aux=X_es, d_aux=d_es, cv_folds=cv_folds, verbosity=verbosity,
+                          X_range=X_range, obj_fn=obj_fn, est_plan=est_plan, valid_fn=NULL, ...)
+  list[partition, is_obj_val_seq, complexity_seq, partition_i, partition_seq, split_seq, lambda, cv_foldid] = fit_ret
+
+  if(verbosity>0) cat("Estimating cell statistics on estimation set\n")
+  cell_stats = estimate_cell_stats(y_es, X_es, d_es, partition, est_plan=est_plan, alpha=alpha)
+
+  full_stat_df = est_full_stats(y, d, X, est_plan, y_es=y_es, d_es=d_es, X_es=X_es)
+
+  return(list(partition=partition, is_obj_val_seq=is_obj_val_seq, complexity_seq=complexity_seq, partition_i=partition_i, partition_seq=partition_seq, split_seq=split_seq, lambda=lambda, cv_foldid=cv_foldid, cell_stats=cell_stats, full_stat_df=full_stat_df, est_plan=est_plan))
+}
+
+
+#' fit_estimate_partition
+#'
+#' Split the data; on one side train/fit the partition, and on the other estimate subgroup effects.
+#' With multiple treatment effects (M) there are 3 options (the first two have the same sample across treatment effects).
+#' 1) Multiple pairs of (Y_{m},W_{m}). y,X,d are then lists of length M. Each element then has the typical size.
+#'    The N_m may differ across m. The number of columns of X will be the same across m.
+#' 2) Multiple treatments and a single outcome. d is then a NxM matrix.
+#' 3) A single treatment and multiple outcomes. y is then a NxM matrix.
+#'
+#' @param y N vector of outcome (label/target) data
+#' @param X NxK matrix of features (covariates). Must be numerical (unordered categorical variables must be
+#'   1-hot encoded.)
+#' @param d (Optional) N vector of treatment data.
+#' @param max_splits Maximum number of splits even if splits continue to improve OOS fit
+#' @param max_cells Maximum number of cells
+#' @param min_size Minimum size of cells
+#' @param cv_folds Number of CV Folds or foldids. If Multiple effect #3 and using vector, then pass in list of vectors.
+#' @param potential_lambdas potential lambdas to search through in CV
+#' @param lambda.1se Use the 1se rule to pick the best lambda
+#' @param partition_i Default is NA. Use this to avoid CV automated selection of the partition
+#' @param tr_split - can be ratio or vector of indexes. If Multiple effect #3 and using vector then pass in list of vectors.
+#' @param verbosity If >0 prints out progress bar for each split
+#' @param pot_break_points NULL or a k-dim list of vectors giving potential split points for non-categorical
+#'   variables (can put c(0) for categorical). Similar to 'discrete splitting' in
+#'   CausalTree, though there they use separate split-points for treated and controls.
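+#' @examples
+#' # Minimal illustrative sketch with simulated data (defaults otherwise):
+#' \dontrun{
+#' n <- 500
+#' X <- data.frame(x1=runif(n), x2=runif(n))
+#' d <- rbinom(n, 1, 0.5)
+#' y <- d*(X$x1 > 0.5) + rnorm(n)
+#' fit <- fit_estimate_partition(y, X, d, cv_folds=3)
+#' print(fit)
+#' }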
+#' @param bucket_min_n Minimum number of observations needed between different split checks for continuous features
+#' @param bucket_min_d_var Ensure positive variance of d for the observations between different split checks
+#'   for continuous features
+#' @param honest Whether to use the emse_hat or mse_hat. Use emse for outcome mean. For treatment effect,
+#'   use if want predictive accuracy, but if only want to identify true dimensions of heterogeneity
+#'   then don't use.
+#' @param ctrl_method Method for determining additional control variables. Empty ("") for nothing, "all",
+#'   "LassoCV", or "rf" (an Estimator_plan object can also be passed directly)
+#' @param pr_cl Parallel Cluster (If NULL, default, then will be single-processor)
+#' @param alpha Default=0.05
+#' @param bump_B Number of bump bootstraps
+#' @param bump_ratio For bootstraps the ratio of sample size to sample (between 0 and 1, default 1)
+#' @param importance_type Options:
+#'   single - (smart) redo full fitting removing each possible dimension
+#'   interaction - (smart) redo full fitting removing each pair of dimensions
+#'   "" - Nothing
+#'
+#' @return An object with class \code{"estimated_partition"}.
+#'   \item{partition}{Partition object defining cuts}
+#'   \item{cell_stats}{list(cell_factor=cell_factor, stats=stat_df) from estimate_cell_stats() using est sample}
+#'   \item{importance_weights}{importance_weights}
+#'   \item{interaction_weights}{interaction_weights}
+#'   \item{has_d}{has_d}
+#'   \item{lambda}{lambda used}
+#'   \item{is_obj_val_seq}{In-sample objective function values for sequence of partitions}
+#'   \item{complexity_seq}{Complexity #s (# cells-1) for sequence of partitions}
+#'   \item{partition_i}{Index of Partition selected in sequence}
+#'   \item{split_seq}{Sequence of splits. Note that split i corresponds to partition i+1}
+#'   \item{index_tr}{Index of training sample (Size of N)}
+#'   \item{cv_foldid}{CV foldids for the training sample (Size of N_tr)}
+#'   \item{varnames}{varnames (or c("X1", "X2",...)
if X doesn't have colnames)} +#' \item{honest}{honest target} +#' \item{est_plan}{Estimation plan} +#' \item{full_stat_df}{full_stat_df} +#' @export +fit_estimate_partition <- function(y, X, d=NULL, max_splits=Inf, max_cells=Inf, min_size=3, + cv_folds=2, potential_lambdas=NULL, lambda.1se=FALSE, partition_i=NA, + tr_split = 0.5, verbosity=0, pot_break_points=NULL, + bucket_min_n=NA, bucket_min_d_var=FALSE, honest=FALSE, + ctrl_method="", pr_cl=NULL, alpha=0.05, bump_B=0, bump_ratio=1, + importance_type="") { + + list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE) + if(is_sep_sample(X) && length(tr_split)>1) { + assert_that(is.list(tr_split) && length(tr_split)==M) + } + X = ensure_good_X(X) + X = update_names_m(X) + + dim_cat = which(get_dim_cat_m(X)) + + #Split the sample + if(length(tr_split)==1) + index_tr = gen_split_m(N, tr_split, m_mode==1) + else + index_tr = tr_split + + list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, index_tr) + + + #Setup est_plan + if(is.character(ctrl_method)) { + if(ctrl_method=="") { + est_plan = gen_simple_est_plan(has_d=!is.null(d)) + } + else { + assert_that(ctrl_method %in% c("all", "LassoCV", "rf"), !is.null(d)) + est_plan = if(ctrl_method=="rf") grid_rf() else lm_X_est(lasso = ctrl_method=="LassoCV") + } + } + else { + assert_that(inherits(ctrl_method, "Estimator_plan")) + est_plan = ctrl_method + } + if(is_sep_estimators(m_mode)) { + est_plan1 = est_plan + est_plan = list() + for(m in 1:M) est_plan[[m]] = est_plan1 + } + + X_range = get_X_range(X) + t0 = Sys.time() + main_ret = fit_estimate_partition_int(X=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, + X_range=X_range, max_splits=max_splits, max_cells=max_cells, min_size=min_size, + cv_folds=cv_folds, potential_lambdas=potential_lambdas, lambda.1se=lambda.1se, + partition_i=partition_i, verbosity=verbosity, pot_break_points=pot_break_points, + bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var, honest=honest, + pr_cl=pr_cl, alpha=alpha, bump_B=bump_B, bump_ratio=bump_ratio, M=M, m_mode=m_mode, + dim_cat=dim_cat, est_plan=est_plan, N_est=N_est) + list[partition, is_obj_val_seq, complexity_seq, partition_i, partition_seq, split_seq, lambda, cv_foldid, cell_stats, full_stat_df, est_plan] = main_ret + + importance_weights <- interaction_weights <- NULL + if(importance_type=="interaction") { + import_ret = get_feature_interactions(X=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, + pot_break_points=pot_break_points, partition=partition, est_plan=est_plan, verbosity=verbosity, pr_cl=pr_cl, + max_splits=max_splits, max_cells=max_cells, min_size=min_size, cv_folds=cv_folds, potential_lambdas=potential_lambdas, lambda.1se=lambda.1se, partition_i=partition_i, + bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var, honest=honest, alpha=alpha, bump_B=bump_B, + bump_ratio=bump_ratio, M=M, m_mode=m_mode, dim_cat=dim_cat, N_est=N_est) + importance_weights = import_ret$delta_k + interaction_weights = import_ret$delta_k12 + } + else if(importance_type %in% c("single", "fast")) { + importance_weights = get_importance_weights(X=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, + pot_break_points=pot_break_points, partition=partition, est_plan=est_plan, type=importance_type, verbosity=verbosity, pr_cl=pr_cl, + max_splits=max_splits, max_cells=max_cells, min_size=min_size, cv_folds=cv_folds, potential_lambdas=potential_lambdas, 
lambda.1se=lambda.1se, partition_i=partition_i, + bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var, honest=honest, alpha=alpha, bump_B=bump_B, + bump_ratio=bump_ratio, M=M, m_mode=m_mode, dim_cat=dim_cat, N_est=N_est) + } + + tn = Sys.time() + td = tn-t0 + if(verbosity>0) cat(paste("Entire Fit-Estimation Duration: ", format(as.numeric(td)), " ", attr(td, "units"), "\n")) + + return(structure(list(partition=partition, + cell_stats=cell_stats, + importance_weights=importance_weights, + interaction_weights=interaction_weights, + has_d=!is.null(d), + lambda=lambda, + is_obj_val_seq=is_obj_val_seq, + complexity_seq=complexity_seq, + partition_i=partition_i, + split_seq=split_seq, + index_tr=index_tr, + cv_foldid=cv_foldid, + varnames=names(X), + honest=honest, + est_plan=est_plan, + full_stat_df=full_stat_df, + m_mode=m_mode, + M=M), + class = c("estimated_partition"))) +} + +#' is.estimated_partition +#' +#' @param x Object +#' +#' @return True if x is an estimated_partition +#' @export +is.estimated_partition <- function(x) { + inherits(x, "estimated_partition") +} + +#' num_cells.estimated_partition +#' +#' @param obj Estimated Partition +#' +#' @return Number of cells +#' @export +#' @method num_cells estimated_partition +num_cells.estimated_partition <- function(obj) { + return(num_cells(obj$partition)) +} + +#' change_complexity +#' +#' Doesn't update the importance weights +#' +#' @param fit estimated_partition +#' @param y Nx1 matrix of outcome (label/target) data +#' @param X NxK matrix of features (covariates). Must be numerical (unordered categorical +#' variables must be 1-hot encoded.) +#' @param d (Optional) NxP matrix (with colnames) or vector of treatment data. If all equally +#' important they should be normalized to have the same variance. 
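+#' @examples
+#' # Illustrative sketch only (assumes `fit` came from fit_estimate_partition with the same y, X, d):
+#' \dontrun{
+#' fit2 <- change_complexity(fit, y, X, d, partition_i=3)
+#' }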
+#' @param partition_i the first partition_i - 1 splits in split_seq are included in the new partition
+#'
+#' @return updated estimated_partition
+#' @export
+change_complexity <- function(fit, y, X, d=NULL, partition_i) {
+  #TODO: Refactor checks from fit_estimate_partition and put them here
+  list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, fit$index_tr)
+
+  fit$partition = partition_from_split_seq(fit$split_seq, fit$partition$X_range,
+                                           varnames=fit$partition$varnames, max_include=partition_i-1)
+  fit$cell_stats = estimate_cell_stats(y_es, X_es, d_es, fit$partition, est_plan=fit$est_plan)
+
+  return(fit)
+}
+
+
+#' get_desc_df.estimated_partition
+#'
+#' @param obj estimated_partition object
+#' @param do_str If TRUE, use a string like "(a, b]", otherwise have two separate columns with a and b
+#' @param drop_unsplit If TRUE, drop columns for variables over which the partition did not split
+#' @param digits digits option (default is NULL)
+#' @param import_order should we use importance ordering or input ordering (default)
+#'
+#' @return data.frame
+#' @export
+get_desc_df.estimated_partition <- function(obj, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, import_order=FALSE) {
+  M = obj$M
+  stats = obj$cell_stats$stats[c(F, rep(T,M), rep(T,M), rep(F,M), rep(F,M), rep(F,M), rep(F,M), rep(T,M), rep(F,M), rep(F,M))]
+  part_df = get_desc_df.grid_partition(obj$partition, do_str=do_str, drop_unsplit=drop_unsplit, digits=digits)
+
+  imp_weights = obj$importance_weights
+  if(drop_unsplit) {
+    imp_weights = imp_weights[obj$partition$nsplits_by_dim>0]
+  }
+  if(import_order) part_df = part_df[, order(-1*imp_weights)]
+
+  return(cbind(part_df, stats))
+}
+
+#' print.estimated_partition
+#'
+#' @param x estimated_partition object
+#' @param do_str If TRUE, use a string like "(a, b]", otherwise have two separate columns with a and b
+#' @param drop_unsplit If TRUE, drop columns for variables over which the partition did not split
+#' @param digits digits option
+#' @param import_order should we use importance ordering or input ordering (default)
+#' @param ... Additional arguments. These won't be passed to print.data.frame
+#'
+#' @return string (and displayed)
+#' @export
+#' @method print estimated_partition
+print.estimated_partition <- function(x, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, import_order=FALSE, ...) {
+  return(print(get_desc_df.estimated_partition(x, do_str, drop_unsplit, digits, import_order=import_order),
+               digits=digits, ...))
+}
+
+#predict.estimated_partition <- function(object, X, d=NULL, type="response") {
+#  TODO: Have to store y_hat as well as tau_hat
+#}
+
+#libs required and suggested. Use if sourcing directly.
+#If you don't want to use the Rcpp version of const_vect (`const_vect = const_vectr`) then you can skip Rcpp
+#lapply(lib_list, require, character.only = TRUE)
+CausalGrid_libs <- function(required=TRUE, suggested=TRUE, load_Rcpp=FALSE) {
+  lib_list = c()
+  if(required) lib_list = c(lib_list, "caret", "gsubfn", "assertthat")
+  if(load_Rcpp) lib_list = c(lib_list, "Rcpp")
+  if(suggested) lib_list = c(lib_list, "ggplot2", "glmnet", "gglasso", "parallel", "pbapply", "ranger")
+  #Build=Rcpp. Full dev=testthat, knitr, rmarkdown, renv, rprojroot
+  return(lib_list)
+}
+
+#' any_sign_effect
+#' fdr - conservative
+#' sim_mom_ineq - Needs sample sizes to be sufficiently large so that the effects are normally distributed
+#'
+#' @param obj obj
+#' @param check_negative If TRUE, check for a negative effect. If FALSE, check for a positive effect.
+#' @param method one of c("fdr", "sim_mom_ineq")
+#' @param alpha alpha
+#' @param n_sim n_sim
+#'
+#' @return list(are_any= whether any effect of the checked sign is significant)
+#' @export
+any_sign_effect <- function(obj, check_negative=T, method="fdr", alpha=0.05, n_sim=500) {
+  #TODO: could also
+  assert_that(method %in% c("fdr", "sim_mom_ineq"))
+  if(method=="fdr") {
+    assert_that(obj$has_d, alpha>0, alpha<1)
+    dofs = obj$cell_stats$stats[["N_est"]] - obj$est_plan$dof
+    pval_right= pt(obj$cell_stats$stats$tstats, df=dofs, lower.tail=FALSE) #right-tailed. Checking for just a positive effect (H_a is "greater")
+    pval_left = pt(obj$cell_stats$stats$tstats, df=dofs, lower.tail=TRUE) #left-tailed. Checking for just a negative effect (H_a is "less")
+    pval1s = if(check_negative) pval_left else pval_right
+    pval1s_fdr = p.adjust(pval1s, "BH")
+    are_any = sum(pval1s_fdr < alpha) > 0
+    return(list(are_any=are_any, pval1s=pval1s, pval1s_fdr=pval1s_fdr))
+  }
+  else {
+    N_cell = nrow(obj$cell_stats$stats)
+    te_se = sqrt(obj$cell_stats$stats[["var_ests"]])
+    tstat_ext = if(check_negative) min(obj$cell_stats$stats[["tstats"]]) else max(obj$cell_stats$stats[["tstats"]])
+    sim_tstat_exts = rep(NA, n_sim)
+    for(s in 1:n_sim) {
+      sim_te = rnorm(N_cell, mean=0, sd=te_se)
+      sim_tstat_exts[s] = if(check_negative) min(sim_te/te_se) else max(sim_te/te_se)
+    }
+    #compare the observed extreme t-stat against the simulated null distribution of extremes
+    if(check_negative) {
+      are_any = tstat_ext < quantile(sim_tstat_exts, alpha)
+    }
+    else {
+      are_any = tstat_ext > quantile(sim_tstat_exts, 1-alpha)
+    }
+  }
+  return(list(are_any=are_any))
+}
diff --git a/R/graphing.R b/R/graphing.R
new file mode 100644
index 0000000..becf7b5
--- /dev/null
+++ b/R/graphing.R
@@ -0,0 +1,32 @@
+# TODO:
+# - Add marginal plots: https://www.r-graph-gallery.com/277-marginal-histogram-for-ggplot2.html
+# - When more than 2-d, have the 2d graphs be the most important ones and split on the least important
+
+#' plot_2D_partition.estimated_partition
+#'
+#' @param grid_fit grid_fit
+#' @param X_names_2D X_names_2D
+#'
+#' @return ggplot2 object
+#' @export
+plot_2D_partition.estimated_partition <- function(grid_fit, X_names_2D) {
+  if (!requireNamespace("ggplot2", quietly = TRUE)) {
+    stop("Package \"ggplot2\" needed for this function to work. Please install it.",
+         call. = FALSE)
+  }
+  split_dims = (grid_fit$partition$nsplits_by_dim > 0)
+  n_split_dims = sum(split_dims)
+  if(n_split_dims<2) {
+    warning("Fewer than 2 dimensions of heterogeneity")
+  }
+  desc_range_df = get_desc_df(grid_fit$partition, drop_unsplit=T)
+  desc_range_df = do.call(cbind, lapply(desc_range_df, function(c) as.data.frame(t(matrix(unlist(c), nrow=2)))))
+
+  colnames(desc_range_df) <- c("xmin", "xmax", "ymin", "ymax")
+  desc_range_df["fill"] = grid_fit$cell_stats$stats$param_ests
+
+  plt = ggplot2::ggplot() +
+    ggplot2::scale_x_continuous(name=X_names_2D[1]) + ggplot2::scale_y_continuous(name=X_names_2D[2]) +
+    ggplot2::geom_rect(data=desc_range_df, mapping=ggplot2::aes(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, fill=fill), color="black")
+  return(plt)
+}
diff --git a/R/grid_partition.R b/R/grid_partition.R
new file mode 100644
index 0000000..c78813c
--- /dev/null
+++ b/R/grid_partition.R
@@ -0,0 +1,1259 @@
+# agnostic to objective function or data splits
+
+# Factor/vector Utils -----------
+
+get_factors_from_splits_dim <- function(X_k, X_k_range, s_by_dim_k) {
+  if(mode(X_k_range)=="character") {
+    windows = get_windows_cat(s_by_dim_k, X_k_range)
+    fac = X_k
+    new_name_map = levels(fac)
+    new_names = c()
+    for(window in windows) {
+      new_name = if(length(window)>1) paste0("{", paste(window, collapse=","), "}") else window[1]
+      new_name_map[levels(fac) %in% window] = new_name
+      new_names = c(new_names, new_name)
+    }
+    levels(fac) <- new_name_map
+    fac = factor(fac, new_names)
+  }
+  else {
+    bottom_break = X_k_range[1]
+    top_break = X_k_range[2]
+    #if(nsplits_by_dim_k>0) {
+    #bottom_split = s_by_dim_k[1]
+    #if(bottom_split==bottom_break)
+    bottom_break = bottom_break-1 #not needed
+    #}
+    top_break = top_break+1
+    breaks = c(bottom_break, s_by_dim_k, top_break)
+    fac = cut(X_k, breaks, labels=NULL, include.lowest=TRUE) #right=FALSE makes [a,b) segments. labels=FALSE makes just numeric vector
+  }
+  return(fac)
+}
+
+get_factors_from_splits_dim_m <- function(X, X_k_range, s_by_dim_k, k) {
+  M_mult = is_sep_sample(X)
+  if(!M_mult)
+    return(get_factors_from_splits_dim(X[,k], X_k_range, s_by_dim_k))
+  return(lapply(X, function(X_s) get_factors_from_splits_dim(X_s[,k], X_k_range, s_by_dim_k)))
+}
+
+factor_from_idxs <- function(N, nfolds, indexOut) {
+  folds = vector("numeric", N)
+  for(f in 1:nfolds) {
+    folds[indexOut[[f]]] = f
+  }
+  folds_f = as.factor(folds)
+  return(folds_f)
+}
+
+#The standard way to check if a vector is constant is const_vectr(), but it is O(n).
+# Checking element-by-element with an early exit would often be faster, but that is inefficient in R and fast in C.
+# const_vect1() and const_vect2() were two versions (first using 'inline', second just Rcpp),
+# but couldn't get to work in building a package. The Rcpp version is now in a separate file.
+
+const_vectr <- function(x) {
+  if(length(x)==0) return(TRUE)
+  r = range(x)
+  return(r[1]==r[2])
+}
+
+#' get_X_range
+#'
+#' @param X data
+#'
+#' @return list of length K with each element being a c(min, max) along that dimension
+#' @export
+get_X_range <- function(X) {
+  if(is_sep_sample(X))
+    X = do.call("rbind", X)
+  if(is.matrix(X)) {
+    are_equal(mode(X), "numeric")
+  }
+  else {
+    assert_that(is.data.frame(X))
+    if(inherits(X, "tbl")) X = as.data.frame(X) #tibble's return tibble (rather than vector) for X[,k], making is.factor(X[,k]) and others fail.
Could switch to doing X[[k]] for df-like objects + for(k in seq_len(ncol(X))) are_equal(mode(X[[k]]), "numeric") + } + assert_that(ncol(X)>=1) + + X_range = list() + K = ncol(X) + for(k in 1:K) { + X_k = X[, k] + X_range[[k]] = if(is.factor(X_k)) levels(X_k) else range(X_k) #c(min, max) + } + return(X_range) +} + +# grid_partition ----------------- + +# for a continuous variables, splits are just values +# for a factor variable, a split is a vector of levels (strings) + +dummy_X_range <- function(K) { + X_range = list() + for(k in 1:K) { + X_range[[k]] = c(-Inf, Inf) + } + return(X_range) +} + +grid_partition <- function(X_range, varnames=NULL) { + K = length(X_range) + s_by_dim = vector("list", length=K) #splits_by_dim(s_seq) #stores Xk_val's + dim_cat = c() + for (k in 1:K) { + if(mode(X_range[[k]])=="character") { + dim_cat = c(dim_cat, k) + s_by_dim[[k]] = list() + } + else { + s_by_dim[[k]] = vector("numeric") + } + } + nsplits_by_dim = rep(0, K) + + return(structure(list(s_by_dim = s_by_dim, nsplits_by_dim = nsplits_by_dim, varnames=varnames, dim_cat=dim_cat, + X_range=X_range), class = c("grid_partition"))) +} + +#' is.grid_partition +#' +#' @param x Object +#' +#' @return True if x is a grid_partition +#' @export +is.grid_partition <- function(x) { + inherits(x, "grid_partition") +} + + +#First element is most insignificant (fastest changing), rather than lexicographic +#cell_i and return value are 1-indexed +segment_indexes_from_cell_i <- function(cell_i, n_segments) { + K = length(n_segments) + size = cumprod(n_segments) + if(cell_i > size[K]) + print("Error: too big") + index = rep(0, K) + cell_i_rem = cell_i-1 #convert to 0-indexing + for(k in 1:K) { + index[k] = cell_i_rem %% n_segments[k] + cell_i_rem = cell_i_rem %/% n_segments[k] + } + index = index+1 #convert from 0-indexing + return(index) +} + +#' get_desc_df.grid_partition +#' +#' @param partition Partition +#' @param cont_bounds_inf If True, will put continuous bounds as -Inf/Inf. 
Otherwise will use X_range bounds +#' @param do_str If True, use a string like "(a, b]", otherwise have two separate columns with a and b +#' @param drop_unsplit If True, drop columns for variables overwhich the partition did not split +#' @param digits digits option +#' @param unsplit_cat_star if we don't split on a categorical var, should we show as "*" (otherwise list all levels) +#' +#' @return data.frame +#' @export +get_desc_df.grid_partition <- function(partition, cont_bounds_inf=TRUE, do_str=FALSE, drop_unsplit=FALSE, + digits=NULL, unsplit_cat_star=TRUE) { + #To check: digits + assert_that(is.flag(cont_bounds_inf), is.flag(do_str), is.flag(drop_unsplit), is.flag(unsplit_cat_star)) + # A split at x_k means that we split to those <= and > + + n_segs = partition$nsplits_by_dim+1 + n_cells = prod(n_segs) + + if(n_cells==1 & drop_unsplit) return(as.data.frame(matrix(NA, nrow=1, ncol=0))) + + #Old code + #library(tidyverse) + #desc_df = data.frame(labels=levels(grid_fit$cell_stats$cell_factor), + # stringsAsFactors = FALSE) %>% separate(labels, names(X), "(?<=]).(?=[(])", PERL=TRUE) + + K = length(partition$nsplits_by_dim) + X_range = partition$X_range + if(cont_bounds_inf) { + for(k in 1:K) { + if(!k %in% partition$dim_cat) X_range[[k]] = c(-Inf, Inf) + } + } + colnames=partition$varnames + if(is.null(colnames)) colnames = paste("X", 1:K, sep="") + + list_of_windows = list() + for(k in 1:K) { + list_of_windows[[k]] = if(k %in% partition$dim_cat) get_windows_cat(partition$s_by_dim[[k]], X_range[[k]]) else get_window_cont(partition$s_by_dim[[k]], X_range[[k]]) + } + + format_cell_cat <- function(win, unsplit_cat_star, n_tot_dim, sep=", ") { + if(unsplit_cat_star && n_tot_dim==1) return("*") + return(paste(win, collapse=sep)) + } + format_cell_cont <- function(win) { + if(is.infinite(win[1]) && is.infinite(win[2])) return("*") + if(is.infinite(win[1])) return(paste0("<=", format(win[2], digits=digits))) + if(is.infinite(win[2])) return(paste0(">", format(win[1], digits=digits))) + return(paste0("(", format(win[1], digits=digits), ", ", format(win[2], digits=digits), "]")) + } + + raw_data = data.frame(row.names=1:n_cells) + str_data = data.frame(row.names=1:n_cells) + for(k in 1:K) { + raw_data_k = list() + str_data_k = c() + for(cell_i in 1:n_cells) { + segment_indexes = segment_indexes_from_cell_i(cell_i, n_segs) + win = list_of_windows[[k]][[segment_indexes[k]]] + raw_data_k[[cell_i]] = win + str_data_k[cell_i] = if(k %in% partition$dim_cat) format_cell_cat(win, unsplit_cat_star, length(list_of_windows[[k]])) else format_cell_cont(win) + } + raw_data[[colnames[k]]] = cbind(raw_data_k) #make a list-column: https://stackoverflow.com/a/51308306 + str_data[[colnames[k]]] = str_data_k + } + desc_df = if(do_str) str_data else raw_data + if(drop_unsplit) desc_df = desc_df[n_segs>1] + + + return(desc_df) +} + +add_split.grid_partition <- function(obj, s) { + k = s[[1]] + X_k_cut = s[[2]] + + if(k %in% obj$dim_cat) obj$s_by_dim[[k]][[obj$nsplits_by_dim[k]+1]] = X_k_cut + else obj$s_by_dim[[k]] = sort(c(X_k_cut, obj$s_by_dim[[k]])) + obj$nsplits_by_dim[k] = obj$nsplits_by_dim[k]+1 + + return(obj) +} + +partition_from_split_seq <- function(split_seq, X_range, varnames=NULL, max_include=Inf) { + part = grid_partition(X_range, varnames) + for(i in seq_len(min(length(split_seq), max_include))) part = add_split.grid_partition(part, split_seq[[i]]) + return(part) +} + +#' num_cells +#' +#' @param obj Object +#' +#' @return Number of cells in partition (at least 1) +#' @export +num_cells <- 
function(obj) { + UseMethod("num_cells", obj) +} + +#' num_cells.grid_partition +#' +#' @param obj Object +#' +#' @return Number of cells +#' @export +#' @method num_cells grid_partition +num_cells.grid_partition <- function(obj) { + return(prod(obj$nsplits_by_dim+1)) +} + +#' print.grid_partition +#' +#' @param x partition object +#' @param do_str If True, use a string like "(a, b]", otherwise have two separate columns with a and b +#' @param drop_unsplit If True, drop columns for variables overwhich the partition did not split +#' @param digits digits Option +#' @param ... Additional arguments. Passed to data.frame +#' +#' @return string (and displayed) +#' @export +#' @method print grid_partition +print.grid_partition <- function(x, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, ...) { + #To check: digits + assert_that(is.flag(do_str), is.flag(drop_unsplit)) + return(print(get_desc_df.grid_partition(x, do_str=do_str, drop_unsplit=drop_unsplit, digits=digits), + digits=digits, ...)) +} + +get_factors_from_partition <- function(partition, X, X_range=NULL) { + X_range = if(is.null(X_range)) partition$X_range else X_range + factors_by_dim = list() + if(is_sep_sample(X)) { + K = ncol(X[[1]]) + for(m in 1:length(X)) { + factors_by_dim_m = list() + for(k in 1:K) { + factors_by_dim_m[[k]] = get_factors_from_splits_dim(X[[m]][, k], X_range[[k]], partition$s_by_dim[[k]]) + } + factors_by_dim[[m]] = factors_by_dim_m + } + } + else { + K = ncol(X) + for(k in 1:K) { + factors_by_dim[[k]] = get_factors_from_splits_dim(X[, k], X_range[[k]], partition$s_by_dim[[k]]) + } + } + return(factors_by_dim) +} + +# partition_split --------------------- + +partition_split <- function(k, X_k_cut) { + return(structure(list(k=k, X_k_cut=X_k_cut), class=c("partition_split"))) +} +is.grid_partition_split <- function(x) inherits(x, "partition_split") + +#' print.partition_split +#' +#' @param x Object +#' @param ... Additional arguments. Unused. +#' +#' @return None +#' @export +#' @method print partition_split +print.partition_split <- function(x, ...) 
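+# For example, a split of dimension 2 at cut-point 0.5 prints as "2: 0.5".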
{ + cat(paste0(x[[1]], ": ", x[[2]], "\n")) +} + +#' get_factor_from_partition +#' +#' @param partition partition +#' @param X X data or list of X +#' @param X_range (Optional) overrides the partition$X_range +#' +#' @return Factor +#' @export +get_factor_from_partition <- function(partition, X, X_range=NULL) { + facts = get_factors_from_partition(partition, X, X_range=X_range) + return(interaction_m(facts, is_sep_sample(X))) +} + +# Search algo -------------------- + +#if not mid-point then the all but the last are the splits +get_usable_break_points <- function(pot_break_points, X, X_range, dim_cat, mid_point=TRUE) { + if(is_sep_sample(X)) X = X[[1]] + K = ncol(X) + #old code + if(is.null(pot_break_points)) { + pot_break_points = list() + for(k in 1:K) { + if(!k %in% dim_cat) { + u = unique(sort(X[, k])) + if(mid_point) { + pot_break_points[[k]] = u[-length(u)] + diff(u) / 2 + } + else { + pot_break_points[[k]] = u[-length(u)] #skip last point + } + } + else { + pot_break_points[[k]] = c(0) #Dummy just for place=holder + } + } + } + else { #make sure they didn't include the lowest point + for(k in 1:K) { + if(!k %in% dim_cat) { + n_k = length(pot_break_points[[k]]) + if(pot_break_points[[k]][n_k]==X_range[[k]][2]) { + pot_break_points[[k]] = pot_break_points[[k]][-n_k] + } + } + pot_break_points[[k]] = unname(pot_break_points[[k]]) #names messed up the get_desc_df() (though not in debugSource) + } + } + return(pot_break_points) +} + +#' quantile_breaks +#' +#' @param X Features +#' @param binary_k vector of dimensions that are binary +#' @param g # of quantiles +#' @param type Quantile type (see ?quantile and https://mathworld.wolfram.com/Quantile.html). +#' Types1-3 are discrete and this is good for passing to unique() when there are clumps +#' +#' @return list of potential breaks +#' @export +quantile_breaks <- function(X, binary_k=c(), g=20, type=3) { + if(is_sep_sample(X)) X = X[[1]] + X = ensure_good_X(X) + + pot_break_points = list() + K = ncol(X) + for(k in 1:K) { + if(k %in% binary_k) { + pot_break_points[[k]] = c(0) + } + else { + X_k = X[[k]] + if(is.factor(X_k)) { + pot_break_points[[k]] = c(0) #Dummy + } + else { + #unique(sort(X[,k])) #we will automatically skip the top point + #if you want g segments, there are g-1 internal nodes, then there are g+1 outer nodes + qs = quantile(X_k, seq(0, 1, length.out=g+1), names=FALSE, type=type) + qs = unique(qs) + pot_break_points[[k]] = qs[-c(length(qs), 1)] + } + + } + } + return(pot_break_points) +} + +# if d vectors are empty doesn't return fail +valid_partition <- function(cell_factor, d=NULL, cell_factor_aux=NULL, d_aux=NULL, min_size=0) { + #check none of the cells are too small + if(min_size>0) { + if(length(cell_factor)==0) return(list(fail=TRUE, min_size=0)) + lowest_size = min(table(cell_factor)) + if(lowest_size=0) break + slopes = c(slopes, best_slope) + hull_i = which.min(i_slopes) + } + if(length(slopes)>1) { + lambda_ties = abs(slopes) #slightly bigger will go will pick the index earlier, slightly bigger later + } + else { + lambda_ties = c() + } + return(lambda_ties) +} + +gen_cat_window_splits <- function(chr_vec) { + n = length(chr_vec) + splits=list() + for(m in seq_len(floor(n/2))) { + cs = combn(chr_vec, m, simplify=F) + if(m==n/2) cs = cs[1:(length(cs)/2)] #or just filter by those that contain chr_vec[1] + splits = c(splits, cs) + } + return(splits) +} + +n_cat_window_splits <- function(window_len) { + n_splits = 0 + for(m in seq_len(floor(window_len/2))) { + n_choose = choose(window_len, m) + n_splits = n_splits 
+ if(m==window_len/2) n_choose/2 else n_choose + } + return(n_splits) +} + +n_cat_splits <- function(s_by_dim_k, X_range_k) { + windows = get_windows_cat(s_by_dim_k, X_range_k) + n_splits = 0 + for(window in windows) n_splits = n_splits + n_cat_window_splits(length(window)) + return(n_splits) +} + +get_windows_cat <- function(s_by_dim_k, X_k_range) { + windows = s_by_dim_k + windows[[length(windows)+1]] = X_k_range[!X_k_range %in% unlist(c(windows))] + return(windows) +} + +get_window_cont <- function(s_by_dim_k, X_k_range) { + windows=list() + n_w = length(s_by_dim_k)+1 + for(w in 1:n_w) { + wmin = if(w==1) X_k_range[1] else s_by_dim_k[w-1] + wmax = if(w==n_w) X_k_range[2] else s_by_dim_k[w] + windows[[w]] = c(wmin, wmax) + } + return(windows) +} + +gen_holdout_interaction <- function(factors_by_dim, k) { + if(length(factors_by_dim)>1) + return(interaction(factors_by_dim[-k])) + return(factor(rep("|", length(factors_by_dim[[1]])))) +} + +n_breaks_k <- function(pot_break_points, k, partition, X_range) { + if(k %in% partition$dim_cat) return(n_cat_splits(partition$s_by_dim[[k]], X_range[[k]])) + return(length(pot_break_points[[k]])) +} + +fit_partition_full_k <- function(k, y, X_d, d, X_range, pb, debug, valid_breaks, factors_by_dim, X_aux, + factors_by_dim_aux, partition, verbosity, allow_empty_aux=TRUE, d_aux, + allow_est_errors_aux, min_size, min_size_aux, obj_fn, N_est, est_plan, + bucket_min_n, bucket_min_d_var, pot_break_points, valid_fn) { #, n_cut + assert_that(is.flag(bucket_min_d_var), is.flag(allow_empty_aux)) + list[M, m_mode, N, K] = get_sample_type(y, X_d, d, checks=FALSE) + search_ret = list() + best_new_val = Inf + valid_breaks_k = valid_breaks[[k]] + cell_factor_nk = gen_holdout_interaction_m(factors_by_dim, k, is_sep_sample(X_d)) + if(!is.null(X_aux)) { + cell_factor_nk_aux = gen_holdout_interaction_m(factors_by_dim_aux, k, is_sep_sample(X_aux)) + } + if(!is_factor_dim_k(X_d, k, m_mode==1)) { + n_pot_break_points_k = length(pot_break_points[[k]]) + vals = rep(NA, n_pot_break_points_k) + prev_split_checked = X_range[[k]][1] + win_LB = X_range[[k]][1]-1 + win_UB = if(length(partition$s_by_dim[[k]])>0) partition$s_by_dim[[k]][1] else X_range[[k]][2] + win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB) + win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB) + for(X_k_cut_i in seq_len(n_pot_break_points_k)) { #cut-point is top end of segment, + if (verbosity>0 && !is.null(pb)) utils::setTxtProgressBar(pb, utils::getTxtProgressBar(pb)+1) + X_k_cut = pot_break_points[[k]][X_k_cut_i] + if(X_k_cut %in% partition$s_by_dim[[k]]) { + prev_split_checked = X_k_cut + win_LB = X_k_cut + higher_prev_split = partition$s_by_dim[[k]][partition$s_by_dim[[k]]>X_k_cut] + win_UB = if(length(higher_prev_split)>0) min(higher_prev_split) else X_range[[k]][2] + win_mask = gen_cont_window_mask_m(X_d, k, win_LB, win_UB) + win_mask_aux = gen_cont_window_mask_m(X_aux, k, win_LB, win_UB) + next + } + if(!valid_breaks_k[[1]][X_k_cut_i]) next + new_split = partition_split(k, X_k_cut) + tent_partition = add_split.grid_partition(partition, new_split) + + tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k) + tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==1) + if(!is.null(X_aux)) { + tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k) + tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux)) + } + 
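+      # Sketch of the bucket guards that follow (per the bucket_min_n / bucket_min_d_var
+      # arguments): a candidate cut is only scored if at least bucket_min_n observations
+      # fall between the previous checked cut and this one, and (if bucket_min_d_var)
+      # d varies within that window, so near-duplicate cut-points are skipped cheaply.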
if(!is.na(bucket_min_n) | bucket_min_d_var) {
+        shifted = gen_cont_window_mask_m(X_d, k, prev_split_checked, X_k_cut)
+      }
+
+      if(!is.na(bucket_min_n) && min_sum(shifted, m_mode==1) < bucket_min_n) next
+      if(bucket_min_d_var && any_const_m(d, shifted, apply_mask_m(cell_factor_nk, shifted, m_mode==1), m_mode)) next
+
+      valid_ret = valid_partition_m(m_mode==1, valid_fn, tent_cell_factor, d=d, min_size=min_size)
+      if(!valid_ret$fail && !is.null(X_aux) && !allow_empty_aux) {
+        win_d_aux = if(!allow_est_errors_aux) d_aux else NULL
+        valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, tent_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
+      }
+      if(valid_ret$fail) {
+        #cat("Invalid partition\n")
+        valid_breaks_k[[1]][X_k_cut_i] = FALSE
+        next
+      }
+      obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor = tent_cell_factor, debug=debug, est_plan=est_plan,
+                       sample="trtr")
+      if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
+        #cat("Estimation errors\n")
+        valid_breaks_k[[1]][X_k_cut_i] = FALSE
+        next
+      }
+      val = obj_ret[1]
+      stopifnot(is.finite(val))
+      prev_split_checked = X_k_cut
+
+      if(val < best_new_val) {
+        if(debug && verbosity>0) print(paste("Testing split at ", X_k_cut, ". Val=", val))
+        best_new_val = val
+        new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
+        if(!is.null(X_aux)) {
+          new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
+        }
+        else new_factors_by_dim_aux = NULL
+        search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim,
+                          new_factors_by_dim_aux=new_factors_by_dim_aux)
+      }
+    }
+
+  }
+  else {
+    windows = get_windows_cat(partition$s_by_dim[[k]], X_range[[k]])
+    for(window_i in seq_len(length(windows))) {
+      window = windows[[window_i]]
+      win_mask = gen_cat_window_mask_m(X_d, k, window)
+      win_mask_aux = gen_cat_window_mask_m(X_aux, k, window)
+      pot_splits = gen_cat_window_splits(window)
+      for(win_split_i in seq_len(length(pot_splits))) {
+        win_split_val = pot_splits[[win_split_i]]
+        #TODO: Refactor with continuous case
+        if (verbosity>0 && !is.null(pb)) utils::setTxtProgressBar(pb, utils::getTxtProgressBar(pb)+1)
+        if(!valid_breaks_k[[window_i]][win_split_i]) next
+
+        new_split = partition_split(k, win_split_val)
+        tent_partition = add_split.grid_partition(partition, new_split)
+
+        tent_split_fac_k = get_factors_from_splits_dim_m(X_d, X_range[[k]], tent_partition$s_by_dim[[k]], k)
+        tent_cell_factor = interaction2_m(cell_factor_nk, tent_split_fac_k, m_mode==1)
+        if(!is.null(X_aux)) {
+          tent_split_fac_k_aux = get_factors_from_splits_dim_m(X_aux, X_range[[k]], tent_partition$s_by_dim[[k]], k)
+          tent_cell_factor_aux = interaction2_m(cell_factor_nk_aux, tent_split_fac_k_aux, is_sep_sample(X_aux))
+        }
+
+        # do_window_approach
+        win_split_cond = gen_cat_win_split_cond_m(X_d, win_mask, k, win_split_val)
+        win_cell_factor = interaction2_m(apply_mask_m(cell_factor_nk, win_mask, m_mode==1), win_split_cond, m_mode==1)
+        win_d = apply_mask_m(d, win_mask, m_mode==1)
+        valid_ret = valid_partition_m(m_mode==1, valid_fn, win_cell_factor, d=win_d, min_size=min_size)
+        if(!valid_ret$fail) {
+          if(!is.null(X_aux) && !allow_empty_aux) {
+            win_split_cond_aux = factor(gen_cat_win_split_cond_m(X_aux, win_mask_aux, k, win_split_val), levels=c(FALSE, TRUE))
+            win_cell_factor_aux = interaction2_m(apply_mask_m(cell_factor_nk_aux, win_mask_aux, is_sep_sample(X_aux)),
+                                                 win_split_cond_aux, is_sep_sample(X_aux), drop=allow_empty_aux)
+            win_d_aux = if(!allow_est_errors_aux) apply_mask_m(d_aux, win_mask_aux, is_sep_sample(X_aux)) else NULL
+            valid_ret = valid_partition_m(is_sep_sample(X_aux), valid_fn, win_cell_factor_aux, d=win_d_aux, min_size=min_size_aux)
+          }
+        }
+        if(valid_ret$fail) {
+          #cat("Invalid partition\n")
+          valid_breaks_k[[window_i]][win_split_i] = FALSE
+          next
+        }
+        if(debug) cat(paste("k", k, ". X_k", win_split_val, "\n"))
+        obj_ret = obj_fn(y, X_d, d, N_est=N_est, cell_factor = tent_cell_factor, debug=debug, est_plan=est_plan,
+                         sample="trtr")
+        if(obj_ret[3]>0) { #don't need to check [2] (empty cells) as we already did that
+          #cat("Estimation errors\n")
+          valid_breaks_k[[window_i]][win_split_i] = FALSE
+          next
+        }
+        val = obj_ret[1]
+        stopifnot(is.finite(val))
+
+        if(val < best_new_val) {
+          if(debug && verbosity>0) print(paste("Testing split at ", win_split_val, ". Val=", val))
+          best_new_val = val
+          new_factors_by_dim = replace_k_factor_m(factors_by_dim, k, tent_split_fac_k, is_sep_sample(X_d))
+          if(!is.null(X_aux)) {
+            new_factors_by_dim_aux = replace_k_factor_m(factors_by_dim_aux, k, tent_split_fac_k_aux, is_sep_sample(X_aux))
+          }
+          else new_factors_by_dim_aux = NULL
+          search_ret = list(val=val, new_split=new_split, new_factors_by_dim=new_factors_by_dim,
+                            new_factors_by_dim_aux=new_factors_by_dim_aux)
+        }
+
+      }
+    }
+  }
+
+  return(list(search_ret, valid_breaks_k))
+}
+
+# There are three general problems with a partition.
+# 1) Empty cells
+# 2) Non-empty cells where objective can't be calculated
+# 3) Cells where it can be calculated but due to small sizes we don't want
+# Main sample: We assume that a valid partition removes #1 and #2. Use min_size for #3.
+# For Aux: Use allow_empty_aux, allow_est_errors_aux, min_size_aux
+# Include d_aux if you want to make sure that non-empty cells in aux have positive variance in d
+# For CV: allow_empty_aux=TRUE, allow_est_errors_aux=FALSE, min_size_aux=1 (weaker check than removing estimation errors)
+# Can set nsplits_k_warn_limit=Inf to disable
+fit_partition_full <- function(y, X, d=NULL, X_aux=NULL, d_aux=NULL, X_range, max_splits=Inf, max_cells=Inf,
+                               min_size=2, verbosity=0, pot_break_points, N_est, obj_fn, allow_est_errors_aux=TRUE,
+                               min_size_aux=2, est_plan, partition=NULL, nsplits_k_warn_limit=1000, pr_cl=NULL,
+                               ...)
{
+  assert_that(max_splits>=0, max_cells>=1, min_size>=1,
+              is.flag(allow_est_errors_aux), nsplits_k_warn_limit>=1)
+  list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE)
+  est_min = ifelse(is.null(d), 2, 3) #If don't always need variance calc: ifelse(is.null(d), ifelse(honest, 2, 1), ifelse(honest, 3, 2))
+  min_size = max(min_size, est_min)
+  if(!allow_est_errors_aux) min_size_aux = max(min_size_aux, est_min)
+  debug = FALSE
+  if(is.null(partition)) partition = grid_partition(X_range, names(X))
+  pot_break_points = get_usable_break_points(pot_break_points, X, X_range, partition$dim_cat)
+  valid_breaks = vector("list", length=K) #splits_by_dim(s_seq) #stores Xk_val's
+
+  for(k in 1:K) {
+    n_split_breaks_k = n_breaks_k(pot_break_points, k, partition, X_range)
+    valid_breaks[[k]] = list(rep(TRUE, n_split_breaks_k))
+    if(!is.na(nsplits_k_warn_limit) && n_split_breaks_k>nsplits_k_warn_limit) warning(paste("Warning: Many splits (", n_split_breaks_k, ") along dimension", k, "\n"))
+  }
+  factors_by_dim = get_factors_from_partition(partition, X)
+  if(!is.null(X_aux)) {
+    factors_by_dim_aux = get_factors_from_partition(partition, X_aux)
+  }
+
+  if(verbosity>0){
+    cat("Grid > Fitting: Started.\n")
+    t0 = Sys.time()
+  }
+  split_i = 1
+  seq_val = c()
+  obj_ret = obj_fn(y, X, d, N_est=N_est, cell_factor = interaction_m(factors_by_dim, is_sep_sample(X)), est_plan=est_plan, sample="trtr")
+  if(obj_ret[3]>0 || !is.finite(obj_ret[1])) {
+    stop("Estimation error with initial partition")
+  }
+  seq_val[1] = obj_ret[1]
+  partition_seq = list()
+  split_seq = list()
+  partition_seq[[1]] = partition
+  tent_cell_factor_aux = NULL
+  style = if(summary(stdout())$class=="terminal") 3 else 1
+  if(!is.null(pr_cl) & !requireNamespace("parallel", quietly = TRUE)) {
+    stop("Package \"parallel\" needed for this function to work. Please install it.", call. = FALSE)
+  }
+  do_pbapply = requireNamespace("pbapply", quietly = TRUE) & (verbosity>0) & (is.null(pr_cl) || length(pr_cl)<2)
+  while(TRUE) {
+    if(split_i>max_splits) break
+    if(num_cells(partition)==max_cells) break
+    n_cuts_k = rep(0, K)
+    for(k in 1:K) {
+      n_cuts_k[k] = n_breaks_k(pot_break_points, k, partition, X_range)
+    }
+    n_cuts_total = sum(n_cuts_k)
+    if(n_cuts_total==0) break
+    if(verbosity>0) {
+      cat(paste("Grid > Fitting > split ", split_i, ": Started\n"))
+      t1 = Sys.time()
+      if(is.null(pr_cl)) pb = utils::txtProgressBar(0, n_cuts_total, style = style)
+    }
+
+    params = c(list(y=y, X_d=X, d=d, X_range=X_range, pb=NULL, debug=debug, valid_breaks=valid_breaks,
+                    factors_by_dim=factors_by_dim, X_aux=X_aux, factors_by_dim_aux=factors_by_dim_aux, partition=partition,
+                    verbosity=verbosity, d_aux=d_aux, allow_est_errors_aux=allow_est_errors_aux,
+                    min_size=min_size, min_size_aux=min_size_aux, obj_fn=obj_fn, N_est=N_est, est_plan=est_plan,
+                    pot_break_points=pot_break_points), list(...))
+
+    col_rets = my_apply(1:K, fit_partition_full_k, verbosity, pr_cl, params)
+
+    best_new_val = Inf
+    best_new_split = NULL
+    for(k in 1:K) {
+      col_ret = col_rets[[k]]
+      search_ret = col_ret[[1]]
+      valid_breaks[[k]] = col_ret[[2]]
+      if(length(search_ret)>0 && search_ret$val < best_new_val) {
+        if(debug && verbosity>0) print(paste("Testing split at ", search_ret$new_split[[2]], ". Val=", search_ret$val))
+        best_new_val = search_ret$val
+        best_new_split = search_ret$new_split
+        best_new_factors_by_dim = search_ret$new_factors_by_dim
+        if(!is.null(X_aux)) {
+          best_new_factors_by_dim_aux = search_ret$new_factors_by_dim_aux
+        }
+      }
+    }
+
+    if (verbosity>0) {
+      t2 = Sys.time() #can use as.numeric(t1) to convert to seconds
+      td = t2-t1
+      if(is.null(pr_cl)) close(pb)
+    }
+    if(is.null(best_new_split)) {
+      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". No valid splits\n"))
+      break
+    }
+    best_new_partition = add_split.grid_partition(partition, best_new_split)
+    if(num_cells(best_new_partition)>max_cells) {
+      if (verbosity>0) cat(paste("Grid > Fitting > split ", split_i, ": Finished. Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ". Best split results in too many cells\n"))
+      break
+    }
+    partition = best_new_partition
+    factors_by_dim = best_new_factors_by_dim
+    if(!is.null(X_aux)) {
+      factors_by_dim_aux = best_new_factors_by_dim_aux
+    }
+    split_i = split_i + 1
+    seq_val[split_i] = best_new_val
+    partition_seq[[split_i]] = partition
+    split_seq[[split_i-1]] = best_new_split
+    if(best_new_split[[1]] %in% partition$dim_cat) {
+      k = best_new_split[[1]]
+      windows = get_windows_cat(partition$s_by_dim[[k]], X_range[[k]])
+      nwindows = length(windows)
+      v_breaks = vector("list", length=nwindows)
+      for(window_i in seq_len(nwindows)) {
+        v_breaks[[window_i]] = rep(TRUE, n_cat_window_splits(length(windows[[window_i]])))
+      }
+      valid_breaks[[k]] = v_breaks
+    }
+    if (verbosity>0) {
+      cat(paste("Grid > Fitting > split ", split_i, ": Finished.",
+                " Duration: ", format(as.numeric(td)), " ", attr(td, "units"), ".",
+                " New split: k=", best_new_split[[1]], ", cut=", best_new_split[[2]], ", val=", best_new_val, "\n"))
+    }
+  }
+  if (verbosity>0) {
+    tn = Sys.time()
+    td = tn-t0
+    cat("Grid > Fitting: Finished.")
+    cat(paste(" Entire Search Duration: ", format(as.numeric(td)), " ", attr(td, "units"), "\n"))
+  }
+  return(list(partition_seq=partition_seq, is_obj_val_seq=seq_val, split_seq=split_seq))
+}
+
+#Allows two lists or two datasets
+add_samples <- function(X, X_aux, M_mult) {
+  if(is.null(X_aux)) return(X)
+  if(M_mult) return(c(X, X_aux))
+  return(list(X, X_aux))
+}
+
+
+fit_partition_bump_b <- function(b, samples, y, X_d, d=NULL, m_mode, X_aux, d_aux, verbosity, ...){
+  if(verbosity>0) cat(paste("Grid > Bumping > b = ", b, "\n"))
+  sample = samples[[b]]
+  list[y_b, X_b, d_b] = subsample_m(y, X_d, d, sample)
+  X_aux2 = add_samples(X_d, X_aux, is_sep_sample(X_d))
+  d_aux2 = add_samples(d, d_aux, is_sep_sample(X_d))
+  fit_partition_full(y=y_b, X=X_b, d=d_b, X_aux=X_aux2, d_aux=d_aux2, verbosity=verbosity, ...)
+}
+
+# ... params sent to fit_partition_full()
+cv_pick_lambda_f <- function(f, y, X_d, d, folds_ret, nfolds, potential_lambdas, N_est,
+                             verbosity, obj_fn, cv_tr_min_size, est_plan, ...)
{ + if(verbosity>0) cat(paste("Grid > CV > Fold", f, "\n")) + supplied_lambda = !is.null(potential_lambdas) + if(supplied_lambda) n_lambda = length(potential_lambdas) + + list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X_d, d, folds_ret, f) + + #if(verbosity>0) cat(paste("- Fitting grid structure on fold", f, "of", nfolds, "\n")) + fit_ret = fit_partition_full(y_f_tr, X_f_tr, d_f_tr, X_f_cv, d_f_cv, + min_size=cv_tr_min_size, verbosity=verbosity, + N_est=N_est, obj_fn=obj_fn, allow_empty_aux=TRUE, + allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan, ...) #min_size_aux is weaker than removing est errors + list[partition_seq, is_obj_val_seq, split_seq] = fit_ret + complexity_seq = sapply(partition_seq, num_cells) - 1 + + if(!supplied_lambda) { #build lambdas. Assuming no slope ties + lambda_ties_f = get_lambda_ties(is_obj_val_seq, complexity_seq) + col_ret = list(lambda_ties_f=lambda_ties_f, partition_seq=partition_seq, is_obj_val_seq=is_obj_val_seq, + complexity_seq=complexity_seq) + } + else { + col_ret = rep(NA, n_lambda) + for(lambda_i in seq_len(n_lambda)) { + lambda = potential_lambdas[lambda_i] + partition_i = which.min(is_obj_val_seq + lambda*complexity_seq) + obj_ret = obj_fn(y_f_cv, X_f_cv, d_f_cv, N_est=N_est, partition=partition_seq[[partition_i]], + est_plan=est_plan, sample="trcv") + oos_obj_val = obj_ret[1] + col_ret[lambda_i] = oos_obj_val + } + } + return(col_ret) +} + +# cv_tr_min_size: We don't want this too large as (since we have less data) otherwise we might not find the +# MSE-min lambda and if the most detailed partition on the full data is best we might have a +# lambda too large and choose one coarser. could choose 2 +# On the other hand the types of partitions we generate when this param is too small will be different +# and incomparable. Could choose (nfolds-2)/nfolds +# Therefore I take the average of the above two approaches. +# Used to warn if best lamda was the smallest, but since there's not much to do about it (we already scale +# cv_tr_min_size), stopped reporting +# Note: We do not want to first fit the full grid and then take potential lambdas as one from each segment that +# picks another grid. Those lambdas aren't gauranteed to include the true lambda min. We basically +# roughly sampling the true CV lambda function (which is is a step-function) and we might miss it and +# wrongly evaluate the benefit of each subgrid and therefore pick the wrong one. +# ... params sent to cv_pick_lambda_f +cv_pick_lambda <- function(y, X, d, folds_ret, nfolds, potential_lambdas, N_est, min_size, verbosity, lambda.1se=FALSE, + min_obs_1se=5, obj_fn, cv_tr_min_size=NA, est_plan, pr_cl=NULL, ...) 
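+# Illustrative pick (hypothetical numbers): with is_obj_val_seq = c(10, 8, 7.5) and
+# complexity_seq = c(0, 1, 2), lambda = 1 selects which.min(c(10, 9, 9.5)) = 2,
+# i.e. the one-split partition; larger lambdas push the choice toward coarser grids.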
{ + #If potential_lambdas is NULL, then only have to iterate through lambda values that change partition_i (for any fold) + #If is_obj_val_seq is monotonic then this is easy and can do sequentially, but not sure if this is the case + supplied_lambda = !is.null(potential_lambdas) + if(supplied_lambda) { + n_lambda = length(potential_lambdas) + lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda) + } + else { + lambda_ties = list() + partition_seqs = list() + is_obj_val_seqs = list() + complexity_seqs = list() + } + if(verbosity>0) cat("Grid > CV: Started.\n") + if(is.na(cv_tr_min_size)) cv_tr_min_size = as.integer(ifelse(nfolds==2, (2+min_size/2)/2, (nfolds-2)/nfolds)*min_size) + + params = c(list(y=y, X_d=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas, + N_est=N_est, verbosity=verbosity-1, + obj_fn=obj_fn, cv_tr_min_size=cv_tr_min_size, + est_plan=est_plan), list(...)) + + col_rets = my_apply(1:nfolds, cv_pick_lambda_f, verbosity==1 || !is.null(pr_cl), pr_cl, params) + + # Process nfolds loop + if(!supplied_lambda) { + for(f in 1:nfolds) { + lambda_ties[[f]] = col_rets[[f]]$lambda_ties_f + partition_seqs[[f]] = col_rets[[f]]$partition_seq + is_obj_val_seqs[[f]] = col_rets[[f]]$is_obj_val_seq + complexity_seqs[[f]] = col_rets[[f]]$complexity_seq + } + } + else { + for(f in 1:nfolds) { + lambda_oos[f, ] = col_rets[[f]] + } + } + + if(!supplied_lambda) { + union_lambda_ties = sort(unlist(lambda_ties), decreasing=TRUE) + mid_points = union_lambda_ties[-length(union_lambda_ties)] + diff(union_lambda_ties)/2 + potential_lambdas = c(union_lambda_ties[1]+1, mid_points, mid_points[length(mid_points)]/2) + if(length(potential_lambdas)==0) { + if(verbosity>0) cat("Note: CV folds consistently picked initial model (complexity didn't improve in-sample objective). 
Defaulting to lambda=0.\n") + potential_lambdas=c(0) + } + n_lambda = length(potential_lambdas) + lambda_oos = matrix(NA, nrow=nfolds, ncol=n_lambda) + lambda_which_partition = matrix(NA, nrow=nfolds, ncol=n_lambda) + + for(f in 1:nfolds) { + list[y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv] = split_sample_folds_m(y, X, d, folds_ret, f) + for(lambda_i in 1:n_lambda) { + lambda = potential_lambdas[lambda_i] + partition_i = which.min(is_obj_val_seqs[[f]] + lambda*complexity_seqs[[f]]) + if(is.na(lambda_oos[f, lambda_i])) { + debug = FALSE + part = partition_seqs[[f]][[partition_i]] + if(debug) cat(paste("s_by_dim", paste(part$s_by_dim, collapse=" "), "\n")) + obj_ret = obj_fn(y_f_cv, X_f_cv, d_f_cv, N_est=N_est, partition=part, debug=debug, + est_plan=est_plan, sample="trcv") + oos_obj_val = obj_ret[1] + lambda_oos[f, lambda_i] = oos_obj_val + } + lambda_which_partition[f, lambda_i] = partition_i + } + + } + } + lambda_oos_means = colMeans(lambda_oos) + lambda_oos_min_i = which.min(lambda_oos_means) + #if(lambda_oos_min_i==length(lambda_oos_means)) cat(paste("Warning: MSE-min lambda is the smallest (of",length(lambda_oos_means),"potential lambdas)\n")) + if(lambda.1se) { + #lambda_oos_sd = apply(X=lambda_oos, MARGIN=2, FUN=sd) #don't need all for now + obs = lambda_oos[, lambda_oos_min_i] + if(min_obs_1se>nfolds) { + for(delta in 1:min(n_lambda-lambda_oos_min_i, lambda_oos_min_i-1)) { + if(lambda_oos_min_i+delta<=n_lambda) { + obs = c(obs, lambda_oos[, lambda_oos_min_i+delta]) + } + if(lambda_oos_min_i-delta>=1) { + obs = c(obs, lambda_oos[, lambda_oos_min_i-delta]) + } + if(length(obs)>=min_obs_1se) break + } + print(obs) + } + max_oos_err_allowed = min(lambda_oos_means) + sd(obs) + if(verbosity>0) cat(paste("max_oos_err_allowed:", paste(max_oos_err_allowed, collapse=" "), "\n")) + lambda_star_i = min(which(lambda_oos_means <= max_oos_err_allowed)) + lambda_star = potential_lambdas[lambda_star_i] + } + else { + lambda_star = potential_lambdas[lambda_oos_min_i] + } + if(verbosity>0){ + cat("Grid > CV: Finished.") + cat(paste(" lambda_oos_means=[", paste(lambda_oos_means, collapse=" "), "].")) + if(length(lambda_star)==0) cat(" Couldn't find any suitable lambdas, returning 1.\n") + else { + cat(paste(" potential_lambdas=[", paste(potential_lambdas, collapse=" "), "].")) + cat(paste(" lambda_star=", lambda_star, ".\n")) + } + } + if(length(lambda_star)==0) { + lambda_star=1 + } + + return(lambda_star) +} + +#' fit_partition +#' +#' CV Fit partition on some data, finds best lambda and then re-fits on full data. +#' +#' @param y Nx1 matrix of outcome (label/target) data +#' @param X NxK matrix of features (covariates) +#' @param d (Optional) NxP matrix (with colnames) of treatment data. If all equally important they +#' should be normalized to have the same variance. +#' @param N_est N of samples in the Estimation dataset +#' @param X_aux aux X sample to compute statistics on (OOS data) +#' @param d_aux aux d sample to compute statistics on (OOS data) +#' @param max_splits Maximum number of splits even if splits continue to improve OOS fit +#' @param max_cells Maximum number of cells even if more splits continue to improve OOS fit +#' @param min_size Minimum cell size when building full grid, cv_tr will use (F-1)/F*min_size, cv_te doesn't use any. 
+#' @param cv_folds Number of Folds +#' @param verbosity If >0 prints out progress bar for each split +#' @param pot_break_points k-dim list of vectors giving potential split points +#' @param potential_lambdas potential lambdas to search through in CV +#' @param X_range list of min/max for each dimension +#' @param lambda.1se Use the 1se rule to pick the best lambda +#' @param bucket_min_n Minimum number of observations needed between different split checks +#' @param bucket_min_d_var Ensure positive variance of d for the observations between different split checks +#' @param obj_fn Whether to use the emse_hat or mse_hat +#' @param est_plan Estimation plan. +#' @param partition_i Default NA. Use this to avoid CV +#' @param pr_cl Default NULL. Parallel cluster (used for fit_partition_full) +#' @param valid_fn Function to quickly check if partition could be valid. User can override. +#' @param bump_B Number of bump bootstraps +#' @param bump_ratio For bootstraps the ratio of sample size to sample (between 0 and 1, default 1) +#' +#' @return list(partition, lambda) +#' @export +fit_partition <- function(y, X, d=NULL, N_est=NA, X_aux=NULL, d_aux=NULL, max_splits=Inf, max_cells=Inf, + min_size=3, cv_folds=2, verbosity=0, pot_break_points=NULL, potential_lambdas=NULL, + X_range=NULL, lambda.1se=FALSE, bucket_min_n=NA, bucket_min_d_var=FALSE, obj_fn, + est_plan, partition_i=NA, pr_cl=NULL, valid_fn=NULL, bump_B=0, bump_ratio=1) { + #To check: y, X, d, N_est, X_aux, d_aux, pot_break_points, potential_lambdas, X_range, bucket_min_n + assert_that(max_splits>0, max_cells>0, min_size>0, is.flag(lambda.1se), is.flag(bucket_min_d_var), + inherits(est_plan, "Estimator_plan") || (is.list(est_plan) && inherits(est_plan[[1]], "Estimator_plan"))) #verbosity can be negative if decrementd from a fit_estimate call + list[M, m_mode, N, K] = get_sample_type(y, X, d, checks=TRUE) + if(is_sep_sample(X) && length(cv_folds)>1) { + assert_that(is.list(cv_folds) && length(cv_folds)==M) + } + check_M_K(M, m_mode, K, X_aux, d_aux) + do_cv = is.na(partition_i) + + if(is.null(X_range)) X_range = get_X_range(X) + if(is.null(valid_fn)) valid_fn = valid_partition + + if(verbosity>0) cat("Grid: Started.\n") + if(do_cv) { + #Check dimensions + if(m_mode==1) { #Different samples + if(length(cv_folds)==1) cv_folds = rep(list(cv_folds), M) #expand here + } + + #Get number of folds + focus_list = if(m_mode==1) cv_folds[[1]] else cv_folds + if(length(focus_list)==1) { + nfolds = cv_folds[[1]] + folds_ret = gen_folds_m(y, cv_folds, m_mode, M) + } + else { + nfolds = length(focus_list) + folds_ret = cv_folds + } + + if(is.null(potential_lambdas) | length(potential_lambdas)>1) { + lambda = cv_pick_lambda(y=y, X=X, d=d, folds_ret=folds_ret, nfolds=nfolds, potential_lambdas=potential_lambdas, N_est=N_est, max_splits=max_splits, max_cells=max_cells, + min_size=min_size, verbosity=verbosity, pot_break_points=pot_break_points, X_range=X_range, lambda.1se=lambda.1se, + bucket_min_n=bucket_min_n, bucket_min_d_var=bucket_min_d_var, obj_fn=obj_fn, + est_plan=est_plan, pr_cl=pr_cl, valid_fn=valid_fn) + } + else { + lambda = potential_lambdas[1] + } + folds_index_out = folds_ret$indexOut + } + else { + assert_that(partition_i>0) + lambda = NA + folds_index_out = NA + max_splits = partition_i-1 + } + + if(verbosity>0) cat("Grid: Fitting grid structure on full set\n") + fit_ret = fit_partition_full(y, X, d, X_aux, d_aux, X_range=X_range, max_splits=max_splits, + max_cells=max_cells, min_size=min_size, verbosity=verbosity-1, + 
pot_break_points=pot_break_points, N_est, bucket_min_n=bucket_min_n, + bucket_min_d_var=bucket_min_d_var, obj_fn=obj_fn, allow_empty_aux=FALSE, + allow_est_errors_aux=FALSE, min_size_aux=1, est_plan=est_plan, + pr_cl=pr_cl, valid_fn=valid_fn) + list[partition_seq, is_obj_val_seq, split_seq] = fit_ret + complexity_seq = sapply(partition_seq, num_cells) - 1 + + if(do_cv) { + #lambda = max(lambda, 1e-8*abs(min(is_obj_val_seq))) #min_lambda option + + partition_best = which.min(is_obj_val_seq + lambda*complexity_seq) + partition_i = partition_best + } + else { + if(length(partition_seq)< partition_i) { + cat("Note: Couldn't build grid to desired granularity. Using most granular") + partition_i = length(partition_seq) + } + } + + if(bump_B>0) { + if(verbosity>0) cat("Grid > Bumping: Started.\n") + assert_that(bump_ratio>0, bump_ratio<=1) + + if(do_cv) n_cv_cells = (complexity_seq + 1)[partition_i] + best_val = is_obj_val_seq[partition_i] + partition_i_b = partition_i + + + samples <- lapply(seq_len(bump_B), function(b){sample_m(bump_ratio, N, m_mode==1)}) + + params = list(samples=samples, y=y, X_d=X, d=d, m_mode=m_mode, X_aux=X_aux, d_aux=d_aux, X_range=X_range, max_splits=max_splits, + max_cells=max_cells, min_size=min_size*bump_ratio, verbosity=verbosity-1, + pot_break_points=pot_break_points, N_est=N_est, bucket_min_n=bucket_min_n, + bucket_min_d_var=bucket_min_d_var, obj_fn=obj_fn, allow_empty_aux=FALSE, + allow_est_errors_aux=FALSE, min_size_aux=min_size, est_plan=est_plan, + pr_cl=NULL, valid_fn=valid_fn) + + b_rets = my_apply(1:bump_B, fit_partition_bump_b, verbosity==1 || !is.null(pr_cl), pr_cl, params) + + best_b = NA + for(b in seq_len(bump_B)) { + b_ret = b_rets[[b]] + if(do_cv) { + partition_i_b = which.min(abs(sapply(b_ret$partition_seq, num_cells)-n_cv_cells)) + } + else { + if(length(b_ret$partition_seq)0 | obj_ret[3]>0) next #N_cell_empty, N_cell_error + if(obj_ret[1] < best_val){ + best_val = obj_ret[1] + best_b = b + } + } + if(!is.na(best_b)) { + if(verbosity>0) { + cat(paste("Grid > Bumping: Finished. Picking bumped partition.")) + cat(paste(" Old (unbumped) is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "].")) + cat(paste(" Old (unbumped) complexity_seq=[", paste(complexity_seq, collapse=" "), "].\n")) + } + list[partition_seq, is_obj_val_seq_best_b, split_seq] = b_rets[[best_b]] + if(do_cv) partition_i = which.min(abs(sapply(partition_seq, num_cells)-n_cv_cells)) + complexity_seq = sapply(partition_seq, num_cells) - 1 + is_obj_val_seq = sapply(partition_seq, function(p){ + obj_fn(y, X, d, N_est=N_est, partition=p, est_plan=est_plan, sample="trtr")[1] + }) + } + else { + if(verbosity>0) cat(paste("Grid > Bumping: Finished. No bumped partitions better than original.\n")) + } + } + + if(verbosity>0) { + #print(partition_seq) + cat(paste("Grid: Finished. 
is_obj_val_seq=[", paste(is_obj_val_seq, collapse=" "), "].")) + if(exists("partition_best")) { + cat(paste(" complexity_seq=[", paste(complexity_seq, collapse=" "), "].")) + cat(paste(" best partition=", paste(partition_best, collapse=" "), ".")) + } + cat("\n") + } + partition = partition_seq[[partition_i]] + return(list(partition=partition, is_obj_val_seq=is_obj_val_seq, complexity_seq=complexity_seq, partition_i=partition_i, + partition_seq=partition_seq, split_seq=split_seq, lambda=lambda, folds_index_out=folds_index_out)) +} diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..76c93d8 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,371 @@ +# Utils + + +#handles vectors and 2D structures +row_sample <- function(data, sample) { + if(is.null(data)) return(NULL) + if(is.null(ncol(data))) + return(data[sample]) + return(data[sample, , drop=FALSE]) +} + +is_vec <- function(X) { + return(is.null(ncol(X)) || ncol(X)==1) +} + +get_dim_cat <- function(X) { + if(is.data.frame(X)) { + return(sapply(X, is.factor) & !sapply(X, is.ordered)) + } + return(rep(F, ncol(X))) +} + +update_names <- function(X) { + if(is.null(colnames(X))){ + colnames(X) = paste("X", 1:ncol(X), sep="") + } + return(X) +} + +#Note that the verbosity passed in here could be different than a member of the params +my_apply <- function(X, fn_k, apply_verbosity, pr_cl, params) { + K = length(X) + if(requireNamespace("pbapply", quietly = TRUE) & (apply_verbosity>0) & (is.null(pr_cl) || length(pr_cl)= 1) + return(X) +} + + +get_sample_type <- function(y, X, d=NULL, checks=FALSE) { + if(is_sep_sample(X)) { #Different samples + m_mode=1 + M = length(X) + N = sapply(X, nrow) + K = ncol(X[[1]]) + + if(checks) { + check_list_dims <- function(new_type) { + assert_that(is.list(new_type), length(new_type)==M) + for(m in 1:M) assert_that(length(new_type[[m]])==N[[m]]) + } + check_list_dims(y) + if(!is.null(d)) check_list_dims(d) + + for(m in 1:M) { + assert_that(ncol(X[[m]])==K) + } + } + + } + else { #Same sample + N = nrow(X) + K = ncol(X) + + if(!is.null(d) && is.matrix(d) && ncol(d)>1) { + m_mode= 2 + M = ncol(d) + if(checks){ + assert_that(!inherits(d, "tbl")) #TODO: Could silently conver + assert_that(nrow(d)==N, length(y)==N) + } + } + else if(!is.null(d) && is.matrix(y) && ncol(y)>1) { + m_mode= 3 + M = ncol(y) + N = nrow(X) + if(checks){ + assert_that(!inherits(y, "tbl")) #TODO: Could silently conver + assert_that(is.null(d) || length(d)==N, nrow(y)==N) + } + } + else { + m_mode= 0 + M=1 + if(checks) + assert_that(is.null(d) || length(d)==N, length(y)==N) + } + + if(M>1) N= rep(N, M) + } + return(list(M, m_mode, N, K)) +} + +check_M_K <- function(M, m_mode, K, X_aux, d_aux) { + if(m_mode==1) { + assert_that(length(X_aux)==M, is.null(d_aux) || length(d_aux)==M) + for(m in 1:M) assert_that(ncol(X_aux[[m]])==K) + } + else { + assert_that(ncol(X_aux)==K) + if(m_mode==2) assert_that(ncol(d_aux)==M) + } +} + +# Return M-list if mode_m==1 else sample +sample_m <- function(ratio, N, M_mult) { + if(!M_mult) { + if(length(N)>1) N=N[1] #for modes 2 & 3 + return(sample(N, N*ratio, replace=TRUE)) + } + return(lapply(N, function(N_s) sample(N_s, N_s*ratio, replace=TRUE))) +} + +#assumes separate samples if m_mode==1 +subsample_m <- function(y, X, d, sample) { + M_mult = is_sep_sample(X) + if(!M_mult) { + return(list(row_sample(y,sample), X[sample,,drop=FALSE], row_sample(d,sample))) + } + return(list(mapply(function(y_s, sample_s) y_s[sample_s], y, sample, SIMPLIFY=FALSE), + mapply(function(X_s, sample_s) X_s[sample_s,,drop=FALSE], 
X, sample, SIMPLIFY=FALSE),
+              mapply(function(d_s, sample_s) d_s[sample_s], d, sample, SIMPLIFY=FALSE)))
+}
+
+
+gen_split_m <- function(N, tr_split, M_mult) {
+  if(!M_mult) {
+    if(length(N)>1) N=N[1] # mode 2 & 3
+    return(base::sample(N, tr_split*N))
+  }
+  return(lapply(N, function(n) base::sample(n, tr_split*n)))
+}
+
+split_sample_m <- function(y, X, d, index_tr) {
+  if(!is_sep_sample(X)) {
+    list[y_tr, y_es] = list(row_sample(y, index_tr), row_sample(y, -index_tr))
+    list[d_tr, d_es] = list(row_sample(d, index_tr), row_sample(d, -index_tr))
+    X_tr = X[index_tr, , drop=FALSE]
+    X_es = X[-index_tr, , drop=FALSE]
+    N_est = nrow(X_es)
+  }
+  else {
+    y_tr = y_es = X_tr = X_es = d_tr = d_es = list()
+    N_est = rep(0, length(X))
+    for(m in 1:length(X))
+      list[y_tr[[m]], y_es[[m]], X_tr[[m]], X_es[[m]], d_tr[[m]], d_es[[m]], N_est[m]] = split_sample_m(y[[m]], X[[m]], d[[m]], index_tr[[m]])
+    N_est = sapply(X_es, nrow)
+  }
+  return(list(y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est))
+}
+
+
+gen_folds_m <- function(y, folds, m_mode, M) {
+  if(m_mode!=1) {
+    if(is.list(y)) y = y[[1]]
+    if(is.matrix(y)) y = y[,1]
+
+    return(gen_folds(y, folds))
+  }
+  return(lapply(1:M, function(m) gen_folds(y[[m]], folds[[m]])))
+}
+
+split_sample_folds_m <- function(y, X, d, folds_ret, f) {
+  if(!is_sep_sample(X)) {
+    list[y_f_tr, y_f_cv] = list(row_sample(y, folds_ret$index[[f]]), row_sample(y, folds_ret$indexOut[[f]]))
+    list[d_f_tr, d_f_cv] = list(row_sample(d, folds_ret$index[[f]]), row_sample(d, folds_ret$indexOut[[f]]))
+    X_f_tr = X[folds_ret$index[[f]], , drop=FALSE]
+    X_f_cv = X[folds_ret$indexOut[[f]], , drop=FALSE]
+  }
+  else {
+    y_f_tr = y_f_cv = X_f_tr = X_f_cv = d_f_tr = d_f_cv = list()
+    for(m in 1:length(X))
+      list[y_f_tr[[m]], y_f_cv[[m]], X_f_tr[[m]], X_f_cv[[m]], d_f_tr[[m]], d_f_cv[[m]]] = split_sample_folds_m(y[[m]], X[[m]], d[[m]], folds_ret[[m]], f)
+  }
+  return(list(y_f_tr, y_f_cv, X_f_tr, X_f_cv, d_f_tr, d_f_cv))
+}
+
+fit_and_residualize_m <- function(est_plan, X_tr, y_tr, d_tr, cv_folds, y_es, X_es, d_es, m_mode, M, verbosity, dim_cat) {
+  if(!is_sep_estimators(m_mode))
+    return(fit_and_residualize(est_plan, X_tr, y_tr, d_tr, cv_folds, y_es, X_es, d_es, verbosity, dim_cat))
+
+  if(m_mode==1) {
+    for(m in 1:M)
+      list[est_plan[[m]], y_tr[[m]], d_tr[[m]], y_es[[m]], d_es[[m]]] = fit_and_residualize(est_plan[[m]], X_tr[[m]], y_tr[[m]], d_tr[[m]], cv_folds, y_es[[m]], X_es[[m]], d_es[[m]], verbosity, dim_cat)
+    return(list(est_plan, y_tr, d_tr, y_es, d_es))
+  }
+
+  #We overwrite the d's
+  for(m in 1:M)
+    list[est_plan[[m]], y_tr[,m], d_tr, y_es[,m], d_es] = fit_and_residualize(est_plan[[m]], X_tr, y_tr[,m], d_tr, cv_folds, y_es[,m], X_es, d_es, verbosity, dim_cat)
+  return(list(est_plan, y_tr, d_tr, y_es, d_es))
+}
+
+
+interaction_m <- function(facts, M_mult=FALSE, drop=FALSE) {
+  if(!M_mult) {
+    return(interaction(facts, drop=drop))
+  }
+  return(lapply(facts, function(f) interaction(f, drop=drop)))
+}
+
+interaction2_m <- function(f1, f2, M_mult=FALSE, drop=FALSE) {
+  if(!M_mult) {
+    return(interaction(f1, f2, drop=drop))
+  }
+  return(mapply(function(f1_s, f2_s) interaction(f1_s, f2_s, drop=drop), f1, f2, SIMPLIFY=FALSE))
+}
+
+
+gen_holdout_interaction_m <- function(factors_by_dim, k, M_mult) {
+  if(!M_mult)
+    return(gen_holdout_interaction(factors_by_dim, k))
+
+  return(lapply(factors_by_dim, function(f_by_dim) gen_holdout_interaction(f_by_dim, k)))
+}
+
+is_factor_dim_k <- function(X, k, M_mult) {
+  if(!M_mult)
+    return(is.factor(X[, k]))
+  return(is.factor(X[[1]][, k]))
+}
+
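+# Data layouts handled by the *_m helpers (per get_sample_type above); shapes are
+# illustrative:
+#   m_mode 0: y N-vector, X (N x K), d N-vector   - single outcome/treatment
+#   m_mode 1: y, X, d are M-lists of per-sample data (separate samples)
+#   m_mode 2: d is an N x M matrix (multiple treatments), y an N-vector
+#   m_mode 3: y is an N x M matrix (multiple outcomes), d an N-vector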
+droplevels_m <- function(factor, M_mult) {
+  if(!M_mult) return(droplevels(factor))
+  return(lapply(factor, droplevels))
+}
+
+min_sum <- function(data, M_mult) {
+  if(!M_mult) return(sum(data))
+  return(min(sapply(data, sum)))
+}
+
+apply_mask_m <- function(data, mask, M_mult) {
+  if(is.null(data)) return(NULL)
+  if(!M_mult) return(row_sample(data, mask))
+  return(mapply(function(data_s, mask_s) row_sample(data_s, mask_s), data, mask, SIMPLIFY=FALSE))
+}
+any_const_m <- function(d, shifted, shifted_cell_factor_nk, m_mode) {
+  if(m_mode==0 || m_mode==3)
+    return(any(by(d[shifted], shifted_cell_factor_nk, FUN=const_vect)))
+  if(m_mode==1)
+    return( any(mapply(function(d_s, shifted_s, shifted_cell_factor_nk_s)
+      any(by(d_s[shifted_s], shifted_cell_factor_nk_s, FUN=const_vect))
+      , d, shifted, shifted_cell_factor_nk )) )
+  #m_mode==2
+  return( any(apply(d, 2, function(d_s) any(by(d_s[shifted], shifted_cell_factor_nk, FUN=const_vect)) )) )
+}
+gen_cat_window_mask_m <- function(X, k, window) {
+  if(is.null(X)) return(NULL)
+  M_mult = is_sep_sample(X)
+  if(!M_mult) return(X[, k] %in% window)
+  return(lapply(X, function(X_s) X_s[, k] %in% window))
+}
+gen_cat_win_split_cond_m <- function(X, win_mask, k, win_split_val) {
+  M_mult = is_sep_sample(X)
+  if(!M_mult)
+    return(factor(X[win_mask, k] %in% win_split_val, levels=c(FALSE, TRUE)))
+  return(mapply(function(X_s, win_mask_s) factor(X_s[win_mask_s, k] %in% win_split_val, levels=c(FALSE, TRUE)), X, win_mask, SIMPLIFY=FALSE))
+}
+
+gen_cont_window_mask_m <- function(X, k, win_LB, win_UB) {
+  if(is.null(X)) return(NULL)
+  M_mult = is_sep_sample(X)
+  if(!M_mult) return(win_LB < X[, k] & X[, k] <= win_UB)
+  return(lapply(X, function(X_s) win_LB < X_s[, k] & X_s[, k] <= win_UB))
+}
+We randomize in generating train/est and trtr/trcv splits. Possibly cv.glmnet and cv.gglasso as well.
+}
diff --git a/man/Do_Residualize.Rd b/man/Do_Residualize.Rd
new file mode 100644
index 0000000..8c144d4
--- /dev/null
+++ b/man/Do_Residualize.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/Estimator_plans.R
+\name{Do_Residualize}
+\alias{Do_Residualize}
+\title{Do_Residualize}
+\usage{
+Do_Residualize(obj, y, X, d, sample)
+}
+\arguments{
+\item{obj}{Object}
+
+\item{y}{y}
+
+\item{X}{X}
+
+\item{d}{d (Default=NULL)}
+
+\item{sample}{one of 'tr' or 'est'}
+}
+\value{
+list(y=) or list(y=, d=)
+}
+\description{
+Do_Residualize
+}
diff --git a/man/Do_Residualize.grid_rf.Rd b/man/Do_Residualize.grid_rf.Rd
new file mode 100644
index 0000000..875dd25
--- /dev/null
+++ b/man/Do_Residualize.grid_rf.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/Estimator_plans.R
+\name{Do_Residualize.grid_rf}
+\alias{Do_Residualize.grid_rf}
+\title{Do_Residualize.grid_rf}
+\usage{
+\method{Do_Residualize}{grid_rf}(obj, y, X, d, sample)
+}
+\arguments{
+\item{obj}{Object}
+
+\item{y}{y}
+
+\item{X}{X}
+
+\item{d}{d (Default=NULL)}
+
+\item{sample}{one of 'tr' or 'est'}
+}
+\value{
+list(y=) or list(y=, d=)
+}
+\description{
+Do_Residualize.grid_rf
+}
diff --git a/man/Do_Residualize.lm_X_est.Rd b/man/Do_Residualize.lm_X_est.Rd
new file mode 100644
index 0000000..12ae050
--- /dev/null
+++ b/man/Do_Residualize.lm_X_est.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/Estimator_plans.R
+\name{Do_Residualize.lm_X_est}
+\alias{Do_Residualize.lm_X_est}
+\title{Do_Residualize.lm_X_est}
+\usage{
+\method{Do_Residualize}{lm_X_est}(obj, y, X, d, sample)
+}
+\arguments{
+\item{obj}{obj}
+
+\item{y}{y}
+
+\item{X}{X}
+
+\item{d}{d}
+
+\item{sample}{one of 'tr' or
'est'} +} +\value{ +list(y=...) or list(y=..., d=...) +} +\description{ +Do_Residualize.lm_X_est +} diff --git a/man/Do_Residualize.simple_est.Rd b/man/Do_Residualize.simple_est.Rd new file mode 100644 index 0000000..671c30f --- /dev/null +++ b/man/Do_Residualize.simple_est.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Do_Residualize.simple_est} +\alias{Do_Residualize.simple_est} +\title{Do_Residualize.simple_est} +\usage{ +\method{Do_Residualize}{simple_est}(obj, y, X, d, sample) +} +\arguments{ +\item{obj}{obj} + +\item{y}{y} + +\item{X}{X} + +\item{d}{d} + +\item{sample}{one of 'tr' or 'est'} +} +\value{ +list(y=...) or list(y=..., d=...) +} +\description{ +Do_Residualize.simple_est +} diff --git a/man/Fit_InitTr.Rd b/man/Fit_InitTr.Rd new file mode 100644 index 0000000..4ffb792 --- /dev/null +++ b/man/Fit_InitTr.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Fit_InitTr} +\alias{Fit_InitTr} +\title{Fit_InitTr} +\usage{ +Fit_InitTr( + obj, + X_tr, + y_tr, + d_tr = NULL, + cv_folds, + verbosity = 0, + dim_cat = c() +) +} +\arguments{ +\item{obj}{Object} + +\item{X_tr}{X} + +\item{y_tr}{y} + +\item{d_tr}{d_tr} + +\item{cv_folds}{CV folds} + +\item{verbosity}{verbosity} + +\item{dim_cat}{vector of dimensions that are categorical} +} +\value{ +Updated Object +} +\description{ +Fit_InitTr +} diff --git a/man/Fit_InitTr.grid_rf.Rd b/man/Fit_InitTr.grid_rf.Rd new file mode 100644 index 0000000..0b0209f --- /dev/null +++ b/man/Fit_InitTr.grid_rf.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Fit_InitTr.grid_rf} +\alias{Fit_InitTr.grid_rf} +\title{Fit_InitTr.grid_rf +Note that for large data, the rf_y_fit and potentially rf_d_fit objects may be large. +They can be null'ed out after fitting} +\usage{ +\method{Fit_InitTr}{grid_rf}( + obj, + X_tr, + y_tr, + d_tr = NULL, + cv_folds, + verbosity = 0, + dim_cat = c() +) +} +\arguments{ +\item{obj}{Object} + +\item{X_tr}{X} + +\item{y_tr}{y} + +\item{d_tr}{d_tr} + +\item{cv_folds}{CV folds} + +\item{verbosity}{verbosity} + +\item{dim_cat}{vector of dimensions that are categorical} +} +\value{ +Updated Object +} +\description{ +Fit_InitTr.grid_rf +Note that for large data, the rf_y_fit and potentially rf_d_fit objects may be large. +They can be null'ed out after fitting +}
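+\examples{ +\dontrun{ +# A hypothetical sketch: a grid_rf plan is normally passed to +# fit_estimate_partition() via ctrl_method (y, X, d not defined here). +fit <- fit_estimate_partition(y, X, d=d, ctrl_method=grid_rf(num.trees=200)) +} +}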
 diff --git a/man/Fit_InitTr.lm_X_est.Rd b/man/Fit_InitTr.lm_X_est.Rd new file mode 100644 index 0000000..0352e41 --- /dev/null +++ b/man/Fit_InitTr.lm_X_est.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Fit_InitTr.lm_X_est} +\alias{Fit_InitTr.lm_X_est} +\title{Fit_InitTr.lm_X_est} +\usage{ +\method{Fit_InitTr}{lm_X_est}( + obj, + X_tr, + y_tr, + d_tr = NULL, + cv_folds, + verbosity = 0, + dim_cat = c() +) +} +\arguments{ +\item{obj}{lm_X_est object} + +\item{X_tr}{X_tr} + +\item{y_tr}{y_tr} + +\item{d_tr}{d_tr} + +\item{cv_folds}{cv_folds} + +\item{verbosity}{verbosity} + +\item{dim_cat}{dim_cat} +} +\value{ +Updated object +} +\description{ +Fit_InitTr.lm_X_est +} diff --git a/man/Fit_InitTr.simple_est.Rd b/man/Fit_InitTr.simple_est.Rd new file mode 100644 index 0000000..1bb4437 --- /dev/null +++ b/man/Fit_InitTr.simple_est.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Fit_InitTr.simple_est} +\alias{Fit_InitTr.simple_est} +\title{Fit_InitTr.simple_est} +\usage{ +\method{Fit_InitTr}{simple_est}( + obj, + X_tr, + y_tr, + d_tr = NULL, + cv_folds, + verbosity = 0, + dim_cat = c() +) +} +\arguments{ +\item{obj}{obj} + +\item{X_tr}{X_tr} + +\item{y_tr}{y_tr} + +\item{d_tr}{d_tr} + +\item{cv_folds}{cv_folds} + +\item{verbosity}{verbosity} + +\item{dim_cat}{dim_cat} +} +\value{ +Updated object +} +\description{ +Fit_InitTr.simple_est +} diff --git a/man/Param_Est.Rd b/man/Param_Est.Rd new file mode 100644 index 0000000..1cd7eac --- /dev/null +++ b/man/Param_Est.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Param_Est} +\alias{Param_Est} +\title{Param_Est} +\usage{ +Param_Est(obj, y, d = NULL, X, sample = "est", ret_var = FALSE) +} +\arguments{ +\item{obj}{Object} + +\item{y}{An N-vector} + +\item{d}{An N-vector or Nxm matrix (so that the effects can be estimated jointly)} + +\item{X}{An NxK matrix or data.frame} + +\item{sample}{Sample: "trtr", "trcv", "est"} + +\item{ret_var}{Return Variance in the return list} +} +\value{ +list(param_est=...) +} +\description{ +Param_Est +} diff --git a/man/Param_Est.grid_rf.Rd b/man/Param_Est.grid_rf.Rd new file mode 100644 index 0000000..b01e016 --- /dev/null +++ b/man/Param_Est.grid_rf.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Param_Est.grid_rf} +\alias{Param_Est.grid_rf} +\title{Param_Est.grid_rf} +\usage{ +\method{Param_Est}{grid_rf}(obj, y, d = NULL, X, sample = "est", ret_var = FALSE) +} +\arguments{ +\item{obj}{Object} + +\item{y}{y} + +\item{d}{d} + +\item{X}{X} + +\item{sample}{Sample: "trtr", "trcv", "est"} + +\item{ret_var}{Return Variance in the return list} +} +\value{ +list(param_est=...) 
+} +\description{ +Param_Est.grid_rf +} diff --git a/man/Param_Est.lm_X_est.Rd b/man/Param_Est.lm_X_est.Rd new file mode 100644 index 0000000..fca30ff --- /dev/null +++ b/man/Param_Est.lm_X_est.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Param_Est.lm_X_est} +\alias{Param_Est.lm_X_est} +\title{Param_Est.lm_X_est} +\usage{ +\method{Param_Est}{lm_X_est}(obj, y, d = NULL, X, sample = "est", ret_var = FALSE) +} +\arguments{ +\item{obj}{obj} + +\item{y}{y} + +\item{d}{d} + +\item{X}{X} + +\item{sample}{Sample: "trtr", "trcv", "est"} + +\item{ret_var}{Return variance in return list} +} +\value{ +list(param_est=...) or list(param_est=..., var_est=...) +} +\description{ +Param_Est.lm_X_est +} diff --git a/man/Param_Est.simple_est.Rd b/man/Param_Est.simple_est.Rd new file mode 100644 index 0000000..48a15d3 --- /dev/null +++ b/man/Param_Est.simple_est.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{Param_Est.simple_est} +\alias{Param_Est.simple_est} +\title{Param_Est.simple_est} +\usage{ +\method{Param_Est}{simple_est}(obj, y, d = NULL, X, sample = "est", ret_var = FALSE) +} +\arguments{ +\item{obj}{obj} + +\item{y}{y} + +\item{d}{d} + +\item{X}{X} + +\item{sample}{Sample: "trtr", "trcv", "est"} + +\item{ret_var}{Return variance in return list} +} +\value{ +list(param_est=...) +} +\description{ +Param_Est.simple_est +} diff --git a/man/any_sign_effect.Rd b/man/any_sign_effect.Rd new file mode 100644 index 0000000..9f632c4 --- /dev/null +++ b/man/any_sign_effect.Rd @@ -0,0 +1,43 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{any_sign_effect} +\alias{any_sign_effect} +\title{any_sign_effect +fdr - conservative +sim_mom_ineq - Needs sample sizes to be sufficiently large so that the effects are normally distributed} +\usage{ +any_sign_effect( + obj, + check_negative = T, + method = "fdr", + alpha = 0.05, + n_sim = 500 +) +} +\arguments{ +\item{obj}{obj} + +\item{check_negative}{If TRUE, check for a negative effect; if FALSE, check for a positive one.} + +\item{method}{one of c("fdr", "sim_mom_ineq")} + +\item{alpha}{alpha} + +\item{n_sim}{n_sim} +} +\value{ +list(are_any= boolean of whether any effect has the checked sign) +} +\description{ +any_sign_effect +fdr - conservative +sim_mom_ineq - Needs sample sizes to be sufficiently large so that the effects are normally distributed +}
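+\examples{ +\dontrun{ +# A hypothetical sketch: est_part is an estimated_partition from +# fit_estimate_partition(); check for any significantly negative cell effect. +res <- any_sign_effect(est_part, check_negative=TRUE, method="fdr") +res$are_any +} +}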
 diff --git a/man/change_complexity.Rd b/man/change_complexity.Rd new file mode 100644 index 0000000..17ad2da --- /dev/null +++ b/man/change_complexity.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{change_complexity} +\alias{change_complexity} +\title{change_complexity} +\usage{ +change_complexity(fit, y, X, d = NULL, partition_i) +} +\arguments{ +\item{fit}{estimated_partition} + +\item{y}{Nx1 matrix of outcome (label/target) data} + +\item{X}{NxK matrix of features (covariates). Must be numerical (unordered categorical +variables must be 1-hot encoded.)} + +\item{d}{(Optional) NxP matrix (with colnames) or vector of treatment data. If all equally +important they should be normalized to have the same variance.} + +\item{partition_i}{partition_i - 1 is the index of the last split in split_seq included in the new partition} +} +\value{ +updated estimated_partition +} +\description{ +Changes the complexity of an estimated partition. Doesn't update the importance weights +} diff --git a/man/est_full_stats.Rd b/man/est_full_stats.Rd new file mode 100644 index 0000000..dcda8a6 --- /dev/null +++ b/man/est_full_stats.Rd @@ -0,0 +1,43 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{est_full_stats} +\alias{est_full_stats} +\title{est_full_stats} +\usage{ +est_full_stats( + y, + d, + X, + est_plan, + y_es = NULL, + d_es = NULL, + X_es = NULL, + index_tr = NULL, + alpha = 0.05 +) +} +\arguments{ +\item{y}{y} + +\item{d}{d} + +\item{X}{X} + +\item{est_plan}{est_plan} + +\item{y_es}{y_es} + +\item{d_es}{d_es} + +\item{X_es}{X_es} + +\item{index_tr}{index_tr} + +\item{alpha}{alpha} +} +\value{ +Stats df +} +\description{ +est_full_stats +} diff --git a/man/estimate_cell_stats.Rd b/man/estimate_cell_stats.Rd new file mode 100644 index 0000000..c7c8144 --- /dev/null +++ b/man/estimate_cell_stats.Rd @@ -0,0 +1,54 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{estimate_cell_stats} +\alias{estimate_cell_stats} +\title{estimate_cell_stats} +\usage{ +estimate_cell_stats( + y, + X, + d = NULL, + partition = NULL, + cell_factor = NULL, + estimator_var = NULL, + est_plan = NULL, + alpha = 0.05 +) +} +\arguments{ +\item{y}{Nx1 matrix of outcome (label/target) data} + +\item{X}{NxK matrix of features (covariates)} + +\item{d}{(Optional) NxP matrix (with colnames) of treatment data. If all equally important they should +be normalized to have the same variance.} + +\item{partition}{(Optional, need this or cell_factor) partitioning returned from fit_estimate_partition} + +\item{cell_factor}{(Optional, need this or partition)} + +\item{estimator_var}{(Optional) a function with signature list(param_est, var_est) = function(y, d) +(where if no d then can pass in null). If NULL then will choose between built-in +mean-estimator and scalar_te_estimator} + +\item{est_plan}{Estimation plan} + +\item{alpha}{Alpha} +} +\value{ +list +\item{cell_factor}{Factor with levels for each cell for X. Length N.} +\item{stats}{data.frame(cell_i, N_est, param_ests, var_ests, tstats, pval, ci_u, ci_l, p_fwer, p_fdr)} +} +\description{ +estimate_cell_stats +}
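+\examples{ +\dontrun{ +# A hypothetical sketch: recompute per-cell statistics for est_part, an +# estimated_partition, on an estimation sample (y_es, X_es, d_es). +stats <- estimate_cell_stats(y_es, X_es, d_es, partition=est_part$partition, + est_plan=est_part$est_plan) +stats$stats +} +}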
 diff --git a/man/fit_estimate_partition.Rd b/man/fit_estimate_partition.Rd new file mode 100644 index 0000000..e804f28 --- /dev/null +++ b/man/fit_estimate_partition.Rd @@ -0,0 +1,124 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{fit_estimate_partition} +\alias{fit_estimate_partition} +\title{fit_estimate_partition} +\usage{ +fit_estimate_partition( + y, + X, + d = NULL, + max_splits = Inf, + max_cells = Inf, + min_size = 3, + cv_folds = 2, + potential_lambdas = NULL, + lambda.1se = FALSE, + partition_i = NA, + tr_split = 0.5, + verbosity = 0, + pot_break_points = NULL, + bucket_min_n = NA, + bucket_min_d_var = FALSE, + honest = FALSE, + ctrl_method = "", + pr_cl = NULL, + alpha = 0.05, + bump_B = 0, + bump_ratio = 1, + importance_type = "" +) +} +\arguments{ +\item{y}{N vector of outcome (label/target) data} + +\item{X}{NxK matrix of features (covariates). Must be numerical (unordered categorical variables must be +1-hot encoded.)} + +\item{d}{(Optional) N vector of treatment data.} + +\item{max_splits}{Maximum number of splits even if splits continue to improve OOS fit} + +\item{max_cells}{Maximum number of cells} + +\item{min_size}{Minimum size of cells} + +\item{cv_folds}{Number of CV Folds or foldids. If multiple-effect case #3 and using a vector, then pass in a list of vectors.} + +\item{potential_lambdas}{potential lambdas to search through in CV} + +\item{lambda.1se}{Use the 1se rule to pick the best lambda} + +\item{partition_i}{Default is NA. Use this to avoid CV automated selection of the partition} + +\item{tr_split}{Can be a ratio or vector of indexes. If multiple-effect case #3 and using a vector, then pass in a list of vectors.} + +\item{verbosity}{If >0 prints out progress bar for each split} + +\item{pot_break_points}{NULL or a k-dim list of vectors giving potential split points for non-categorical +variables (can put c(0) for categorical). Similar to 'discrete splitting' in +CausalTree though there they do separate split-points for treated and controls.} + +\item{bucket_min_n}{Minimum number of observations needed between different split checks for continuous features} + +\item{bucket_min_d_var}{Ensure positive variance of d for the observations between different split checks +for continuous features} + +\item{honest}{Whether to use the emse_hat or mse_hat. Use emse for outcome mean. For treatment effect, +use if you want predictive accuracy, but if you only want to identify the true dimensions of heterogeneity +then don't use.} + +\item{ctrl_method}{Method for determining additional control variables. Empty ("") for nothing, "all" or "lasso"} + +\item{pr_cl}{Parallel Cluster (If NULL, default, then will be single-processor)} + +\item{alpha}{Default=0.05} + +\item{bump_B}{Number of bump bootstraps} + +\item{bump_ratio}{For bootstraps the ratio of sample size to sample (between 0 and 1, default 1)} + +\item{importance_type}{Options: +single - (smart) redo full fitting removing each possible dimension +interaction - (smart) redo full fitting removing each pair of dimensions + "" - Nothing} +} +\value{ +An object with class \code{"estimated_partition"}. +\item{partition}{Partition obj defining cuts} +\item{cell_stats}{list(cell_factor=cell_factor, stats=stat_df) from estimate_cell_stats() using est sample} +\item{importance_weights}{importance_weights} +\item{interaction_weights}{interaction_weights} +\item{has_d}{has_d} +\item{lambda}{lambda used} +\item{is_obj_val_seq}{In-sample objective function values for sequence of partitions} +\item{complexity_seq}{Complexity #s (# cells-1) for sequence of partitions} +\item{partition_i}{Index of Partition selected in sequence} +\item{split_seq}{Sequence of splits. Note that split i corresponds to partition i+1} +\item{index_tr}{Index of training sample (Size of N)} +\item{cv_foldid}{CV foldids for the training sample (Size of N_tr)} +\item{varnames}{varnames (or c("X1", "X2",...) if X doesn't have colnames)} +\item{honest}{honest target} +\item{est_plan}{Estimation plan} +\item{full_stat_df}{full_stat_df} +} +\description{ +Split the data; on one side train/fit the partition and then on the other estimate subgroup effects. +With multiple treatment effects (M) there are 3 options (the first two have the same sample across treatment effects). + 1) Multiple pairs of (Y_{m},W_{m}). y,X,d are then lists of length M. Each element then has the typical size. + The N_m may differ across m. The number of columns of X will be the same across m. + 2) Multiple treatments and a single outcome. d is then a NxM matrix. + 3) A single treatment and multiple outcomes. y is then a NxM matrix. +}
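+\examples{ +\dontrun{ +# An illustrative sketch on simulated data (not from the package docs): +# one treatment whose effect differs across X1. +n <- 1000 +X <- matrix(runif(2*n), ncol=2) +d <- rbinom(n, 1, 0.5) +y <- d*(X[,1] > 0.5) + rnorm(n) +fit <- fit_estimate_partition(y, X, d=d, cv_folds=3) +print(fit) +} +}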
 diff --git a/man/fit_partition.Rd b/man/fit_partition.Rd new file mode 100644 index 0000000..f92a8c1 --- /dev/null +++ b/man/fit_partition.Rd @@ -0,0 +1,89 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{fit_partition} +\alias{fit_partition} +\title{fit_partition} +\usage{ +fit_partition( + y, + X, + d = NULL, + N_est = NA, + X_aux = NULL, + d_aux = NULL, + max_splits = Inf, + max_cells = Inf, + min_size = 3, + cv_folds = 2, + verbosity = 0, + pot_break_points = NULL, + potential_lambdas = NULL, + X_range = NULL, + lambda.1se = FALSE, + bucket_min_n = NA, + bucket_min_d_var = FALSE, + obj_fn, + est_plan, + partition_i = NA, + pr_cl = NULL, + valid_fn = NULL, + bump_B = 0, + bump_ratio = 1 +) +} +\arguments{ +\item{y}{Nx1 matrix of outcome (label/target) data} + +\item{X}{NxK matrix of features (covariates)} + +\item{d}{(Optional) NxP matrix (with colnames) of treatment data. If all equally important they +should be normalized to have the same variance.} + +\item{N_est}{N of samples in the Estimation dataset} + +\item{X_aux}{aux X sample to compute statistics on (OOS data)} + +\item{d_aux}{aux d sample to compute statistics on (OOS data)} + +\item{max_splits}{Maximum number of splits even if splits continue to improve OOS fit} + +\item{max_cells}{Maximum number of cells even if more splits continue to improve OOS fit} + +\item{min_size}{Minimum cell size when building full grid, cv_tr will use (F-1)/F*min_size, cv_te doesn't use any.} + +\item{cv_folds}{Number of Folds} + +\item{verbosity}{If >0 prints out progress bar for each split} + +\item{pot_break_points}{k-dim list of vectors giving potential split points} + +\item{potential_lambdas}{potential lambdas to search through in CV} + +\item{X_range}{list of min/max for each dimension} + +\item{lambda.1se}{Use the 1se rule to pick the best lambda} + +\item{bucket_min_n}{Minimum number of observations needed between different split checks} + +\item{bucket_min_d_var}{Ensure positive variance of d for the observations between different split checks} + +\item{obj_fn}{Objective function: emse_hat or mse_hat} + +\item{est_plan}{Estimation plan.} + +\item{partition_i}{Default NA. Use this to avoid CV} + +\item{pr_cl}{Default NULL. Parallel cluster (used for fit_partition_full)} + +\item{valid_fn}{Function to quickly check if partition could be valid. User can override.} + +\item{bump_B}{Number of bump bootstraps} + +\item{bump_ratio}{For bootstraps the ratio of sample size to sample (between 0 and 1, default 1)} +} +\value{ +list(partition, lambda) +} +\description{ +Fits a partition on the data, using CV to find the best lambda, and then re-fits on the full data. 
+} diff --git a/man/get_X_range.Rd b/man/get_X_range.Rd new file mode 100644 index 0000000..dc31ad3 --- /dev/null +++ b/man/get_X_range.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{get_X_range} +\alias{get_X_range} +\title{get_X_range} +\usage{ +get_X_range(X) +} +\arguments{ +\item{X}{data} +} +\value{ +list of length K with each element being a c(min, max) along that dimension +} +\description{ +get_X_range +} diff --git a/man/get_desc_df.estimated_partition.Rd b/man/get_desc_df.estimated_partition.Rd new file mode 100644 index 0000000..90fe7ff --- /dev/null +++ b/man/get_desc_df.estimated_partition.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{get_desc_df.estimated_partition} +\alias{get_desc_df.estimated_partition} +\title{get_desc_df.estimated_partition} +\usage{ +get_desc_df.estimated_partition( + obj, + do_str = TRUE, + drop_unsplit = TRUE, + digits = NULL, + import_order = FALSE +) +} +\arguments{ +\item{obj}{estimated_partition object} + +\item{do_str}{If True, use a string like "(a, b]", otherwise have two separate columns with a and b} + +\item{drop_unsplit}{If True, drop columns for variables over which the partition did not split} + +\item{digits}{digits option (default is NULL)} + +\item{import_order}{should we use importance ordering or input ordering (default)} +} +\value{ +data.frame +} +\description{ +get_desc_df.estimated_partition +}
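+\examples{ +\dontrun{ +# A hypothetical sketch: tabulate the cells of est_part, an estimated_partition. +get_desc_df.estimated_partition(est_part, do_str=TRUE, drop_unsplit=TRUE) +} +}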
 diff --git a/man/get_desc_df.grid_partition.Rd b/man/get_desc_df.grid_partition.Rd new file mode 100644 index 0000000..b00dcde --- /dev/null +++ b/man/get_desc_df.grid_partition.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{get_desc_df.grid_partition} +\alias{get_desc_df.grid_partition} +\title{get_desc_df.grid_partition} +\usage{ +get_desc_df.grid_partition( + partition, + cont_bounds_inf = TRUE, + do_str = FALSE, + drop_unsplit = FALSE, + digits = NULL, + unsplit_cat_star = TRUE +) +} +\arguments{ +\item{partition}{Partition} + +\item{cont_bounds_inf}{If True, will put continuous bounds as -Inf/Inf. Otherwise will use X_range bounds} + +\item{do_str}{If True, use a string like "(a, b]", otherwise have two separate columns with a and b} + +\item{drop_unsplit}{If True, drop columns for variables over which the partition did not split} + +\item{digits}{digits option} + +\item{unsplit_cat_star}{if we don't split on a categorical var, should we show as "*" (otherwise list all levels)} +} +\value{ +data.frame +} +\description{ +get_desc_df.grid_partition +} diff --git a/man/get_factor_from_partition.Rd b/man/get_factor_from_partition.Rd new file mode 100644 index 0000000..8670887 --- /dev/null +++ b/man/get_factor_from_partition.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{get_factor_from_partition} +\alias{get_factor_from_partition} +\title{get_factor_from_partition} +\usage{ +get_factor_from_partition(partition, X, X_range = NULL) +} +\arguments{ +\item{partition}{partition} + +\item{X}{X data or list of X} + +\item{X_range}{(Optional) overrides the partition$X_range} +} +\value{ +Factor +} +\description{ +get_factor_from_partition +} diff --git a/man/grid_rf.Rd b/man/grid_rf.Rd new file mode 100644 index 0000000..2a3e46d --- /dev/null +++ b/man/grid_rf.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{grid_rf} +\alias{grid_rf} +\title{grid_rf} +\usage{ +grid_rf(num.trees = 500, num.threads = NULL, dof = 2, resid_est = TRUE) +} +\arguments{ +\item{num.trees}{number of trees in the random forest} + +\item{num.threads}{num.threads} + +\item{dof}{degrees-of-freedom} + +\item{resid_est}{Residualize the Estimation sample (using fit from training)} +} +\value{ +grid_rf object +} +\description{ +grid_rf +} diff --git a/man/is.estimated_partition.Rd b/man/is.estimated_partition.Rd new file mode 100644 index 0000000..1cfa6ea --- /dev/null +++ b/man/is.estimated_partition.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{is.estimated_partition} +\alias{is.estimated_partition} +\title{is.estimated_partition} +\usage{ +is.estimated_partition(x) +} +\arguments{ +\item{x}{Object} +} +\value{ +True if x is an estimated_partition +} +\description{ +is.estimated_partition +} diff --git a/man/is.grid_partition.Rd b/man/is.grid_partition.Rd new file mode 100644 index 0000000..eb0a66d --- /dev/null +++ b/man/is.grid_partition.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{is.grid_partition} +\alias{is.grid_partition} +\title{is.grid_partition} +\usage{ +is.grid_partition(x) +} +\arguments{ +\item{x}{Object} +} +\value{ +True if x is a grid_partition +} +\description{ +is.grid_partition +} diff --git a/man/is.grid_rf.Rd b/man/is.grid_rf.Rd new file mode 100644 index 0000000..70e4b8f --- /dev/null +++ b/man/is.grid_rf.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{is.grid_rf} +\alias{is.grid_rf} +\title{is.grid_rf} +\usage{ +is.grid_rf(x) +} +\arguments{ +\item{x}{Object} +} +\value{ +Boolean +} +\description{ +is.grid_rf +} diff --git a/man/is.lm_X_est.Rd b/man/is.lm_X_est.Rd new file mode 100644 index 0000000..3c98e8b --- /dev/null +++ b/man/is.lm_X_est.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{is.lm_X_est} +\alias{is.lm_X_est} +\title{is.lm_X_est} 
+\usage{ +is.lm_X_est(x) +} +\arguments{ +\item{x}{Object} +} +\value{ +Boolean +} +\description{ +is.lm_X_est +} diff --git a/man/is.simple_est.Rd b/man/is.simple_est.Rd new file mode 100644 index 0000000..1eb5092 --- /dev/null +++ b/man/is.simple_est.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Estimator_plans.R +\name{is.simple_est} +\alias{is.simple_est} +\title{is.simple_est} +\usage{ +is.simple_est(x) +} +\arguments{ +\item{x}{Object} +} +\value{ +Boolean +} +\description{ +is.simple_est +} diff --git a/man/num_cells.Rd b/man/num_cells.Rd new file mode 100644 index 0000000..f253d2e --- /dev/null +++ b/man/num_cells.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{num_cells} +\alias{num_cells} +\title{num_cells} +\usage{ +num_cells(obj) +} +\arguments{ +\item{obj}{Object} +} +\value{ +Number of cells in partition (at least 1) +} +\description{ +num_cells +} diff --git a/man/num_cells.estimated_partition.Rd b/man/num_cells.estimated_partition.Rd new file mode 100644 index 0000000..736fa31 --- /dev/null +++ b/man/num_cells.estimated_partition.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{num_cells.estimated_partition} +\alias{num_cells.estimated_partition} +\title{num_cells.estimated_partition} +\usage{ +\method{num_cells}{estimated_partition}(obj) +} +\arguments{ +\item{obj}{Estimated Partition} +} +\value{ +Number of cells +} +\description{ +num_cells.estimated_partition +} diff --git a/man/num_cells.grid_partition.Rd b/man/num_cells.grid_partition.Rd new file mode 100644 index 0000000..a4e8e3b --- /dev/null +++ b/man/num_cells.grid_partition.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{num_cells.grid_partition} +\alias{num_cells.grid_partition} +\title{num_cells.grid_partition} +\usage{ +\method{num_cells}{grid_partition}(obj) +} +\arguments{ +\item{obj}{Object} +} +\value{ +Number of cells +} +\description{ +num_cells.grid_partition +} diff --git a/man/plot_2D_partition.estimated_partition.Rd b/man/plot_2D_partition.estimated_partition.Rd new file mode 100644 index 0000000..ff10597 --- /dev/null +++ b/man/plot_2D_partition.estimated_partition.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/graphing.R +\name{plot_2D_partition.estimated_partition} +\alias{plot_2D_partition.estimated_partition} +\title{plot_2D_partition.estimated_partition} +\usage{ +plot_2D_partition.estimated_partition(grid_fit, X_names_2D) +} +\arguments{ +\item{grid_fit}{grid_fit} + +\item{X_names_2D}{X_names_2D} +} +\value{ +ggplot2 object +} +\description{ +plot_2D_partition.estimated_partition +} diff --git a/man/predict_te.estimated_partition.Rd b/man/predict_te.estimated_partition.Rd new file mode 100644 index 0000000..fb14694 --- /dev/null +++ b/man/predict_te.estimated_partition.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{predict_te.estimated_partition} +\alias{predict_te.estimated_partition} +\title{predict_te.estimated_partition} +\usage{ +predict_te.estimated_partition(obj, new_X) +} +\arguments{ +\item{obj}{estimated_partition object} + +\item{new_X}{new X} +} +\value{ +predicted treatment effect +} +\description{ +Predicted unit-level treatment effect +} diff --git 
a/man/print.estimated_partition.Rd b/man/print.estimated_partition.Rd new file mode 100644 index 0000000..73807b3 --- /dev/null +++ b/man/print.estimated_partition.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit_estimate.R +\name{print.estimated_partition} +\alias{print.estimated_partition} +\title{print.estimated_partition} +\usage{ +\method{print}{estimated_partition}( + x, + do_str = TRUE, + drop_unsplit = TRUE, + digits = NULL, + import_order = FALSE, + ... +) +} +\arguments{ +\item{x}{estimated_partition object} + +\item{do_str}{If True, use a string like "(a, b]", otherwise have two separate columns with a and b} + +\item{drop_unsplit}{If True, drop columns for variables over which the partition did not split} + +\item{digits}{digits option} + +\item{import_order}{should we use importance ordering or input ordering (default)} + +\item{...}{Additional arguments. These won't be passed to print.data.frame} +} +\value{ +string (and displayed) +} +\description{ +print.estimated_partition +} diff --git a/man/print.grid_partition.Rd b/man/print.grid_partition.Rd new file mode 100644 index 0000000..b5042fb --- /dev/null +++ b/man/print.grid_partition.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{print.grid_partition} +\alias{print.grid_partition} +\title{print.grid_partition} +\usage{ +\method{print}{grid_partition}(x, do_str = TRUE, drop_unsplit = TRUE, digits = NULL, ...) +} +\arguments{ +\item{x}{partition object} + +\item{do_str}{If True, use a string like "(a, b]", otherwise have two separate columns with a and b} + +\item{drop_unsplit}{If True, drop columns for variables over which the partition did not split} + +\item{digits}{digits option} + +\item{...}{Additional arguments. Passed to data.frame} +} +\value{ +string (and displayed) +} +\description{ +print.grid_partition +} diff --git a/man/print.partition_split.Rd b/man/print.partition_split.Rd new file mode 100644 index 0000000..0d68f72 --- /dev/null +++ b/man/print.partition_split.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{print.partition_split} +\alias{print.partition_split} +\title{print.partition_split} +\usage{ +\method{print}{partition_split}(x, ...) +} +\arguments{ +\item{x}{Object} + +\item{...}{Additional arguments. Unused.} +} +\value{ +None +} +\description{ +print.partition_split +} diff --git a/man/quantile_breaks.Rd b/man/quantile_breaks.Rd new file mode 100644 index 0000000..428764c --- /dev/null +++ b/man/quantile_breaks.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/grid_partition.R +\name{quantile_breaks} +\alias{quantile_breaks} +\title{quantile_breaks} +\usage{ +quantile_breaks(X, binary_k = c(), g = 20, type = 3) +} +\arguments{ +\item{X}{Features} + +\item{binary_k}{vector of dimensions that are binary} + +\item{g}{# of quantiles} + +\item{type}{Quantile type (see ?quantile and https://mathworld.wolfram.com/Quantile.html). +Types 1-3 are discrete, which is good for passing to unique() when there are clumps} +} +\value{ +list of potential breaks +} +\description{ +quantile_breaks +}
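+\examples{ +\dontrun{ +# A hypothetical sketch: candidate split points at up to 20 quantiles per column. +X <- matrix(rexp(2000), ncol=2) +pot_breaks <- quantile_breaks(X, g=20) +} +}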
 diff --git a/project/.gitignore b/project/.gitignore new file mode 100644 index 0000000..2d19fc7 --- /dev/null +++ b/project/.gitignore @@ -0,0 +1 @@ +*.html diff --git a/project/ct_utils.R b/project/ct_utils.R new file mode 100644 index 0000000..1b7a568 --- /dev/null +++ b/project/ct_utils.R @@ -0,0 +1,175 @@ +library(ggplot2) +library(causalTree) +library(gsubfn) #for the list[...] multi-assignment used below + +rpart_label_component <- function(object) { + #TODO: Didn't copy the part with categorical variables, so this might not work then. Could do that. + #Copied from print.rpart + ff <- object$frame + n <- nrow(ff) + is.leaf <- (ff$var == "") + whichrow <- !is.leaf + index <- cumsum(c(1, ff$ncompete + ff$nsurrogate + !is.leaf)) + irow <- index[c(whichrow, FALSE)] + ncat <- object$splits[irow, 2L] + jrow <- irow[ncat < 2L] + cutpoint <- object$splits[jrow, 4L] + lsplit <- rsplit <- numeric(length(irow)) + lsplit[ncat<2L] <- cutpoint + rsplit[ncat<2L] <- cutpoint + + + vnames <- ff$var[whichrow] + varname <- (as.character(vnames)) + node <- as.numeric(row.names(ff)) + parent <- match(node %/% 2L, node[whichrow]) + odd <- (as.logical(node %% 2L)) + labels_var <- character(n) + labels_num <- numeric(n) + labels_num[odd] <- rsplit[parent[odd]] + labels_num[!odd] <- lsplit[parent[!odd]] + labels_num[1L] <- NA + labels_var[odd] <- varname[parent[odd]] + labels_var[!odd] <- varname[parent[!odd]] + labels_var[1L] <- "root" + list(labels_var, labels_num) +} + +plot_partition.rpart <- function(yvals, varnames, x_min, x_max, y_min, y_max, labels_var, labels_num, depth) { + nnode = length(depth) + if(nnode<=1) { + return(data.frame(xmin=c(x_min), xmax=c(x_max), ymin=c(y_min), ymax=c(y_max), fill=c(yvals[1]))) + } + yvals = yvals[2:nnode] + labels_var = labels_var[2:nnode] + labels_num = labels_num[2:nnode] + depth = depth[2:nnode] + nnode = length(depth) + i1 = which(depth==0)[1] + i2 = which(depth==0)[2] + varname = labels_var[1] + dim = which(varnames==varname)[1] + cutoff = labels_num[1] + if(dim==1) { + x_max1 = cutoff + x_min2 = cutoff + y_max1 = y_max + y_min2 = y_min + } + else { + x_max1 = x_max + x_min2 = x_min + y_max1 = cutoff + y_min2 = cutoff + } + + ret1 = plot_partition.rpart(yvals[1:(i2-1)], varnames, x_min, x_max1, y_min, y_max1, labels_var[1:(i2-1)], labels_num[1:(i2-1)], depth[1:(i2-1)]-1) + ret2 = plot_partition.rpart(yvals[i2:nnode], varnames, x_min2, x_max, y_min2, y_max, labels_var[i2:nnode], labels_num[i2:nnode], depth[i2:nnode]-1) + + return(rbind(ret1, ret2)) +} + + +#' plot_2D_partition.rpart +#' +#' @param cart_fit rpart fit object +#' @param X_range X_range +#' +#' @return ggplot fig +plot_2D_partition.rpart <- function(cart_fit, X_range) { + #Note: plotmo doesn't work well because it's just a grid and doesn't find the boundaries + varnames = names(cart_fit$ordered) + node <- as.numeric(row.names(cart_fit$frame)) + yvals = cart_fit$frame$yval + depth <- rpart:::tree.depth(node) + list[labels_var, labels_num] <- rpart_label_component(cart_fit) + nnode = length(depth) + rects = plot_partition.rpart(yvals, varnames, X_range[[1]][1], X_range[[1]][2], X_range[[2]][1], X_range[[2]][2], labels_var, labels_num, depth-1) + + plt = ggplot() + + scale_x_continuous(name=varnames[1]) +scale_y_continuous(name=varnames[2]) + + geom_rect(data=rects, mapping=aes(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, fill=fill), color="black") + return(plt) +} + +ct_cv_tree 
<- function(form, data, treatment, index_tr=NULL, tr_split=NA, split.Honest=TRUE, cv.Honest=TRUE, + minsize=2L, split.Bucket=FALSE, bucketNum=5, xval=10) { + N = nrow(data) + if(is.null(index_tr)) { + if(is.na(tr_split)) tr_split=0.5 + index_tr = sample(N, tr_split*N) + } + #could've done causalTree() and then estimate.causalTree + #fn_ret <- capture.output(ctree<-causalTree(form, data = data, treatment = treatment, + # split.Rule = "CT", cv.option = "CT", split.Honest = T, cv.Honest = T, split.Bucket = F, + # xval = 2, cp = 0, minsize=minsize), + # type="output") #does some random output + #print(fn_ret) + #opcp <- ctree$cptable[,1][which.min(ctree$cptable[,4])] + #ct_opfit <- prune(ctree, opcp) + + split.alpha = if(split.Honest) 0.5 else 1 + + fn_ret <- capture.output(honestTree <- honest.causalTree(form, data = data[index_tr,], treatment = treatment[index_tr], + est_data = data[-index_tr,], + est_treatment = treatment[-index_tr], + split.Rule = "CT", split.Honest = split.Honest, + HonestSampleSize = nrow(data[-index_tr,]), + split.Bucket = split.Bucket, bucketNum=bucketNum, + cv.option = "CT", cv.Honest = cv.Honest, minsize=minsize, + split.alpha=split.alpha, xval=xval)) + #print(fn_ret) + opcp <- honestTree$cptable[,1][which.min(honestTree$cptable[,4])] + opTree <- prune(honestTree, opcp) + + return(opTree) +} + +num_cells.rpart <- function(obj){ + sum(obj$frame[["var"]]=='') +} + +ct_nsplits_by_dim <- function(obj, ndim) { + library(stringr) + strs = paste(obj$frame$var[obj$frame$var!=""]) + int_tbl = table(as.integer(str_sub(strs, start=2))) + ret = rep(0, ndim) + for(k in 1:ndim) { + k_str = as.character(k) + if(k_str %in% names(int_tbl)) + ret[k] = int_tbl[[k_str]] + } + return(ret) +} + +#Just nodes and treatment effects +ct_desc <- function(ct_m, tex_table=TRUE, digits=3) { + ct_m_desc <- capture.output(print(ct_m)) + ct_m_desc = ct_m_desc[-c(1:5)] + new_str = c() + for(i in 1:length(ct_m_desc)) { + non_num_init = str_extract(ct_m_desc[i], paste0("^[ ]*[:digit:]+[)] (root|[:alnum:]*(< |>=))[ ]*")) + nums = as.numeric(str_split(str_sub(ct_m_desc[i], start=str_length(non_num_init)+1, end=str_length(ct_m_desc[i])-2), " ")[[1]]) + node_path = if(i==1) non_num_init else paste(non_num_init, format(nums[1], digits=digits)) + str_effect = format(nums[length(nums)], digits=digits) + is_leaf = str_sub(ct_m_desc[i],start=str_length(ct_m_desc[i]))=="*" + if(tex_table) { + n_spaces = str_length(str_extract(node_path, "^[ ]*")) + node_path = paste0(paste(replicate(n_spaces, "~"), collapse = ""), str_sub(node_path, start=n_spaces)) + if(is_leaf) + new_str[i] = paste0(node_path, " & ", str_effect, " \\\\") + else + new_str[i] = paste0(node_path, " & \\\\") + } + else { + new_str[i] = paste(node_path, str_effect) + if(is_leaf) new_str[i]= paste(new_str[i], "*") + } + } + #cat(file_cont, file=o_fname) + if(tex_table) { + new_str = c("\\begin{tabular}{lr}", " \\hline", "Node & Est. 
\\\\", " \\hline", new_str, " \\hline", "\\end{tabular}") + } + return(new_str) +} + diff --git a/project/sims.R b/project/sims.R new file mode 100644 index 0000000..80d5f3b --- /dev/null +++ b/project/sims.R @@ -0,0 +1,443 @@ +library(gsubfn) +library(devtools) +suppressPackageStartupMessages(library(data.table)) +do_load_all=F +if(!do_load_all){ + library(CausalGrid) +} else { + #Won't work in parallel + devtools::load_all(".", export_all=FALSE, helpers=FALSE) +} +library(causalTree) +library(doParallel) +library(foreach) +library(xtable) +library(stringr) +library(glmnet) +library(ranger) +library(gridExtra) +library(ggplot2) +source("project/ct_utils.R") +source("tests/dgps.R") + +#Paths +export_dir = "C:/Users/bquist/OneDrive - Microsoft/SubgroupAnalysis/writeup/" #can be turned off below +sim_rdata_fname = "project/sim.RData" +log_file = "project/log.txt" +tbl_export_path = paste0(export_dir, "tables/") +fig_export_path = paste0(export_dir, "figs/") + +#Estimation config +b = 4 +nfolds = 5 +minsize=25 + +#Sim parameters +S = 100 #3 100, TODO: Is this just 1!? +Ns = c(500, 1000) #c(500, 1000) +D = 3 #3 +Ks = c(2, 10, 20) +N_test = 8000 +NN = length(Ns) +NIters = NN*D*S + +good_features = list(c(T, F), c(T, T, rep(F, 8)), c(rep(T, 4), rep(F, 16))) + +#Execution config +n_parallel = 5 +my_seed = 1337 +set.seed(my_seed) +rf.num.threads = 1 #NULL will multi-treatd, doesn't seem to help much with small data + + + +# Helper functions -------- +yX_data <- function(y, X) { + yX = cbind(y, X) + colnames(yX) = c("Y", paste("X", 1:ncol(X), sep="")) + yX = as.data.frame(yX) +} + +sim_ct_fit <- function(y, X, w, tr_sample, honest=FALSE) { + yX = yX_data(y, X) + set.seed(my_seed) + fit = ct_cv_tree("Y~.", data=yX, treatment=w, index_tr=tr_sample, minsize=minsize, split.Bucket=TRUE, bucketNum=b, xval=nfolds, split.Honest=honest, cv.Honest=honest) + attr(fit$terms, ".Environment") <- NULL #save space other captures environment + return(fit) +} + +sim_ct_predict_te <- function(obj, y_te, X_te) { + yX_te = yX_data(y_te, X_te) + return(predict(obj, newdata=yX_te, type="vector")) +} + +sim_eval_ct <- function(data1, data2, good_mask, honest=FALSE) { + list[y, X, w, tau] = data1 + N = nrow(X)/2 + tr_sample = c(rep(TRUE, N), rep(FALSE, N)) + ct_fit = sim_ct_fit(y, X, w, tr_sample, honest=honest) + nl = num_cells(ct_fit) + nsplits_by_dim = ct_nsplits_by_dim(ct_fit, ncol(X)) + ngood = sum(nsplits_by_dim[good_mask]) + ntot = sum(nsplits_by_dim) + list[y_te, X_te, w_te, tau_te] = data2 + mse = mean((tau_te - sim_ct_predict_te(ct_fit, y_te, X_te))^2) + return(list(ct_fit, nl, mse, ngood, ntot)) +} + +sim_cg_fit <- function(y, X, w, tr_sample, verbosity=0, honest=FALSE, do_rf=FALSE, num.threads=rf.num.threads, ...) 
{ + set.seed(my_seed) + if(do_rf) { + return(fit_estimate_partition(y, X, d=w, tr_split=tr_sample, cv_folds=nfolds, verbosity=verbosity, min_size=2*minsize, max_splits=10, bucket_min_d_var=TRUE, bucket_min_n=2*b, honest=honest, ctrl_method=grid_rf(num.threads=num.threads), ...)) + } + else { + return(fit_estimate_partition(y, X, d=w, tr_split=tr_sample, cv_folds=nfolds, verbosity=verbosity, min_size=2*minsize, max_splits=10, bucket_min_d_var=TRUE, bucket_min_n=2*b, honest=honest, ...)) + } +} + +sim_cg_vectors <- function(grid_fit, mask, X_te, tau_te) { + nsplits = grid_fit$partition$nsplits_by_dim + preds = predict_te.estimated_partition(grid_fit, new_X=X_te) + cg_mse = mean((preds - tau_te)^2) + return(c(num_cells(grid_fit), cg_mse, sum(nsplits[mask]), sum(nsplits))) +} + +ct_ndim_splits <- function(obj, K) { + return(sum(ct_nsplits_by_dim(obj, K)>0)) +} + +cg_ndim_splits <- function(obj) { + return(sum(obj$partition$nsplits_by_dim>0)) +} + +# Generate Data -------- + +cat("Generating Data\n") +data1 = list() +data2 = list() +for(d in 1:D) { + for(N_i in 1:NN) { + N = Ns[N_i] + for(s in 1:S){ + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + data1[[iter]] = AI_sim(n=2*N, design=d) + data2[[iter]] = AI_sim(n=N_test, design=d) + } + } +} + +# Eval CT ------ +cat("Eval CT\n") + +results_ct_h = matrix(0,nrow=0, ncol=7) +results_ct_a = matrix(0,nrow=0, ncol=7) +ct_h_fit_models = list() +ct_a_fit_models = list() +for(d in 1:D) { + for(N_i in 1:NN) { + for(s in 1:S){ + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + list[ct_h_fit, nl_h, mse_h, ngood_h, ntot_h] = sim_eval_ct(data1[[iter]], data2[[iter]], good_features[[d]], honest=T) + results_ct_h = rbind(results_ct_h, c(d, N_i, s, nl_h, mse_h, ngood_h, ntot_h)) + ct_h_fit_models[[iter]] = ct_h_fit + list[ct_a_fit, nl, mse, ngood, ntot] = sim_eval_ct(data1[[iter]], data2[[iter]], good_features[[d]]) + results_ct_a = rbind(results_ct_a, c(d, N_i, s, nl, mse, ngood, ntot)) + ct_a_fit_models[[iter]] = ct_a_fit + } + } +} +ct_a_nl = results_ct_a[,4] +ct_h_nl = results_ct_h[,4] + +# Eval CG ----- +cat("Eval CG\n") + + +if(n_parallel>1) { + if(file.exists(log_file)) file.remove(log_file) + cl <- makeCluster(n_parallel, outfile=log_file) + registerDoParallel(cl) +} + +bar_length = if(n_parallel>1) NN*D else S*NN*D +t1 = Sys.time() +cat(paste("Start time: ",t1,"\n")) +pb = utils::txtProgressBar(0, bar_length, style = 3) +run=1 + +outer_results = list() +for(d in 1:D) { + for(N_i in 1:NN) { + results_s = list() + if(n_parallel>1) { + utils::setTxtProgressBar(pb, run) + run = run+1 + } + #for(s in 1:S){ #Non-Foreach + results_s = foreach(s=1:S, .packages=c("proto","gsubfn","rpart", "rpart.plot", "data.table","causalTree", "ranger", "lattice", "ggplot2", "caret", "Matrix", "foreach", "CausalGrid"), .errorhandling = "pass") %dopar% { #, .combine=rbind + if(n_parallel==1) { + #utils::setTxtProgressBar(pb, run) + run = run+1 + } + res = c(s) + N = Ns[N_i] + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + list[y, X, w, tau] = data1[[iter]] + list[y_te, X_te, w_te, tau_te] = data2[[iter]] + tr_sample = c(rep(TRUE, N), rep(FALSE, N)) + + + grid_a_fit <- sim_cg_fit(y, X, w, tr_sample, honest=FALSE) + grid_a_LassoCV_fit <- sim_cg_fit(y, X, w, tr_sample, honest=FALSE, ctrl_method="LassoCV") + grid_a_RF_fit <- sim_cg_fit(y, X, w, tr_sample, honest=FALSE, do_rf=TRUE) + + grid_a_m_fit <- change_complexity(grid_a_fit, y, X, d=w, which.min(abs(ct_a_nl[iter] - (grid_a_fit$complexity_seq + 1)))) + grid_a_LassoCV_m_fit <- change_complexity(grid_a_LassoCV_fit, y, X, d=w, 
 which.min(abs(ct_a_nl[iter] - (grid_a_LassoCV_fit$complexity_seq + 1)))) + grid_a_RF_m_fit <- change_complexity(grid_a_RF_fit, y, X, d=w, which.min(abs(ct_a_nl[iter] - (grid_a_RF_fit$complexity_seq + 1)))) + + res = c(res, sim_cg_vectors(grid_a_fit, good_features[[d]], X_te, tau_te)) + res = c(res, sim_cg_vectors(grid_a_LassoCV_fit, good_features[[d]], X_te, tau_te)) + res = c(res, sim_cg_vectors(grid_a_RF_fit, good_features[[d]], X_te, tau_te)) + res = c(res, sim_cg_vectors(grid_a_m_fit, good_features[[d]], X_te, tau_te)) + res = c(res, sim_cg_vectors(grid_a_LassoCV_m_fit, good_features[[d]], X_te, tau_te)) + res = c(res, sim_cg_vectors(grid_a_RF_m_fit, good_features[[d]], X_te, tau_te)) + + #Save space + grid_a_RF_m_fit$est_plan$rf_y_fit <- grid_a_RF_m_fit$est_plan$rf_d_fit <- NULL + res = list(grid_a_fit, grid_a_LassoCV_fit, grid_a_RF_m_fit, res) + + #results_s[[s]] = res #Non-Foreach + res #Foreach + } + outer_results[[(d-1)*NN + (N_i-1) + 1]] = results_s + } +} + +t2 = Sys.time() #can use as.numeric(t1) to convert to seconds +td = t2-t1 +close(pb) +cat(paste("Total time: ",format(as.numeric(td))," ", attr(td,"units"),"\n")) + +if(n_parallel>1) stopCluster(cl) + + +# Collect results ---- + +cg_a_fit_models = list() +cg_a_LassoCV_fit_models = list() +cg_a_RF_fit_models = list() +results_cg_a = matrix(0,nrow=0, ncol=7) +results_cg_a_LassoCV = matrix(0,nrow=0, ncol=7) +results_cg_a_RF = matrix(0,nrow=0, ncol=7) +results_cg_a_m = matrix(0,nrow=0, ncol=7) +results_cg_a_LassoCV_m = matrix(0,nrow=0, ncol=7) +results_cg_a_RF_m = matrix(0,nrow=0, ncol=7) +n_errors = matrix(0, nrow=D, ncol=NN) +for(d in 1:D) { + for(N_i in 1:NN) { + results_s = outer_results[[(d-1)*NN + (N_i-1) + 1]] + for(s in 1:S) { + res = results_s[[s]] + if(inherits(res, "error")) { + cat(paste("Error: d=", d, "N_i=", N_i, "s=", s, "\n")) + n_errors[d,N_i] = n_errors[d,N_i] + 1 + next + } + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + cg_a_fit_models[[iter]] = res[[1]] + cg_a_LassoCV_fit_models[[iter]] = res[[2]] + cg_a_RF_fit_models[[iter]] = res[[3]] + res = res[[4]] + #s = res[1] + results_cg_a = rbind(results_cg_a, c(d, N_i, s, res[2:5])) + results_cg_a_LassoCV = rbind(results_cg_a_LassoCV, c(d, N_i, s, res[6:9])) + results_cg_a_RF = rbind(results_cg_a_RF, c(d, N_i, s, res[10:13])) + results_cg_a_m = rbind(results_cg_a_m, c(d, N_i, s, res[14:17])) + results_cg_a_LassoCV_m = rbind(results_cg_a_LassoCV_m, c(d, N_i, s, res[18:21])) + results_cg_a_RF_m = rbind(results_cg_a_RF_m, c(d, N_i, s, res[22:25])) + } + } +} +if(sum(n_errors)>0){ + cat("N errors:\n") + print(n_errors) +} + + +if(F){ #Output raw results + save(S, Ns, D, data1, data2, ct_h_fit_models, ct_a_fit_models, cg_a_fit_models, cg_a_LassoCV_fit_models, cg_a_RF_fit_models, + results_ct_h, results_ct_a, results_cg_a, results_cg_a_LassoCV, results_cg_a_RF, results_cg_a_m, results_cg_a_LassoCV_m, results_cg_a_RF_m, + file = sim_rdata_fname) +} +if(F){ + load(sim_rdata_fname) + ct_a_nl = results_ct_a[,4] + ct_h_nl = results_ct_h[,4] +} + +# Output Results ------ + + +sum_res = function(full_res) { + nl = matrix(0, nrow=D, ncol=NN) + mse = matrix(0, nrow=D, ncol=NN) + pct_good = matrix(0, nrow=D, ncol=NN) + for(d in 1:D) { + for(N_i in 1:NN) { + start = (d-1)*NN*S + (N_i-1)*S + 1 + end = (d-1)*NN*S + (N_i-1)*S + S + nl[d,N_i] = mean(full_res[start:end,4]) + mse[d,N_i] = mean(full_res[start:end,5]) + pct_good[d,N_i] = sum(full_res[start:end,6])/sum(full_res[start:end,7]) + } + } + return(list(nl, mse, pct_good)) +} + 
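+#Each results_* row is c(design d, sample-size index N_i, sim s, n_cells, mse, n_splits_on_good_dims, n_splits_total); +#sum_res() averages n_cells and mse over the S sims and pools good/total splits into a ratio. + 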
+list[nl_CT_h, mse_CT_h, pct_good_CT_h] = sum_res(results_ct_h) +list[nl_CT_a, mse_CT_a, pct_good_CT_a] = sum_res(results_ct_a) +list[nl_CG_a, mse_CG_a, pct_good_CG_a] = sum_res(results_cg_a) +list[nl_CG_a_LassoCV, mse_CG_a_LassoCV, pct_good_CG_a_LassoCV] = sum_res(results_cg_a_LassoCV) +list[nl_CG_a_RF, mse_CG_a_RF, pct_good_CG_a_RF] = sum_res(results_cg_a_RF) +list[nl_CG_a_m, mse_CG_a_m, pct_good_CG_a_m] = sum_res(results_cg_a_m) +list[nl_CG_a_LassoCV_m, mse_CG_a_LassoCV_m, pct_good_CG_a_LassoCV_m] = sum_res(results_cg_a_LassoCV_m) +list[nl_CG_a_RF_m, mse_CG_a_RF_m, pct_good_CG_a_RF_m] = sum_res(results_cg_a_RF_m) + + +flatten_table <- function(mat) { + new_mat = cbind(mat[1,, drop=F], mat[2,, drop=F], mat[3,, drop=F]) + colnames(new_mat) = c("N=500", "N=1000", "N=500", "N=1000", "N=500", "N=1000") + new_mat +} +compose_table <- function(mat_CT_h, mat_CT_a, mat_CG_a, mat_CG_a_LassoCV, mat_CG_a_RF) { + new_mat = rbind(flatten_table(mat_CT_h), flatten_table(mat_CT_a), flatten_table(mat_CG_a), flatten_table(mat_CG_a_LassoCV), flatten_table(mat_CG_a_RF)) + rownames(new_mat) = c("Causal Tree (CT)", "Causal Tree - Adaptive (CT-A)", "Causal Grid (CG)", "Causal Grid w/ Linear Controls (CG-X)", "Causal Grid w/ RF Controls (CG-RF)") + new_mat +} + +fmt_table <- function(xtbl, o_fname) { + capt_ret <- capture.output(file_cont <- print(xtbl, floating=F, comment = F)) + file_cont = paste0(str_sub(file_cont, end=35), " & \\multicolumn{2}{c}{Design 1} & \\multicolumn{2}{c}{Design 2} & \\multicolumn{2}{c}{Design 3}\\\\ \n", str_sub(file_cont, start=36)) + cat(file_cont, file=o_fname) +} + +n_cells_comp = compose_table(nl_CT_h, nl_CT_a, nl_CG_a, nl_CG_a_LassoCV, nl_CG_a_RF) +mse_comp = compose_table(mse_CT_h, mse_CT_a, mse_CG_a, mse_CG_a_LassoCV, mse_CG_a_RF) +ratio_good_comp = compose_table(pct_good_CT_h, pct_good_CT_a, pct_good_CG_a, pct_good_CG_a_LassoCV_m, pct_good_CG_a_RF_m) + +if(F){ #Output tables + fmt_table(xtable(n_cells_comp, digits=2), paste0(tbl_export_path, "n_cells.tex")) + fmt_table(xtable(mse_comp, digits=3), paste0(tbl_export_path, "mse.tex")) + fmt_table(xtable(ratio_good_comp), paste0(tbl_export_path, "ratio_good.tex")) +} + +# Output examples ------------ +ct_sim_plot <- function(d, N_i, s, honest=TRUE) { + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + list[y, X, w, tau] = data1[[iter]] + X_range = get_X_range(X) + if(honest) + plt = plot_2D_partition.rpart(ct_h_fit_models[[iter]], X_range=X_range) + else + plt = plot_2D_partition.rpart(ct_a_fit_models[[iter]], X_range=X_range) + return(plt + ggtitle("Causal Tree") + labs(fill = "tau(X)")) +} +cg_sim_plot <- function(d, N_i, s, lasso=FALSE) { + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + list[y, X, w, tau] = data1[[iter]] + N = Ns[N_i] + X_range = get_X_range(X) + if(lasso) + grid_fit = cg_a_LassoCV_fit_models[[iter]] + else + grid_fit = cg_a_fit_models[[iter]] + grid_a_m_fit <- change_complexity(grid_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_fit$complexity_seq + 1)))) + plt = plot_2D_partition.estimated_partition(grid_a_m_fit, c("X1", "X2")) + return(plt + ggtitle("Causal Grid") + labs(fill = "tau(X)")) +} +ct_cg_plot <- function(d, N_i, s, ct_honest=TRUE, cg_lasso=FALSE) { + ct_plt = ct_sim_plot(d, N_i, s, honest=ct_honest) + cg_plt = cg_sim_plot(d, N_i, s, lasso=cg_lasso) + grid.arrange(ct_plt + ggtitle("Causal Tree"), cg_plt + ggtitle("Causal Grid"), ncol=2) +} + +if(F) { + #Pick which one to show. 
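+ #n_dim_splits cols: 1=CT-honest, 2=CT-adaptive, 3=CG, 4=CG-Lasso, 5=CG matched-complexity, 6=CG-Lasso matched-complexity. + #The masks below pick example runs where both CT and CG split on exactly the two true dims.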
+ ct_a_nl = results_ct_a[,4] + ct_h_nl = results_ct_h[,4] + n_dim_splits = matrix(0,nrow=D*NN*S, ncol=6) + for(d in 1:D) { + k = Ks[d] + for(N_i in 1:NN) { + for(s in 1:S) { + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + N = Ns[N_i] + list[y, X, w, tau] = data1[[iter]] + #list[y_te, X_te, w_te, tau_te] = data2[[iter]] + tr_sample = c(rep(TRUE, N), rep(FALSE, N)) + + grid_a_fit <- cg_a_fit_models[[iter]] + grid_a_LassoCV_fit <- cg_a_LassoCV_fit_models[[iter]] + grid_a_m_fit <- change_complexity(grid_a_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_a_fit$complexity_seq + 1)))) + grid_a_LassoCV_m_fit <- change_complexity(grid_a_LassoCV_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_a_LassoCV_fit$complexity_seq + 1)))) + + n_dim_splits[iter, 1] = ct_ndim_splits(ct_h_fit_models[[iter]], k) + n_dim_splits[iter, 2] = ct_ndim_splits(ct_a_fit_models[[iter]], k) + n_dim_splits[iter, 3] = cg_ndim_splits(grid_a_fit) + n_dim_splits[iter, 4] = cg_ndim_splits(grid_a_LassoCV_fit) + n_dim_splits[iter, 5] = cg_ndim_splits(grid_a_m_fit) + n_dim_splits[iter, 6] = cg_ndim_splits(grid_a_LassoCV_m_fit) + } + } + } + #const_mask = results_ct_h[,1]==2 & results_ct_h[,4]>=4 & n_dim_splits[,1]==2 & n_dim_splits[,5]==2 #cg: normal + #grph_pick = cbind(results_ct_h[const_mask,c(1,2,3, 4)],results_cg_a_m[const_mask,c(4)]) + const_mask = results_ct_h[,1]==2 & results_ct_h[,4]>=4 & n_dim_splits[,1]==2 & n_dim_splits[,6]==2 #cg: Lasso + grph_pick = cbind(results_ct_h[const_mask,c(1,2,3, 4)],results_cg_a_LassoCV_m[const_mask,c(4)]) + grph_pick + ct_cg_plot(2, 1, 57, ct_honest=TRUE, cg_lasso=TRUE) +} + +cg_table <- function(obj, digits=3){ + stats = obj$cell_stats$stats[c("param_est")] + colnames(stats) <- c("Est.") + return(cbind(get_desc_df(obj$partition, do_str=TRUE, drop_unsplit=TRUE, digits=digits), stats)) +} + +ggplot_to_pdf <- function(plt_obj, filename) { + pdf(file=filename) + print(plt_obj) + dev.off() +} + +if(F){ + d=2 #so we can have a 2d model + N_i=1 + s=57 + iter = (d-1)*NN*S + (N_i-1)*S + (s-1) + 1 + list[y, X, w, tau] = data1[[iter]] + X_range = get_X_range(X) + + ct_m = ct_h_fit_models[[iter]] + ct_m_desc = ct_desc(ct_m) + cat(ct_m_desc, file=paste0(tbl_export_path, "ct_ex_2d.tex"), sep="\n") + ct_plt = plot_2D_partition.rpart(ct_m, X_range=X_range) + labs(fill = "tau(X)") + print(ct_plt + ggtitle("Causal Tree")) + + grid_fit = cg_a_LassoCV_fit_models[[iter]] #cg_a_fit_models[[iter]] + cg_m <- change_complexity(grid_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_fit$complexity_seq + 1)))) + print(cg_m) + cg_m_tbl = cg_table(cg_m) + temp <- capture.output(cg_tbl_file_cont <- print(xtable(cg_m_tbl, digits=3), floating=F, comment = F)) + cat(cg_tbl_file_cont, file=paste0(tbl_export_path, "cg_ex_2d.tex")) + cg_plt = plot_2D_partition.estimated_partition(cg_m, c("X1", "X2")) + labs(fill = "tau(X)") + print(cg_plt + ggtitle("Causal Grid")) + + ggplot_to_pdf(ct_plt + ggtitle("Causal Tree"), paste0(fig_export_path, "ct_ex_2d.pdf")) + ggplot_to_pdf(cg_plt + ggtitle("Causal Grid"), paste0(fig_export_path, "cg_ex_2d.pdf")) + + pdf(file=paste0(fig_export_path, "cg_ct_ex_2d.pdf"), width=8, height=4) + grid.arrange(ct_plt + ggtitle("Causal Tree"), cg_plt + ggtitle("Causal Grid"), ncol=2) + dev.off() +} + diff --git a/renv.lock b/renv.lock new file mode 100644 index 0000000..c825c7f --- /dev/null +++ b/renv.lock @@ -0,0 +1,953 @@ +{ + "R": { + "Version": "3.5.2", + "Repositories": [ + { + "Name": "CRAN", + "URL": "https://cran.r-project.org" + }, + { + "Name": "MRAN", + "URL": 
"https://mran.microsoft.com/snapshot/2019-04-15" + } + ] + }, + "Packages": { + "BH": { + "Package": "BH", + "Version": "1.69.0-1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "f4605d46264b35f53072fc9ee7ace15f" + }, + "DT": { + "Package": "DT", + "Version": "0.15", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "85738c69035e67ec4b484a5e02640ef6" + }, + "KernSmooth": { + "Package": "KernSmooth", + "Version": "2.23-15", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "081f417f4d6d55b7e8981433e8404a22" + }, + "MASS": { + "Package": "MASS", + "Version": "7.3-51.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c64a9edaef0658e36905934c5a7aa499" + }, + "Matrix": { + "Package": "Matrix", + "Version": "1.2-15", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ffbff17356922d442a4d6fab32e2bc96" + }, + "ModelMetrics": { + "Package": "ModelMetrics", + "Version": "1.2.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d85508dd2162bf34aaf15d6a022e42e5" + }, + "R6": { + "Package": "R6", + "Version": "2.4.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d4618c4cb36bc46831551c5d85815818" + }, + "RColorBrewer": { + "Package": "RColorBrewer", + "Version": "1.1-2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "e031418365a7f7a766181ab5a41a5716" + }, + "Rcpp": { + "Package": "Rcpp", + "Version": "1.0.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d5ae8445d4972caed1c5517ffae908d7" + }, + "RcppEigen": { + "Package": "RcppEigen", + "Version": "0.3.3.7.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c6faf038ba4346b1de19ad7c99b8f94a" + }, + "RcppRoll": { + "Package": "RcppRoll", + "Version": "0.3.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "84a03997fbb5acfb3c9b43bad88fea1f" + }, + "SQUAREM": { + "Package": "SQUAREM", + "Version": "2017.10-1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "5f2dab05aaaf51d7f87cf7ecbbe07541" + }, + "askpass": { + "Package": "askpass", + "Version": "1.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "e8a22846fff485f0be3770c2da758713" + }, + "assertthat": { + "Package": "assertthat", + "Version": "0.2.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "50c838a310445e954bc13f26f26a6ecf" + }, + "backports": { + "Package": "backports", + "Version": "1.1.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b9a4db2667d7e43d197ed12e40781889" + }, + "base64enc": { + "Package": "base64enc", + "Version": "0.1-3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "543776ae6848fde2f48ff3816d0628bc" + }, + "brew": { + "Package": "brew", + "Version": "1.0-6", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "92a5f887f9ae3035ac7afde22ba73ee9" + }, + "callr": { + "Package": "callr", + "Version": "3.4.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "e56fe17ffeddfdcfcef40981e41e1c40" + }, + "caret": { + "Package": "caret", + "Version": "6.0-82", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "5792a55fc1f6adb4fe46924a449a579a" + }, + "causalTree": { + "Package": "causalTree", + "Version": "0.0", + "Source": "GitHub", + "RemoteType": "github", + "RemoteHost": "api.github.com", + "RemoteRepo": "causalTree", + "RemoteUsername": "susanathey", + "RemoteRef": "master", + "RemoteSha": "48604762b7db547f49e0e50460eb31a344933bba", + "Hash": "fa7f48ab9d73169a37244b40a307df8c" + }, + "class": { + "Package": "class", + "Version": "7.3-15", + "Source": 
"Repository", + "Repository": "CRAN", + "Hash": "4fba6a022803b6c3f30fd023be3fa818" + }, + "cli": { + "Package": "cli", + "Version": "2.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ff0becff7bfdfe3f75d29aff8f3172dd" + }, + "clipr": { + "Package": "clipr", + "Version": "0.5.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "f470b4aeb573f770fea6ced401c7fb39" + }, + "codetools": { + "Package": "codetools", + "Version": "0.2-16", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "89cf4b8207269ccf82fbeb6473fd662b" + }, + "colorspace": { + "Package": "colorspace", + "Version": "1.4-1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6b436e95723d1f0e861224dd9b094dfb" + }, + "commonmark": { + "Package": "commonmark", + "Version": "1.7", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0f22be39ec1d141fd03683c06f3a6e67" + }, + "covr": { + "Package": "covr", + "Version": "3.5.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6d80a9fc3c0c8473153b54fa54719dfd" + }, + "crayon": { + "Package": "crayon", + "Version": "1.3.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0d57bc8e27b7ba9e45dba825ebc0de6b" + }, + "crosstalk": { + "Package": "crosstalk", + "Version": "1.1.0.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ae55f5d7c02f0ab43c58dd050694f2b4" + }, + "curl": { + "Package": "curl", + "Version": "3.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c71bf321a357db97242bc233a1f99a55" + }, + "data.table": { + "Package": "data.table", + "Version": "1.12.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "2acfeb805afc84b919dcbe1f32a23529" + }, + "desc": { + "Package": "desc", + "Version": "1.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6c8fe8fa26a23b79949375d372c7b395" + }, + "devtools": { + "Package": "devtools", + "Version": "2.3.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "415656f50722f5b6e6bcf80855ce11b9" + }, + "digest": { + "Package": "digest", + "Version": "0.6.18", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "496dd262e1ec64e452151479a74c972f" + }, + "doParallel": { + "Package": "doParallel", + "Version": "1.0.14", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "3335c17fb0a900813001058f1ce35fc4" + }, + "dplyr": { + "Package": "dplyr", + "Version": "0.8.5", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "57a42ddf80f429764ff7987128c3fd0a" + }, + "ellipsis": { + "Package": "ellipsis", + "Version": "0.3.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "fd2844b3a43ae2d27e70ece2df1b4e2a" + }, + "evaluate": { + "Package": "evaluate", + "Version": "0.13", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "45057c310ad47bb712f8b6c2cc72a0cd" + }, + "fansi": { + "Package": "fansi", + "Version": "0.4.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b31d9e5d051553d1177083aeba04b5b9" + }, + "foreach": { + "Package": "foreach", + "Version": "1.4.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c179d1dd8abd4b888214d44f4de2359a" + }, + "fs": { + "Package": "fs", + "Version": "1.5.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "44594a07a42e5f91fac9f93fda6d0109" + }, + "generics": { + "Package": "generics", + "Version": "0.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b8cff1d1391fd1ad8b65877f4c7f2e53" + }, + "gglasso": { + "Package": "gglasso", + "Version": "1.4", + "Source": "Repository", + 
"Repository": "CRAN", + "Hash": "f41ac9fff68996bab6243ab2b07f9ca6" + }, + "ggplot2": { + "Package": "ggplot2", + "Version": "3.3.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "4ded8b439797f7b1693bd3d238d0106b" + }, + "gh": { + "Package": "gh", + "Version": "1.1.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "89ea5998938d1ad55f035c8a86f96b74" + }, + "git2r": { + "Package": "git2r", + "Version": "0.25.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b0b62a21371dd846b4f790ebf279706f" + }, + "glmnet": { + "Package": "glmnet", + "Version": "2.0-16", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ac94187f0f9c5cf9634887f597726615" + }, + "glue": { + "Package": "glue", + "Version": "1.4.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "2aefa994e8df5da17dc09afd80f924d5" + }, + "gower": { + "Package": "gower", + "Version": "0.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a42dbfd0520f16ec2a9e513969656ead" + }, + "gridExtra": { + "Package": "gridExtra", + "Version": "2.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "7d7f283939f563670a697165b2cf5560" + }, + "gsubfn": { + "Package": "gsubfn", + "Version": "0.7", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a8ebb0bb0edcf041a7649ee43a0e1735" + }, + "gtable": { + "Package": "gtable", + "Version": "0.3.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ac5c6baf7822ce8732b343f14c072c4d" + }, + "highr": { + "Package": "highr", + "Version": "0.8", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "4dc5bb88961e347a0f4d8aad597cbfac" + }, + "htmltools": { + "Package": "htmltools", + "Version": "0.5.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "7d651b7131794fe007b1ad6f21aaa401" + }, + "htmlwidgets": { + "Package": "htmlwidgets", + "Version": "1.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0c8df16eba2c955487aad63a7e7051a6" + }, + "httr": { + "Package": "httr", + "Version": "1.4.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a525aba14184fec243f9eaec62fbed43" + }, + "ini": { + "Package": "ini", + "Version": "0.3.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6154ec2223172bce8162d4153cda21f7" + }, + "ipred": { + "Package": "ipred", + "Version": "0.9-8", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a89adbc5ac6bd1c9a1cb1ff5341fee4d" + }, + "isoband": { + "Package": "isoband", + "Version": "0.2.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "9b2f7cf1899f583a36d367702ecf49a3" + }, + "iterators": { + "Package": "iterators", + "Version": "1.0.10", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "02caec9a169f9344577950df8f70aaa8" + }, + "jsonlite": { + "Package": "jsonlite", + "Version": "1.7.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "1ec84e070b88b37ed169f19def40d47c" + }, + "knitr": { + "Package": "knitr", + "Version": "1.22", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d3085a2c6c75da96ad333143dcc35ce8" + }, + "labeling": { + "Package": "labeling", + "Version": "0.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "73832978c1de350df58108c745ed0e3e" + }, + "later": { + "Package": "later", + "Version": "1.1.0.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d0a62b247165aabf397fded504660d8a" + }, + "lattice": { + "Package": "lattice", + "Version": "0.20-38", + "Source": "Repository", + "Repository": "CRAN", + "Hash": 
"848f8c593fd1050371042d18d152e3d7" + }, + "lava": { + "Package": "lava", + "Version": "1.6.5", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "816194cdf187672306a858a0e822350e" + }, + "lazyeval": { + "Package": "lazyeval", + "Version": "0.2.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d908914ae53b04d4c0c0fd72ecc35370" + }, + "lifecycle": { + "Package": "lifecycle", + "Version": "0.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "361811f31f71f8a617a9a68bf63f1f42" + }, + "lubridate": { + "Package": "lubridate", + "Version": "1.7.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "796afeea047cda6bdb308d374a33eeb6" + }, + "magrittr": { + "Package": "magrittr", + "Version": "1.5", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "1bb58822a20301cee84a41678e25d9b7" + }, + "markdown": { + "Package": "markdown", + "Version": "0.9", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6e05029d16b0df430ab2d31d151a3ac2" + }, + "memoise": { + "Package": "memoise", + "Version": "1.1.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "58baa74e4603fcfb9a94401c58c8f9b1" + }, + "mgcv": { + "Package": "mgcv", + "Version": "1.8-28", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "aa301a255aac625db12ee1793bd79265" + }, + "mime": { + "Package": "mime", + "Version": "0.6", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "65dd22e780565119a78036189cb3b885" + }, + "munsell": { + "Package": "munsell", + "Version": "0.5.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6dfe8bf774944bd5595785e3229d8771" + }, + "nlme": { + "Package": "nlme", + "Version": "3.1-137", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "4320ab595f7bbff5458bc6a044a57fc0" + }, + "nnet": { + "Package": "nnet", + "Version": "7.3-12", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "68287aec1f476c41d16ce1ace445800c" + }, + "numDeriv": { + "Package": "numDeriv", + "Version": "2016.8-1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "e3554c342c94ffc1095d6488e6521cd6" + }, + "openssl": { + "Package": "openssl", + "Version": "1.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "2a9e2d199c54f6061aba18976e958b1c" + }, + "pbapply": { + "Package": "pbapply", + "Version": "1.4-3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "52f8028b028076bc3b7ee5d6251abf0d" + }, + "pillar": { + "Package": "pillar", + "Version": "1.4.6", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "bdf26e55ccb7df3e49a490150277f002" + }, + "pkgbuild": { + "Package": "pkgbuild", + "Version": "1.1.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "404684bc4e3685007f9720adf13b06c1" + }, + "pkgconfig": { + "Package": "pkgconfig", + "Version": "2.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "f5940986fb19bcef52284068baeb3f29" + }, + "pkgload": { + "Package": "pkgload", + "Version": "1.1.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "b6b150cd4709e0c0c9b5d51ac4376282" + }, + "plogr": { + "Package": "plogr", + "Version": "0.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "09eb987710984fc2905c7129c7d85e65" + }, + "plyr": { + "Package": "plyr", + "Version": "1.8.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "23026346e3e0f023f326919320627a01" + }, + "praise": { + "Package": "praise", + "Version": "1.0.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a555924add98c99d2f411e37e7d25e9f" 
+ }, + "prettyunits": { + "Package": "prettyunits", + "Version": "1.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "f3c960f0105f2ed179460864979fc37c" + }, + "processx": { + "Package": "processx", + "Version": "3.4.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "03446ed0b8129916f73676726cb3c48f" + }, + "prodlim": { + "Package": "prodlim", + "Version": "2018.04.18", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "dede5cafa9509f68d39368bd4526f36b" + }, + "promises": { + "Package": "promises", + "Version": "1.1.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a8730dcbdd19f9047774909f0ec214a4" + }, + "proto": { + "Package": "proto", + "Version": "1.0.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "5cb1623df69ee6102d011c7f78f5791d" + }, + "ps": { + "Package": "ps", + "Version": "1.3.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "919a32c940a25bc95fd464df9998a6ba" + }, + "purrr": { + "Package": "purrr", + "Version": "0.3.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "97def703420c8ab10d8f0e6c72101e02" + }, + "ranger": { + "Package": "ranger", + "Version": "0.12.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "561326df07a5bc5266ba17ce3b81cbf1" + }, + "rcmdcheck": { + "Package": "rcmdcheck", + "Version": "1.3.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ed95895886dab6d2a584da45503555da" + }, + "recipes": { + "Package": "recipes", + "Version": "0.1.5", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0e9821b4db35f76c5be32c7acc745972" + }, + "rematch2": { + "Package": "rematch2", + "Version": "2.1.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "76c9e04c712a05848ae7a23d2f170a40" + }, + "remotes": { + "Package": "remotes", + "Version": "2.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "430a0908aee75b1fcba0e62857cab0ce" + }, + "renv": { + "Package": "renv", + "Version": "0.12.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "7340c71f46a0fd16506cfa804e224e44" + }, + "reshape2": { + "Package": "reshape2", + "Version": "1.4.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "15a23ad30f51789188e439599559815c" + }, + "rex": { + "Package": "rex", + "Version": "1.1.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6d3dbb5d528c8f726861018472bc668c" + }, + "rlang": { + "Package": "rlang", + "Version": "0.4.7", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c06d2a6887f4b414f8e927afd9ee976a" + }, + "rmarkdown": { + "Package": "rmarkdown", + "Version": "2.5", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "20a0a94af9e8f7040510447763aab3e9" + }, + "roxygen2": { + "Package": "roxygen2", + "Version": "7.1.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "fcd94e00cc409b25d07ca50f7bf339f5" + }, + "rpart": { + "Package": "rpart", + "Version": "4.1-13", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6315535da80d5cc6c2e573966d8c8210" + }, + "rpart.plot": { + "Package": "rpart.plot", + "Version": "3.0.7", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ab5aa89ca83659bd61b649866af1c9e0" + }, + "rprojroot": { + "Package": "rprojroot", + "Version": "1.3-2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "f6a407ae5dd21f6f80a6708bbb6eb3ae" + }, + "rstudioapi": { + "Package": "rstudioapi", + "Version": "0.11", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "33a5b27a03da82ac4b1d43268f80088a" + }, + 
"rversions": { + "Package": "rversions", + "Version": "2.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0ec41191f744d0f5afad8c6f35cc36e4" + }, + "scales": { + "Package": "scales", + "Version": "1.0.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "2e46c8ab2c109085d5b3a775ea2df19c" + }, + "sessioninfo": { + "Package": "sessioninfo", + "Version": "1.1.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "308013098befe37484df72c39cf90d6e" + }, + "stringi": { + "Package": "stringi", + "Version": "1.4.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "74a50760af835563fb2c124e66aa134e" + }, + "stringr": { + "Package": "stringr", + "Version": "1.4.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0759e6b6c0957edb1311028a49a35e76" + }, + "survival": { + "Package": "survival", + "Version": "2.43-3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d6fc8c1de7e40274ff7bc53524cccd4b" + }, + "sys": { + "Package": "sys", + "Version": "3.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "22319ae218b22b9a14d5e7ecbf841703" + }, + "testthat": { + "Package": "testthat", + "Version": "2.3.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "0829b987b8961fb07f3b1b64a2fbc495" + }, + "tibble": { + "Package": "tibble", + "Version": "3.0.1", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "1c61e4cad000e03b1bd687db16a75926" + }, + "tidyr": { + "Package": "tidyr", + "Version": "1.0.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "fb73a010ace00d6c584c2b53a21b969c" + }, + "tidyselect": { + "Package": "tidyselect", + "Version": "1.0.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "7d4b0f1ab542d8cb7a40c593a4de2f36" + }, + "timeDate": { + "Package": "timeDate", + "Version": "3043.102", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "fde4fc571f5f61978652c229d4713845" + }, + "tinytex": { + "Package": "tinytex", + "Version": "0.26", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "db6477efcfbffcd9b3758c3c2882cf58" + }, + "usethis": { + "Package": "usethis", + "Version": "1.6.3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c541a7aed5f7fb3b487406bf92842e34" + }, + "utf8": { + "Package": "utf8", + "Version": "1.1.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "4a5081acfb7b81a572e4384a7aaf2af1" + }, + "vctrs": { + "Package": "vctrs", + "Version": "0.2.4", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6c839a149a30cb4ffc70443efa74c197" + }, + "viridisLite": { + "Package": "viridisLite", + "Version": "0.3.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ce4f6271baa94776db692f1cb2055bee" + }, + "whisker": { + "Package": "whisker", + "Version": "0.3-2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c944abf3f12a97b8369a6f6ba8186d23" + }, + "withr": { + "Package": "withr", + "Version": "2.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "ecd17882a0b4419545691e095b74ee89" + }, + "xfun": { + "Package": "xfun", + "Version": "0.19", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "a42372606cb76f34da9d090326e9f955" + }, + "xml2": { + "Package": "xml2", + "Version": "1.3.2", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "d4d71a75dd3ea9eb5fa28cc21f9585e2" + }, + "xopen": { + "Package": "xopen", + "Version": "1.0.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "6c85f015dee9cc7710ddd20f86881f58" + }, + "xtable": { + "Package": 
"xtable", + "Version": "1.8-3", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "7f777cd034efddabf07fcaf2f287ec43" + }, + "yaml": { + "Package": "yaml", + "Version": "2.2.0", + "Source": "Repository", + "Repository": "CRAN", + "Hash": "c78bdf1d16bd4ec7ecc86c6986d53309" + } + } +} diff --git a/renv/.gitignore b/renv/.gitignore new file mode 100644 index 0000000..643b5c5 --- /dev/null +++ b/renv/.gitignore @@ -0,0 +1,4 @@ +library/ +python/ +staging/ +settings.dcf diff --git a/renv/activate.R b/renv/activate.R new file mode 100644 index 0000000..ff7e655 --- /dev/null +++ b/renv/activate.R @@ -0,0 +1,349 @@ + +local({ + + # the requested version of renv + version <- "0.12.0" + + # the project directory + project <- getwd() + + # avoid recursion + if (!is.na(Sys.getenv("RENV_R_INITIALIZING", unset = NA))) + return(invisible(TRUE)) + + # signal that we're loading renv during R startup + Sys.setenv("RENV_R_INITIALIZING" = "true") + on.exit(Sys.unsetenv("RENV_R_INITIALIZING"), add = TRUE) + + # signal that we've consented to use renv + options(renv.consent = TRUE) + + # load the 'utils' package eagerly -- this ensures that renv shims, which + # mask 'utils' packages, will come first on the search path + library(utils, lib.loc = .Library) + + # check to see if renv has already been loaded + if ("renv" %in% loadedNamespaces()) { + + # if renv has already been loaded, and it's the requested version of renv, + # nothing to do + spec <- .getNamespaceInfo(.getNamespace("renv"), "spec") + if (identical(spec[["version"]], version)) + return(invisible(TRUE)) + + # otherwise, unload and attempt to load the correct version of renv + unloadNamespace("renv") + + } + + # load bootstrap tools + bootstrap <- function(version, library) { + + # read repos (respecting override if set) + repos <- Sys.getenv("RENV_CONFIG_REPOS_OVERRIDE", unset = NA) + if (is.na(repos)) + repos <- getOption("repos") + + # fix up repos + on.exit(options(repos = repos), add = TRUE) + repos[repos == "@CRAN@"] <- "https://cloud.r-project.org" + options(repos = repos) + + # attempt to download renv + tarball <- tryCatch(renv_bootstrap_download(version), error = identity) + if (inherits(tarball, "error")) + stop("failed to download renv ", version) + + # now attempt to install + status <- tryCatch(renv_bootstrap_install(version, tarball, library), error = identity) + if (inherits(status, "error")) + stop("failed to install renv ", version) + + } + + renv_bootstrap_download_impl <- function(url, destfile) { + + mode <- "wb" + + # https://bugs.r-project.org/bugzilla/show_bug.cgi?id=17715 + fixup <- + Sys.info()[["sysname"]] == "Windows" && + substring(url, 1L, 5L) == "file:" + + if (fixup) + mode <- "w+b" + + download.file( + url = url, + destfile = destfile, + mode = mode, + quiet = TRUE + ) + + } + + renv_bootstrap_download <- function(version) { + + methods <- list( + renv_bootstrap_download_cran_latest, + renv_bootstrap_download_cran_archive, + renv_bootstrap_download_github + ) + + for (method in methods) { + path <- tryCatch(method(version), error = identity) + if (is.character(path) && file.exists(path)) + return(path) + } + + stop("failed to download renv ", version) + + } + + renv_bootstrap_download_cran_latest <- function(version) { + + # check for renv on CRAN matching this version + db <- as.data.frame(available.packages(), stringsAsFactors = FALSE) + + entry <- db[db$Package %in% "renv" & db$Version %in% version, ] + if (nrow(entry) == 0) { + fmt <- "renv %s is not available from your declared package repositories" + 
stop(sprintf(fmt, version)) + } + + message("* Downloading renv ", version, " from CRAN ... ", appendLF = FALSE) + + info <- tryCatch( + download.packages("renv", destdir = tempdir()), + condition = identity + ) + + if (inherits(info, "condition")) { + message("FAILED") + return(FALSE) + } + + message("OK") + info[1, 2] + + } + + renv_bootstrap_download_cran_archive <- function(version) { + + name <- sprintf("renv_%s.tar.gz", version) + repos <- getOption("repos") + urls <- file.path(repos, "src/contrib/Archive/renv", name) + destfile <- file.path(tempdir(), name) + + message("* Downloading renv ", version, " from CRAN archive ... ", appendLF = FALSE) + + for (url in urls) { + + status <- tryCatch( + renv_bootstrap_download_impl(url, destfile), + condition = identity + ) + + if (identical(status, 0L)) { + message("OK") + return(destfile) + } + + } + + message("FAILED") + return(FALSE) + + } + + renv_bootstrap_download_github <- function(version) { + + enabled <- Sys.getenv("RENV_BOOTSTRAP_FROM_GITHUB", unset = "TRUE") + if (!identical(enabled, "TRUE")) + return(FALSE) + + # prepare download options + pat <- Sys.getenv("GITHUB_PAT") + if (nzchar(Sys.which("curl")) && nzchar(pat)) { + fmt <- "--location --fail --header \"Authorization: token %s\"" + extra <- sprintf(fmt, pat) + saved <- options("download.file.method", "download.file.extra") + options(download.file.method = "curl", download.file.extra = extra) + on.exit(do.call(base::options, saved), add = TRUE) + } else if (nzchar(Sys.which("wget")) && nzchar(pat)) { + fmt <- "--header=\"Authorization: token %s\"" + extra <- sprintf(fmt, pat) + saved <- options("download.file.method", "download.file.extra") + options(download.file.method = "wget", download.file.extra = extra) + on.exit(do.call(base::options, saved), add = TRUE) + } + + message("* Downloading renv ", version, " from GitHub ... ", appendLF = FALSE) + + url <- file.path("https://api.github.com/repos/rstudio/renv/tarball", version) + name <- sprintf("renv_%s.tar.gz", version) + destfile <- file.path(tempdir(), name) + + status <- tryCatch( + renv_bootstrap_download_impl(url, destfile), + condition = identity + ) + + if (!identical(status, 0L)) { + message("FAILED") + return(FALSE) + } + + message("Done!") + return(destfile) + + } + + renv_bootstrap_install <- function(version, tarball, library) { + + # attempt to install it into project library + message("* Installing renv ", version, " ... 
", appendLF = FALSE) + dir.create(library, showWarnings = FALSE, recursive = TRUE) + + # invoke using system2 so we can capture and report output + bin <- R.home("bin") + exe <- if (Sys.info()[["sysname"]] == "Windows") "R.exe" else "R" + r <- file.path(bin, exe) + args <- c("--vanilla", "CMD", "INSTALL", "-l", shQuote(library), shQuote(tarball)) + output <- system2(r, args, stdout = TRUE, stderr = TRUE) + message("Done!") + + # check for successful install + status <- attr(output, "status") + if (is.numeric(status) && !identical(status, 0L)) { + header <- "Error installing renv:" + lines <- paste(rep.int("=", nchar(header)), collapse = "") + text <- c(header, lines, output) + writeLines(text, con = stderr()) + } + + status + + } + + renv_bootstrap_prefix <- function() { + + # construct version prefix + version <- paste(R.version$major, R.version$minor, sep = ".") + prefix <- paste("R", numeric_version(version)[1, 1:2], sep = "-") + + # include SVN revision for development versions of R + # (to avoid sharing platform-specific artefacts with released versions of R) + devel <- + identical(R.version[["status"]], "Under development (unstable)") || + identical(R.version[["nickname"]], "Unsuffered Consequences") + + if (devel) + prefix <- paste(prefix, R.version[["svn rev"]], sep = "-r") + + # build list of path components + components <- c(prefix, R.version$platform) + + # include prefix if provided by user + prefix <- Sys.getenv("RENV_PATHS_PREFIX") + if (nzchar(prefix)) + components <- c(prefix, components) + + # build prefix + paste(components, collapse = "/") + + } + + renv_bootstrap_library_root <- function(project) { + + path <- Sys.getenv("RENV_PATHS_LIBRARY", unset = NA) + if (!is.na(path)) + return(path) + + path <- Sys.getenv("RENV_PATHS_LIBRARY_ROOT", unset = NA) + if (!is.na(path)) + return(file.path(path, basename(project))) + + file.path(project, "renv/library") + + } + + renv_bootstrap_validate_version <- function(version) { + + loadedversion <- utils::packageDescription("renv", fields = "Version") + if (version == loadedversion) + return(TRUE) + + # assume four-component versions are from GitHub; three-component + # versions are from CRAN + components <- strsplit(loadedversion, "[.-]")[[1]] + remote <- if (length(components) == 4L) + paste("rstudio/renv", loadedversion, sep = "@") + else + paste("renv", loadedversion, sep = "@") + + fmt <- paste( + "renv %1$s was loaded from project library, but renv %2$s is recorded in lockfile.", + "Use `renv::record(\"%3$s\")` to record this version in the lockfile.", + "Use `renv::restore(packages = \"renv\")` to install renv %2$s into the project library.", + sep = "\n" + ) + + msg <- sprintf(fmt, loadedversion, version, remote) + warning(msg, call. 
= FALSE) + + FALSE + + } + + renv_bootstrap_load <- function(project, libpath, version) { + + # try to load renv from the project library + if (!requireNamespace("renv", lib.loc = libpath, quietly = TRUE)) + return(FALSE) + + # warn if the version of renv loaded does not match + renv_bootstrap_validate_version(version) + + # load the project + renv::load(project) + + TRUE + + } + + # construct path to library root + root <- renv_bootstrap_library_root(project) + + # construct library prefix for platform + prefix <- renv_bootstrap_prefix() + + # construct full libpath + libpath <- file.path(root, prefix) + + # attempt to load + if (renv_bootstrap_load(project, libpath, version)) + return(TRUE) + + # load failed; attempt to bootstrap + bootstrap(version, libpath) + + # exit early if we're just testing bootstrap + if (!is.na(Sys.getenv("RENV_BOOTSTRAP_INSTALL_ONLY", unset = NA))) + return(TRUE) + + # try again to load + if (requireNamespace("renv", lib.loc = libpath, quietly = TRUE)) { + message("Successfully installed and loaded renv ", version, ".") + return(renv::load()) + } + + # failed to download or load renv; warn the user + msg <- c( + "Failed to find an renv installation: the project will not be loaded.", + "Use `renv::activate()` to re-initialize the project." + ) + + warning(paste(msg, collapse = "\n"), call. = FALSE) + +}) diff --git a/src/CausalGrid.cpp b/src/CausalGrid.cpp new file mode 100644 index 0000000..d4d6331 --- /dev/null +++ b/src/CausalGrid.cpp @@ -0,0 +1,12 @@ +#include <Rcpp.h> +using namespace Rcpp; + +// [[Rcpp::export]] +bool const_vect(NumericVector var){ + for (int i = 0, size = var.size(); i < size; ++i) { + if (var[i] - var[0] > 0 || var[0] - var[i] > 0) + return false; + } + + return true; +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp new file mode 100644 index 0000000..2a6351f --- /dev/null +++ b/src/RcppExports.cpp @@ -0,0 +1,28 @@ +// Generated by using Rcpp::compileAttributes() -> do not edit by hand +// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 + +#include <Rcpp.h> + +using namespace Rcpp; + +// const_vect +bool const_vect(NumericVector var); +RcppExport SEXP _CausalGrid_const_vect(SEXP varSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< NumericVector >::type var(varSEXP); + rcpp_result_gen = Rcpp::wrap(const_vect(var)); + return rcpp_result_gen; +END_RCPP +} + +static const R_CallMethodDef CallEntries[] = { + {"_CausalGrid_const_vect", (DL_FUNC) &_CausalGrid_const_vect, 1}, + {NULL, NULL, 0} +}; + +RcppExport void R_init_CausalGrid(DllInfo *dll) { + R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); + R_useDynamicSymbols(dll, FALSE); +} diff --git a/tests/dgps.R b/tests/dgps.R new file mode 100644 index 0000000..db85774 --- /dev/null +++ b/tests/dgps.R @@ -0,0 +1,96 @@ + +exp_data <- function(n_4=25, dim_D=1, err_sd=0.01){ + #n_4 is n/4.
We get this to make sure we have the same in each chunk + n = n_4*4 + stopifnot(dim_D %in% c(0,1,2)) + #dim_D in {0,1,2} + X1 = cbind(runif(n_4, 0, .5), runif(n_4, 0, .5)) + X2 = cbind(runif(n_4, 0, .5), runif(n_4, .5, 1)) + X3 = cbind(runif(n_4, .5, 1), runif(n_4, 0, .5)) + X4 = cbind(runif(n_4, .5, 1), runif(n_4, .5, 1)) + X = rbind(X1, X2, X3, X4) + + alpha = ifelse(X[,1]>.5, ifelse(X[,2]>.5,.5,.8), ifelse(X[,2]>.5, 2, -2)) + #alpha=0 + y = alpha + rnorm(n,0,err_sd) + if(dim_D) { + if(dim_D==1) { + beta = ifelse(X[,1]>.5, ifelse(X[,2]>.5,-1,2), ifelse(X[,2]>.5, 4, 6)) + #beta = ifelse(X[,1]>.5, -1,1) + d = matrix(rnorm(n), n, 1) + y = y + beta*d + colnames(d) <- "d" + } + else { + beta1 = ifelse(X[,1]>.5,-1, 4) + beta2 = ifelse(X[,2]>.5, 2, 6) + d = matrix(rnorm(2*n), n, 2) + y = y + beta1*d[,1] + beta2*d[,2] + colnames(d) = c("d1", "d2") + } + + } + else { + d = NULL + } + y = as.matrix(y, nrow=n, ncol=1) + colnames(y) = "y" + colnames(X) = c("X1", "X2") + return(list(y=y, X=X, d=d)) +} + +mix_data_y <- function(n=200) { + X = data.frame(X1=c(rep(0, n/4), rep(1, n/4), rep(0, n/4), rep(1, n/4)), + X2=factor(c(rep("A", n/2), rep("B", n/2)))) + alpha = c(rep(0, n/4), rep(1, n/4), rep(2, n/4), rep(3, n/4)) + y = alpha + return(list(y=y, X=X)) +} + +mix_data_d <- function(n=200) { + X = data.frame(X1=c(rep(0, n/4), rep(1, n/4), rep(0, n/4), rep(1, n/4)), + X2=factor(c(rep("A", n/2), rep("B", n/4), rep("C", n/4)))) + tau = c(rep(0, n/4), rep(1, n/4), rep(2, n/4), rep(3, n/4)) + d = rep(0:1, n/2) + y = d*tau + return(list(y=y, X=X, d=d)) +} + + +two_groups_data <- function(){ + + X = matrix(factor(c(rep("M", 100), rep("F", 100))),nrow = 200 ,ncol = 1) + y = c(rep(5, 100), rep(50, 100)) + + return(list(y=y, X=X)) +} + +two_groups_data_int <- function(){ + + X = matrix(c(rep(1, 100), rep(2, 100), rep(0, 200)) ,nrow = 200 ,ncol = 2) + y = c(rep(5, 100), rep(50, 100)) + + return(list(y=y, X=X)) +} + +AI_sim <- function(n=500, design=1) { + w = rbinom(n, 1, 0.5) + K = c(2, 10, 20)[design] + X = matrix(rnorm(n*K), nrow=n, ncol=K) + X_I = X>0 + if(design==1) { + eta = X %*% matrix(c(0.5, 1), ncol=1) + kappa = X %*% matrix(c(0.5, 0), ncol=1) + } + if(design==2) { + eta = X %*% matrix(c(rep(0.5, 2), rep(1, 4), rep(0, 4)), ncol=1) + kappa = (X*X_I) %*% matrix(c(rep(1,2), rep(0,8)), ncol=1) + } + if(design==3) { + eta = X %*% matrix(c(rep(0.5, 4), rep(1, 4), rep(0, 12)), ncol=1) + kappa = (X*X_I) %*% matrix(c(rep(1,4), rep(0,16)), ncol=1) + } + epsilon = rnorm(n, 0, 0.01) + Y = eta + 0.5*(2*w-1)*kappa + epsilon + return(list(Y, X, w, kappa)) +} diff --git a/tests/testthat.R b/tests/testthat.R new file mode 100644 index 0000000..bf489c4 --- /dev/null +++ b/tests/testthat.R @@ -0,0 +1,4 @@ +library(testthat) +library(CausalGrid) + +test_check("CausalGrid") diff --git a/tests/testthat/test_multi.R b/tests/testthat/test_multi.R new file mode 100644 index 0000000..69aaca9 --- /dev/null +++ b/tests/testthat/test_multi.R @@ -0,0 +1,82 @@ +library(testthat) + +if(F) { + devtools::load_all(".", export_all=FALSE, helpers=FALSE) +} +set.seed(1337) + +ys = rnorm(100) +ds = rnorm(100) +Xs = matrix(1:200, ncol=2) +ysm = matrix(ys) +dsm = matrix(ds) +tr_splits = seq(1, by=2, length.out=50) +cvsa = seq(1, by=2, length.out=25) +cvsb = seq(2, by=2, length.out=25) +cv_foldss = list(index=list(cvsa, cvsb), + indexOut=list(cvsb, cvsa)) + +yl = rnorm(120) +dl = rnorm(120) +Xl = matrix(1:240, ncol=2) +ylm = matrix(yl) +dlm = matrix(dl) +tr_splitl = seq(1, by=2, length.out=60) +cvla = seq(1, by=2, length.out=30) +cvlb = seq(2, 
by=2, length.out=30) +cv_foldsl = list(index=list(cvla, cvlb), + indexOut=list(cvlb, cvla)) + +y0 = ys +d0 = ds +X0 = Xs + +y0m = ysm +d0m = dsm +X0m = Xs + +y1 = list(ys, yl, ys) +d1 = list(ds, dl, ds) +X1 = list(Xs, Xl, Xs) + +y1m = list(ysm, ylm, ysm) +d1m = list(dsm, dlm, dsm) +X1m = list(Xs, Xl, Xs) + +y2 = ys +d2 = cbind(d1=ds, d2=1:10, d3=1:5) +X2 = Xs + +y2m = ysm +d2m = cbind(d1=ds, d2=1:10, d3=1:5) +X2m = Xs + +y3 = cbind(y1=ys, y2=ys, y3=ys) +d3 = ds +X3 = Xs + +y3m = cbind(y1=ys, y2=ys, y3=ys) +d3m = dsm +X3m = Xs + +print(0) +#fit_estimate_partition(y0, X0, d0, partition_i=2, tr_split = tr_splits) +#fit_estimate_partition(y0, X0, d0, bump_B=2, cv_folds=cv_foldss) +#fit_estimate_partition(y0m, X0, d0m, bump_B=2) + +print(1) +fit_estimate_partition(y1, X1, d1, partition_i=2, tr_split = list(tr_splits,tr_splitl,tr_splits)) +fit_estimate_partition(y1, X1, d1, bump_B=2, cv_folds=list(cv_foldss,cv_foldsl,cv_foldss)) +fit_estimate_partition(y1m, X1, d1m, bump_B=2) + +print(2) +fit_estimate_partition(y2, X2, d2, partition_i=2, tr_split = tr_splits) +fit_estimate_partition(y2, X2, d2, bump_B=2, cv_folds=cv_foldss) +fit_estimate_partition(y2m, X2m, d2m, bump_B=2) + +print(3) +fit_estimate_partition(y3, X3, d3, partition_i=2, tr_split = tr_splits) +fit_estimate_partition(y3, X3, d3, bump_B=2, cv_folds=cv_foldss) +fit_estimate_partition(y3m, X3m, d3m, bump_B=2) + +expect_equal(1,1) \ No newline at end of file diff --git a/tests/testthat/testres.R b/tests/testthat/testres.R new file mode 100644 index 0000000..1ede36e --- /dev/null +++ b/tests/testthat/testres.R @@ -0,0 +1,64 @@ +# To run from the command line with load_all: set do_load_all=T, run the code in the first if(FALSE) block once, and on subsequent runs just run that block's last line +# Undo before building the package + +library(testthat) +library(rprojroot) +testthat_root_dir <- rprojroot::find_testthat_root_file() #R cmd check doesn't copy over git and RStudio proj file + +if(FALSE) { #Run manually to debug + library(rprojroot) + testthat_root_dir <- rprojroot::find_testthat_root_file() + debugSource(paste0(testthat_root_dir,"/testres.R")) +} + +do_load_all=F +if(!do_load_all){ + library(CausalGrid) +} else { + library(devtools) + devtools::load_all(".", export_all=FALSE, helpers=FALSE) +} + +set.seed(1337) + +context("Works OK") + +source(paste0(testthat_root_dir,"/../dgps.R")) + + +# Mean outcome +data <- exp_data(n_4=100, dim_D=0, err_sd = 0.00) +ret1 <- fit_estimate_partition(data$y, data$X, cv_folds=2, verbosity=0) +print(ret1$partition$s_by_dim) +test_that("We get OK results (OOS)", { + expect_lt(ret1$partition$s_by_dim[[1]][1], .6) + expect_gt(ret1$partition$s_by_dim[[1]][1], .4) + expect_lt(ret1$partition$s_by_dim[[2]][1], .6) + expect_gt(ret1$partition$s_by_dim[[2]][1], .4) +}) + +# +# # Treatment effect (1) +# set.seed(1337) +# data <- exp_data(n_4=100, dim_D=1, err_sd = 0.01) +# ret2 <- fit_partition(data$y, data$X, d = data$d, cv_folds=2, verbosity=0) +# #print(ret2$partition$s_by_dim) +# test_that("We get OK results (OOS)", { +# expect_lt(ret2$partition$s_by_dim[[1]][1], .6) +# expect_gt(ret2$partition$s_by_dim[[1]][1], .4) +# expect_lt(ret2$partition$s_by_dim[[2]][1], .6) +# expect_gt(ret2$partition$s_by_dim[[2]][1], .4) +# }) +# +# # Treatment effect (multiple) +# set.seed(1337) +# data <- exp_data(n_4=100, dim_D=2, err_sd = 0.01) +# ret3 <- fit_partition(data$y, data$X, d = data$d, cv_folds=2, verbosity=0) +# #print(ret3$partition$s_by_dim) +# test_that("We get OK results (OOS)", { +# expect_lt(ret3$partition$s_by_dim[[1]][1],
.6) +# expect_gt(ret3$partition$s_by_dim[[1]][1], .4) +# expect_lt(ret3$partition$s_by_dim[[2]][1], .6) +# expect_gt(ret3$partition$s_by_dim[[2]][1], .4) +# }) + diff --git a/tests/testthat/testrun.R b/tests/testthat/testrun.R new file mode 100644 index 0000000..acb92a3 --- /dev/null +++ b/tests/testthat/testrun.R @@ -0,0 +1,146 @@ +# To run from the command line with load_all: set do_load_all=T, run the code in the first if(FALSE) block once, and on subsequent runs just run that block's last line +# Undo before building the package + +library(testthat) +library(rprojroot) +testthat_root_dir <- rprojroot::find_testthat_root_file() #R cmd check doesn't copy over git and RStudio proj file + +if(FALSE) { #Run manually to debug + library(rprojroot) + testthat_root_dir <- rprojroot::find_testthat_root_file() + debugSource(paste0(testthat_root_dir,"/testrun.R")) +} + +do_load_all=F +if(!do_load_all){ + library(CausalGrid) +} else { + library(devtools) + devtools::load_all(".", export_all=FALSE, helpers=FALSE) +} + +set.seed(1337) + +source(paste0(testthat_root_dir,"/../dgps.R")) + +data <- mix_data_d(n=1000) +pot_break_points = list(c(0.5), c(0)) + +# Just y --------------- + +ret1 <- fit_estimate_partition(data$y, data$X, cv_folds=2, verbosity=0, pot_break_points=pot_break_points) +print(ret1$partition) +test_that("We get OK results (OOS)", { + expect_equal(ret1$partition$nsplits_by_dim, c(1,1)) +}) + + +# Include d --------------- + +ret1d <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points) +print(ret1d$partition) +test_that("We get OK results (OOS)", { + expect_equal(ret1d$partition$nsplits_by_dim, c(1,1)) +}) +any_sign_effect(ret1d, check_negative=T, method="fdr") # +#any_sign_effect(ret1d, check_negative=T, method="sim_mom_ineq") #the sim produces a treatment effect with 0 std err, which causes problems + +ret2d <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points, ctrl_method="all") +print(ret2d$partition) +#TODO: Should I check this? +#test_that("We get OK results (OOS)", { +# expect_equal(ret2d$partition$nsplits_by_dim, c(1,1)) +#}) + +ret3d <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=3, verbosity=0, pot_break_points=pot_break_points, ctrl_method="LassoCV") +print(ret3d$partition) +#TODO: Should I check this? +#test_that("We get OK results (OOS)", { +# expect_equal(ret3d$partition$nsplits_by_dim, c(1,1)) +#}) + +ret4d <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points, ctrl_method="rf") +print(ret4d$partition) +#TODO: Should I check this?
+#test_that("We get OK results (OOS)", { +# expect_equal(ret4d$partition$nsplits_by_dim, c(1,1)) +#}) + +ret1db <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points, bump_B=2) + + +ret1dc <- fit_estimate_partition(data$y, data$X, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points, importance_type="single") + +X_3 = data$X +X_3$X3 = data$X$X2 +pot_break_points_3 = pot_break_points +pot_break_points_3[[3]] = pot_break_points[[2]] +print("---------------") +ret1dd <- fit_estimate_partition(data$y, X_3, data$d, cv_folds=2, verbosity=2, pot_break_points=pot_break_points_3, importance_type="interaction", bump_B=3) +print("---------------") +ret1dd <- fit_estimate_partition(data$y, X_3, data$d, cv_folds=2, verbosity=1, pot_break_points=pot_break_points_3, importance_type="interaction", bump_B=3) +print("---------------") +ret1dd <- fit_estimate_partition(data$y, X_3, data$d, cv_folds=2, verbosity=0, pot_break_points=pot_break_points_3, importance_type="interaction", bump_B=3) + +# Old test +if(FALSE) { + dim_D = 1 + n_4 = 100 + data <- exp_data(n_4=n_4, dim_D, err_sd = 1e-7) + K = ncol(data$X) + # Limit break points? + # pot_break_points = list() + # g=4 + # for(k in 1:ncol(data$X)) { + # pot_break_points[[k]] = quantile(data$X[,k], seq(0,1,length.out=g+1))[-c(g+1,1)] + # } + pot_break_points = NULL + tr_index = c(sample(n_4, n_4/2), sample(n_4, n_4/2)+n_4, sample(n_4, n_4/2) + 2*n_4, sample(n_4, n_4/2) + 3*n_4) + X = as.data.frame(data$X) + X[[2]] = factor(c("a", "a", "b", "c")) + ret2 <- fit_estimate_partition(data$y, X, tr_split = tr_index, cv_folds=5, max_splits=Inf, verbosity=1, pot_break_points=pot_break_points, d=data$d, ctrl_method="LassoCV") #bucket_min_d_var, bucket_min_n + print(ret2) + cat(paste("s_by_dim", paste(ret2$partition$s_by_dim, collapse=" "),"\n")) + cat(paste("lambda", paste(ret2$lambda, collapse=" "),"\n")) + cat(paste("param_ests", paste(ret2$cell_stats$param_ests, collapse=" "),"\n")) + cat(paste("var_ests", paste(ret2$cell_stats$var_ests, collapse=" "),"\n")) + cat(paste("cell_sizes", paste(ret2$cell_stats$cell_sizes, collapse=" "),"\n")) + + #View implied model + est_df = data.frame(y=data$y, f = get_factor_from_partition(ret2$partition, data$X)) + if (dim_D) est_df = cbind(est_df, data$d) + if(dim_D==0) { + ols_fit = lm(y~0+f, data=est_df) + } + if (dim_D==1) { + ols_fit = lm(y~0+f+d:f, data=est_df) + } + if (dim_D==2) { + ols_fit = lm(y~0+f+(d1+d2):f, data=est_df) + } + print(summary(ols_fit)) + #0.5003187, 0.5000464 + + + #Compare to manually-specified split + my_partition = add_split.grid_partition(add_split.grid_partition(grid_partition(get_X_range(data$X)),partition_split(1, .5)), partition_split(2, .5)) + y_tr = data$y[ret2$index_tr] + X_tr = data$X[ret2$index_tr, , drop=FALSE] + d_tr = data$d[ret2$index_tr, , drop=FALSE] + X_es = data$X[-ret2$index_tr, , drop=FALSE] + N_est = nrow(X_es) + my_part_mse = mse_hat_obj(y_tr,X_tr,d_tr, N_est=N_est, partition=my_partition) + print(paste("emse:",my_part_mse,". 
+pen", my_part_mse + ret2$lambda*(num_cells(my_partition)-1))) + est_df[['f2']] = interaction(get_factors_from_partition(my_partition, data$X)) + if(dim_D==0) { + ols_fit = lm(y~0+f2, data=est_df) + } + if(dim_D==1) { + ols_fit = lm(y~0+f2+d:f2, data=est_df) + } + if(dim_D==2) { + ols_fit = lm(y~0+f2+(d1+d2):f2, data=est_df) + } + print(summary(ols_fit)) + +} diff --git a/vignettes/.gitignore b/vignettes/.gitignore new file mode 100644 index 0000000..097b241 --- /dev/null +++ b/vignettes/.gitignore @@ -0,0 +1,2 @@ +*.html +*.R diff --git a/vignettes/vignette.Rmd b/vignettes/vignette.Rmd new file mode 100644 index 0000000..2d20104 --- /dev/null +++ b/vignettes/vignette.Rmd @@ -0,0 +1,33 @@ +--- +title: "vignette" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{vignette} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +```{r setup} +library(CausalGrid) +``` + +Let's get some fake data +```{r} +N = 1000 +X = matrix(rnorm(N), ncol=1) +d = rbinom(N, 1, 0.5) +tau = as.integer(X[,1]>0) +y = d*tau + rnorm(1000, 0, 0.0001) +``` + +Here's how you use the package +```{r} +est_part = fit_estimate_partition(y, X, d) +``` diff --git a/writeups/oneline algo.lyx b/writeups/oneline algo.lyx new file mode 100644 index 0000000..e56036e --- /dev/null +++ b/writeups/oneline algo.lyx @@ -0,0 +1,236 @@ +#LyX 2.3 created this file. For more info see http://www.lyx.org/ +\lyxformat 544 +\begin_document +\begin_header +\save_transient_properties true +\origin unavailable +\textclass article +\use_default_options true +\maintain_unincluded_children false +\language english +\language_package default +\inputencoding auto +\fontencoding global +\font_roman "default" "default" +\font_sans "default" "default" +\font_typewriter "default" "default" +\font_math "auto" "auto" +\font_default_family default +\use_non_tex_fonts false +\font_sc false +\font_osf false +\font_sf_scale 100 100 +\font_tt_scale 100 100 +\use_microtype false +\use_dash_ligatures true +\graphics default +\default_output_format default +\output_sync 0 +\bibtex_command default +\index_command default +\paperfontsize default +\use_hyperref false +\papersize default +\use_geometry false +\use_package amsmath 1 +\use_package amssymb 1 +\use_package cancel 1 +\use_package esint 1 +\use_package mathdots 1 +\use_package mathtools 1 +\use_package mhchem 1 +\use_package stackrel 1 +\use_package stmaryrd 1 +\use_package undertilde 1 +\cite_engine basic +\cite_engine_type default +\use_bibtopic false +\use_indices false +\paperorientation portrait +\suppress_date false +\justification true +\use_refstyle 1 +\use_minted 0 +\index Index +\shortcut idx +\color #008000 +\end_index +\secnumdepth 3 +\tocdepth 3 +\paragraph_separation indent +\paragraph_indentation default +\is_math_indent 0 +\math_numbering_side default +\quotes_style english +\dynamic_quotes 0 +\papercolumns 1 +\papersides 1 +\paperpagestyle default +\tracking_changes false +\output_changes false +\html_math_output 0 +\html_css_as_file 0 +\html_be_strict false +\end_header + +\begin_body + +\begin_layout Title +Online Cross-Fit Error Algo +\end_layout + +\begin_layout Standard + +\emph on +Background +\emph default +: For variance (MSE) there is the Welford algorithm for computing variance + online. 
+ +\begin_inset Formula +\begin{align*} +\bar{x}_{k} & =\bar{x}_{k-1}+\frac{x_{k}-\bar{x}_{k-1}}{k}\\ +\bar{x}_{k} & =\bar{x}_{k-\tau}+\frac{(\sum_{t\in\tau}x_{t})-\tau\bar{x}_{k-\tau}}{k}\\ +\\ +S_{k} & =S_{k-1}+(x_{k}-\bar{x}_{k-1})(x_{k}-\bar{x}_{k})\\ +S_{k} & =S_{k-1}+\frac{(x_{k}-\bar{x}_{k-1})}{k}k(x_{k}-\bar{x}_{k})\\ +S_{k} & =S_{k-1}+(\bar{x}_{k}-\bar{x}_{k-1})k(x_{k}-\bar{x}_{k})\\ +\\ +\\ +S_{k}-S_{k-1}\\ +\sum_{i=1}^{k-1}[(x_{i}-\bar{x}_{k})^{2}-(x_{i}-\bar{x}_{k-1})^{2}]+(x_{k}-\bar{x}_{k})^{2}\\ +\sum_{i=1}^{k-1}[-2x_{i}(\bar{x}_{k}-\bar{x}_{k-1})+(\bar{x}_{k}-\bar{x}_{k-1})(\bar{x}_{k}+\bar{x}_{k-1})]+(x_{k}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k}-\bar{x}_{k-1})\sum_{i=1}^{k-1}[-x_{i}+\bar{x}_{k}-x_{i}+\bar{x}_{k-1}]+(x_{k}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k-1}-\bar{x}_{k})\sum_{i=1}^{k-1}(x_{i}-\bar{x}_{k})+(x_{k}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k-1}-\bar{x}_{k})[\sum_{i=1}^{k}(x_{i}-\bar{x}_{k})-(x_{k}-\bar{x}_{k})]+(x_{k}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k-1}-\bar{x}_{k})(\bar{x}_{k}-x_{k})+(x_{k}-\bar{x}_{k})^{2}\\ +\\ +S_{k}-S_{k-\tau}\\ +\sum_{i=1}^{k-\tau}[(x_{i}-\bar{x}_{k})^{2}-(x_{i}-\bar{x}_{k-\tau})^{2}]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +\sum_{i=1}^{k-\tau}[-2x_{i}(\bar{x}_{k}-\bar{x}_{k-\tau})+(\bar{x}_{k}-\bar{x}_{k-\tau})(\bar{x}_{k}+\bar{x}_{k-\tau})]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k}-\bar{x}_{k-\tau})\sum_{i=1}^{k-\tau}[-x_{i}+\bar{x}_{k}-x_{i}+\bar{x}_{k-\tau}]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k}-\bar{x}_{k-\tau})\sum_{i=1}^{k-\tau}[-x_{i}+\bar{x}_{k}]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k}-\bar{x}_{k-\tau})[\sum_{i=1}^{k}(\bar{x}_{k}-x_{i})-\sum_{i=k-\tau+1}^{k}(\bar{x}_{k}-x_{i})]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +(\bar{x}_{k}-\bar{x}_{k-\tau})[\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})]+\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})^{2}\\ +\sum_{i=k-\tau+1}^{k}(x_{i}-\bar{x}_{k})(x_{i}-\bar{x}_{k-\tau})\\ +\\ +MSE\ \sigma_{n}^{2} & =S_{n}/n +\end{align*} + +\end_inset + + +\end_layout + +\begin_layout Standard +For cross-fitting we need to keep track of two running means ( +\begin_inset Formula $\bar{x}^{a}$ +\end_inset + +, +\begin_inset Formula $\bar{x}^{b}$ +\end_inset + +) and now +\begin_inset Formula $S_{k}^{b}=\sum_{i=1}^{k}(x_{i}-\bar{x}^{a})^{2}$ +\end_inset + +. + If we add a new data point to +\begin_inset Formula $a$ +\end_inset + +, then we don't update +\begin_inset Formula $S^{a}$ +\end_inset + + or +\begin_inset Formula $\bar{x}^{b}$ +\end_inset + +, but we do update +\begin_inset Formula $\bar{x}^{a}$ +\end_inset + + as normal and then this affects +\begin_inset Formula $S^{b}$ +\end_inset + +. + Suppose that +\begin_inset Formula $\Delta=\bar{x}_{k}^{a}-\bar{x}_{k-1}^{a}$ +\end_inset + +.
+ We update +\begin_inset Formula $S^{b}$ +\end_inset + + +\begin_inset Formula +\begin{align*} +S_{(b=k',a=k)}^{b} & =\sum_{i\in b}(x_{i}-\bar{x}_{k}^{a})^{2}\\ + & =\sum_{i\in b}(x_{i}-(\Delta+\bar{x}_{k-1}^{a}))^{2}\\ + & =\sum_{i\in b}((x_{i}-\bar{x}_{k-1}^{a})-\Delta)^{2}\\ + & =S_{(b=k',a=k-1)}^{b}+k'\Delta^{2}+2\Delta k'\bar{x}_{k-1}^{a}-2\Delta\sum_{i\in b}x_{i}\\ + & =S_{(b=k',a=k-1)}^{b}+k'\Delta^{2}+2\Delta k'\bar{x}_{k-1}^{a}-2\Delta k'\bar{x}_{k'}^{b} +\end{align*} + +\end_inset + + +\end_layout + +\begin_layout Standard +References: +\end_layout + +\begin_layout Standard +- +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://stats.stackexchange.com/questions/332951/online-algorithm-for-the-mean-square-error +\end_layout + +\end_inset + + +\end_layout + +\begin_layout Standard +- +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://stats.stackexchange.com/questions/235129/online-estimation-of-variance-with-limited-memory/235151#235151 +\end_layout + +\end_inset + + +\end_layout + +\begin_layout Standard +- +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm +\end_layout + +\end_inset + + +\end_layout + +\end_body +\end_document
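
The cross-fit recursion in the writeup is easy to sanity-check in code. Below is a minimal R sketch, assuming a simple list-based state; the helper names (`welford_add`, `crossfit_add_a`) and their fields are illustrative for this note only and are not part of the CausalGrid API.

```r
# Online variance state: (n, mean, S) with S = sum_i (x_i - mean)^2,
# so the MSE estimate is sigma^2 = S / n.

# Standard Welford step: S_k = S_{k-1} + (x_k - xbar_{k-1}) * (x_k - xbar_k)
welford_add <- function(state, x) {
  n_new <- state$n + 1
  mean_new <- state$mean + (x - state$mean) / n_new     # xbar_k
  S_new <- state$S + (x - state$mean) * (x - mean_new)  # uses xbar_{k-1} and xbar_k
  list(n = n_new, mean = mean_new, S = S_new)
}

# Cross-fit step: adding x to sample a updates xbar^a (and S^a) as usual, and
# adjusts S^b = sum_{i in b} (x_i - xbar^a)^2 via Delta = xbar^a_k - xbar^a_{k-1}:
#   S^b <- S^b + k' * Delta * (Delta + 2 * xbar^a_{k-1} - 2 * xbar^b_{k'})
crossfit_add_a <- function(a, b, x) {
  mean_a_old <- a$mean
  a <- welford_add(a, x)
  Delta <- a$mean - mean_a_old
  b$S_cross <- b$S_cross + b$n * Delta * (Delta + 2 * mean_a_old - 2 * b$mean)
  list(a = a, b = b)
}

# Quick check against direct recomputation:
xs_b <- c(2, 4)                               # the b sample (k' = 2)
a <- list(n = 1, mean = 1, S = 0)             # a sample so far: {1}
b <- list(n = 2, mean = mean(xs_b),
          S_cross = sum((xs_b - a$mean)^2))   # S^b relative to current xbar^a
upd <- crossfit_add_a(a, b, 5)                # add x = 5 to sample a
stopifnot(isTRUE(all.equal(upd$b$S_cross, sum((xs_b - upd$a$mean)^2))))
```

Note that the b-sample size k' multiplies every term of the S^b adjustment; with k' = 1 it reduces to the single-point case.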