aboutsummaryrefslogtreecommitdiff
path: root/tests/test_numpy_array.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_numpy_array.py')
-rw-r--r--tests/test_numpy_array.py169
1 files changed, 148 insertions, 21 deletions
diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py
index 02f3ecfc..12e7d17d 100644
--- a/tests/test_numpy_array.py
+++ b/tests/test_numpy_array.py
@@ -1,8 +1,6 @@
-# -*- coding: utf-8 -*-
import pytest
import env # noqa: F401
-
from pybind11_tests import numpy_array as m
np = pytest.importorskip("numpy")
@@ -20,13 +18,11 @@ def test_dtypes():
assert check.numpy == check.pybind11, check
if check.numpy.num != check.pybind11.num:
print(
- "NOTE: typenum mismatch for {}: {} != {}".format(
- check, check.numpy.num, check.pybind11.num
- )
+ f"NOTE: typenum mismatch for {check}: {check.numpy.num} != {check.pybind11.num}"
)
-@pytest.fixture(scope="function")
+@pytest.fixture()
def arr():
return np.array([[1, 2, 3], [4, 5, 6]], "=u2")
@@ -71,7 +67,7 @@ def test_array_attributes():
@pytest.mark.parametrize(
- "args, ret", [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
+ ("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
)
def test_index_offset(arr, args, ret):
assert m.index_at(arr, *args) == ret
@@ -97,7 +93,7 @@ def test_dim_check_fail(arr):
@pytest.mark.parametrize(
- "args, ret",
+ ("args", "ret"),
[
([], [1, 2, 3, 4, 5, 6]),
([1], [4, 5, 6]),
@@ -118,9 +114,7 @@ def test_at_fail(arr, dim):
for func in m.at_t, m.mutate_at_t:
with pytest.raises(IndexError) as excinfo:
func(arr, *([0] * dim))
- assert str(excinfo.value) == "index dimension mismatch: {} (ndim = 2)".format(
- dim
- )
+ assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
def test_at(arr):
@@ -194,8 +188,6 @@ def test_make_empty_shaped_array():
def test_wrap():
def assert_references(a, b, base=None):
- from distutils.version import LooseVersion
-
if base is None:
base = a
assert a is not b
@@ -206,7 +198,8 @@ def test_wrap():
assert a.flags.f_contiguous == b.flags.f_contiguous
assert a.flags.writeable == b.flags.writeable
assert a.flags.aligned == b.flags.aligned
- if LooseVersion(np.__version__) >= LooseVersion("1.14.0"):
+ # 1.13 supported Python 3.6
+ if tuple(int(x) for x in np.__version__.split(".")[:2]) >= (1, 14):
assert a.flags.writebackifcopy == b.flags.writebackifcopy
else:
assert a.flags.updateifcopy == b.flags.updateifcopy
@@ -218,12 +211,14 @@ def test_wrap():
assert b[0, 0] == 1234
a1 = np.array([1, 2], dtype=np.int16)
- assert a1.flags.owndata and a1.base is None
+ assert a1.flags.owndata
+ assert a1.base is None
a2 = m.wrap(a1)
assert_references(a1, a2)
a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order="F")
- assert a1.flags.owndata and a1.base is None
+ assert a1.flags.owndata
+ assert a1.base is None
a2 = m.wrap(a1)
assert_references(a1, a2)
@@ -412,7 +407,7 @@ def test_array_unchecked_fixed_dims(msg):
assert m.proxy_auxiliaries2_const_ref(z1)
-def test_array_unchecked_dyn_dims(msg):
+def test_array_unchecked_dyn_dims():
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
m.proxy_add2_dyn(z1, 10)
assert np.all(z1 == [[11, 12], [13, 14]])
@@ -445,7 +440,7 @@ def test_initializer_list():
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
-def test_array_resize(msg):
+def test_array_resize():
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
m.array_reshape2(a)
assert a.size == 9
@@ -458,30 +453,85 @@ def test_array_resize(msg):
try:
m.array_resize3(a, 3, True)
except ValueError as e:
- assert str(e).startswith("cannot resize an array")
+ assert str(e).startswith("cannot resize an array") # noqa: PT017
# transposed array doesn't own data
b = a.transpose()
try:
m.array_resize3(b, 3, False)
except ValueError as e:
- assert str(e).startswith("cannot resize this array: it does not own its data")
+ assert str(e).startswith( # noqa: PT017
+ "cannot resize this array: it does not own its data"
+ )
# ... but reshape should be fine
m.array_reshape2(b)
assert b.shape == (8, 8)
@pytest.mark.xfail("env.PYPY")
-def test_array_create_and_resize(msg):
+def test_array_create_and_resize():
a = m.create_and_resize(2)
assert a.size == 4
assert np.all(a == 42.0)
+def test_array_view():
+ a = np.ones(100 * 4).astype("uint8")
+ a_float_view = m.array_view(a, "float32")
+ assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32
+
+ a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32
+ assert a_int16_view.shape == (100 * 2,)
+
+
+def test_array_view_invalid():
+ a = np.ones(100 * 4).astype("uint8")
+ with pytest.raises(TypeError):
+ m.array_view(a, "deadly_dtype")
+
+
+def test_reshape_initializer_list():
+ a = np.arange(2 * 7 * 3) + 1
+ x = m.reshape_initializer_list(a, 2, 7, 3)
+ assert x.shape == (2, 7, 3)
+ assert list(x[1][4]) == [34, 35, 36]
+ with pytest.raises(ValueError) as excinfo:
+ m.reshape_initializer_list(a, 1, 7, 3)
+ assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"
+
+
+def test_reshape_tuple():
+ a = np.arange(3 * 7 * 2) + 1
+ x = m.reshape_tuple(a, (3, 7, 2))
+ assert x.shape == (3, 7, 2)
+ assert list(x[1][4]) == [23, 24]
+ y = m.reshape_tuple(x, (x.size,))
+ assert y.shape == (42,)
+ with pytest.raises(ValueError) as excinfo:
+ m.reshape_tuple(a, (3, 7, 1))
+ assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
+ with pytest.raises(ValueError) as excinfo:
+ m.reshape_tuple(a, ())
+ assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"
+
+
def test_index_using_ellipsis():
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
assert a.shape == (6,)
+@pytest.mark.parametrize(
+ "test_func",
+ [
+ m.test_fmt_desc_float,
+ m.test_fmt_desc_double,
+ m.test_fmt_desc_const_float,
+ m.test_fmt_desc_const_double,
+ ],
+)
+def test_format_descriptors_for_floating_point_types(test_func):
+ assert "numpy.ndarray[numpy.float" in test_func.__doc__
+
+
@pytest.mark.parametrize("forcecast", [False, True])
@pytest.mark.parametrize("contiguity", [None, "C", "F"])
@pytest.mark.parametrize("noconvert", [False, True])
@@ -539,3 +589,80 @@ def test_dtype_refcount_leak():
m.ndim(a)
after = getrefcount(dtype)
assert after == before
+
+
+def test_round_trip_float():
+ arr = np.zeros((), np.float64)
+ arr[()] = 37.2
+ assert m.round_trip_float(arr) == 37.2
+
+
+# HINT: An easy and robust way (although only manual unfortunately) to check for
+# ref-count leaks in the test_.*pyobject_ptr.* functions below is to
+# * temporarily insert `while True:` (one-by-one),
+# * run this test, and
+# * run the Linux `top` command in another shell to visually monitor
+# `RES` for a minute or two.
+# If there is a leak, it is usually evident in seconds because the `RES`
+# value increases without bounds. (Don't forget to Ctrl-C the test!)
+
+
+# For use as a temporary user-defined object, to maximize sensitivity of the tests below:
+# * Ref-count leaks will be immediately evident.
+# * Sanitizers are much more likely to detect heap-use-after-free due to
+# other ref-count bugs.
+class PyValueHolder:
+ def __init__(self, value):
+ self.value = value
+
+
+def WrapWithPyValueHolder(*values):
+ return [PyValueHolder(v) for v in values]
+
+
+def UnwrapPyValueHolder(vhs):
+ return [vh.value for vh in vhs]
+
+
+def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray():
+ # Intentionally all temporaries, do not change.
+ assert (
+ m.pass_array_pyobject_ptr_return_sum_str_values(
+ np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
+ )
+ == "-3four5.0"
+ )
+
+
+def test_pass_array_pyobject_ptr_return_sum_str_values_list():
+ # Intentionally all temporaries, do not change.
+ assert (
+ m.pass_array_pyobject_ptr_return_sum_str_values(
+ WrapWithPyValueHolder(2, "three", -4.0)
+ )
+ == "2three-4.0"
+ )
+
+
+def test_pass_array_pyobject_ptr_return_as_list():
+ # Intentionally all temporaries, do not change.
+ assert UnwrapPyValueHolder(
+ m.pass_array_pyobject_ptr_return_as_list(
+ np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
+ )
+ ) == [-1, "two", 3.0]
+
+
+@pytest.mark.parametrize(
+ ("return_array_pyobject_ptr", "unwrap"),
+ [
+ (m.return_array_pyobject_ptr_cpp_loop, list),
+ (m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
+ ],
+)
+def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap):
+ # Intentionally all temporaries, do not change.
+ arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0))
+ assert isinstance(arr_from_list, np.ndarray)
+ assert arr_from_list.dtype == np.dtype("O")
+ assert unwrap(arr_from_list) == [6, "seven", -8.0]