diff --git a/examples/01_prepare_data/data_prep_retail.ipynb b/examples/01_prepare_data/data_prep_retail.ipynb deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/01_prepare_data/data_prepare_retail.ipynb b/examples/01_prepare_data/data_prepare_retail.ipynb new file mode 100644 index 00000000..dbc33a1a --- /dev/null +++ b/examples/01_prepare_data/data_prepare_retail.ipynb @@ -0,0 +1,1344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation.\n", + "\n", + "Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Preparation for Retail Sales Forecasting\n", + "\n", + "This notebook shows how to split the Orange Juice dataset into training and test sets for training and evaluating different retail sales forecasting methods.\n", + "\n", + "We use backtesting, a method that tests a predictive model on historical data, to evaluate the forecasting methods. 
Unlike standard [K-fold cross validation](https://en.wikipedia.org/wiki/Cross-validation_%28statistics%29), which randomly splits data into K folds, we split the data so that every time stamp in the training set is no later than any time stamp in the test set. This ensures that no future information is used (except for information that we can know beforehand, e.g., the price of the product in the next few weeks, since we can set the price manually).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Global Settings and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "System version: 3.6.7 | packaged by conda-forge | (default, Nov 6 2019, 16:19:42) \n", + "[GCC 7.3.0]\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "import forecasting_lib.common.forecast_settings as fs\n", + "from forecasting_lib.common.utils import git_repo_path\n", + "from forecasting_lib.dataset.ojdata import download_ojdata, split_train_test\n", + "\n", + "print(\"System version: {}\".format(sys.version))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Use False if you've already downloaded and split the data\n", + "DOWNLOAD_DATA = True\n", + "\n", + "# Data directory\n", + "DATA_DIR = os.path.join(git_repo_path(), \"ojdata\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Data\n", + "\n", + "We need to download the Orange Juice data before splitting it into training and test sets. By default, the following cell will download the data. 
If you've already done so, you may skip this part by switching `DOWNLOAD_DATA` to `False`.\n", + "\n", + "The dataset is from the R package [bayesm](https://cran.r-project.org/web/packages/bayesm/index.html) and is part of the [Dominick's dataset](https://www.chicagobooth.edu/research/kilts/datasets/dominicks). It contains the following two csv files:\n", + "\n", + "1. `yx.csv` includes weekly sales of refrigerated orange juice at 83 stores. This file has 106139 rows and 19 columns. It contains weekly sales and prices of 11 orange juice brands as well as information about profit, deal, and advertisement for each brand. Note that weekly sales are captured by a column named `logmove`, which corresponds to the natural logarithm of the number of units sold. To get the number of units sold, you need to apply an exponential transform to this column.\n", + "\n", + "2. `storedemo.csv` includes demographic information on those stores. This table has 83 rows and 13 columns. For every store, the table describes demographic information of its consumers, distance to the nearest warehouse store, average distance to the nearest 5 supermarkets, ratio of its sales to the nearest warehouse store, and ratio of its sales\n", + "to the average of the nearest 5 stores.\n", + "\n", + "Note that the week number starts from 40 in this dataset, while the full Dominick's dataset has week numbers from 1 to 400. According to [Dominick's Data Manual](https://www.chicagobooth.edu/-/media/enterprise/centers/kilts/datasets/dominicks-dataset/dominicks-manual-and-codebook_kiltscenter.aspx), week 1 starts on 09/14/1989.\n", + "Please see pages 40 and 41 of the [bayesm reference manual](https://cran.r-project.org/web/packages/bayesm/bayesm.pdf) and the [Dominick's Data Manual](https://www.chicagobooth.edu/-/media/enterprise/centers/kilts/datasets/dominicks-dataset/dominicks-manual-and-codebook_kiltscenter.aspx) for more details about the data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting data download ...\n", + "Data download completed. Data saved to /data/home/chenhui/work/forecasting/ojdata\n" + ] + } + ], + "source": [ + "if DOWNLOAD_DATA:\n", + " download_ojdata(DATA_DIR)\n", + " print(\"Data download completed. Data saved to \" + DATA_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split Data for Single-Round Forecasting\n", + "\n", + "Next, we can use the `split_train_test()` utility function to split the data in `yx.csv` into training and test sets. If we want to do a one-time model training and evaluation, we can split the data using the default settings provided in `forecasting_lib.common.forecast_settings`.\n", + "\n", + "The data split function will return training data and test data as dataframes. The training data includes `train_df` and `aux_df`, with `train_df` containing the historical sales up to week 135 (the time we make forecasts) and `aux_df` containing price/promotion information up until week 138. Here we assume that future price and promotion information up to a certain number of weeks ahead is predetermined and known. The test data is stored in `test_df`, which contains the sales of each product in weeks 137 and 138. Assuming the current week is week 135, our goal is to forecast the sales in weeks 137 and 138 using the training data. There is a one-week gap between the current week and the first target week of forecasting because, in practice, we want to leave time for planning inventory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "data_generator = split_train_test(DATA_DIR, fs)\n", + "[train_df, test_df, aux_df] = next(data_generator)\n", + "train_df.reset_index(inplace=True)\n", + "test_df.reset_index(inplace=True)\n", + "aux_df.reset_index(inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

|  | store | brand | week | logmove | constant | price1 | price2 | price3 | price4 | price5 | price6 | price7 | price8 | price9 | price10 | price11 | deal | feat | profit |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 1 | 40 | 9.018695 | 1 | 0.060469 | 0.060497 | 0.042031 | 0.029531 | 0.049531 | 0.053021 | 0.038906 | 0.041406 | 0.028906 | 0.024844 | 0.038984 | 1 | 0.0 | 37.992326 |
| 1 | 2 | 1 | 46 | 8.723231 | 1 | 0.060469 | 0.060312 | 0.045156 | 0.046719 | 0.049531 | 0.047813 | 0.045781 | 0.027969 | 0.042969 | 0.042031 | 0.038984 | 0 | 0.0 | 30.126667 |
| 2 | 2 | 1 | 47 | 8.253228 | 1 | 0.060469 | 0.060312 | 0.045156 | 0.046719 | 0.037344 | 0.053021 | 0.045781 | 0.041406 | 0.048125 | 0.032656 | 0.038984 | 0 | 0.0 | 30.000000 |
| 3 | 2 | 1 | 48 | 8.987197 | 1 | 0.060469 | 0.060312 | 0.049844 | 0.037344 | 0.049531 | 0.053021 | 0.045781 | 0.041406 | 0.042344 | 0.032656 | 0.038984 | 0 | 0.0 | 29.950000 |
| 4 | 2 | 1 | 50 | 9.093357 | 1 | 0.060469 | 0.060312 | 0.043594 | 0.031094 | 0.049531 | 0.053021 | 0.046648 | 0.041406 | 0.042344 | 0.032656 | 0.038203 | 0 | 0.0 | 29.920000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 84178 | 137 | 11 | 131 | 9.631154 | 1 | 0.027969 | 0.051979 | 0.049080 | 0.039820 | 0.031094 | 0.048395 | 0.037500 | 0.038906 | 0.023281 | 0.022187 | 0.025703 | 1 | 0.0 | 17.170000 |
| 84179 | 137 | 11 | 132 | 9.704061 | 1 | 0.030504 | 0.051979 | 0.043594 | 0.033927 | 0.033167 | 0.045729 | 0.031094 | 0.038906 | 0.025313 | 0.024844 | 0.026328 | 1 | 1.0 | 18.630000 |
| 84180 | 137 | 11 | 133 | 8.995165 | 1 | 0.043056 | 0.051979 | 0.045542 | 0.031094 | 0.037205 | 0.046579 | 0.033470 | 0.037969 | 0.020156 | 0.025625 | 0.029609 | 1 | 0.0 | 25.350000 |
| 84181 | 137 | 11 | 134 | 8.912473 | 1 | 0.039062 | 0.049301 | 0.049588 | 0.032300 | 0.031094 | 0.050937 | 0.042031 | 0.035781 | 0.022031 | 0.031094 | 0.029609 | 1 | 0.0 | 25.320000 |
| 84182 | 137 | 11 | 135 | 9.901886 | 1 | 0.040473 | 0.045729 | 0.046957 | 0.045223 | 0.033493 | 0.050937 | 0.033941 | 0.035781 | 0.026406 | 0.022969 | 0.023359 | 1 | 1.0 | 5.350000 |

84183 rows × 19 columns

|  | store | brand | week | logmove | constant | price1 | price2 | price3 | price4 | price5 | price6 | price7 | price8 | price9 | price10 | price11 | deal | feat | profit |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 1 | 137 | 9.189321 | 1 | 0.041645 | 0.051979 | 0.047656 | 0.038801 | 0.032656 | 0.038125 | 0.032861 | 0.036094 | 0.037344 | 0.022187 | 0.032422 | 0 | 0.0 | 20.425098 |
| 1 | 2 | 1 | 138 | 9.738613 | 1 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 1 | 1.0 | 11.290000 |
| 2 | 2 | 2 | 137 | 8.738735 | 1 | 0.041645 | 0.051979 | 0.047656 | 0.038801 | 0.032656 | 0.038125 | 0.032861 | 0.036094 | 0.037344 | 0.022187 | 0.032422 | 0 | 0.0 | 33.300308 |
| 3 | 2 | 2 | 138 | 9.601301 | 1 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 1 | 1.0 | 9.430000 |
| 4 | 2 | 3 | 137 | 7.560080 | 1 | 0.041645 | 0.051979 | 0.047656 | 0.038801 | 0.032656 | 0.038125 | 0.032861 | 0.036094 | 0.037344 | 0.022187 | 0.032422 | 0 | 0.0 | 30.506667 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1821 | 137 | 9 | 138 | 5.950643 | 1 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 0 | 0.0 | 29.490000 |
| 1822 | 137 | 10 | 137 | 10.606189 | 1 | 0.042785 | 0.051979 | 0.047656 | 0.040621 | 0.032656 | 0.038125 | 0.033353 | 0.036875 | 0.037344 | 0.021094 | 0.032109 | 1 | 0.0 | 5.110000 |
| 1823 | 137 | 10 | 138 | 8.886271 | 1 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 0 | 0.0 | 34.120000 |
| 1824 | 137 | 11 | 137 | 8.912473 | 1 | 0.042785 | 0.051979 | 0.047656 | 0.040621 | 0.032656 | 0.038125 | 0.033353 | 0.036875 | 0.037344 | 0.021094 | 0.032109 | 0 | 0.0 | 31.720000 |
| 1825 | 137 | 11 | 138 | 8.723231 | 1 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 0 | 0.0 | 33.590000 |

1826 rows × 19 columns

|  | store | brand | week | price1 | price2 | price3 | price4 | price5 | price6 | price7 | price8 | price9 | price10 | price11 | deal | feat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 1 | 40 | 0.060469 | 0.060497 | 0.042031 | 0.029531 | 0.049531 | 0.053021 | 0.038906 | 0.041406 | 0.028906 | 0.024844 | 0.038984 | 1 | 0.0 |
| 1 | 2 | 1 | 46 | 0.060469 | 0.060312 | 0.045156 | 0.046719 | 0.049531 | 0.047813 | 0.045781 | 0.027969 | 0.042969 | 0.042031 | 0.038984 | 0 | 0.0 |
| 2 | 2 | 1 | 47 | 0.060469 | 0.060312 | 0.045156 | 0.046719 | 0.037344 | 0.053021 | 0.045781 | 0.041406 | 0.048125 | 0.032656 | 0.038984 | 0 | 0.0 |
| 3 | 2 | 1 | 48 | 0.060469 | 0.060312 | 0.049844 | 0.037344 | 0.049531 | 0.053021 | 0.045781 | 0.041406 | 0.042344 | 0.032656 | 0.038984 | 0 | 0.0 |
| 4 | 2 | 1 | 50 | 0.060469 | 0.060312 | 0.043594 | 0.031094 | 0.049531 | 0.053021 | 0.046648 | 0.041406 | 0.042344 | 0.032656 | 0.038203 | 0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 86906 | 137 | 11 | 134 | 0.039062 | 0.049301 | 0.049588 | 0.032300 | 0.031094 | 0.050937 | 0.042031 | 0.035781 | 0.022031 | 0.031094 | 0.029609 | 1 | 0.0 |
| 86907 | 137 | 11 | 135 | 0.040473 | 0.045729 | 0.046957 | 0.045223 | 0.033493 | 0.050937 | 0.033941 | 0.035781 | 0.026406 | 0.022969 | 0.023359 | 1 | 1.0 |
| 86908 | 137 | 11 | 136 | 0.049844 | 0.047412 | 0.047656 | 0.046554 | 0.043594 | 0.050937 | 0.031094 | 0.035781 | 0.026875 | 0.020156 | 0.032422 | 0 | 0.0 |
| 86909 | 137 | 11 | 137 | 0.042785 | 0.051979 | 0.047656 | 0.040621 | 0.032656 | 0.038125 | 0.033353 | 0.036875 | 0.037344 | 0.021094 | 0.032109 | 0 | 0.0 |
| 86910 | 137 | 11 | 138 | 0.037344 | 0.038958 | 0.047656 | 0.035781 | 0.043594 | 0.050937 | 0.042031 | 0.038906 | 0.037344 | 0.031094 | 0.032422 | 0 | 0.0 |

86911 rows × 16 columns