diff --git a/docs/analyzers.md b/docs/analyzers.md index 6d8244e..1220822 100644 --- a/docs/analyzers.md +++ b/docs/analyzers.md @@ -16,6 +16,7 @@ Here are the current supported functionalities of Analyzers. | Compliance | Compliance(instance, predicate) | Done| | Correlation | Correlation(column1, column2) | Done| | CountDistinct | CountDistinct(columns) | Done| +| CustomSql | CustomSql(expression, disambiguator) | Done| | Datatype | Datatype(column) | Done| | Distinctness | Distinctness(columns) | Done| | Entropy | Entropy(column) | Done| diff --git a/pydeequ/analyzers.py b/pydeequ/analyzers.py index 3952c93..cc22196 100644 --- a/pydeequ/analyzers.py +++ b/pydeequ/analyzers.py @@ -360,6 +360,30 @@ def _analyzer_jvm(self): return self._deequAnalyzers.CountDistinct(to_scala_seq(self._jvm, self.columns)) +class CustomSql(_AnalyzerObject): + """ + A custom SQL-based analyzer that executes the provided SQL expression. + The expression must return a single value. + + :param str expression: A SQL expression to execute. + :param str disambiguator: A label used to distinguish this metric + when running multiple custom SQL analyzers. Defaults to "*". + """ + + def __init__(self, expression: str, disambiguator: str = "*"): + self.expression = expression + self.disambiguator = disambiguator + + @property + def _analyzer_jvm(self): + """ + Builds the underlying JVM CustomSql analyzer for this expression. + + :return: the deequ CustomSql analyzer object (not the query result) + """ + return self._deequAnalyzers.CustomSql(self.expression, self.disambiguator) + + class DataType(_AnalyzerObject): """ Data Type Analyzer. 
Returns the datatypes of column diff --git a/tests/test_analyzers.py b/tests/test_analyzers.py index 175e8ae..67b225f 100644 --- a/tests/test_analyzers.py +++ b/tests/test_analyzers.py @@ -14,6 +14,7 @@ Compliance, Correlation, CountDistinct, + CustomSql, DataType, Distinctness, Entropy, @@ -111,6 +112,14 @@ def CountDistinct(self, columns): df_from_json = self.spark.read.json(self.sc.parallelize([result_json])) self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect()) return result_df.select("value").collect() + + def CustomSql(self, expression, disambiguator="*"): + result = self.AnalysisRunner.onData(self.df).addAnalyzer(CustomSql(expression, disambiguator)).run() + result_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, result) + result_json = AnalyzerContext.successMetricsAsJson(self.spark, result) + df_from_json = self.spark.read.json(self.sc.parallelize([result_json])) + self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect()) + return result_df.select("value", "instance").collect() def Datatype(self, column, where=None): result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run() @@ -298,6 +307,26 @@ def test_CountDistinct(self): def test_fail_CountDistinct(self): self.assertEqual(self.CountDistinct("b"), [Row(value=1.0)]) + def test_CustomSql(self): + self.df.createOrReplaceTempView("input_table") + self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0, instance="*")]) + self.assertEqual( + self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table", disambiguator="foo"), + [Row(value=3.0, instance="foo")] + ) + self.assertEqual( + self.CustomSql("SELECT MAX(c) FROM input_table", disambiguator="bar"), + [Row(value=6.0, instance="bar")] + ) + + @pytest.mark.xfail(reason="@unittest.expectedFailure") + def test_fail_CustomSql(self): + self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)]) + + 
@pytest.mark.xfail(reason="@unittest.expectedFailure") + def test_fail_CustomSql_incorrect_query(self): + self.CustomSql("SELECT SUM(b)") + def test_DataType(self): self.assertEqual( self.Datatype("b"),