Save/Load ForecastFlowML#
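This guide shows how to save a trained ForecastFlowML model and load it back later for prediction. It covers three workflows: a PySpark DataFrame with distributed results, a PySpark DataFrame with local results, and a pandas DataFrame.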
Import packages#
from forecastflowml import ForecastFlowML
from forecastflowml import FeatureExtractor
from forecastflowml.data.loader import load_walmart_m5
from lightgbm import LGBMRegressor
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pickle
Initialize Spark#
# Start a local Spark session with 4 worker threads ("local[4]"), 8 GB of
# driver memory and 4 shuffle partitions.
# Arrow is enabled to speed up Spark <-> pandas conversions such as the
# ``toPandas()`` calls used later in this guide.
# ``spark.sql.execution.arrow.pyspark.enabled`` is the Spark 3.0+ name of the
# deprecated ``spark.sql.execution.arrow.enabled`` option.
spark = (
    SparkSession.builder.master("local[4]")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "4")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)
Sample Dataset#
# Load the Walmart M5 sample dataset (forecastflowml.data.loader) into a
# Spark DataFrame and print its first rows.
df = load_walmart_m5(spark)
df.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|
|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|
+--------------------+-----------+-------+------+--------+--------+----------+-----+
only showing top 10 rows
Feature Engineering#
# Weekly lags of the target column: 7, 14, 21 and 28 days. These produce the
# lag_7 ... lag_28 columns shown in the preview.
weekly_lags = list(range(7, 7 * 4 + 1, 7))
feature_extractor = FeatureExtractor(
    id_col="id",
    date_col="date",
    target_col="sales",
    lag_window_features={"lag": weekly_lags},
)
# localCheckpoint() materializes the features and truncates the query plan,
# so the train/test filters below do not re-run the feature pipeline.
df_features = feature_extractor.transform(df).localCheckpoint()
df_features.show(n=10)
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null|
|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null|
+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+
only showing top 10 rows
Train/Test Dataset#
# Rows dated before the split date form the training set; rows on or after
# it form the test set.
split_date = "2016-04-25"
df_train = df_features.filter(F.col("date") < split_date)
df_test = df_features.filter(F.col("date") >= split_date)
Initialize Model#
# Estimator handed to ForecastFlowML.
base_model = LGBMRegressor()
forecast_flow = ForecastFlowML(
    # Column roles in the input data.
    id_col="id",
    date_col="date",
    target_col="sales",
    # Grouping column; its values (store ids such as CA_2) appear in the
    # "group" column of the predictions.
    group_col="store_id",
    date_frequency="days",
    # NOTE(review): presumably each model covers 7 days of the 28-day
    # forecast horizon -- confirm against the ForecastFlowML docs.
    model_horizon=7,
    max_forecast_horizon=28,
    model=base_model,
)
PySpark DataFrame with Distributed Results#
Save#
# train() returns a Spark DataFrame of trained models, which is written to
# Parquet. mode("overwrite") lets this cell be re-run without failing because
# the output path already exists.
(
    forecast_flow.train(df_train)
    .write.mode("overwrite")
    .parquet("trained_models.parquet")
)
# The ForecastFlowML object is pickled as well; it is needed again at load
# time to call predict().
with open("forecast_flow.pickle", "wb") as f:
    pickle.dump(forecast_flow, f)
Load#
# Read the trained models back as a Spark DataFrame.
trained_models = spark.read.parquet("trained_models.parquet")
# Restore the ForecastFlowML object. Only unpickle files you created
# yourself: pickle.load can execute arbitrary code from a malicious file.
with open("forecast_flow.pickle", "rb") as f:
    forecast_flow = pickle.load(f)
# Predict on the test set using the reloaded models.
forecast_flow.predict(df_test, trained_models).show(10)
+-----+--------------------+----------+----------+
|group| id| date|prediction|
+-----+--------------------+----------+----------+
| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|
| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|
+-----+--------------------+----------+----------+
only showing top 10 rows
PySpark DataFrame with Local Results#
Save#
# With local_result=True the return value of train() is not used. The trained
# models are presumably kept on ``forecast_flow`` itself: predict() is later
# called without a trained-models argument. Pickling the object therefore
# saves them.
forecast_flow.train(df_train, local_result=True)
with open("forecast_flow.pickle", "wb") as f:
    pickle.dump(forecast_flow, f)
Load#
# Restore the ForecastFlowML object trained with local_result=True. Only
# unpickle files you created yourself: pickle.load can execute arbitrary code
# from a malicious file.
with open("forecast_flow.pickle", "rb") as f:
    forecast_flow = pickle.load(f)
# No trained-models DataFrame is passed; the SparkSession is supplied via
# ``spark=`` instead.
forecast_flow.predict(df_test, spark=spark).show(10)
+-----+--------------------+----------+----------+
|group| id| date|prediction|
+-----+--------------------+----------+----------+
| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|
| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|
| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|
| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|
+-----+--------------------+----------+----------+
only showing top 10 rows
Pandas DataFrame#
Save#
# Train on a pandas DataFrame (the SparkSession is still passed via spark=),
# then pickle the ForecastFlowML object for later prediction.
forecast_flow.train(df_train.toPandas(), spark=spark)
with open("forecast_flow.pickle", "wb") as f:
    pickle.dump(forecast_flow, f)
Load#
# Restore the ForecastFlowML object. Only unpickle files you created
# yourself: pickle.load can execute arbitrary code from a malicious file.
with open("forecast_flow.pickle", "rb") as f:
    forecast_flow = pickle.load(f)
# With pandas input, the predictions come back as a pandas DataFrame
# (see the table below).
forecast_flow.predict(df_test.toPandas(), spark=spark)
|  | group | id | date | prediction |
|---|---|---|---|---|
| 0 | CA_2 | FOODS_1_179_CA_2_evaluation | 2016-04-25 | 0.571575 |
| 1 | CA_2 | FOODS_1_179_CA_2_evaluation | 2016-04-26 | 1.093739 |
| 2 | CA_2 | FOODS_1_179_CA_2_evaluation | 2016-04-27 | 0.571575 |
| 3 | CA_2 | FOODS_1_179_CA_2_evaluation | 2016-04-28 | 0.571575 |
| 4 | CA_2 | FOODS_1_179_CA_2_evaluation | 2016-04-29 | 0.571575 |
| ... | ... | ... | ... | ... |
| 26427 | TX_2 | HOUSEHOLD_2_481_TX_2_evaluation | 2016-05-18 | 0.665920 |
| 26428 | TX_2 | HOUSEHOLD_2_481_TX_2_evaluation | 2016-05-19 | 0.665920 |
| 26429 | TX_2 | HOUSEHOLD_2_481_TX_2_evaluation | 2016-05-20 | 0.665920 |
| 26430 | TX_2 | HOUSEHOLD_2_481_TX_2_evaluation | 2016-05-21 | 1.017469 |
| 26431 | TX_2 | HOUSEHOLD_2_481_TX_2_evaluation | 2016-05-22 | 0.665920 |
26432 rows × 4 columns