remove any references to torch/tabnet as we are not currently running those models. Replace meta packages like tidyverse and tidymodels with their underlying packages. Change how weighted MAPE is calculated to help improve negative forecasts.

This commit is contained in:
Mike Tokic 2021-12-17 11:40:40 -08:00
Родитель fe87c4254f
Коммит 9f26dc1da5
9 изменённых файлов: 12 добавлений и 94 удалений

5
.github/workflows/R-CMD-check.yaml поставляемый
Просмотреть файл

@ -53,9 +53,8 @@ jobs:
run: |
install.packages(c('rcmdcheck', 'knitr', 'rmarkdown', 'lubridate', 'foreach', 'modeltime',
'Cubist', 'doParallel', 'devtools', 'dplyr', 'earth', 'glmnet',
'gtools', 'hts', 'kernlab',
'modeltime.gluonts', 'modeltime.resample', 'plyr', 'purrr', 'rules',
'tabnet', 'tibble', 'tidyr', 'timetk', 'torch'))
'gtools', 'hts', 'kernlab', 'modeltime.gluonts', 'modeltime.resample', 'plyr', 'purrr',
'rules', 'tibble', 'tidyr', 'timetk'))
devtools::install_github(c("Azure/rAzureBatch", "Azure/doAzureParallel"))

Просмотреть файл

@ -51,7 +51,6 @@ Imports:
rsample,
rules,
stringr,
tabnet,
tibble,
tidyr,
tidyselect,
@ -66,6 +65,5 @@ Suggests:
Config/testthat/edition: 3
Depends:
R (>= 3.6.0),
modeltime,
torch
modeltime
VignetteBuilder: knitr

Просмотреть файл

@ -21,12 +21,10 @@ export(stlm_arima)
export(stlm_ets)
export(svm_poly)
export(svm_rbf)
export(tabnet)
export(tbats)
export(theta)
export(xgboost)
import(modeltime)
import(torch)
importFrom(Cubist,cubist)
importFrom(Cubist,cubistControl)
importFrom(earth,earth)

Просмотреть файл

@ -245,8 +245,8 @@ get_back_test_scenario_hist_periods<- function(full_data_tbl,
get_export_packages <- function(){
c('modeltime', 'modeltime.gluonts', 'modeltime.resample',
'timetk', 'rules', 'Cubist', 'glmnet', 'earth', 'kernlab', 'xgboost',
'tidyverse', 'lubridate', 'prophet', 'torch', 'tabnet',
"doParallel", "parallel")
'dplyr', 'tibble', 'tidyr', 'purrr', 'stringr', 'lubridate', 'prophet',
'doParallel', 'parallel')
}
#' Fetches a list of parallel transfer functions
@ -256,7 +256,7 @@ get_export_packages <- function(){
get_transfer_functions <- function(){
c("arima", "arima_boost", "croston", "cubist", "deepar", "ets", "glmnet", "mars",
"meanf", "nbeats", "nnetar", "nnetar_xregs", "prophet", "prophet_boost", "prophet_xregs",
"snaive", "stlm_arima", "stlm_ets", "svm_poly", "svm_rbf", "tbats", "tabnet", "theta", "xgboost",
"snaive", "stlm_arima", "stlm_ets", "svm_poly", "svm_rbf", "tbats", "theta", "xgboost",
"multivariate_prep_recipe_1", "multivariate_prep_recipe_2","combo_specific_filter",
"construct_forecast_models", "get_model_functions", "get_not_all_data_models",
"get_r1_data_models", "get_r2_data_models", "get_deep_learning_models", "get_frequency_adjustment_models",

Просмотреть файл

@ -377,8 +377,9 @@ forecast_time_series <- function(input_data,
}
combinations_tbl <- foreach::foreach(i = model_combinations[[1]], .combine = 'rbind',
.packages = c('tidyverse', 'lubridate',
"doParallel", "parallel", "gtools"),
.packages = c('dplyr', 'tibble', 'tidyr', 'purrr',
'stringr', 'lubridate',
'doParallel', 'parallel', "gtools"),
.export = c("fcst_prep")) %dopar% {
fcst_combination_temp <- fcst_prep %>%
@ -468,8 +469,8 @@ forecast_time_series <- function(input_data,
dplyr::mutate(Target = ifelse(Target == 0, 0.1, Target)) %>%
dplyr::mutate(MAPE = round(abs((FCST - Target) / Target), digits = 4)) %>%
dplyr::group_by(Model, Combo) %>%
dplyr::mutate(Combo_Total = sum(Target, na.rm = TRUE),
weighted_MAPE = (Target/Combo_Total)*MAPE) %>%
dplyr::mutate(Combo_Total = sum(abs(Target), na.rm = TRUE),
weighted_MAPE = (abs(Target)/Combo_Total)*MAPE) %>%
dplyr::summarise(Rolling_MAPE = sum(weighted_MAPE, na.rm=TRUE)) %>%
dplyr::arrange(Rolling_MAPE) %>%
dplyr::ungroup()

Просмотреть файл

@ -1399,58 +1399,6 @@ svm_rbf <- function(train_data,
}
#' TabNet
#'
#' @param train_data Training Data
#' @param parallel Parallel
#' @param date_rm_regex Date RM Regex
#' @param fiscal_year_start Fiscal Year Start
#'
#' @return Get Tab Net
#' @keywords internal
#' @export
tabnet <- function(train_data,
parallel,
fiscal_year_start,
date_rm_regex,
pca) {
date_rm_regex_final <- "(.xts$)|(.iso$)|(hour)|(minute)|(second)|(am.pm)|(day)|(week)"
#create model recipe
recipe_spec_tabnet <- train_data %>%
get_recipie_configurable(fiscal_year_start,
date_rm_regex_final,
mutate_adj_half = FALSE,
step_nzv = "none",
one_hot = TRUE,
pca = pca)
model_spec_tabnet <- tabnet::tabnet(
mode = "regression",
batch_size = tune::tune(),
virtual_batch_size = tune::tune(),
epochs = tune::tune()
) %>%
parsnip::set_engine("torch")
wflw_spec_tune_tabnet <- get_workflow_simple(model_spec_tabnet,
recipe_spec_tabnet)
tune_results_tabnet <- train_data %>%
get_kfold_tune_grid(wkflw = wflw_spec_tune_tabnet,
parallel = parallel)
wflw_fit_tabnet <- train_data %>%
get_fit_wkflw_best(tune_results_tabnet,
wflw_spec_tune_tabnet)
cli::cli_alert_success("tabnet")
return(wflw_fit_tabnet)
}
#' Tbats
#'
#' @param train_data Training Data

Просмотреть файл

@ -27,8 +27,6 @@ utils::globalVariables(c(".id", ".key", ".model_desc", ".pred", ".resample_id",
#' @importFrom glmnet glmnet
#' @import torch
# * cbind.fill custom function ----
#create function to cbind dataframes that contain different amounts of rows
#https://github.com/cvarrichio/rowr/blob/master/R/rowr.R

Просмотреть файл

@ -1,24 +0,0 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/models.R
\name{tabnet}
\alias{tabnet}
\title{TabNet}
\usage{
tabnet(train_data, parallel, fiscal_year_start, date_rm_regex, pca)
}
\arguments{
\item{train_data}{Training Data}
\item{parallel}{Parallel}
\item{fiscal_year_start}{Fiscal Year Start}
\item{date_rm_regex}{Date RM Regex}
}
\value{
Get Tab Net
}
\description{
TabNet
}
\keyword{internal}

Просмотреть файл

@ -80,7 +80,7 @@ azure_batch_cluster_config <- list(
),
"containerImage" = "mftokic/finn-azure-batch-dev", # docker image you can use that automatically downloads software needed for Finn to run in cloud
"rPackages" = list(
"cran" = c('Rcpp', 'modeltime', 'modeltime.resample', 'tidymodels', 'lubridate', 'rules', 'Cubist', 'earth', 'kernlab', 'doParallel', 'tidyverse', 'torch', 'tabnet', 'prophet', 'glmnet', 'gtools'), # finnts package dependencies
"cran" = c('Rcpp', 'modeltime', 'modeltime.resample', 'parsnip', 'tune', 'recipes', 'rsample', 'workflows', 'dials', 'lubridate', 'rules', 'Cubist', 'earth', 'kernlab', 'doParallel', 'dplyr', 'tibble', 'tidyr', 'purrr', 'stringr', 'prophet', 'glmnet', 'gtools'), # finnts package dependencies
"github" = list(),
"bioconductor" = list()
),