You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gentoo-overlay/dev-python/numpy/files/numpy-1.20.3-float-hashing-...

130 lines
4.6 KiB

From ad2a73c18dcff95d844c382c94ab7f73b5571cf3 Mon Sep 17 00:00:00 2001
From: Sebastian Berg <sebastian@sipsolutions.net>
Date: Tue, 4 May 2021 17:43:26 -0500
Subject: [PATCH] MAINT: Adjust NumPy float hashing to Python's slightly
changed hash
This is necessary because NumPy reuses Python's double hash, and the
semi-private function that computes it (_Py_HashDouble) has a new signature:
it now also takes the object, so it can return the identity hash for NaN.
closes gh-18833, gh-18907
---
numpy/core/src/common/npy_pycompat.h | 16 ++++++++++
numpy/core/src/multiarray/scalartypes.c.src | 13 ++++----
numpy/core/tests/test_scalarmath.py | 34 +++++++++++++++++++++
3 files changed, 57 insertions(+), 6 deletions(-)
diff --git a/numpy/core/src/common/npy_pycompat.h b/numpy/core/src/common/npy_pycompat.h
index aa0b5c1224d..9e94a971090 100644
--- a/numpy/core/src/common/npy_pycompat.h
+++ b/numpy/core/src/common/npy_pycompat.h
@@ -3,4 +3,20 @@
#include "numpy/npy_3kcompat.h"
+
+/*
+ * Since Python 3.10a7 (or b1), Python uses the object's identity for the hash
+ * when a value is NaN. See https://bugs.python.org/issue43475
+ */
+#if PY_VERSION_HEX > 0x030a00a6
+#define Npy_HashDouble _Py_HashDouble
+#else
+static NPY_INLINE Py_hash_t
+Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val)
+{
+ return _Py_HashDouble(val);
+}
+#endif
+
+
#endif /* _NPY_COMPAT_H_ */
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index a001500b0a9..9930f7791d6 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -3172,7 +3172,7 @@ static npy_hash_t
static npy_hash_t
@lname@_arrtype_hash(PyObject *obj)
{
- return _Py_HashDouble((double) PyArrayScalar_VAL(obj, @name@));
+ return Npy_HashDouble(obj, (double)PyArrayScalar_VAL(obj, @name@));
}
/* borrowed from complex_hash */
@@ -3180,14 +3180,14 @@ static npy_hash_t
c@lname@_arrtype_hash(PyObject *obj)
{
npy_hash_t hashreal, hashimag, combined;
- hashreal = _Py_HashDouble((double)
- PyArrayScalar_VAL(obj, C@name@).real);
+ hashreal = Npy_HashDouble(
+ obj, (double)PyArrayScalar_VAL(obj, C@name@).real);
if (hashreal == -1) {
return -1;
}
- hashimag = _Py_HashDouble((double)
- PyArrayScalar_VAL(obj, C@name@).imag);
+ hashimag = Npy_HashDouble(
+ obj, (double)PyArrayScalar_VAL(obj, C@name@).imag);
if (hashimag == -1) {
return -1;
}
@@ -3202,7 +3202,8 @@ c@lname@_arrtype_hash(PyObject *obj)
static npy_hash_t
half_arrtype_hash(PyObject *obj)
{
- return _Py_HashDouble(npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
+ return Npy_HashDouble(
+ obj, npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
}
static npy_hash_t
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index d91b4a39146..09a734284a7 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -712,6 +712,40 @@ def test_shift_all_bits(self, type_code, op):
assert_equal(res_arr, res_scl)
+class TestHash:
+ @pytest.mark.parametrize("type_code", np.typecodes['AllInteger'])
+ def test_integer_hashes(self, type_code):
+ scalar = np.dtype(type_code).type
+ for i in range(128):
+ assert hash(i) == hash(scalar(i))
+
+ @pytest.mark.parametrize("type_code", np.typecodes['AllFloat'])
+ def test_float_and_complex_hashes(self, type_code):
+ scalar = np.dtype(type_code).type
+ for val in [np.pi, np.inf, 3, 6.]:
+ numpy_val = scalar(val)
+ # Cast back to Python, in case the NumPy scalar has less precision
+ if numpy_val.dtype.kind == 'c':
+ val = complex(numpy_val)
+ else:
+ val = float(numpy_val)
+ assert val == numpy_val
+ print(repr(numpy_val), repr(val))
+ assert hash(val) == hash(numpy_val)
+
+ if hash(float(np.nan)) != hash(float(np.nan)):
+ # If Python distinguishes different NaNs we do so too (gh-18833)
+ assert hash(scalar(np.nan)) != hash(scalar(np.nan))
+
+ @pytest.mark.parametrize("type_code", np.typecodes['Complex'])
+ def test_complex_hashes(self, type_code):
+ # Test some complex valued hashes specifically:
+ scalar = np.dtype(type_code).type
+ for val in [np.pi+1j, np.inf-3j, 3j, 6.+1j]:
+ numpy_val = scalar(val)
+ assert hash(complex(numpy_val)) == hash(numpy_val)
+
+
@contextlib.contextmanager
def recursionlimit(n):
o = sys.getrecursionlimit()