{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Time Series Cross-Validation\n", "\n", "Time series cross-validation evaluates the performance of a forecasting model \n", "by splitting a time series dataset into multiple folds. Unlike traditional \n", "cross-validation, where observations are randomly partitioned into training and test \n", "sets, time series cross-validation preserves the temporal ordering of the data: each \n", "fold trains only on observations that precede its validation period, so no future \n", "information leaks into training." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import FeatureExtractor\n", "from forecastflowml import ForecastFlowML\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from pyspark.sql import SparkSession\n", "import pyspark.sql.functions as F\n", "from lightgbm import LGBMRegressor\n", "import pandas as pd\n", "import plotly.express as px\n", "import plotly.io as pio\n", "\n", "pd.set_option(\"display.max_columns\", 100)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "spark = (\n", " SparkSession.builder.master(\"local[4]\")\n", " .config(\"spark.driver.memory\", \"8g\")\n", " .config(\"spark.sql.shuffle.partitions\", \"4\")\n", " .config(\"spark.sql.execution.arrow.enabled\", \"true\")\n", " .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark).localCheckpoint()\n", "df.show(5)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Extract Features" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n", 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null| null| null| null| null| null| null| null| null| 31| 2| 5| 5| 0| 1| 1|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 1| 3| 5| 1| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 2| 4| 5| 1| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 3| 5| 5| 1| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 4| 6| 5| 1| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 5| 7| 5| 1| 1| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| 6| 1| 5| 1| 1| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| 2.0| 2.0| null| null| null| null| null| null| null| null| null| 7| 2| 6| 1| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 1.0| 1.0| 1.0| null| null| null| null| null| null| null| null| 
null| 8| 3| 6| 2| 0| 1| 2|2011|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.6666666666666666| 0.6666666666666666| 0.6666666666666666| null| null| null| null| null| null| null| null| null| 9| 4| 6| 2| 0| 1| 2|2011|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n", " },\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"week_of_month\",\n", " \"weekend\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", ")\n", "df_train = feature_extractor.transform(df).localCheckpoint()\n", "df_train.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize ForecastFlowML" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(),\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def plot_cv_forecast(df_train, cv_forecast):\n", " pio.renderers.default = \"notebook\"\n", "\n", " cv_state = (\n", " df_train.select(\"id\", \"store_id\", \"date\", 
\"sales\")\n", " # Attach each fold's predictions to the actual sales; the left join keeps dates without forecasts\n", " .join(\n", " cv_forecast.select(\"id\", \"date\", \"cv\", \"prediction\"),\n", " on=[\"id\", \"date\"],\n", " how=\"left\",\n", " )\n", " # One prediction column per fold, named after its cv value (\"0\", \"1\", ...)\n", " .groupBy(\"id\", \"store_id\", \"date\", \"sales\")\n", " .pivot(\"cv\")\n", " .sum(\"prediction\")\n", " # Aggregate item-level sales and predictions to store level\n", " .groupBy(\"store_id\", \"date\")\n", " .agg(\n", " F.sum(\"sales\").alias(\"sales\"),\n", " *[F.sum(f\"{i}\").alias(f\"cv_{i}\") for i in range(3)], # assumes three folds\n", " )\n", " .orderBy(\"store_id\", \"date\")\n", " ).toPandas()\n", "\n", " # One panel per store: actual sales against each fold's forecast\n", " fig = px.line(\n", " cv_state,\n", " x=\"date\",\n", " y=[\"sales\", *[f\"cv_{i}\" for i in range(3)]],\n", " facet_row_spacing=0.04,\n", " facet_col=\"store_id\",\n", " facet_col_wrap=2,\n", " height=1000,\n", " width=720,\n", " )\n", " fig = fig.update_layout(\n", " legend=dict(orientation=\"h\", yanchor=\"top\", y=1.07, xanchor=\"center\", x=0.5),\n", " margin=dict(l=0, r=10, t=5, b=5),\n", " legend_title=\"\",\n", " )\n", " fig = fig.update_traces(line=dict(width=1.7))\n", " fig = fig.update_yaxes(matches=None, title=\"\")\n", " fig = fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-05-22\"])\n", " return fig" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Increasing Training Size\n", "\n", "By default, `cross_validate` trains every fold on all observations before its cutoff, so folds with later cutoffs have larger training sets (an expanding window). The result has one row per series, date and fold (`cv`), with the actual `target` and the model's `prediction`.\n", "\n", "![Expanding-window cross-validation folds](../_static/cross_validation-default.svg)\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+---+------+----------+\n", "|group| id| date| cv|target|prediction|\n", "+-----+--------------------+----------+---+------+----------+\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-25| 0| 1.0|0.40366215|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 0| 0.0|0.40702468|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-27| 0| 0.0|0.35053134|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-28| 0| 0.0|0.35053134|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-29| 0| 0.0|0.39713493|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-30| 0| 0.0|0.53590035|\n", "| 
CA_2|FOODS_1_179_CA_2_...|2016-05-01| 0| 0.0|0.43878192|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-25| 0| 0.0|0.15198386|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-26| 0| 0.0|0.14098388|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-27| 0| 0.0|0.09492111|\n", "+-----+--------------------+----------+---+------+----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()\n", "cv_forecast.show(10)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## No Refit\n", "\n", "With `refit=False`, the model is trained once instead of being refitted for every fold, and the same fitted model produces the forecasts for all folds. This saves training time at the cost of not updating the model as the cutoff moves.\n", "\n", "![No-refit cross-validation folds](../_static/cross_validation-no_refit.svg)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(df_train, refit=False).localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Step Length Between Folds\n", "\n", "`cv_step_length` sets how far apart the cutoffs of consecutive folds are. With daily data and `cv_step_length=14`, each fold's cutoff is 14 days from the next one.\n", "\n", "![Cross-validation folds with a custom step length](../_static/cross_validation-step_length.svg)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(\n", " df_train, cv_step_length=14\n", ").localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Fixed Training Size\n", "\n", "`max_train_size` caps the length of each fold's training window. With `max_train_size=365`, every fold trains on at most the 365 most recent days before its cutoff, so the window slides forward instead of growing.\n", "\n", "![Sliding-window cross-validation folds](../_static/cross_validation-max_train_size.svg)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(\n", " df_train, max_train_size=365\n", ").localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] } ], "metadata": { "kernelspec": { "display_name": "sspark37", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.16" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }