Save/Load ForecastFlowML

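This guide shows how to save a trained ForecastFlowML model and load it back later for prediction. It covers three cases: training on a PySpark DataFrame with distributed results, training on a PySpark DataFrame with local results, and training on a Pandas DataFrame.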

Import packages

from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pickle

Initialize Spark

# Start (or reuse) a local Spark session with 4 worker threads.
# Note: "spark.sql.execution.arrow.enabled" has been deprecated since Spark 3.0.
# "spark.sql.execution.arrow.pyspark.enabled" is its replacement and turns on
# Arrow-based conversion between Spark and pandas DataFrames (e.g. toPandas()).
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    # One shuffle partition per local worker thread.
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

Sample Dataset

# Load the bundled Walmart M5 sample as a Spark DataFrame, then print its first rows.
df = load_walmart_m5(spark)
df.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-29|  2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-30|  5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-01-31|  3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-01|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-02|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-03|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-04|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-05|  1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-06|  0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS|    TX_2|      TX|2011-02-07|  3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows

Feature Engineering

# Lag features of the target at whole weeks: 7, 14, 21 and 28 days.
weekly_lags = list(range(7, 7 * 4 + 1, 7))

feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={"lag": weekly_lags},
)

# localCheckpoint() truncates the Spark lineage, so later actions reuse the
# computed features instead of re-running the whole transformation.
df_features = feature_extractor.transform(df).localCheckpoint()
df_features.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
|                  id|    item_id|dept_id|cat_id|store_id|state_id|      date|sales|lag_7|lag_14|lag_21|lag_28|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-01-31|  2.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-01|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-02|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-03|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-04|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-05|  0.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-06|  1.0| null|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-07|  0.0|  2.0|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-08|  0.0|  0.0|  null|  null|  null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS|    WI_2|      WI|2011-02-09|  0.0|  0.0|  null|  null|  null|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
only showing top 10 rows

Train/Test Dataset

# Rows before the split date are used for training; rows on or after it form
# the test period.
split_date = "2016-04-25"
df_train = df_features.filter(F.col("date") < split_date)
df_test = df_features.filter(F.col("date") >= split_date)

Initialize Model

# Configure the forecaster with LightGBM as the underlying regressor.
# NOTE(review): group_col="store_id" presumably trains one model per store
# (the prediction output has a per-store "group" column), and model_horizon=7
# with max_forecast_horizon=28 presumably splits the 28-day horizon into
# 7-day chunks. Confirm against the ForecastFlowML docs.
forecast_flow = ForecastFlowML(
    id_col="id",
    date_col="date",
    target_col="sales",
    group_col="store_id",
    date_frequency="days",
    max_forecast_horizon=28,
    model_horizon=7,
    model=LGBMRegressor(),
)

PySpark DataFrame with Distributed Results

Save

# train() returns the trained models as a Spark DataFrame, which is written to
# Parquet. mode("overwrite") makes the cell safe to re-run: the default save
# mode ("errorifexists") raises once the path already exists.
forecast_flow.train(df_train).write.mode("overwrite").parquet(
    "trained_models.parquet"
)
# Pickle the ForecastFlowML object as well, so its settings can be restored
# and used with the saved model table at load time.
with open("forecast_flow.pickle", "wb") as f:
    pickle.dump(forecast_flow, f)

Load

# Restore the pickled ForecastFlowML object, reload the trained-model table
# from Parquet, and forecast the test period with both.
# Only unpickle files you trust: pickle can execute arbitrary code on load.
with open("forecast_flow.pickle", "rb") as pickle_file:
    forecast_flow = pickle.loads(pickle_file.read())
trained_models = spark.read.parquet("trained_models.parquet")

predictions = forecast_flow.predict(df_test, trained_models)
predictions.show(n=10)
+-----+--------------------+----------+----------+
|group|                  id|      date|prediction|
+-----+--------------------+----------+----------+
| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|
| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|
+-----+--------------------+----------+----------+
only showing top 10 rows

PySpark DataFrame with Local Results

Save

# Train with local_result=True. The return value is not kept here; only the
# ForecastFlowML object is pickled.
# NOTE(review): the trained models presumably live on the object itself,
# since predict() is later called without a model table. Confirm.
forecast_flow.train(df_train, local_result=True)
with open("forecast_flow.pickle", "wb") as pickle_file:
    pickle_file.write(pickle.dumps(forecast_flow))

Load

# Restore the pickled ForecastFlowML object and forecast the test period,
# passing the active Spark session explicitly.
# Only unpickle files you trust: pickle can execute arbitrary code on load.
with open("forecast_flow.pickle", "rb") as pickle_file:
    forecast_flow = pickle.loads(pickle_file.read())

predictions = forecast_flow.predict(df_test, spark=spark)
predictions.show(n=10)
+-----+--------------------+----------+----------+
|group|                  id|      date|prediction|
+-----+--------------------+----------+----------+
| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|
| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|
+-----+--------------------+----------+----------+
only showing top 10 rows

Pandas DataFrame

Save

# Train from a pandas DataFrame: the Spark training data is collected to the
# driver with toPandas(), and the Spark session is passed in explicitly.
# Then pickle the ForecastFlowML object.
forecast_flow.train(df_train.toPandas(), spark=spark)
with open("forecast_flow.pickle", "wb") as pickle_file:
    pickle_file.write(pickle.dumps(forecast_flow))

Load

# Restore the pickled ForecastFlowML object and forecast from a pandas copy of
# the test data. The final bare expression is kept so the notebook displays
# the returned predictions.
# Only unpickle files you trust: pickle can execute arbitrary code on load.
with open("forecast_flow.pickle", "rb") as pickle_file:
    forecast_flow = pickle.loads(pickle_file.read())
test_pdf = df_test.toPandas()
forecast_flow.predict(test_pdf, spark=spark)
group id date prediction
0 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-25 0.571575
1 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-26 1.093739
2 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-27 0.571575
3 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-28 0.571575
4 CA_2 FOODS_1_179_CA_2_evaluation 2016-04-29 0.571575
... ... ... ... ...
26427 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-18 0.665920
26428 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-19 0.665920
26429 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-20 0.665920
26430 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-21 1.017469
26431 TX_2 HOUSEHOLD_2_481_TX_2_evaluation 2016-05-22 0.665920

26432 rows × 4 columns