diff --git a/pydeequ/configs.py b/pydeequ/configs.py
index d4d4b31..3f3d9ee 100644
--- a/pydeequ/configs.py
+++ b/pydeequ/configs.py
@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
+import logging
 from functools import lru_cache
 import os
 import re
-
+import pyspark
 
 SPARK_TO_DEEQU_COORD_MAPPING = {
     "3.5": "com.amazon.deequ:deequ:2.0.7-spark-3.5",
@@ -22,7 +23,12 @@ def _extract_major_minor_versions(full_version: str):
 @lru_cache(maxsize=None)
 def _get_spark_version() -> str:
-    try:
-        spark_version = os.environ["SPARK_VERSION"]
-    except KeyError:
-        raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}")
+    # SPARK_VERSION, when set and non-empty, takes precedence; otherwise fall
+    # back to the installed PySpark version so the variable is optional.
+    spark_version = os.getenv("SPARK_VERSION")
+    if not spark_version:
+        spark_version = str(pyspark.__version__)
+        logging.info(
+            "SPARK_VERSION environment variable is not set, using Spark version from PySpark %s for Deequ Maven jars",
+            spark_version,
+        )
 
diff --git a/tests/test_config.py b/tests/test_config.py
index c2956b3..d61ccfd 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,5 +1,21 @@
+from unittest import mock
+
+import pyspark
 import pytest
-from pydeequ.configs import _extract_major_minor_versions
+
+from pydeequ.configs import _extract_major_minor_versions, _get_spark_version
+
+
+@pytest.fixture
+def mock_env(monkeypatch):
+    """Run a test with SPARK_VERSION unset and an empty version cache."""
+    # Only the variable under test is removed; the rest of the environment
+    # is left untouched.
+    monkeypatch.delenv("SPARK_VERSION", raising=False)
+    _get_spark_version.cache_clear()
+    yield
+    # Do not leak a version cached from a patched pyspark into other tests.
+    _get_spark_version.cache_clear()
 
 
 @pytest.mark.parametrize(
@@ -13,3 +29,22 @@
 )
 def test_extract_major_minor_versions(full_version, major_minor_version):
     assert _extract_major_minor_versions(full_version) == major_minor_version
+
+
+@pytest.mark.parametrize(
+    "spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")]
+)
+def test__get_spark_version(spark_version, expected, mock_env):
+    """Without SPARK_VERSION, the major.minor of the PySpark version is used."""
+    with mock.patch.object(pyspark, "__version__", spark_version):
+        assert _get_spark_version() == expected
+
+
+def test__get_spark_version_with_cache(mock_env):
+    """The first resolved version is cached and returned by later calls."""
+    with mock.patch.object(pyspark, "__version__", "3.2.1"):
+        assert _get_spark_version() == "3.2"
+    # Later PySpark version changes must not be observed through the cache.
+    for spark_version in ("3.1", "3.10.3", "3.10"):
+        with mock.patch.object(pyspark, "__version__", spark_version):
+            assert _get_spark_version() == "3.2"