Time Series Cross Validation

Time series cross validation is a statistical technique for evaluating a forecasting model on a time series dataset by splitting the data into multiple folds. Unlike traditional cross validation, where observations are randomly partitioned into training and test sets, time series cross validation preserves the temporal ordering of the data: each fold is trained only on observations that precede its test window, so no information from the future leaks into training.

Imports

from forecastflowml import FeatureExtractor
from forecastflowml import ForecastFlowML
from forecastflowml.data.loader import load_walmart_m5
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from lightgbm import LGBMRegressor
import pandas as pd
import plotly.express as px
import plotly.io as pio
import sys
import os

# Point PySpark at the same Python interpreter that is running this notebook.
os.environ.update(PYSPARK_PYTHON=sys.executable)
# Display up to 100 columns when pandas DataFrames are rendered.
pd.options.display.max_columns = 100

Initialize Spark

# Local Spark session on 4 cores, with shuffle partitions set to the same
# count so small local shuffles are not split into many tiny tasks.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "4g")
    .config("spark.sql.shuffle.partitions", "4")
    # Enable Arrow for Spark <-> pandas conversion (used by ``toPandas()`` when
    # plotting). The previous key, "spark.sql.execution.pyarrow.enabled", is
    # not a Spark setting and was silently ignored; this is the documented key.
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

Sample Dataset

# Load the Walmart M5 sample and checkpoint it locally before reuse.
df = load_walmart_m5(spark)
df = df.localCheckpoint()
df.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-15|  3.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-16|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-17|  1.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-18|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-19|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-20|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-21|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-22|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-23|  0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-24|  0.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows

Extract Features

# Feature pipeline: lagged rolling means of ``sales`` plus calendar features.
# Each ``[window, lag]`` pair yields one ``window_<w>_lag_<l>_mean`` column;
# pairs are listed lag by lag, with windows varying fastest.
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={
        "mean": [
            [win, lag_days]
            for lag_days in (7, 14, 21, 28)
            for win in (7, 14, 30)
        ],
    },
    # Calendar attributes derived from ``date``, in output-column order.
    date_features=[
        "day_of_month",
        "day_of_week",
        "week_of_year",
        "week_of_month",
        "weekend",
        "quarter",
        "month",
        "year",
    ],
)
df_train = feature_extractor.transform(df)
df_train = df_train.localCheckpoint()
df_train.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-15|  3.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          15|          4|           3|            3|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-16|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          16|          5|           3|            3|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-17|  1.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          17|          6|           3|            3|      1|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-18|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          18|          7|           3|            3|      1|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-19|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          19|          1|           4|            3|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-20|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          20|          2|           4|            3|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-21|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          21|          3|           4|            3|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-22|  0.0|                3.0|                 3.0|                 3.0|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          22|          4|           4|            4|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-23|  0.0|                1.5|                 1.5|                 1.5|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          23|          5|           4|            4|      0|      1|    1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS|    TX_1|      TX|2015-01-24|  0.0| 1.3333333333333333|  1.3333333333333333|  1.3333333333333333|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          24|          6|           4|            4|      1|      1|    1|2015|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
only showing top 10 rows

Initialize ForecastFlowML

forecast_flow = ForecastFlowML(
    # Column roles in ``df_train``; ``group_col`` presumably splits training
    # by store (confirm in the ForecastFlowML documentation).
    id_col="id",
    group_col="store_id",
    date_col="date",
    target_col="sales",
    # Daily data: each model covers 7 days, forecasts extend up to 28 days.
    date_frequency="days",
    max_forecast_horizon=28,
    model_horizon=7,
    # Estimator used by ForecastFlowML for fitting.
    model=LGBMRegressor(),
)
def plot_cv_forecast(
    df_train,
    cv_forecast,
    n_cv=3,
    date_range=("2015-11-01", "2016-05-22"),
):
    """Plot store-level actual sales against each cross-validation fold's forecast.

    Item-level ``sales`` from ``df_train`` and item-level ``prediction`` values
    from ``cv_forecast`` are summed per ``store_id`` and ``date``. Each fold's
    predictions become their own ``cv_<i>`` line next to the actual ``sales``,
    with one facet per store.

    Args:
        df_train: Spark DataFrame with ``id``, ``store_id``, ``date`` and
            ``sales`` columns.
        cv_forecast: Spark DataFrame from ``ForecastFlowML.cross_validate``
            with ``id``, ``date``, ``cv`` and ``prediction`` columns.
        n_cv: Number of folds to plot. Fold ids are taken to be
            ``0 .. n_cv - 1``; set this to match the number of folds that
            ``cross_validate`` produced. Folds outside that range are ignored.
        date_range: ``(start, end)`` x-axis range as date strings, or ``None``
            to let plotly autoscale the axis.

    Returns:
        The plotly ``Figure``.

    Side effects:
        Sets ``plotly.io.renderers.default`` to ``"notebook"`` globally.
    """
    pio.renderers.default = "notebook"

    fold_ids = list(range(n_cv))
    cv_cols = [f"cv_{i}" for i in fold_ids]

    cv_state = (
        df_train.select("id", "store_id", "date", "sales")
        .join(
            cv_forecast.select("id", "date", "cv", "prediction"),
            on=["id", "date"],
            how="left",
        )
        .groupBy("id", "store_id", "date", "sales")
        # Explicit pivot values avoid the extra Spark job that discovers the
        # distinct ``cv`` values. A missing fold then becomes an all-null
        # column instead of failing the aggregation below.
        .pivot("cv", fold_ids)
        .sum("prediction")
        .groupBy("store_id", "date")
        .agg(
            F.sum("sales").alias("sales"),
            # Pivoted columns are named after the fold value ("0", "1", ...).
            *[F.sum(str(i)).alias(f"cv_{i}") for i in fold_ids],
        )
        .orderBy("store_id", "date")
    ).toPandas()

    fig = px.line(
        cv_state,
        x="date",
        y=["sales", *cv_cols],
        facet_row_spacing=0.04,
        facet_col="store_id",
        facet_col_wrap=2,
        height=700,
        width=720,
    )
    fig = fig.update_layout(
        legend=dict(orientation="h", yanchor="top", y=1.09, xanchor="center", x=0.5),
        margin=dict(l=0, r=10, t=5, b=5),
        legend_title="",
    )
    fig = fig.update_traces(line=dict(width=1.7))
    # Stores differ in scale, so each facet gets its own y-axis.
    fig = fig.update_yaxes(matches=None, title="")

    xaxes_kwargs = {"type": "date"}
    if date_range is not None:
        xaxes_kwargs["range"] = list(date_range)
    fig = fig.update_xaxes(**xaxes_kwargs)
    return fig

Increasing Training Size

image info

# Cross-validate with the default settings, then checkpoint the result locally.
cv_forecast = forecast_flow.cross_validate(df_train)
cv_forecast = cv_forecast.localCheckpoint()
cv_forecast.show(n=10)
+--------+--------------------+-------------------+---+-----+----------+
|store_id|                  id|               date| cv|sales|prediction|
+--------+--------------------+-------------------+---+-----+----------+
|    CA_1|FOODS_1_064_CA_1_...|2016-04-25 00:00:00|  0|  2.0|0.94086176|
|    CA_1|FOODS_1_064_CA_1_...|2016-04-26 00:00:00|  0|  0.0| 0.9023968|
|    CA_1|FOODS_1_064_CA_1_...|2016-04-27 00:00:00|  0|  2.0| 1.0340574|
|    CA_1|FOODS_1_064_CA_1_...|2016-04-28 00:00:00|  0|  4.0|0.98158336|
|    CA_1|FOODS_1_064_CA_1_...|2016-04-29 00:00:00|  0|  0.0| 0.9397872|
|    CA_1|FOODS_1_064_CA_1_...|2016-04-30 00:00:00|  0|  0.0| 1.3279248|
|    CA_1|FOODS_1_064_CA_1_...|2016-05-01 00:00:00|  0|  0.0| 1.3603985|
|    CA_1|FOODS_1_121_CA_1_...|2016-04-25 00:00:00|  0|  0.0|0.59270364|
|    CA_1|FOODS_1_121_CA_1_...|2016-04-26 00:00:00|  0|  1.0|0.60231286|
|    CA_1|FOODS_1_121_CA_1_...|2016-04-27 00:00:00|  0|  1.0|0.61724263|
+--------+--------------------+-------------------+---+-----+----------+
only showing top 10 rows
# Actual vs. per-fold forecast, aggregated per store.
plot_cv_forecast(df_train=df_train, cv_forecast=cv_forecast)

No Refit

image info

# refit=False: presumably the model is not retrained for every fold (confirm
# in the ForecastFlowML documentation).
cv_forecast = forecast_flow.cross_validate(df_train, refit=False)
cv_forecast = cv_forecast.localCheckpoint()
plot_cv_forecast(df_train=df_train, cv_forecast=cv_forecast)

Step Length Between Folds

image info

# cv_step_length=14: presumably consecutive folds are 14 days apart (confirm).
cv_forecast = forecast_flow.cross_validate(df_train, cv_step_length=14)
cv_forecast = cv_forecast.localCheckpoint()
plot_cv_forecast(df_train=df_train, cv_forecast=cv_forecast)

Fixed Training Size

image info

# max_train_size=365: presumably caps each fold's training window at 365
# periods of the configured date frequency (confirm).
cv_forecast = forecast_flow.cross_validate(df_train, max_train_size=365)
cv_forecast = cv_forecast.localCheckpoint()
plot_cv_forecast(df_train=df_train, cv_forecast=cv_forecast)