diff --git a/pydeequ/scala_utils.py b/pydeequ/scala_utils.py index b6d3e83..e001ccf 100644 --- a/pydeequ/scala_utils.py +++ b/pydeequ/scala_utils.py @@ -37,6 +37,9 @@ def apply(self, arg): """Implements the apply function""" return self.lambda_function(arg) + def hashCode(self): + return self.gateway.jvm.java.lang.Integer.hashCode(hash(self.lambda_function)) + class Java: """scala.Function1: a function that takes one argument""" @@ -55,6 +58,9 @@ def apply(self, t1, t2): """Implements the apply function""" return self.lambda_function(t1, t2) + def hashCode(self): + return self.gateway.jvm.java.lang.Integer.hashCode(hash(self.lambda_function)) + class Java: """scala.Function2: a function that takes two arguments""" diff --git a/tests/test_checks.py b/tests/test_checks.py index ae2402d..873aaf1 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -1366,4 +1366,25 @@ def test_fail_list_of_constraints(self): .run() df = VerificationResult.checkResultsAsDataFrame(self.spark, result) - self.assertEqual(df.select("constraint_status").collect(), [Row(constraint_status="Success"), Row(constraint_status="Success")]) \ No newline at end of file + self.assertEqual(df.select("constraint_status").collect(), [Row(constraint_status="Success"), Row(constraint_status="Success")]) + + def test_hash_code(self): + """ + Lack of Exception is passing. Previously this test would fail with: + AttributeError: 'ScalaFunction1' object has no attribute 'hashCode' + """ + vrb = VerificationSuite(self.spark) \ + .onData(self.df) + check = Check(self.spark, CheckLevel.Error, "Enough checks to trigger a hashCode not an attribute of ScalaFunction1") + check.isComplete('b') + vrb.addCheck(check) + check.containsEmail('email') + vrb.addCheck(check) + check.isGreaterThanOrEqualTo("d", "b") + vrb.addCheck(check) + check.isLessThanOrEqualTo("b", "d") + vrb.addCheck(check) + check.hasDataType("d", ConstrainableDataTypes.String, lambda x: x >= 1) + vrb.addCheck(check) + + result = vrb.run() diff --git a/tests/test_scala_utils.py b/tests/test_scala_utils.py index 0dfaf35..a3815aa 100644 --- a/tests/test_scala_utils.py +++ b/tests/test_scala_utils.py @@ -25,13 +25,24 @@ def test_scala_function1(self): self.assertFalse(notNoneTest.apply(None)) self.assertTrue(notNoneTest.apply("foo")) + # Test hashCode() + self.assertNotEqual(greaterThan10.hashCode(), notNoneTest.hashCode()) + self.assertTrue(isinstance(greaterThan10.hashCode(), int)) + appendTest = ScalaFunction1(self.sc._gateway, "{}test".format) self.assertEqual("xtest", appendTest.apply("x")) def test_scala_function2(self): - concatFunction = ScalaFunction2(self.sc._gateway, lambda x, y: x + y) + lambda_func = lambda x, y: x + y + concatFunction = ScalaFunction2(self.sc._gateway, lambda_func) self.assertEqual("ab", concatFunction.apply("a", "b")) + anotherConcatFunction = ScalaFunction2(self.sc._gateway, lambda_func) + + # Test hashCode() + self.assertEqual(concatFunction.hashCode(), anotherConcatFunction.hashCode()) + self.assertTrue(isinstance(concatFunction.hashCode(), int)) + if __name__ == "__main__": unittest.main()