{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "a02bb3a2-e5d7-4d14-b966-827457675b75", "showTitle": false, "title": "" } }, "source": [ "# Save/Load ForecastFlowML\n", "\n", "This guide shows how a trained ``ForecastFlowML`` model can be saved and later loaded to make predictions." ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Import packages" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import ForecastFlowML\n", "from forecastflowml import FeatureExtractor\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from lightgbm import LGBMRegressor\n", "from pyspark.sql import SparkSession\n", "import pyspark.sql.functions as F\n", "import pickle" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Local session with 4 worker threads.\n", "# `spark.sql.execution.arrow.enabled` is deprecated since Spark 3.0;\n", "# the supported key is `spark.sql.execution.arrow.pyspark.enabled`.\n", "spark = (\n", "    SparkSession.builder.master(\"local[4]\")\n", "    .config(\"spark.driver.memory\", \"8g\")\n", "    .config(\"spark.sql.shuffle.partitions\", \"4\")\n", "    .config(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", "    .getOrCreate()\n", ")" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-29| 2.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-30| 5.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-01-31| 3.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-01| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-02| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-03| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-04| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-05| 1.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-06| 0.0|\n", "|FOODS_1_013_TX_2_...|FOODS_1_013|FOODS_1| FOODS| TX_2| TX|2011-02-07| 3.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark)\n", "df.show(10)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Engineering" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-01-31| 2.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-01| 0.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-02| 0.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-03| 0.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-04| 0.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-05| 0.0| null| null| null| null|\n",
"|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-06| 1.0| null| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-07| 0.0| 2.0| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-08| 0.0| 0.0| null| null| null|\n", "|FOODS_1_011_WI_2_...|FOODS_1_011|FOODS_1| FOODS| WI_2| WI|2011-02-09| 0.0| 0.0| null| null| null|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "# Lags of 7, 14, 21 and 28 rows of the target, i.e. 1-4 weeks on this daily data.\n", "feature_extractor = FeatureExtractor(\n", "    id_col=\"id\",\n", "    date_col=\"date\",\n", "    target_col=\"sales\",\n", "    lag_window_features={\"lag\": [7, 14, 21, 28]},\n", ")\n", "\n", "df_features = feature_extractor.transform(df)\n", "df_features = df_features.localCheckpoint()\n", "df_features.show(10)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Train/Test Dataset" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Rows strictly before the split date are used for training; the rest for testing.\n", "split_date = \"2016-04-25\"\n", "date_column = F.col(\"date\")\n", "\n", "df_train = df_features.filter(date_column < split_date)\n", "df_test = df_features.filter(date_column >= split_date)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Model" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# LightGBM regressor with default hyper-parameters, passed as `model`.\n", "base_model = LGBMRegressor()\n", "\n", "forecast_flow = ForecastFlowML(\n", "    model=base_model,\n", "    group_col=\"store_id\",\n", "    id_col=\"id\",\n", "    date_col=\"date\",\n", "    target_col=\"sales\",\n", "    date_frequency=\"days\",\n", "    model_horizon=7,\n", "    max_forecast_horizon=28,\n", ")" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame with Distributed Results" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Save" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"# The result of `train` is written out as parquet and read back in the Load cell below.\n", "# mode(\"overwrite\") keeps this cell re-runnable: the default save mode raises\n", "# an error when \"trained_models.parquet\" already exists from a previous run.\n", "forecast_flow.train(df_train).write.mode(\"overwrite\").parquet(\"trained_models.parquet\")\n", "\n", "# The ForecastFlowML object itself is pickled separately from the parquet output.\n", "with open(\"forecast_flow.pickle\", \"wb\") as f:\n", "    pickle.dump(forecast_flow, f)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Load" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+----------+\n", "|group| id| date|prediction|\n", "+-----+--------------------+----------+----------+\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|\n", "+-----+--------------------+----------+----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "trained_models = spark.read.parquet(\"trained_models.parquet\")\n", "\n", "# Only unpickle files you created yourself: pickle.load can execute arbitrary code.\n", "with open(\"forecast_flow.pickle\", \"rb\") as f:\n", "    forecast_flow = pickle.load(f)\n", "\n", "forecast_flow.predict(df_test, trained_models).show(10)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark DataFrame with Local Results" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Save" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# With local_result=True nothing is written to parquet; only the pickled object is saved.\n", "forecast_flow.train(df_train, local_result=True)\n", "with open(\"forecast_flow.pickle\", \"wb\") as f:\n", "    pickle.dump(forecast_flow, f)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Load" ] },
{ "cell_type": "code",
"execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+----------+----------+\n", "|group| id| date|prediction|\n", "+-----+--------------------+----------+----------+\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-25|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-26| 1.0937389|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-27|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-28|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-29|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-04-30|0.57157516|\n", "| CA_2|FOODS_1_179_CA_2_...|2016-05-01|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-25|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-26|0.57157516|\n", "| CA_2|FOODS_1_192_CA_2_...|2016-04-27|0.57157516|\n", "+-----+--------------------+----------+----------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "# Restore the pickled object; no trained-models DataFrame is passed to predict here,\n", "# only the test data and the Spark session.\n", "with open(\"forecast_flow.pickle\", \"rb\") as pickle_file:\n", "    forecast_flow = pickle.load(pickle_file)\n", "\n", "predictions = forecast_flow.predict(df_test, spark=spark)\n", "predictions.show(10)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Pandas DataFrame" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Save" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Train on a pandas DataFrame (collected to the driver) and pickle the fitted object.\n", "forecast_flow.train(df_train.toPandas(), spark=spark)\n", "\n", "with open(\"forecast_flow.pickle\", \"wb\") as pickle_file:\n", "    pickle.dump(forecast_flow, pickle_file)" ] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Load" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | group | \n", "id | \n", "date | \n", "prediction | \n", "
|---|---|---|---|---|
| 0 | \n", "CA_2 | \n", "FOODS_1_179_CA_2_evaluation | \n", "2016-04-25 | \n", "0.571575 | \n", "
| 1 | \n", "CA_2 | \n", "FOODS_1_179_CA_2_evaluation | \n", "2016-04-26 | \n", "1.093739 | \n", "
| 2 | \n", "CA_2 | \n", "FOODS_1_179_CA_2_evaluation | \n", "2016-04-27 | \n", "0.571575 | \n", "
| 3 | \n", "CA_2 | \n", "FOODS_1_179_CA_2_evaluation | \n", "2016-04-28 | \n", "0.571575 | \n", "
| 4 | \n", "CA_2 | \n", "FOODS_1_179_CA_2_evaluation | \n", "2016-04-29 | \n", "0.571575 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 26427 | \n", "TX_2 | \n", "HOUSEHOLD_2_481_TX_2_evaluation | \n", "2016-05-18 | \n", "0.665920 | \n", "
| 26428 | \n", "TX_2 | \n", "HOUSEHOLD_2_481_TX_2_evaluation | \n", "2016-05-19 | \n", "0.665920 | \n", "
| 26429 | \n", "TX_2 | \n", "HOUSEHOLD_2_481_TX_2_evaluation | \n", "2016-05-20 | \n", "0.665920 | \n", "
| 26430 | \n", "TX_2 | \n", "HOUSEHOLD_2_481_TX_2_evaluation | \n", "2016-05-21 | \n", "1.017469 | \n", "
| 26431 | \n", "TX_2 | \n", "HOUSEHOLD_2_481_TX_2_evaluation | \n", "2016-05-22 | \n", "0.665920 | \n", "
26432 rows × 4 columns
\n", "