diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 26c85cf6c..3578b538d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 diff --git a/Dockerfile b/Dockerfile index 32a69050d..5c9365c86 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,30 +4,12 @@ FROM apache/spark-py:v3.2.1 # Change to root user for installation steps USER 0 -# Uninstall existing python and replace it with miniconda. -# This is to get the right version of Python in Debian, since Prophet doesn't play nice with Python 3.9+. -# FIXME: maybe optimize the size? this image is currently 3.2GB. -RUN apt-get update && \ - apt-get remove -y python3 python3-pip && \ - apt-get install -y --no-install-recommends curl && \ - apt-get autoremove -yqq --purge && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* -RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - # Install prophet while we're at it, since this is easier to conda install than pip install - /opt/conda/bin/conda install -y prophet && \ - /opt/conda/bin/conda clean -ya -ENV PATH="/opt/conda/bin:${SPARK_HOME}/bin:${PATH}" - -# Install (for spark-sql) and Merlion; get pyspark & py4j from the PYTHONPATH +# Install pyarrow (for spark-sql) and Merlion; get pyspark & py4j from the PYTHONPATH ENV PYTHONPATH="${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-0.10.9.3-src.zip:${PYTHONPATH}" COPY *.md ./ COPY setup.py ./ COPY merlion merlion -RUN pip install pyarrow "./[prophet]" && pip uninstall -y py4j +RUN pip install pyarrow "./" && pip uninstall -y py4j # Copy Merlion pyspark apps COPY spark /opt/spark/apps diff --git a/merlion/models/automl/autosarima.py b/merlion/models/automl/autosarima.py index 98a624b0e..73ef51484 100644 --- a/merlion/models/automl/autosarima.py +++ b/merlion/models/automl/autosarima.py @@ -7,10 +7,9 @@ """ Automatic hyperparameter selection for SARIMA. """ -from collections import Iterator from copy import copy, deepcopy import logging -from typing import Any, Optional, Tuple, Union +from typing import Any, Iterator, Optional, Tuple, Union import numpy as np diff --git a/merlion/models/forecast/prophet.py b/merlion/models/forecast/prophet.py index 826e9b766..bad4d20c0 100644 --- a/merlion/models/forecast/prophet.py +++ b/merlion/models/forecast/prophet.py @@ -7,21 +7,15 @@ """ Wrapper around Facebook's popular Prophet model for time series forecasting. """ +import copy import logging import os from typing import Iterable, List, Tuple, Union -try: - import prophet -except ImportError as e: - err_msg = ( - "Try installing Merlion with optional dependencies using `pip install salesforce-merlion[prophet]` or " - "`pip install `salesforce-merlion[all]`" - ) - raise ImportError(str(e) + ". " + err_msg) - import numpy as np import pandas as pd +import prophet +import prophet.serialize from merlion.models.automl.seasonality import SeasonalityModel from merlion.models.forecast.base import ForecasterBase, ForecasterConfig @@ -144,14 +138,19 @@ def require_even_sampling(self) -> bool: return False def __getstate__(self): - stan_backend = self.model.stan_backend - if hasattr(stan_backend, "logger"): - model_logger = self.model.stan_backend.logger - self.model.stan_backend.logger = None - state_dict = super().__getstate__() - if hasattr(stan_backend, "logger"): - self.model.stan_backend.logger = model_logger - return state_dict + try: + model = prophet.serialize.model_to_json(self.model) + except ValueError: # prophet.serialize only works for fitted models, so deepcopy as a backup + model = copy.deepcopy(self.model) + return {k: model if k == "model" else copy.deepcopy(v) for k, v in self.__dict__.items()} + + def __setstate__(self, state): + if "model" in state: + model = state["model"] + if isinstance(model, str): + state = copy.copy(state) + state["model"] = prophet.serialize.model_from_json(model) + super().__setstate__(state) @property def yearly_seasonality(self): diff --git a/setup.py b/setup.py index dcd66e19f..e2e44a43a 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,7 @@ ] # optional dependencies -extra_require = { - "plot": ["plotly>=4.13"], - "prophet": ["prophet", "pystan<3.0"], # pystan >= 3.0 doesn't work with prophet - "deep-learning": ["torch>=1.1.0"], - "spark": ["pyspark[sql]>=3"], -} +extra_require = {"plot": ["plotly>=4.13"], "deep-learning": ["torch>=1.1.0"], "spark": ["pyspark[sql]>=3"]} extra_require["all"] = sum(extra_require.values(), []) @@ -29,7 +24,7 @@ def read_file(fname): setup( name="salesforce-merlion", - version="1.2.2", + version="1.2.3", author=", ".join(read_file("AUTHORS.md").split("\n")), author_email="abhatnagar@salesforce.com", description="Merlion: A Machine Learning Framework for Time Series Intelligence", @@ -52,6 +47,8 @@ def read_file(fname): "numpy>=1.19; python_version < '3.7'", # however, numpy 1.20+ requires python 3.7+ "packaging", "pandas>=1.1.0", # >=1.1.0 for origin kwarg to df.resample() + "prophet>=1.1; python_version >= '3.7'", # 1.1 removes dependency on pystan + "prophet==1.0.1; python_version < '3.7'", # however, prophet 1.1 requires python 3.7+ "scikit-learn>=0.22", # >=0.22 for changes to isolation forest algorithm "scipy>=1.6.0; python_version >= '3.7'", # 1.6.0 adds multivariate_t density to scipy.stats "scipy>=1.5.0; python_version < '3.7'", # however, scipy 1.6.0 requires python 3.7+