{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "a02bb3a2-e5d7-4d14-b966-827457675b75", "showTitle": false, "title": "" } }, "source": [ "# Feature Importance\n", "\n", "This quick guide shows how to retrieve feature importances from ``ForecastFlowML``." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import ForecastFlowML\n", "from forecastflowml import FeatureExtractor\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from lightgbm import LGBMRegressor\n", "from pyspark.sql import SparkSession\n", "import sys\n", "import os\n", "\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "spark = (\n", " SparkSession.builder.master(\"local[4]\")\n", " .config(\"spark.driver.memory\", \"4g\")\n", " .config(\"spark.sql.shuffle.partitions\", \"4\")\n", " .config(\"spark.sql.execution.pyarrow.enabled\", \"true\")\n", " .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0|\n", 
"|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark)\n", "df.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Engineering" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| 15| 4| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| 16| 5| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| 
null| 17| 6| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| 18| 7| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| 19| 1| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| 20| 2| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| 21| 3| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| null| null| null| 22| 4| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 0.0| null| null| null| 23| 5| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.0| null| null| null| 24| 6| 4| 4| 1| 1| 1|2015|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"lag\": [7 * (i + 1) for i in range(4)],\n", " },\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"week_of_month\",\n", " \"weekend\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", ")\n", "df_train = feature_extractor.transform(df).localCheckpoint()\n", "df_train.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Importance" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " 
date_col=\"date\",\n", " target_col=\"sales\",\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Distributed Results" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | store_id | \n", "forecast_horizon | \n", "feature | \n", "importance | \n", "
|---|---|---|---|---|
| 0 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "day_of_week | \n", "354.0 | \n", "
| 1 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "day_of_month | \n", "687.0 | \n", "
| 2 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "week_of_year | \n", "704.0 | \n", "
| 3 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "month | \n", "50.0 | \n", "
| 4 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "week_of_month | \n", "0.0 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 139 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "week_of_month | \n", "0.0 | \n", "
| 140 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "weekend | \n", "0.0 | \n", "
| 141 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "year | \n", "85.0 | \n", "
| 142 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "quarter | \n", "0.0 | \n", "
| 143 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "lag_28 | \n", "1060.0 | \n", "
144 rows × 4 columns
\n", "| \n", " | store_id | \n", "forecast_horizon | \n", "feature | \n", "importance | \n", "
|---|---|---|---|---|
| 0 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "day_of_week | \n", "354 | \n", "
| 1 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "day_of_month | \n", "687 | \n", "
| 2 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "week_of_year | \n", "704 | \n", "
| 3 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "month | \n", "50 | \n", "
| 4 | \n", "CA_1 | \n", "[1, 2, 3, 4, 5, 6, 7] | \n", "week_of_month | \n", "0 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 139 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "week_of_month | \n", "0 | \n", "
| 140 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "weekend | \n", "0 | \n", "
| 141 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "year | \n", "85 | \n", "
| 142 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "quarter | \n", "0 | \n", "
| 143 | \n", "WI_1 | \n", "[22, 23, 24, 25, 26, 27, 28] | \n", "lag_28 | \n", "1060 | \n", "
144 rows × 4 columns
\n", "