Time Series Cross Validation#

Time series cross validation evaluates the performance of a forecasting model by splitting a time series dataset into multiple folds. Unlike traditional cross validation, where observations are randomly partitioned into training and test sets, time series cross validation preserves the temporal ordering of the data: every fold is trained only on observations that precede its test period, so the model is never evaluated on dates it has already seen.

Imports#

from forecastflowml import FeatureExtractor
from forecastflowml import ForecastFlowML
from forecastflowml.data.loader import load_walmart_m5
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from lightgbm import LGBMRegressor
import pandas as pd
import plotly.express as px
import plotly.io as pio

# Let pandas display up to 100 columns before it truncates wide DataFrames.
pd.set_option("display.max_columns", 100)

Initialize Spark#

# Create (or reuse) a local Spark session with 4 worker threads, 8 GB of
# driver memory and 4 shuffle partitions, and enable Arrow for Spark <-> pandas
# conversion.
# NOTE(review): "spark.sql.execution.arrow.enabled" is deprecated since
# Spark 3.0 in favour of "spark.sql.execution.arrow.pyspark.enabled" — confirm
# the target Spark version before renaming the key.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.sql.execution.arrow.enabled", "true")
    .getOrCreate()
)

Sample Dataset#

# Load the Walmart M5 sample through the library's data loader.
# localCheckpoint() materializes the DataFrame and truncates its lineage so
# later steps do not recompute the load.
df = load_walmart_m5(spark).localCheckpoint()
# Preview the first 5 rows.
df.show(5)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-29|  2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-30|  5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-31|  3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-01|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-02|  0.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 5 rows

Extract Features#

# Configure feature extraction on the `sales` target, keyed by series `id`
# and ordered by `date`.
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={
        # Every [window, lag] combination of 3 window sizes and 4 lags: 12
        # features, which appear in the output as `window_{w}_lag_{l}_mean`.
        "mean": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],
    },
    # Calendar features; each one becomes an output column of the same name.
    date_features=[
        "day_of_month",
        "day_of_week",
        "week_of_year",
        "week_of_month",
        "weekend",
        "quarter",
        "month",
        "year",
    ],
)
# Build the feature table and checkpoint it, since every cross-validation run
# below reuses it.
df_train = feature_extractor.transform(df).localCheckpoint()
# Preview the first 10 rows; lag/window columns are null until enough history
# exists for the series.
df_train.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-01-31|  2.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|          31|          2|           5|            5|      0|      1|    1|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-01|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           1|          3|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-02|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           2|          4|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-03|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           3|          5|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-04|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           4|          6|           5|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-05|  0.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           5|          7|           5|            1|      1|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-06|  1.0|               null|                null|                null|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           6|          1|           5|            1|      1|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-07|  0.0|                2.0|                 2.0|                 2.0|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           7|          2|           6|            1|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-08|  0.0|                1.0|                 1.0|                 1.0|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           8|          3|           6|            2|      0|      1|    2|2011|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-09|  0.0| 0.6666666666666666|  0.6666666666666666|  0.6666666666666666|                null|                 null|                 null|                null|                 null|                 null|                null|                 null|                 null|           9|          4|           6|            2|      0|      1|    2|2011|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
only showing top 10 rows

Initialize ForecastFlowML#

# Forecasting pipeline over daily data with a LightGBM regressor.
# NOTE(review): `group_col="store_id"` presumably trains models per store (the
# cross-validation output's `group` column holds store ids) — confirm.
# NOTE(review): `model_horizon=7` with `max_forecast_horizon=28` presumably
# means each model covers a 7-day slice of a 28-day horizon — confirm against
# the ForecastFlowML documentation.
forecast_flow = ForecastFlowML(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
    model=LGBMRegressor(),
)
def plot_cv_forecast(
    df_train,
    cv_forecast,
    n_folds=3,
    date_range=("2015-11-01", "2016-05-22"),
):
    """Plot store-level actual sales against the forecast of each CV fold.

    The per-series predictions in ``cv_forecast`` are left-joined onto the
    actuals in ``df_train``, pivoted so that every fold becomes its own
    column, and then summed per ``store_id`` and ``date``. The result is drawn
    as one line-chart facet per store, with one line for the actual sales and
    one line per fold.

    Parameters
    ----------
    df_train
        Spark DataFrame with at least the columns ``id``, ``store_id``,
        ``date`` and ``sales``.
    cv_forecast
        Spark DataFrame with at least the columns ``id``, ``date``, ``cv`` and
        ``prediction``.
    n_folds
        Number of folds to plot. Folds are expected to be numbered
        ``0 .. n_folds - 1`` in the ``cv`` column; every one of these ids must
        be present in ``cv_forecast`` because the pivot only creates columns
        for values that occur in the data. Defaults to 3.
    date_range
        ``(start, end)`` pair limiting the visible x-axis window. Defaults to
        ``("2015-11-01", "2016-05-22")``.

    Returns
    -------
    The plotly figure.

    Notes
    -----
    Sets ``pio.renderers.default`` to ``"notebook"`` as a side effect, and
    collects the aggregated data to the driver with ``toPandas()``.
    """
    pio.renderers.default = "notebook"

    # Single source of truth for the fold ids and their plot labels, so the
    # aggregation and the plotted series cannot drift apart.
    fold_ids = list(range(n_folds))
    fold_labels = [f"cv_{fold}" for fold in fold_ids]

    cv_store = (
        df_train.select("id", "store_id", "date", "sales")
        # Left join keeps every actual observation; dates outside any fold's
        # test window simply carry a null prediction.
        .join(
            cv_forecast.select("id", "date", "cv", "prediction"),
            on=["id", "date"],
            how="left",
        )
        # One row per series and date, one prediction column per fold id.
        .groupBy("id", "store_id", "date", "sales")
        .pivot("cv")
        .sum("prediction")
        # Roll the individual series up to store level.
        .groupBy("store_id", "date")
        .agg(
            F.sum("sales").alias("sales"),
            # The pivot names each column after the fold id it came from.
            *[
                F.sum(f"{fold}").alias(label)
                for fold, label in zip(fold_ids, fold_labels)
            ],
        )
        .orderBy("store_id", "date")
    ).toPandas()

    fig = px.line(
        cv_store,
        x="date",
        y=["sales", *fold_labels],
        facet_row_spacing=0.04,
        facet_col="store_id",
        facet_col_wrap=2,
        height=1000,
        width=720,
    )
    fig = fig.update_layout(
        legend=dict(orientation="h", yanchor="top", y=1.07, xanchor="center", x=0.5),
        margin=dict(l=0, r=10, t=5, b=5),
        legend_title="",
    )
    fig = fig.update_traces(line=dict(width=1.7))
    # Independent y-axes: stores differ in sales volume.
    fig = fig.update_yaxes(matches=None, title="")
    fig = fig.update_xaxes(type="date", range=list(date_range))
    return fig

Increasing Training Size#

[Figure: cross-validation folds with increasing training size]

# Run cross validation with the default settings and checkpoint the result.
# The output carries the columns group, id, date, cv (fold id), target and
# prediction.
cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()
cv_forecast.show(10)
+-----+--------------------+----------+---+------+----------+
|group|                  id|      date| cv|target|prediction|
+-----+--------------------+----------+---+------+----------+
| CA_2|FOODS_1_179_CA_2_...|2016-04-25|  0|   1.0|0.40366215|
| CA_2|FOODS_1_179_CA_2_...|2016-04-26|  0|   0.0|0.40702468|
| CA_2|FOODS_1_179_CA_2_...|2016-04-27|  0|   0.0|0.35053134|
| CA_2|FOODS_1_179_CA_2_...|2016-04-28|  0|   0.0|0.35053134|
| CA_2|FOODS_1_179_CA_2_...|2016-04-29|  0|   0.0|0.39713493|
| CA_2|FOODS_1_179_CA_2_...|2016-04-30|  0|   0.0|0.53590035|
| CA_2|FOODS_1_179_CA_2_...|2016-05-01|  0|   0.0|0.43878192|
| CA_2|FOODS_1_192_CA_2_...|2016-04-25|  0|   0.0|0.15198386|
| CA_2|FOODS_1_192_CA_2_...|2016-04-26|  0|   0.0|0.14098388|
| CA_2|FOODS_1_192_CA_2_...|2016-04-27|  0|   0.0|0.09492111|
+-----+--------------------+----------+---+------+----------+
only showing top 10 rows
# Compare the store-level forecast of each fold with the actual sales.
plot_cv_forecast(df_train, cv_forecast)

No Refit#

[Figure: cross-validation folds without refitting the model]

# Cross validate with refit=False.
# NOTE(review): presumably the model is fitted once and reused for every fold
# instead of being retrained per fold — confirm against the ForecastFlowML docs.
cv_forecast = forecast_flow.cross_validate(df_train, refit=False).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)

Step Length Between Folds#

[Figure: cross-validation folds separated by a custom step length]

# Cross validate with cv_step_length=14.
# NOTE(review): presumably consecutive folds start 14 days apart — confirm
# against the ForecastFlowML docs.
cv_forecast = forecast_flow.cross_validate(
    df_train, cv_step_length=14
).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)

Fixed Training Size#

[Figure: cross-validation folds with a fixed training size]

# Cross validate with max_train_size=365.
# NOTE(review): presumably each fold trains on at most the most recent 365
# days, giving a fixed-size training window — confirm against the
# ForecastFlowML docs.
cv_forecast = forecast_flow.cross_validate(
    df_train, max_train_size=365
).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)