From 0d89182f1fe310051d49287480ea54d6aadcb680 Mon Sep 17 00:00:00 2001
From: wwitkowski
Date: Mon, 19 May 2025 17:22:41 +0200
Subject: [PATCH 1/4] add customsql analyzer

---
 pydeequ/analyzers.py    | 24 ++++++++++++++++++++++++
 tests/test_analyzers.py | 19 +++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/pydeequ/analyzers.py b/pydeequ/analyzers.py
index 3952c93..0615833 100644
--- a/pydeequ/analyzers.py
+++ b/pydeequ/analyzers.py
@@ -360,6 +360,30 @@ def _analyzer_jvm(self):
         return self._deequAnalyzers.CountDistinct(to_scala_seq(self._jvm, self.columns))
 
 
+class CustomSql(_AnalyzerObject):
+    """
+    A custom SQL-based analyzer executing provided SQL expression.
+    The expression must return a single value.
+
+    :param str where: A label used to distinguish this metric
+        when running multiple custom SQL analyzers. Defaults to "*".
+    :param str expression: A SQL expression to execute.
+    """
+
+    def __init__(self, expression: str, disambiguator: str = "*"):
+        self.expression = expression
+        self.disambiguator = disambiguator
+
+    @property
+    def _analyzer_jvm(self):
+        """
+        Returns the result of SQL expression execution.
+
+        :return self
+        """
+        return self._deequAnalyzers.CustomSql(self.expression, self.disambiguator)
+
+
 class DataType(_AnalyzerObject):
     """
     Data Type Analyzer. Returns the datatypes of column
diff --git a/tests/test_analyzers.py b/tests/test_analyzers.py
index 175e8ae..885350d 100644
--- a/tests/test_analyzers.py
+++ b/tests/test_analyzers.py
@@ -14,6 +14,7 @@
     Compliance,
     Correlation,
     CountDistinct,
+    CustomSql,
     DataType,
     Distinctness,
     Entropy,
@@ -111,6 +112,14 @@ def CountDistinct(self, columns):
         df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
         self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
         return result_df.select("value").collect()
+
+    def CustomSql(self, expression, disambiguator="*"):
+        result = self.AnalysisRunner.onData(self.df).addAnalyzer(CustomSql(expression, disambiguator)).run()
+        result_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, result)
+        result_json = AnalyzerContext.successMetricsAsJson(self.spark, result)
+        df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
+        self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
+        return result_df.select("value").collect()
 
     def Datatype(self, column, where=None):
         result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run()
@@ -298,6 +307,16 @@ def test_CountDistinct(self):
     def test_fail_CountDistinct(self):
         self.assertEqual(self.CountDistinct("b"), [Row(value=1.0)])
 
+    def test_CustomSql(self):
+        self.df.createOrReplaceTempView("input_table")
+        self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0)])
+        self.assertEqual(self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table"), [Row(value=3.0)])
+        self.assertEqual(self.CustomSql("SELECT MAX(c) FROM input_table"), [Row(value=6.0)])
+
+    @pytest.mark.xfail(reason="@unittest.expectedFailure")
+    def test_fail_CustomSql(self):
+        self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)])
+
     def test_DataType(self):
         self.assertEqual(
             self.Datatype("b"),

From cd740d01746007b0298fe30de92f4f10575d078a Mon Sep 17 00:00:00 2001
From: wwitkowski
Date: Sun, 25 May 2025 09:33:26 +0200
Subject: [PATCH 2/4] add unit tests for disambiguator and incorrect query

---
 pydeequ/analyzers.py    |  2 +-
 tests/test_analyzers.py | 20 ++++++++++++++++----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/pydeequ/analyzers.py b/pydeequ/analyzers.py
index 0615833..cc22196 100644
--- a/pydeequ/analyzers.py
+++ b/pydeequ/analyzers.py
@@ -366,7 +366,7 @@ class CustomSql(_AnalyzerObject):
     The expression must return a single value.
 
     :param str expression: A SQL expression to execute.
-    :param str where: A label used to distinguish this metric
+    :param str disambiguator: A label used to distinguish this metric
         when running multiple custom SQL analyzers. Defaults to "*".
     """
 
diff --git a/tests/test_analyzers.py b/tests/test_analyzers.py
index 885350d..0f180b7 100644
--- a/tests/test_analyzers.py
+++ b/tests/test_analyzers.py
@@ -3,6 +3,7 @@
 import pytest
 
 from pyspark.sql import Row
+from pyspark.errors import AnalysisException
 
 from pydeequ import PyDeequSession
 from pydeequ.analyzers import (
@@ -119,7 +120,7 @@ def CustomSql(self, expression, disambiguator="*"):
         result_json = AnalyzerContext.successMetricsAsJson(self.spark, result)
         df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
         self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
-        return result_df.select("value").collect()
+        return result_df.select("value", "instance").collect()
 
     def Datatype(self, column, where=None):
         result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run()
@@ -309,14 +310,25 @@ def test_fail_CountDistinct(self):
 
     def test_CustomSql(self):
         self.df.createOrReplaceTempView("input_table")
-        self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0)])
-        self.assertEqual(self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table"), [Row(value=3.0)])
-        self.assertEqual(self.CustomSql("SELECT MAX(c) FROM input_table"), [Row(value=6.0)])
+        self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0, instance="*")])
+        self.assertEqual(
+            self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table", disambiguator="foo"),
+            [Row(value=3.0, instance="foo")]
+        )
+        self.assertEqual(
+            self.CustomSql("SELECT MAX(c) FROM input_table", disambiguator="bar"),
+            [Row(value=6.0, instance="bar")]
+        )
 
     @pytest.mark.xfail(reason="@unittest.expectedFailure")
     def test_fail_CustomSql(self):
         self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)])
 
+    @pytest.mark.xfail(reason="@unittest.expectedFailure")
+    def test_fail_CustomSql_incorrect_query(self):
+        with self.assertRaises(AnalysisException):
+            self.CustomSql("SELECT SUM(b)")
+
     def test_DataType(self):
         self.assertEqual(
             self.Datatype("b"),

From bca4cf1119bdc60d1d2742a19b904f481d9b1e0b Mon Sep 17 00:00:00 2001
From: wwitkowski
Date: Sun, 25 May 2025 09:42:41 +0200
Subject: [PATCH 3/4] update incorrect query unit test

---
 tests/test_analyzers.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_analyzers.py b/tests/test_analyzers.py
index 0f180b7..67b225f 100644
--- a/tests/test_analyzers.py
+++ b/tests/test_analyzers.py
@@ -3,7 +3,6 @@
 import pytest
 
 from pyspark.sql import Row
-from pyspark.errors import AnalysisException
 
 from pydeequ import PyDeequSession
 from pydeequ.analyzers import (
@@ -326,8 +325,7 @@ def test_fail_CustomSql(self):
 
     @pytest.mark.xfail(reason="@unittest.expectedFailure")
     def test_fail_CustomSql_incorrect_query(self):
-        with self.assertRaises(AnalysisException):
-            self.CustomSql("SELECT SUM(b)")
+        self.CustomSql("SELECT SUM(b)")
 
     def test_DataType(self):
         self.assertEqual(

From 636ff93cc96f902cce52cf6966f1bda150288171 Mon Sep 17 00:00:00 2001
From: wwitkowski
Date: Fri, 30 May 2025 12:45:53 +0200
Subject: [PATCH 4/4] update docs

---
 docs/analyzers.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/analyzers.md b/docs/analyzers.md
index 6d8244e..1220822 100644
--- a/docs/analyzers.md
+++ b/docs/analyzers.md
@@ -16,6 +16,7 @@ Here are the current supported functionalities of Analyzers.
 | Compliance | Compliance(instance, predicate) | Done|
 | Correlation | Correlation(column1, column2) | Done|
 | CountDistinct | CountDistinct(columns) | Done|
+| CustomSql | CustomSql(expression, disambiguator) | Done|
 | Datatype | Datatype(column) | Done|
 | Distinctness | Distinctness(columns) | Done|
 | Entropy | Entropy(column) | Done|
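
Usage sketch (not part of the patch series): a minimal example of how the CustomSql analyzer added above could be exercised from pydeequ, assuming the SPARK_VERSION environment variable is set and the Deequ jars resolve via pydeequ.deequ_maven_coord. The DataFrame contents and the "sum_b" label are illustrative, not taken from the patches.

    # Usage sketch only; assumes SPARK_VERSION is exported and Deequ jars are resolvable.
    from pyspark.sql import SparkSession

    import pydeequ
    from pydeequ.analyzers import AnalysisRunner, AnalyzerContext, CustomSql

    spark = (
        SparkSession.builder
        .config("spark.jars.packages", pydeequ.deequ_maven_coord)
        .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
        .getOrCreate()
    )

    # Illustrative data: "a" holds strings, "b" and "c" hold numbers.
    df = spark.createDataFrame([("foo", 1, 5), ("bar", 2, 6), ("baz", 3, 6)], ["a", "b", "c"])
    df.createOrReplaceTempView("input_table")  # CustomSql queries a registered view/table

    result = (
        AnalysisRunner(spark)
        .onData(df)
        .addAnalyzer(CustomSql("SELECT SUM(b) FROM input_table", disambiguator="sum_b"))
        .run()
    )

    # The disambiguator appears in the "instance" column of the metrics DataFrame,
    # which is what the updated unit test asserts on.
    AnalyzerContext.successMetricsAsDataFrame(spark, result).show()

    spark.stop()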