This commit is contained in:
soukayna 2024-01-30 14:04:11 -05:00
Родитель 65a5354c44
Коммит 3fd9ad8f73
1 изменённых файлов: 221 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Prophet\n",
"\n",
"### Uses the Prophet model to forecast future values."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import os, sys\n",
"from tqdm import tqdm\n",
"from subseasonal_toolkit.utils.notebook_util import isnotebook\n",
"if isnotebook():\n",
" # Autoreload packages that are modified\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
"else:\n",
" from argparse import ArgumentParser\n",
"import pandas as pd\n",
"import numpy as np\n",
"from scipy.spatial.distance import cdist, euclidean\n",
"from datetime import datetime, timedelta\n",
"from ttictoc import tic, toc\n",
"from subseasonal_data.utils import get_measurement_variable\n",
"from subseasonal_toolkit.utils.general_util import printf\n",
"from subseasonal_toolkit.utils.experiments_util import get_id_name, get_th_name, get_first_year, get_start_delta\n",
"from subseasonal_toolkit.utils.models_util import (get_submodel_name, start_logger, log_params, get_forecast_filename,\n",
" save_forecasts)\n",
"from subseasonal_toolkit.utils.eval_util import get_target_dates, mean_rmse_to_score, save_metric\n",
"from sklearn.linear_model import *\n",
"\n",
"from subseasonal_data import data_loaders"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"#\n",
"# Specify model parameters\n",
"#\n",
"if isnotebook():\n",
"    # Interactive session: set the arguments by hand\n",
"    gt_id = \"contest_tmp2m\"\n",
"    horizon = \"34w\"\n",
"    target_dates = \"std_contest\"\n",
"else:\n",
"    # Executed as a script: read the arguments from the command line\n",
"    cli_parser = ArgumentParser()\n",
"    cli_parser.add_argument(\"pos_vars\",nargs=\"*\")  # gt_id and horizon\n",
"    cli_parser.add_argument('--target_dates', '-t', default=\"std_test\")\n",
"    cli_args, _unknown_args = cli_parser.parse_known_args()\n",
"\n",
"    # First positional argument is the ground-truth id, second is the horizon\n",
"    gt_id = get_id_name(cli_args.pos_vars[0])  # \"contest_precip\" or \"contest_tmp2m\"\n",
"    horizon = get_th_name(cli_args.pos_vars[1])  # \"12w\", \"34w\", or \"56w\"\n",
"    target_dates = cli_args.target_dates\n",
"\n",
"#\n",
"# Process model parameters\n",
"#\n",
"# Forecasts from years before FIRST_SAVE_YEAR are not written to disk\n",
"FIRST_SAVE_YEAR = 2007\n",
"\n",
"# Subtracting start_delta from a target date yields the last viable training date.\n",
"start_delta = timedelta(days=get_start_delta(horizon, gt_id))\n",
"\n",
"# Model and submodel names used when logging and saving results\n",
"model_name = \"prophet\"\n",
"submodel_name = get_submodel_name(model_name)\n",
"\n",
"if not isnotebook():\n",
"    # Script runs send their output to a log file\n",
"    logger = start_logger(\n",
"        model=model_name, submodel=submodel_name, gt_id=gt_id,\n",
"        horizon=horizon, target_dates=target_dates)\n",
"    # Record the parameter values in the log; eval only ever sees the fixed names listed here\n",
"    params_names = ['gt_id', 'horizon', 'target_dates']\n",
"    params_values = [eval(param) for param in params_names]\n",
"    log_params(params_names, params_values)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"printf('Loading target variable and dropping extraneous columns')\n",
"tic()\n",
"# Name of the measurement column for this ground-truth id\n",
"var = get_measurement_variable(gt_id)\n",
"# Retain only the date, the grid coordinates, and the measurement column\n",
"gt = (\n",
"    data_loaders.get_ground_truth(gt_id)\n",
"    .loc[:, [\"start_date\", \"lat\", \"lon\", var]]\n",
")\n",
"toc()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"printf('Pivoting dataframe to have one column per lat-lon pair and one row per start_date')\n",
"tic()\n",
"# Index by (lat, lon, start_date), collapse to a Series, then move lat/lon into the columns\n",
"gt = (\n",
"    gt.set_index(['lat', 'lon', 'start_date'])\n",
"    .squeeze()\n",
"    .unstack(['lat', 'lon'])\n",
")\n",
"toc()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"#\n",
"# Make predictions for each target date\n",
"#\n",
"from fbprophet import Prophet\n",
"from pandas.tseries.offsets import DateOffset\n",
"def forecast(df, pred_date, num_periods, start_delta, m):\n",
"    \"\"\"Fit `m` on the history preceding the forecast gap and predict forward.\n",
"\n",
"    Args:\n",
"        df: frame with a `ds` date column, passed to `m.fit` after filtering.\n",
"        pred_date: date being forecast; only rows with `ds < pred_date - start_delta` are fitted.\n",
"        num_periods: number of trailing prediction rows to return.\n",
"        start_delta: timedelta between the last usable training date and `pred_date`.\n",
"        m: model object exposing `fit`, `make_future_dataframe`, and `predict`.\n",
"\n",
"    Returns:\n",
"        Tuple of (the last `num_periods` prediction rows with columns `ds` and `yhat`, `m`).\n",
"    \"\"\"\n",
"    cutoff = pred_date - start_delta\n",
"    history = df[df.ds < cutoff]\n",
"    m.fit(history)\n",
"    # Extend far enough to cover both the gap and the requested window\n",
"    days_ahead = num_periods + start_delta.days\n",
"    future = m.make_future_dataframe(periods=days_ahead)\n",
"    predicted = m.predict(future)\n",
"    return predicted.tail(num_periods)[[\"ds\", \"yhat\"]], m\n",
"\n",
"def get_first_fourth_month(date):\n",
"    \"\"\"Return the latest April 30, August 31, or December 31 on or before `date`.\n",
"\n",
"    Walks backwards one day at a time; `date` itself is returned when it already\n",
"    falls on one of those three days.\n",
"    \"\"\"\n",
"    period_ends = ((4, 30), (8, 31), (12, 31))\n",
"    one_day = DateOffset(days=1)\n",
"    current = date\n",
"    while True:\n",
"        if (current.month, current.day) in period_ends:\n",
"            return current\n",
"        current = current - one_day\n",
"\n",
"def get_predictions(df: pd.DataFrame, date, start_delta):\n",
"    \"\"\"Forecast every column of `df` for `date`, fitting one Prophet model per column.\n",
"\n",
"    Args:\n",
"        df: frame indexed by `start_date` with one column per series to forecast.\n",
"        date: target date whose prediction is extracted from each forecast.\n",
"        start_delta: timedelta between the last usable training date and `date`.\n",
"\n",
"    Returns:\n",
"        List with one `yhat` value per column of `df`, in column order.\n",
"\n",
"    Raises:\n",
"        ValueError: from `.item()` if a column's forecast does not contain\n",
"            exactly one row for `date`.\n",
"    \"\"\"\n",
"    # Truncate the history at the latest 4/30, 8/31, or 12/31 on or before `date`.\n",
"    true_date = get_first_fourth_month(date)\n",
"    df = df.loc[:true_date]\n",
"    # Number of days in the four months following the truncation date;\n",
"    # identical for every column, so computed once.\n",
"    num_periods = ((true_date + DateOffset(months=4)) - true_date).days\n",
"    # Collected across ALL columns: resetting this inside the loop would leave\n",
"    # only the final column's forecast in the result.\n",
"    all_preds = []\n",
"    for column in df:\n",
"        series = df[column]\n",
"        # Prophet expects the dates in `ds` and the observations in `y`.\n",
"        df_single = pd.DataFrame(\n",
"            {\"ds\": series.reset_index()[\"start_date\"], \"y\": series.values})\n",
"        # A fresh model is built for each column.\n",
"        m = Prophet(yearly_seasonality=True, weekly_seasonality=False)\n",
"        preds, _ = forecast(df_single, date, num_periods, start_delta, m)\n",
"        all_preds.append(preds)\n",
"    # Pick out the prediction made for `date` in each column's forecast.\n",
"    return [z[z.ds == date][\"yhat\"].item() for z in all_preds]\n",
"\n",
"\n",
"tic()\n",
"# Dates to forecast, selected by the target_dates string and the horizon\n",
"target_date_objs = pd.Series(get_target_dates(date_str=target_dates,horizon=horizon))\n",
"# Per-target-date RMSE; entries remain NaN for dates that are never evaluated\n",
"rmses = pd.Series(index=target_date_objs, dtype=np.float64)\n",
"# One row per target date, with the same columns as the pivoted ground truth\n",
"preds = pd.DataFrame(index = target_date_objs, columns = gt.columns,\n",
"                     dtype=np.float64)\n",
"preds.index.name = \"start_date\"\n",
"# Sort target_date_objs by day of week\n",
"target_date_objs = target_date_objs[target_date_objs.dt.weekday.argsort(kind='stable')]\n",
"toc()\n",
"for target_date_obj in target_date_objs:\n",
"    tic()\n",
"    target_date_str = datetime.strftime(target_date_obj, '%Y%m%d')\n",
"    # Find the last observable training date for this target\n",
"    last_train_date = target_date_obj - start_delta\n",
"    if last_train_date not in gt.index:\n",
"        printf(f'-Warning: no prophet prediction for {target_date_str}; skipping')\n",
"        continue\n",
"    printf(f'Forming prophet prediction for {target_date_obj}')\n",
"\n",
"    # Key logic: fit on data up to the last training date and forecast the target date\n",
"    preds.loc[target_date_obj,:] = get_predictions(gt.loc[:last_train_date,:], target_date_obj, start_delta)\n",
"\n",
"    # Save prediction to file in standard format\n",
"    if target_date_obj.year >= FIRST_SAVE_YEAR:\n",
"        save_forecasts(\n",
"            preds.loc[[target_date_obj],:].unstack().rename(\"pred\").reset_index(),\n",
"            model=model_name, submodel=submodel_name,\n",
"            gt_id=gt_id, horizon=horizon,\n",
"            target_date_str=target_date_str)\n",
"    # Evaluate and store error if we have ground truth data\n",
"    if target_date_obj in gt.index:\n",
"        rmse = np.sqrt(np.square(preds.loc[target_date_obj,:] - gt.loc[target_date_obj,:]).mean())\n",
"        rmses.loc[target_date_obj] = rmse\n",
"        print(\"-rmse: {}, score: {}\".format(rmse, mean_rmse_to_score(rmse)))\n",
"        # Running mean over the dates evaluated so far (NaN entries are skipped by mean)\n",
"        mean_rmse = rmses.mean()\n",
"        print(\"-mean rmse: {}, running score: {}\".format(mean_rmse, mean_rmse_to_score(mean_rmse)))\n",
"    toc()\n",
"\n",
"printf(\"Save rmses in standard format\")\n",
"rmses = rmses.sort_index().reset_index()\n",
"rmses.columns = ['start_date','rmse']\n",
"save_metric(rmses, model=model_name, submodel=submodel_name, gt_id=gt_id, horizon=horizon, target_dates=target_dates, metric=\"rmse\")"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"language_info": {
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}