Time Series Cross Validation
Time series cross validation evaluates a forecasting model on a time series dataset by splitting the data into multiple folds. Standard k-fold cross validation assigns observations to training and test sets at random. Time series cross validation instead preserves temporal order: each fold trains only on observations before a cutoff date and is tested on the period that follows. As a result, the model is never trained on data from the future it is asked to forecast.
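The splitting rule can be sketched in plain Python. This is a minimal illustration, not part of ForecastFlowML. It assumes the observations are already sorted by date and that every test window has the same length `horizon`. `expanding_window_splits` is a hypothetical helper. In every fold, all training indices come before all test indices, which is the property shuffled k-fold lacks.

```python
def expanding_window_splits(n_obs, n_folds, horizon):
    """Return (train_idx, test_idx) pairs for an expanding-window split.

    Hypothetical helper for illustration. Observations are assumed to be
    sorted by time. The last fold tests on the final `horizon` observations;
    each earlier fold's test window sits `horizon` steps further back, and
    training uses everything before it.
    """
    if n_obs <= n_folds * horizon:
        raise ValueError("not enough observations for the requested folds")
    splits = []
    for k in range(n_folds, 0, -1):  # oldest fold first
        test_start = n_obs - k * horizon
        train_idx = list(range(test_start))
        test_idx = list(range(test_start, test_start + horizon))
        splits.append((train_idx, test_idx))
    return splits


for train_idx, test_idx in expanding_window_splits(n_obs=20, n_folds=3, horizon=4):
    print(f"train 0..{train_idx[-1]}  test {test_idx[0]}..{test_idx[-1]}")
# train 0..7  test 8..11
# train 0..11  test 12..15
# train 0..15  test 16..19
```

Each training set contains the previous one, so the training size grows from fold to fold.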
Imports
from forecastflowml import FeatureExtractor
from forecastflowml import ForecastFlowML
from forecastflowml.data.loader import load_walmart_m5
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from lightgbm import LGBMRegressor
import pandas as pd
import plotly.express as px
import plotly.io as pio
import sys
import os
# Make Spark worker processes use the same Python interpreter as the driver
os.environ["PYSPARK_PYTHON"] = sys.executable
pd.set_option("display.max_columns", 100)
Initialize Spark
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "4g")
    .config("spark.sql.shuffle.partitions", "4")
    # Use Apache Arrow for Spark <-> pandas conversions such as toPandas()
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)
Sample Dataset
df = load_walmart_m5(spark).localCheckpoint()
df.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows
Extract Features
feature_extractor = FeatureExtractor(
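    id_col="id",
    date_col="date",
    target_col="sales",
    # Rolling means of sales over 7-, 14- and 30-day windows,
    # each shifted back by 7, 14, 21 and 28 days ([window, lag] pairs)
    lag_window_features={
        "mean": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],
    },
    # Calendar features derived from the date column
    date_features=[
        "day_of_month",
        "day_of_week",
        "week_of_year",
        "week_of_month",
        "weekend",
        "quarter",
        "month",
        "year",
    ],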
)
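The `lag_window_features` comprehension is easy to misread, so here is what it expands to in plain Python. Lag is the outer loop and window the inner loop, which gives 12 `[window, lag]` pairs. These pairs line up one-to-one with the 12 `window_*_lag_*_mean` columns printed by `df_train.show(10)`. The naming pattern below is inferred from that printed header, not taken from the library source.

```python
# Same comprehension as in the FeatureExtractor call: lag is the outer loop,
# window the inner loop, so the pairs are grouped by lag.
pairs = [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]]

# Column naming pattern inferred from the printed df_train header (assumption).
columns = [f"window_{window}_lag_{lag}_mean" for window, lag in pairs]

print(len(pairs))    # 12
print(pairs[:4])     # [[7, 7], [14, 7], [30, 7], [7, 14]]
print(columns[:3])   # ['window_7_lag_7_mean', 'window_14_lag_7_mean', 'window_30_lag_7_mean']
```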
df_train = feature_extractor.transform(df).localCheckpoint()
df_train.show(10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| null| null| null| null| null| null| null| null| 15| 4| 3| 3| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 16| 5| 3| 3| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| 17| 6| 3| 3| 1| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 18| 7| 3| 3| 1| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 19| 1| 4| 3| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 20| 2| 4| 3| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| 21| 3| 4| 3| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| 3.0| 3.0| null| null| null| null| null| null| null| null| null| 22| 4| 4| 4| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 1.5| 1.5| 1.5| null| null| null| null| null| null| null| null| null| 23| 5| 4| 4| 0| 1| 1|2015|
|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.3333333333333333| 1.3333333333333333| 1.3333333333333333| null| null| null| null| null| null| null| null| null| 24| 6| 4| 4| 1| 1| 1|2015|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+------------+-----------+------------+-------------+-------+-------+-----+----+
only showing top 10 rows
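The printed rows make it possible to check how the features are computed. The sketch below is reverse-engineered from those rows and is an assumption, not the library's code.

- A `window_W_lag_L_mean` value appears to be the mean of up to `W` sales values ending `L` days before the current date. At the start of the series it is averaged over whatever is available, which gives 3.0, 1.5 and about 1.3333 on 2015-01-22 to 2015-01-24.
- The calendar columns match ISO weekday and ISO week numbers.
- `week_of_month` equals the day of the month divided by 7, rounded up.

`lagged_rolling_mean` and `date_features` are hypothetical helpers.

```python
import math
from datetime import date, timedelta


def lagged_rolling_mean(values, window, lag):
    """Mean of up to `window` values ending `lag` steps before each position.

    Hypothetical helper reproducing the printed rows (assumption). Positions
    with no value at least `lag` steps back get None; shorter windows at the
    start of the series are averaged over what is available.
    """
    out = []
    for t in range(len(values)):
        end = t - lag  # last observation the feature may use
        if end < 0:
            out.append(None)
            continue
        start = max(0, end - window + 1)
        chunk = values[start:end + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def date_features(d):
    """Calendar features matching the printed columns (inferred, assumption)."""
    return {
        "day_of_month": d.day,
        "day_of_week": d.isoweekday(),       # Monday=1 ... Sunday=7
        "week_of_year": d.isocalendar()[1],  # ISO week number
        "week_of_month": math.ceil(d.day / 7),
        "weekend": int(d.isoweekday() >= 6),
        "quarter": (d.month - 1) // 3 + 1,
        "month": d.month,
        "year": d.year,
    }


# Sales of FOODS_1_002 at TX_1 from 2015-01-15 to 2015-01-24, as printed above
sales = [3.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(lagged_rolling_mean(sales, window=7, lag=7)[7:])  # [3.0, 1.5, 1.3333333333333333]

first_day = date(2015, 1, 15)
print(date_features(first_day + timedelta(days=4)))  # 2015-01-19, a Monday
```

The first seven lag-7 values are `None` because no sales value lies at least seven days back. This matches the `null` entries in the table.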
Initialize ForecastFlowML
forecast_flow = ForecastFlowML(
    group_col="store_id",
    id_col="id",
    date_col="date",
    target_col="sales",
    date_frequency="days",
    model_horizon=7,
    max_forecast_horizon=28,
    model=LGBMRegressor(),
)
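One way to read `model_horizon=7` together with `max_forecast_horizon=28` is direct multi-horizon forecasting split into four 7-day blocks. This page does not say how ForecastFlowML handles this internally, so the sketch below is an assumption. It only illustrates the general no-leakage rule: a model forecasting h days ahead can use a lag of L days only if L >= h. That rule is consistent with the lags of 7, 14, 21 and 28 days chosen above. `horizon_blocks` and `leak_free_lags` are hypothetical helpers.

```python
# Assumption: forecast steps are split into model_horizon-sized blocks.
# This illustrates the leakage rule only; it is not ForecastFlowML's code.


def horizon_blocks(model_horizon, max_forecast_horizon):
    """Split forecast steps 1..max_forecast_horizon into consecutive blocks."""
    return [
        (start, min(start + model_horizon - 1, max_forecast_horizon))
        for start in range(1, max_forecast_horizon + 1, model_horizon)
    ]


def leak_free_lags(lags, block_end):
    """Lags usable by a model predicting up to `block_end` steps ahead.

    A lag-L feature at step h reads the value L days earlier, which is only
    known at forecast time when L >= h.
    """
    return [lag for lag in lags if lag >= block_end]


lags = [7, 14, 21, 28]  # lags used in the FeatureExtractor above
for start, end in horizon_blocks(model_horizon=7, max_forecast_horizon=28):
    print(f"steps {start}-{end}: lags {leak_free_lags(lags, end)}")
# steps 1-7: lags [7, 14, 21, 28]
# steps 8-14: lags [14, 21, 28]
# steps 15-21: lags [21, 28]
# steps 22-28: lags [28]
```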
def plot_cv_forecast(df_train, cv_forecast):
    """Plot store-level actual sales against each fold's cross-validation forecast."""
    pio.renderers.default = "notebook"
    cv_state = (
        df_train.select("id", "store_id", "date", "sales")
        .join(
            cv_forecast.select("id", "date", "cv", "prediction"),
            on=["id", "date"],
            how="left",
        )
        # One prediction column per fold; folds are assumed to be labeled 0, 1, 2
        .groupBy("id", "store_id", "date", "sales")
        .pivot("cv")
        .sum("prediction")
        # Aggregate item-level sales and predictions to store level
        .groupBy("store_id", "date")
        .agg(
            F.sum("sales").alias("sales"),
            *[F.sum(f"{i}").alias(f"cv_{i}") for i in range(3)],
        )
        .orderBy("store_id", "date")
    ).toPandas()
    fig = px.line(
        cv_state,
        x="date",
        y=["sales", *[f"cv_{i}" for i in range(3)]],
        facet_row_spacing=0.04,
        facet_col="store_id",
        facet_col_wrap=2,
        height=700,
        width=720,
    )
    fig = fig.update_layout(
        legend=dict(orientation="h", yanchor="top", y=1.09, xanchor="center", x=0.5),
        margin=dict(l=0, r=10, t=5, b=5),
        legend_title="",
    )
    fig = fig.update_traces(line=dict(width=1.7))
    fig = fig.update_yaxes(matches=None, title="")
    fig = fig.update_xaxes(type="date", range=["2015-11-01", "2016-05-22"])
    return fig
Increasing Training Size
cv_forecast = forecast_flow.cross_validate(df_train).localCheckpoint()
cv_forecast.show(10)
+--------+--------------------+-------------------+---+-----+----------+
|store_id| id| date| cv|sales|prediction|
+--------+--------------------+-------------------+---+-----+----------+
| CA_1|FOODS_1_064_CA_1_...|2016-04-25 00:00:00| 0| 2.0|0.94086176|
| CA_1|FOODS_1_064_CA_1_...|2016-04-26 00:00:00| 0| 0.0| 0.9023968|
| CA_1|FOODS_1_064_CA_1_...|2016-04-27 00:00:00| 0| 2.0| 1.0340574|
| CA_1|FOODS_1_064_CA_1_...|2016-04-28 00:00:00| 0| 4.0|0.98158336|
| CA_1|FOODS_1_064_CA_1_...|2016-04-29 00:00:00| 0| 0.0| 0.9397872|
| CA_1|FOODS_1_064_CA_1_...|2016-04-30 00:00:00| 0| 0.0| 1.3279248|
| CA_1|FOODS_1_064_CA_1_...|2016-05-01 00:00:00| 0| 0.0| 1.3603985|
| CA_1|FOODS_1_121_CA_1_...|2016-04-25 00:00:00| 0| 0.0|0.59270364|
| CA_1|FOODS_1_121_CA_1_...|2016-04-26 00:00:00| 0| 1.0|0.60231286|
| CA_1|FOODS_1_121_CA_1_...|2016-04-27 00:00:00| 0| 1.0|0.61724263|
+--------+--------------------+-------------------+---+-----+----------+
only showing top 10 rows
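`cv_forecast` places each fold's actual `sales` next to its `prediction`, so a natural next step is to score each fold with an error metric. Below is a minimal plain-Python sketch of mean absolute error (MAE) grouped by the `cv` column. `mae_by_fold` is a hypothetical helper. MAE is just one possible metric and is not prescribed by this page.

```python
from collections import defaultdict


def mae_by_fold(rows):
    """Mean absolute error per fold from (cv, sales, prediction) tuples."""
    abs_err = defaultdict(float)
    counts = defaultdict(int)
    for cv, sales, prediction in rows:
        abs_err[cv] += abs(sales - prediction)
        counts[cv] += 1
    return {cv: abs_err[cv] / counts[cv] for cv in sorted(counts)}


# First rows of fold 0 for FOODS_1_064 at CA_1, copied from the output above
rows = [
    (0, 2.0, 0.94086176),
    (0, 0.0, 0.9023968),
    (0, 2.0, 1.0340574),
]
print(mae_by_fold(rows))
```

In PySpark the same per-fold aggregation could run directly on the DataFrame, for example `cv_forecast.groupBy("cv").agg(F.avg(F.abs(F.col("sales") - F.col("prediction"))).alias("mae"))`.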
plot_cv_forecast(df_train, cv_forecast)
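The output above shows that fold 0 is tested on the 28 days starting 2016-04-25, which ends on 2016-05-22, the last date on the plot's x-axis. The sketch below lays out the expanding-window folds for this dataset under several assumptions:

- Folds are counted back from the end of the data.
- The default step between folds equals the 28-day horizon.
- There are 3 folds, as in `plot_cv_forecast`.
- Training starts at 2015-01-15, the first date in the printed sample, which may not be the true start of the dataset.

`fold_windows` is a hypothetical helper. Read in chronological order (fold 2, then 1, then 0), each training set contains the previous one.

```python
from datetime import date, timedelta


def fold_windows(first_date, last_date, n_folds, horizon, step):
    """Expanding-window folds counted back from the end of the data (assumption).

    Fold 0 tests on the last `horizon` days; fold k moves the test window
    back by k * step days. Training always starts at `first_date`.
    """
    folds = []
    for k in range(n_folds):
        test_end = last_date - timedelta(days=k * step)
        test_start = test_end - timedelta(days=horizon - 1)
        train_end = test_start - timedelta(days=1)
        folds.append({"cv": k, "train": (first_date, train_end), "test": (test_start, test_end)})
    return folds


for fold in fold_windows(date(2015, 1, 15), date(2016, 5, 22), n_folds=3, horizon=28, step=28):
    print(fold["cv"], "train", *fold["train"], "test", *fold["test"])
# 0 train 2015-01-15 2016-04-24 test 2016-04-25 2016-05-22
# 1 train 2015-01-15 2016-03-27 test 2016-03-28 2016-04-24
# 2 train 2015-01-15 2016-02-28 test 2016-02-29 2016-03-27
```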
No Refit
cv_forecast = forecast_flow.cross_validate(df_train, refit=False).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)
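With `refit=False` the model is presumably trained once and reused for every fold, instead of being retrained at each cutoff. This is cheaper, but later folds are then forecast by a model that has not seen the most recent history. This page does not state which training window the single fit uses. The toy loop below assumes it is the window before the earliest test fold, so that no test period leaks into training. `cross_validate_toy` and its training-mean "model" are hypothetical stand-ins for ForecastFlowML and LightGBM.

```python
def cross_validate_toy(series, test_starts, horizon, refit=True):
    """Toy CV loop with a 'model' that predicts the training mean.

    series: values sorted by time; test_starts: start index of each test window.
    With refit=False the model is fit once, on the data before the earliest
    test window (an assumption made here to avoid training on any test period).
    """
    def fit(train):
        return sum(train) / len(train)

    fits = 0
    model = None
    if not refit:
        model = fit(series[:min(test_starts)])
        fits += 1
    forecasts = {}
    for start in test_starts:
        if refit:
            model = fit(series[:start])  # retrain on all data before this fold
            fits += 1
        forecasts[start] = [model] * horizon
    return forecasts, fits


series = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(cross_validate_toy(series, test_starts=[4, 6], horizon=2, refit=True))
# ({4: [2.5, 2.5], 6: [3.5, 3.5]}, 2)
print(cross_validate_toy(series, test_starts=[4, 6], horizon=2, refit=False))
# ({4: [2.5, 2.5], 6: [2.5, 2.5]}, 1)
```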
Step Length Between Folds
cv_forecast = forecast_flow.cross_validate(
    df_train, cv_step_length=14
).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)
Fixed Training Size
cv_forecast = forecast_flow.cross_validate(
    df_train, max_train_size=365
).localCheckpoint()
plot_cv_forecast(df_train, cv_forecast)
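With `max_train_size=365`, the training window changes from expanding to sliding. Each fold is trained on at most the most recent 365 periods before its test window, so every fold sees a comparable amount of history. The sketch below assumes the size is measured in days (matching `date_frequency="days"`) and reuses the fold 0 test start of 2016-04-25 from the output above. The other test dates come from the assumed default fold layout. `train_window` is a hypothetical helper.

```python
from datetime import date, timedelta


def train_window(first_date, test_start, max_train_size=None):
    """Inclusive (start, end) training range before a test window.

    With max_train_size=None the window expands back to first_date; otherwise
    it keeps only the last max_train_size days (assumed unit: days).
    """
    end = test_start - timedelta(days=1)
    if max_train_size is None:
        return first_date, end
    return max(first_date, end - timedelta(days=max_train_size - 1)), end


first = date(2015, 1, 15)
for test_start in [date(2016, 2, 29), date(2016, 3, 28), date(2016, 4, 25)]:
    start, end = train_window(first, test_start, max_train_size=365)
    print(test_start, "train", start, end, (end - start).days + 1, "days")
```

Unlike the expanding window, the start date moves forward with every fold, so older observations drop out of training.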