{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Time Series Cross Validation\n", "\n", "Time series cross validation evaluates a forecasting model by splitting a time series \n", "into multiple folds. Unlike traditional cross validation, which partitions the data \n", "randomly into training and test sets, time series cross validation preserves the \n", "temporal ordering of the data: each fold is trained on observations that precede \n", "its test period." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import FeatureExtractor\n", "from forecastflowml import ForecastFlowML\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from pyspark.sql import SparkSession\n", "import pyspark.sql.functions as F\n", "from lightgbm import LGBMRegressor\n", "import pandas as pd\n", "import plotly.express as px\n", "import plotly.io as pio\n", "import sys\n", "import os\n", "\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "pd.set_option(\"display.max_columns\", 100)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "spark = (\n", " SparkSession.builder.master(\"local[4]\")\n", " .config(\"spark.driver.memory\", \"4g\")\n", " .config(\"spark.sql.shuffle.partitions\", \"4\")\n", " .config(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark).localCheckpoint()\n", "df.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Extract Features" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| 
date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| null| null| null| null| null| null| null| null| 15| 4| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 16| 5| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| 17| 6| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 18| 7| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 19| 1| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 20| 2| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| null| null| null| null| null| null| 
null| null| 21| 3| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| 3.0| 3.0| null| null| null| null| null| null| null| null| null| 22| 4| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 1.5| 1.5| 1.5| null| null| null| null| null| null| null| null| null| 23| 5| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.3333333333333333| 1.3333333333333333| 1.3333333333333333| null| null| null| null| null| null| null| null| null| 24| 6| 4| 4| 1| 1| 1|2015|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n", " },\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"week_of_month\",\n", " \"weekend\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", ")\n", "df_train = feature_extractor.transform(df).localCheckpoint()\n", "df_train.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize ForecastFlowML" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " 
date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(),\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def plot_cv_forecast(df_train, cv_forecast):\n", " # Plot actual sales against the forecasts of each fold, summed per store.\n", " # The helper assumes exactly three folds labelled 0, 1 and 2 in the cv column.\n", " pio.renderers.default = \"notebook\"\n", "\n", " # Spread each fold's predictions into its own column, then sum per store and date.\n", " cv_state = (\n", " df_train.select(\"id\", \"store_id\", \"date\", \"sales\")\n", " .join(\n", " cv_forecast.select(\"id\", \"date\", \"cv\", \"prediction\"),\n", " on=[\"id\", \"date\"],\n", " how=\"left\",\n", " )\n", " .groupBy(\"id\", \"store_id\", \"date\", \"sales\")\n", " .pivot(\"cv\")\n", " .sum(\"prediction\")\n", " .groupBy(\"store_id\", \"date\")\n", " .agg(\n", " F.sum(\"sales\").alias(\"sales\"),\n", " *[F.sum(f\"{i}\").alias(f\"cv_{i}\") for i in range(3)],\n", " )\n", " .orderBy(\"store_id\", \"date\")\n", " ).toPandas()\n", "\n", " # One panel per store, with the actuals and one line per fold.\n", " fig = px.line(\n", " cv_state,\n", " x=\"date\",\n", " y=[\"sales\", *[f\"cv_{i}\" for i in range(3)]],\n", " facet_row_spacing=0.04,\n", " facet_col=\"store_id\",\n", " facet_col_wrap=2,\n", " height=700,\n", " width=720,\n", " )\n", " fig = fig.update_layout(\n", " legend=dict(orientation=\"h\", yanchor=\"top\", y=1.09, xanchor=\"center\", x=0.5),\n", " margin=dict(l=0, r=10, t=5, b=5),\n", " legend_title=\"\",\n", " )\n", " fig = fig.update_traces(line=dict(width=1.7))\n", " # Give every panel its own y-axis scale.\n", " fig = fig.update_yaxes(matches=None, title=\"\")\n", " # Restrict the x-axis to the months around the cross validation folds.\n", " fig = fig.update_xaxes(type=\"date\", range=[\"2015-11-01\", \"2016-05-22\"])\n", " return fig" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Increasing Training Size\n", "\n", "By default, each fold trains on all data that precedes its forecast window, so the \n", "training set grows from one fold to the next.\n", "\n", "![Cross validation with increasing training size](../_static/cross_validation-default.svg)\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------+--------------------+-------------------+---+-----+----------+\n", "|store_id| id| date| cv|sales|prediction|\n", 
"+--------+--------------------+-------------------+---+-----+----------+\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-25 00:00:00| 0| 2.0|0.94086176|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-26 00:00:00| 0| 0.0| 0.9023968|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-27 00:00:00| 0| 2.0| 1.0340574|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-28 00:00:00| 0| 4.0|0.98158336|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-29 00:00:00| 0| 0.0| 0.9397872|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-04-30 00:00:00| 0| 0.0| 1.3279248|\n", "| CA_1|FOODS_1_064_CA_1_...|2016-05-01 00:00:00| 0| 0.0| 1.3603985|\n", "| CA_1|FOODS_1_121_CA_1_...|2016-04-25 00:00:00| 0| 0.0|0.59270364|\n", "| CA_1|FOODS_1_121_CA_1_...|2016-04-26 00:00:00| 0| 1.0|0.60231286|\n", "| CA_1|FOODS_1_121_CA_1_...|2016-04-27 00:00:00| 0| 1.0|0.61724263|\n", "+--------+--------------------+-------------------+---+-----+----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()\n", "cv_forecast.show(10)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## No Refit\n", "\n", "![image info](../_static/cross_validation-no_refit.svg)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(df_train, refit=False).localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Step Length Between Folds\n", "\n", "![image info](../_static/cross_validation-step_length.svg)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(\n", " df_train, cv_step_length=14\n", ").localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Fixed Training Size\n", "\n", "![image info](../_static/cross_validation-max_train_size.svg)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "cv_forecast = forecast_flow.cross_validate(\n", " df_train, max_train_size=365\n", ").localCheckpoint()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cv_forecast(df_train, cv_forecast)" ] } ], "metadata": { "kernelspec": { "display_name": "sspark37", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }