{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Engineering\n", "\n", "ForecastFlowML includes a preprocessing module that creates features from a time \n", "series dataset. This user guide shows how these features can be created in a scalable way\n", "before the modelling phase." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from forecastflowml import FeatureExtractor\n", "from forecastflowml import ForecastFlowML\n", "from forecastflowml.data.loader import load_walmart_m5\n", "from pyspark.sql import SparkSession\n", "from lightgbm import LGBMRegressor\n", "import pandas as pd\n", "import sys\n", "import os\n", "\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "pd.set_option(\"display.max_columns\", 100)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize Spark" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "spark = (\n", " SparkSession.builder.master(\"local[4]\")\n", " .config(\"spark.driver.memory\", \"4g\")\n", " .config(\"spark.sql.shuffle.partitions\", \"4\")\n", " .config(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " .getOrCreate()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Sample Dataset" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| 
TX|2015-01-16| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df = load_walmart_m5(spark).localCheckpoint()\n", "df.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Overview\n", "\n", "\n", "With ``FeatureExtractor``, we can extract:\n", "- Lag features\n", "- Rolling statistics (mean, standard deviation etc.) 
with specified lags\n", "- Counts of consecutive occurrences of a specific value, which can be used to count out-of-stock periods\n", "- History length, i.e. the number of periods since the beginning of the time series\n", "- Date features\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Lags\n", "\n", "When extracting the features, we should be careful about the lags we are creating.\n", "In this example, we are going to prepare features for four weekly models.\n", "\n", "- Model 1 will predict days 1–7, so it cannot use the 6 most recent lags.\n", "- Model 2 will predict days 8–14, so it cannot use the 13 most recent lags.\n", "- Model 3 will predict days 15–21, so it cannot use the 20 most recent lags.\n", "- Model 4 will predict days 22–28, so it cannot use the 27 most recent lags.\n", "\n", "For lag features, we are going to extract the sales on the same weekday over the past 4 weeks. \n", "\n", "![image info](../_static/lag.svg)\n", "\n", "Since each model has a different horizon, each is allowed to use a different set of lags in the modelling phase. In summary, we need to extract ``lag_7``, ``lag_14``, ``lag_21``, ``lag_28``, ``lag_35``, ``lag_42``, ``lag_49`` and ``lag_56`` as features." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 0.0| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.0| null| null| null| null| null| null| null|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+\n", "only showing top 10 
rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"lag\": [7 * (i + 1) for i in range(8)],\n", " },\n", ")\n", "df_features = feature_extractor.transform(df)\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Rolling Statistics\n", "\n", "For rolling statistics, we are going to calculate the mean over **windows** of 7, 14 and 30 days, each offset by the **most recent lag** that a model can use: 7 days for model 1, 14 days for model 2, 21 days for model 3 and 28 days for model 4.\n", "\n", "![image info](../_static/lag_window.svg)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| null| null| null| 
null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| 3.0| 3.0| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 1.5| 1.5| 1.5| null| null| null| null| null| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.3333333333333333| 1.3333333333333333| 1.3333333333333333| null| null| null| null| null| null| null| null| null|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " 
\"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n", " },\n", ")\n", "df_features = feature_extractor.transform(df)\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Out-of-stock Periods\n", "\n", "Sometimes a product might be out of stock for a certain period. We are now going to\n", "count the consecutive periods with zero sales, offset by the **most recent lags** \n", "that the models can use." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null|\n", 
"|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 0| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 1| null| null| null|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 0| null| null| null|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----------------------------+------------------------------+------------------------------+------------------------------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " count_consecutive_values={\n", " \"value\": 0,\n", " \"lags\": [7, 14, 21, 28],\n", " },\n", ")\n", "df_features = feature_extractor.transform(df)\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## History Length\n", "\n", "We can also count the total number of periods that have passed since the start of the time series." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|history_length|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| 1|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| 2|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| 3|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| 4|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| 5|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| 6|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| 7|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 8|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 9|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 10|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+--------------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " history_length=True,\n", ")\n", "df_features = feature_extractor.transform(df)\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Date Features\n", "\n", "Finally, we can also include date-derived features." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "| id| item_id|dept_id|cat_id|store_id|state_id| date|sales|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| 15| 5| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| 16| 6| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| 17| 7| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| 18| 1| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| 19| 2| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| 20| 3| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| 21| 4| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 22| 5| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 23| 6| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 24| 7| 4| 4| 1| 1| 1|2015|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " 
date_col=\"date\",\n", " target_col=\"sales\",\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"week_of_month\",\n", " \"weekend\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", ")\n", "df_features = feature_extractor.transform(df)\n", "df_features.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Combine Features\n", "\n", "Let's combine all of the feature extraction steps into a single ``FeatureExtractor``." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "feature_extractor = FeatureExtractor(\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " lag_window_features={\n", " \"lag\": [7 * (i + 1) for i in range(8)],\n", " \"mean\": [[window, lag] for lag in [7, 14, 21, 28] for window in [7, 14, 30]],\n", " },\n", " date_features=[\n", " \"day_of_month\",\n", " \"day_of_week\",\n", " \"week_of_year\",\n", " \"week_of_month\",\n", " \"weekend\",\n", " \"quarter\",\n", " \"month\",\n", " \"year\",\n", " ],\n", " count_consecutive_values={\n", " \"value\": 0,\n", " \"lags\": [7, 14, 21, 28],\n", " },\n", " history_length=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "| id| 
item_id|dept_id|cat_id|store_id|state_id| date|sales|lag_7|lag_14|lag_21|lag_28|lag_35|lag_42|lag_49|lag_56|window_7_lag_7_mean|window_14_lag_7_mean|window_30_lag_7_mean|window_7_lag_14_mean|window_14_lag_14_mean|window_30_lag_14_mean|window_7_lag_21_mean|window_14_lag_21_mean|window_30_lag_21_mean|window_7_lag_28_mean|window_14_lag_28_mean|window_30_lag_28_mean|count_consecutive_value_lag_7|count_consecutive_value_lag_14|count_consecutive_value_lag_21|count_consecutive_value_lag_28|history_length|day_of_month|day_of_week|week_of_year|week_of_month|weekend|quarter|month|year|\n", "+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-15| 3.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 1| 15| 5| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-16| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 2| 16| 6| 3| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-17| 1.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 3| 17| 7| 3| 3| 1| 1| 1|2015|\n", 
"|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-18| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 4| 18| 1| 3| 3| 1| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-19| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 5| 19| 2| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-20| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 6| 20| 3| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-21| 0.0| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| null| 7| 21| 4| 4| 3| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-22| 0.0| 3.0| null| null| null| null| null| null| null| 3.0| 3.0| 3.0| null| null| null| null| null| null| null| null| null| 0| null| null| null| 8| 22| 5| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-23| 0.0| 0.0| null| null| null| null| null| null| null| 1.5| 1.5| 1.5| null| null| null| null| null| null| null| null| null| 1| null| null| null| 9| 23| 6| 4| 4| 0| 1| 1|2015|\n", "|FOODS_1_002_TX_1_...|FOODS_1_002|FOODS_1| FOODS| TX_1| TX|2015-01-24| 0.0| 1.0| null| null| null| null| null| null| null| 1.3333333333333333| 1.3333333333333333| 1.3333333333333333| null| null| null| null| null| null| null| null| null| 0| null| null| null| 10| 24| 7| 4| 4| 1| 1| 1|2015|\n", 
"+--------------------+-----------+-------+------+--------+--------+----------+-----+-----+------+------+------+------+------+------+------+-------------------+--------------------+--------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+--------------------+---------------------+---------------------+-----------------------------+------------------------------+------------------------------+------------------------------+--------------+------------+-----------+------------+-------------+-------+-------+-----+----+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "df_train = feature_extractor.transform(df).localCheckpoint()\n", "df_train.show(10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "We can now pass the features created by ``FeatureExtractor`` to ``ForecastFlowML`` for training. As mentioned in the lag feature creation step, we set ``use_lag_range=28`` so that each model uses only the lags within 28 days of its most recent allowed lag. Since we know that the models are small enough not to cause memory problems, we keep them as a class attribute by passing the ``local_result=True`` argument." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " store_id forecast_horizon \\\n", "0 CA_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "1 TX_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "2 WI_1 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "3 TX_2 [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13,... \n", "\n", " model start_time \\\n", "0 [[128, 4, 149, 236, 1, 0, 0, 0, 0, 0, 0, 140, ... 19-May-2023 (03:34:43) \n", "1 [[128, 4, 149, 236, 1, 0, 0, 0, 0, 0, 0, 140, ... 19-May-2023 (03:34:44) \n", "2 [[128, 4, 149, 236, 1, 0, 0, 0, 0, 0, 0, 140, ... 19-May-2023 (03:34:45) \n", "3 [[128, 4, 149, 236, 1, 0, 0, 0, 0, 0, 0, 140, ... 19-May-2023 (03:34:42) \n", "\n", " end_time elapsed_seconds \n", "0 19-May-2023 (03:34:44) 0.7 \n", "1 19-May-2023 (03:34:45) 0.7 \n", "2 19-May-2023 (03:34:46) 0.9 \n", "3 19-May-2023 (03:34:43) 0.8 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forecast_flow = ForecastFlowML(\n", " group_col=\"store_id\",\n", " id_col=\"id\",\n", " date_col=\"date\",\n", " target_col=\"sales\",\n", " date_frequency=\"days\",\n", " model_horizon=7,\n", " max_forecast_horizon=28,\n", " model=LGBMRegressor(),\n", " use_lag_range=28,\n", ")\n", "forecast_flow.train(df_train, local_result=True)\n", "forecast_flow.model_" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Examine Features\n", "\n", "Let's examine which features are used for each model." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " store_id forecast_horizon feature \\\n", "0 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week \n", "1 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year \n", "2 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month \n", "3 CA_1 [1, 2, 3, 4, 5, 6, 7] month \n", "4 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter \n", ".. ... ... ... \n", "283 WI_1 [22, 23, 24, 25, 26, 27, 28] lag_56 \n", "284 WI_1 [22, 23, 24, 25, 26, 27, 28] window_7_lag_28_mean \n", "285 WI_1 [22, 23, 24, 25, 26, 27, 28] window_14_lag_28_mean \n", "286 WI_1 [22, 23, 24, 25, 26, 27, 28] window_30_lag_28_mean \n", "287 WI_1 [22, 23, 24, 25, 26, 27, 28] count_consecutive_value_lag_28 \n", "\n", " importance \n", "0 103 \n", "1 126 \n", "2 23 \n", "3 32 \n", "4 0 \n", ".. ... \n", "283 184 \n", "284 278 \n", "285 325 \n", "286 390 \n", "287 74 \n", "\n", "[288 rows x 4 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "importance = forecast_flow.get_feature_importance()\n", "importance" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here we can see that the minimum lag used by the first-week model is ``lag_7``, and by the second-week model ``lag_14``." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " store_id forecast_horizon feature \\\n", "0 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_week \n", "1 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_year \n", "2 CA_1 [1, 2, 3, 4, 5, 6, 7] week_of_month \n", "3 CA_1 [1, 2, 3, 4, 5, 6, 7] month \n", "4 CA_1 [1, 2, 3, 4, 5, 6, 7] quarter \n", "5 CA_1 [1, 2, 3, 4, 5, 6, 7] history_length \n", "6 CA_1 [1, 2, 3, 4, 5, 6, 7] weekend \n", "7 CA_1 [1, 2, 3, 4, 5, 6, 7] year \n", "8 CA_1 [1, 2, 3, 4, 5, 6, 7] day_of_month \n", "9 CA_1 [1, 2, 3, 4, 5, 6, 7] lag_7 \n", "10 CA_1 [1, 2, 3, 4, 5, 6, 7] lag_14 \n", "11 CA_1 [1, 2, 3, 4, 5, 6, 7] lag_21 \n", "12 CA_1 [1, 2, 3, 4, 5, 6, 7] lag_28 \n", "13 CA_1 [1, 2, 3, 4, 5, 6, 7] lag_35 \n", "14 CA_1 [1, 2, 3, 4, 5, 6, 7] window_7_lag_7_mean \n", "15 CA_1 [1, 2, 3, 4, 5, 6, 7] window_14_lag_7_mean \n", "16 CA_1 [1, 2, 3, 4, 5, 6, 7] window_30_lag_7_mean \n", "17 CA_1 [1, 2, 3, 4, 5, 6, 7] count_consecutive_value_lag_7 \n", "18 CA_1 [8, 9, 10, 11, 12, 13, 14] day_of_week \n", "19 CA_1 [8, 9, 10, 11, 12, 13, 14] week_of_year \n", "20 CA_1 [8, 9, 10, 11, 12, 13, 14] week_of_month \n", "21 CA_1 [8, 9, 10, 11, 12, 13, 14] month \n", "22 CA_1 [8, 9, 10, 11, 12, 13, 14] quarter \n", "23 CA_1 [8, 9, 10, 11, 12, 13, 14] history_length \n", "24 CA_1 [8, 9, 10, 11, 12, 13, 14] weekend \n", "25 CA_1 [8, 9, 10, 11, 12, 13, 14] year \n", "26 CA_1 [8, 9, 10, 11, 12, 13, 14] day_of_month \n", "27 CA_1 [8, 9, 10, 11, 12, 13, 14] lag_14 \n", "28 CA_1 [8, 9, 10, 11, 12, 13, 14] lag_21 \n", "29 CA_1 [8, 9, 10, 11, 12, 13, 14] lag_28 \n", "30 CA_1 [8, 9, 10, 11, 12, 13, 14] lag_35 \n", "31 CA_1 [8, 9, 10, 11, 12, 13, 14] lag_42 \n", "32 CA_1 [8, 9, 10, 11, 12, 13, 14] window_7_lag_14_mean \n", "33 CA_1 [8, 9, 10, 11, 12, 13, 14] window_14_lag_14_mean \n", "34 CA_1 [8, 9, 10, 11, 12, 13, 14] window_30_lag_14_mean \n", "35 CA_1 [8, 9, 10, 11, 12, 13, 14] count_consecutive_value_lag_14 \n", "\n", " importance \n", "0 103 \n", "1 126 \n", "2 23 \n", "3 32 \n", "4 0 \n", "5 199 \n", "6 69 \n", "7 0 \n", 
"8 182 \n", "9 234 \n", "10 255 \n", "11 221 \n", "12 290 \n", "13 290 \n", "14 345 \n", "15 246 \n", "16 363 \n", "17 22 \n", "18 88 \n", "19 183 \n", "20 26 \n", "21 18 \n", "22 0 \n", "23 257 \n", "24 61 \n", "25 0 \n", "26 155 \n", "27 280 \n", "28 224 \n", "29 248 \n", "30 229 \n", "31 298 \n", "32 288 \n", "33 240 \n", "34 384 \n", "35 21 " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "importance[importance[\"store_id\"] == \"CA_1\"].head(36)" ] } ], "metadata": { "kernelspec": { "display_name": "sspark37", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }