91 строка
3.6 KiB
R
91 строка
3.6 KiB
R
|
|
|
|
#' Create 2D plots of parameter estimates
#'
#' Creates a 2D plot of parameter estimates, or a series of such slices if the
#' partition splits on more than two features.
#'
#' With one split feature a single strip of rectangles is drawn along the
#' x-axis. With two split features a single 2D grid is drawn. With more than
#' two, one 2D plot is produced for every combination of segments of the
#' remaining (off-axis) split features, and each plot is titled with that
#' combination.
#'
#' @param x Object dispatched on by this method (class `estimated_partition`).
#'   The fields read are `x$partition$nsplits_by_dim`, `x$partition$varnames`,
#'   `x$cell_stats$param_ests` and, optionally, `x$importance_weights`.
#' @param X_names_2D Optional character vector. With one or two split features
#'   it only supplies the axis label(s). With more than two split features it
#'   must name the two split features to put on the x and y axes. If `NULL`,
#'   the two split features with the largest `x$importance_weights` are used,
#'   or the first two split features when no weights are available.
#' @param ... Additional arguments. Unused.
#'
#' @return A ggplot2 object, a list of such objects (more than two split
#'   features), or `NULL` if no feature is split.
#' @export
plot.estimated_partition <- function(x, X_names_2D = NULL, ...) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package \"ggplot2\" needed for this function to work. Please install it.",
         call. = FALSE)
  }

  split_dims <- (x$partition$nsplits_by_dim > 0)
  n_split_dims <- sum(split_dims)
  if (n_split_dims == 0) {
    print("Nothing to graph as no heterogeneity")
    return(NULL)
  }
  desc_range_df <- get_desc_df(x$partition, drop_unsplit = TRUE,
                               cont_bounds_inf = FALSE)

  # One split feature: draw the cells as a strip of unit-height rectangles.
  if (n_split_dims == 1) {
    # Each column holds one (lower, upper) pair per cell; flatten every column
    # into a two-column data frame and bind them side by side.
    rect_df <- do.call(cbind, lapply(desc_range_df, function(rng) {
      as.data.frame(t(matrix(unlist(rng), nrow = 2)))
    }))
    rect_df["ymin"] <- 0
    rect_df["ymax"] <- 1
    colnames(rect_df) <- c("xmin", "xmax", "ymin", "ymax")
    rect_df["estimate"] <- x$cell_stats$param_ests
    xname <- if (!is.null(X_names_2D)) {
      X_names_2D[1]
    } else {
      x$partition$varnames[split_dims]
    }
    # The y axis carries no information here, so hide its title/text/ticks.
    plt <- ggplot2::ggplot() +
      ggplot2::scale_x_continuous(name = xname) +
      ggplot2::theme(axis.title.y = ggplot2::element_blank(),
                     axis.text.y = ggplot2::element_blank(),
                     axis.ticks.y = ggplot2::element_blank()) +
      ggplot2::geom_rect(data = rect_df,
                         mapping = ggplot2::aes(xmin = xmin, xmax = xmax,
                                                ymin = ymin, ymax = ymax,
                                                fill = estimate),
                         color = "black")
    return(plt)
  }

  # Two split features: a single 2D grid.
  if (n_split_dims == 2) {
    if (is.null(X_names_2D)) X_names_2D <- x$partition$varnames[split_dims]
    return(gen_one_plt(desc_range_df, x$cell_stats$param_ests, X_names_2D))
  }

  # More than two split features: one 2D slice per combination of segments of
  # the off-axis features. The unclassed string description yields integer
  # segment codes that still carry their "levels" attribute.
  desc_range_df_fact <- data.frame(lapply(
    get_desc_df(x$partition, drop_unsplit = TRUE, do_str = TRUE),
    unclass
  ))
  split_names <- names(desc_range_df)

  if (is.null(X_names_2D)) {
    if (is.null(x$importance_weights)) {
      X_names_2D <- x$partition$varnames[split_dims][1:2]
    } else {
      # NOTE(review): assumes importance_weights is parallel to varnames and
      # that the two MOST important split features belong on the axes --
      # confirm against the intended design.
      split_idx <- which(split_dims)
      ranked <- split_idx[order(x$importance_weights[split_idx],
                                decreasing = TRUE)]
      X_names_2D <- x$partition$varnames[ranked[1:2]]
    }
  } else if (length(X_names_2D) != 2 || !all(X_names_2D %in% split_names)) {
    stop("X_names_2D must name exactly two of the split features: ",
         paste(split_names, collapse = ", "), call. = FALSE)
  }

  other_idx <- !(split_names %in% X_names_2D)
  other_cols <- which(other_idx)
  # other_idx is aligned with the split features only, so restrict
  # nsplits_by_dim to those before applying it.
  n_segs_other <- (x$partition$nsplits_by_dim[split_dims] + 1)[other_idx]
  names_other <- split_names[other_idx]
  # Total number of slices is the product over all off-axis features.
  n_slices <- prod(n_segs_other)

  # TRUE for every row of M whose entries equal v element by element.
  test_row_equals_vec <- function(M, v) {
    rowSums(M == rep(v, each = nrow(M))) == ncol(M)
  }

  plts <- vector("list", n_slices)
  for (slice_i in seq_len(n_slices)) {
    segment_indexes <- segment_indexes_from_cell_i(slice_i, n_segs_other)
    row_idx <- test_row_equals_vec(
      desc_range_df_fact[, other_idx, drop = FALSE],
      segment_indexes
    )
    # Human-readable label of the selected segment for each off-axis feature.
    levels_desc <- vapply(seq_along(segment_indexes), function(k) {
      levels(desc_range_df_fact[, other_cols[k]])[segment_indexes[k]]
    }, character(1))
    plts[[slice_i]] <- gen_one_plt(desc_range_df[row_idx, X_names_2D],
                                   x$cell_stats$param_ests[row_idx],
                                   X_names_2D) +
      ggplot2::ggtitle(paste(paste(names_other, levels_desc), collapse = ", "))
  }

  plts
}
|
|
|
|
|
|
# Build a single 2D rectangle plot of parameter estimates.
#
# desc_range_df: data frame with two columns (x feature, then y feature); each
#   column holds one (lower, upper) pair per cell.
# param_ests: estimates used for the fill colour, one per cell.
# X_names_2D: axis titles; element 1 labels x, element 2 labels y.
#
# Returns a ggplot2 object with one black-bordered rectangle per cell.
gen_one_plt <- function(desc_range_df, param_ests, X_names_2D) {
  # Flatten one feature's (lower, upper) pairs into a two-column data frame,
  # one row per cell.
  bounds_of <- function(ranges) {
    as.data.frame(matrix(unlist(ranges), ncol = 2, byrow = TRUE))
  }

  rect_df <- do.call(cbind, lapply(desc_range_df, bounds_of))
  colnames(rect_df) <- c("xmin", "xmax", "ymin", "ymax")
  rect_df["estimate"] <- param_ests

  x_axis <- ggplot2::scale_x_continuous(name = X_names_2D[1])
  y_axis <- ggplot2::scale_y_continuous(name = X_names_2D[2])
  cells <- ggplot2::geom_rect(
    data = rect_df,
    mapping = ggplot2::aes(xmin = xmin, xmax = xmax,
                           ymin = ymin, ymax = ymax,
                           fill = estimate),
    color = "black"
  )

  plt <- ggplot2::ggplot() + x_axis + y_axis + cells
  return(plt)
}
|