Skip to content

Commit 0ba6961

Browse files
authored
Add support for flat uint8 arrow arrays for multi channel images (#8908)
2 parents 22d6265 + 5066917 commit 0ba6961

File tree

3 files changed

+162
-13
lines changed

3 files changed

+162
-13
lines changed

.ci/requirements-mypy.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ IceSpringPySideStubs-PySide6
44
ipython
55
numpy
66
packaging
7+
pyarrow-stubs
78
pytest
89
sphinx
910
types-atheris

Tests/test_pyarrow.py

Lines changed: 147 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any # undone
3+
from typing import Any, NamedTuple
44

55
import pytest
66

@@ -10,30 +10,73 @@
1010
assert_deep_equal,
1111
assert_image_equal,
1212
hopper,
13+
is_big_endian,
1314
)
1415

15-
pyarrow = pytest.importorskip("pyarrow", reason="PyArrow not installed")
16+
TYPE_CHECKING = False
17+
if TYPE_CHECKING:
18+
import pyarrow
19+
else:
20+
pyarrow = pytest.importorskip("pyarrow", reason="PyArrow not installed")
1621

1722
TEST_IMAGE_SIZE = (10, 10)
1823

1924

2025
def _test_img_equals_pyarray(
21-
img: Image.Image, arr: Any, mask: list[int] | None
26+
img: Image.Image, arr: Any, mask: list[int] | None, elts_per_pixel: int = 1
2227
) -> None:
23-
assert img.height * img.width == len(arr)
28+
assert img.height * img.width * elts_per_pixel == len(arr)
2429
px = img.load()
2530
assert px is not None
31+
if elts_per_pixel > 1 and mask is None:
32+
# have to do element-wise comparison when we're comparing
33+
# flattened r,g,b,a to a pixel.
34+
mask = list(range(elts_per_pixel))
2635
for x in range(0, img.size[0], int(img.size[0] / 10)):
2736
for y in range(0, img.size[1], int(img.size[1] / 10)):
2837
if mask:
38+
pixel = px[x, y]
39+
assert isinstance(pixel, tuple)
2940
for ix, elt in enumerate(mask):
30-
pixel = px[x, y]
31-
assert isinstance(pixel, tuple)
32-
assert pixel[ix] == arr[y * img.width + x].as_py()[elt]
41+
if elts_per_pixel == 1:
42+
assert pixel[ix] == arr[y * img.width + x].as_py()[elt]
43+
else:
44+
assert (
45+
pixel[ix]
46+
== arr[(y * img.width + x) * elts_per_pixel + elt].as_py()
47+
)
3348
else:
3449
assert_deep_equal(px[x, y], arr[y * img.width + x].as_py())
3550

3651

52+
def _test_img_equals_int32_pyarray(
53+
img: Image.Image, arr: Any, mask: list[int] | None, elts_per_pixel: int = 1
54+
) -> None:
55+
assert img.height * img.width * elts_per_pixel == len(arr)
56+
px = img.load()
57+
assert px is not None
58+
if mask is None:
59+
# have to do element-wise comparison when we're comparing
60+
# flattened rgba in a uint32 to a pixel.
61+
mask = list(range(elts_per_pixel))
62+
for x in range(0, img.size[0], int(img.size[0] / 10)):
63+
for y in range(0, img.size[1], int(img.size[1] / 10)):
64+
pixel = px[x, y]
65+
assert isinstance(pixel, tuple)
66+
arr_pixel_int = arr[y * img.width + x].as_py()
67+
arr_pixel_tuple = (
68+
arr_pixel_int % 256,
69+
(arr_pixel_int // 256) % 256,
70+
(arr_pixel_int // 256**2) % 256,
71+
(arr_pixel_int // 256**3),
72+
)
73+
if is_big_endian():
74+
arr_pixel_tuple = arr_pixel_tuple[::-1]
75+
76+
for ix, elt in enumerate(mask):
77+
assert pixel[ix] == arr_pixel_tuple[elt]
78+
79+
3780
# really hard to get a non-nullable list type
3881
fl_uint8_4_type = pyarrow.field(
3982
"_", pyarrow.list_(pyarrow.field("_", pyarrow.uint8()).with_nullable(False), 4)
@@ -55,14 +98,14 @@ def _test_img_equals_pyarray(
5598
("HSV", fl_uint8_4_type, [0, 1, 2]),
5699
),
57100
)
58-
def test_to_array(mode: str, dtype: Any, mask: list[int] | None) -> None:
101+
def test_to_array(mode: str, dtype: pyarrow.DataType, mask: list[int] | None) -> None:
59102
img = hopper(mode)
60103

61104
# Resize to non-square
62105
img = img.crop((3, 0, 124, 127))
63106
assert img.size == (121, 127)
64107

65-
arr = pyarrow.array(img)
108+
arr = pyarrow.array(img) # type: ignore[call-overload]
66109
_test_img_equals_pyarray(img, arr, mask)
67110
assert arr.type == dtype
68111

@@ -79,8 +122,8 @@ def test_lifetime() -> None:
79122

80123
img = hopper("L")
81124

82-
arr_1 = pyarrow.array(img)
83-
arr_2 = pyarrow.array(img)
125+
arr_1 = pyarrow.array(img) # type: ignore[call-overload]
126+
arr_2 = pyarrow.array(img) # type: ignore[call-overload]
84127

85128
del img
86129

@@ -97,8 +140,8 @@ def test_lifetime2() -> None:
97140

98141
img = hopper("L")
99142

100-
arr_1 = pyarrow.array(img)
101-
arr_2 = pyarrow.array(img)
143+
arr_1 = pyarrow.array(img) # type: ignore[call-overload]
144+
arr_2 = pyarrow.array(img) # type: ignore[call-overload]
102145

103146
assert arr_1.sum().as_py() > 0
104147
del arr_1
@@ -110,3 +153,94 @@ def test_lifetime2() -> None:
110153
px = img2.load()
111154
assert px # make mypy happy
112155
assert isinstance(px[0, 0], int)
156+
157+
158+
class DataShape(NamedTuple):
159+
dtype: pyarrow.DataType
160+
# Strictly speaking, elt should be a pixel or pixel component, so
161+
# list[uint8][4], float, int, uint32, uint8, etc. But more
162+
# correctly, it should be exactly the dtype from the line above.
163+
elt: Any
164+
elts_per_pixel: int
165+
166+
167+
UINT_ARR = DataShape(
168+
dtype=fl_uint8_4_type,
169+
elt=[1, 2, 3, 4], # array of 4 uint8 per pixel
170+
elts_per_pixel=1, # only one array per pixel
171+
)
172+
173+
UINT = DataShape(
174+
dtype=pyarrow.uint8(),
175+
elt=3, # one uint8,
176+
elts_per_pixel=4, # but repeated 4x per pixel
177+
)
178+
179+
UINT32 = DataShape(
180+
dtype=pyarrow.uint32(),
181+
elt=0xABCDEF45, # one packed int >= 0x80000000, so it doesn't fit in an int32
182+
elts_per_pixel=1, # one per pixel
183+
)
184+
185+
INT32 = DataShape(
186+
dtype=pyarrow.uint32(),
187+
elt=0x12CDEF45, # one packed int
188+
elts_per_pixel=1, # one per pixel
189+
)
190+
191+
192+
@pytest.mark.parametrize(
193+
"mode, data_tp, mask",
194+
(
195+
("L", DataShape(pyarrow.uint8(), 3, 1), None),
196+
("I", DataShape(pyarrow.int32(), 1 << 24, 1), None),
197+
("F", DataShape(pyarrow.float32(), 3.14159, 1), None),
198+
("LA", UINT_ARR, [0, 3]),
199+
("LA", UINT, [0, 3]),
200+
("RGB", UINT_ARR, [0, 1, 2]),
201+
("RGBA", UINT_ARR, None),
202+
("CMYK", UINT_ARR, None),
203+
("YCbCr", UINT_ARR, [0, 1, 2]),
204+
("HSV", UINT_ARR, [0, 1, 2]),
205+
("RGB", UINT, [0, 1, 2]),
206+
("RGBA", UINT, None),
207+
("CMYK", UINT, None),
208+
("YCbCr", UINT, [0, 1, 2]),
209+
("HSV", UINT, [0, 1, 2]),
210+
),
211+
)
212+
def test_fromarray(mode: str, data_tp: DataShape, mask: list[int] | None) -> None:
213+
(dtype, elt, elts_per_pixel) = data_tp
214+
215+
ct_pixels = TEST_IMAGE_SIZE[0] * TEST_IMAGE_SIZE[1]
216+
arr = pyarrow.array([elt] * (ct_pixels * elts_per_pixel), type=dtype)
217+
img = Image.fromarrow(arr, mode, TEST_IMAGE_SIZE)
218+
219+
_test_img_equals_pyarray(img, arr, mask, elts_per_pixel)
220+
221+
222+
@pytest.mark.parametrize(
223+
"mode, data_tp, mask",
224+
(
225+
("LA", UINT32, [0, 3]),
226+
("RGB", UINT32, [0, 1, 2]),
227+
("RGBA", UINT32, None),
228+
("CMYK", UINT32, None),
229+
("YCbCr", UINT32, [0, 1, 2]),
230+
("HSV", UINT32, [0, 1, 2]),
231+
("LA", INT32, [0, 3]),
232+
("RGB", INT32, [0, 1, 2]),
233+
("RGBA", INT32, None),
234+
("CMYK", INT32, None),
235+
("YCbCr", INT32, [0, 1, 2]),
236+
("HSV", INT32, [0, 1, 2]),
237+
),
238+
)
239+
def test_from_int32array(mode: str, data_tp: DataShape, mask: list[int] | None) -> None:
240+
(dtype, elt, elts_per_pixel) = data_tp
241+
242+
ct_pixels = TEST_IMAGE_SIZE[0] * TEST_IMAGE_SIZE[1]
243+
arr = pyarrow.array([elt] * (ct_pixels * elts_per_pixel), type=dtype)
244+
img = Image.fromarrow(arr, mode, TEST_IMAGE_SIZE)
245+
246+
_test_img_equals_int32_pyarray(img, arr, mask, elts_per_pixel)

src/libImaging/Storage.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,8 @@ ImagingNewArrow(
723723
int64_t pixels = (int64_t)xsize * (int64_t)ysize;
724724

725725
// fmt:off // don't reformat this
726+
// stored as a single array, one element per pixel, either single band
727+
// or multiband, where each pixel is an I32.
726728
if (((strcmp(schema->format, "I") == 0 // uint32 (Arrow format "I")
727729
&& im->pixelsize == 4 // 4xchar* storage
728730
&& im->bands >= 2) // INT32 into any INT32 Storage mode
@@ -735,6 +737,7 @@ ImagingNewArrow(
735737
return im;
736738
}
737739
}
740+
// Stored as [[r,g,b,a],...]
738741
if (strcmp(schema->format, "+w:4") == 0 // 4 up array
739742
&& im->pixelsize == 4 // storage as 32 bpc
740743
&& schema->n_children > 0 // make sure schema is well formed.
@@ -750,6 +753,17 @@ ImagingNewArrow(
750753
return im;
751754
}
752755
}
756+
// Stored as [r,g,b,a,r,g,b,a,...]
757+
if (strcmp(schema->format, "C") == 0 // uint8
758+
&& im->pixelsize == 4 // storage as 32 bpc
759+
&& schema->n_children == 0 // make sure schema is well formed.
760+
&& strcmp(im->arrow_band_format, "C") == 0 // expected format
761+
&& 4 * pixels == external_array->length) { // expected length
762+
// single flat array, interleaved storage.
763+
if (ImagingBorrowArrow(im, external_array, 1, array_capsule)) {
764+
return im;
765+
}
766+
}
753767
// fmt: on
754768
ImagingDelete(im);
755769
return NULL;

0 commit comments

Comments
 (0)