Commit 936539c

[data] Move core arrow util dependencies to arrow_utils.py (#51306)
- Moves the `parse_version` logic inside of the `get_pyarrow_version` function to avoid redundancy.
- Moves the two depended-on utils to `_private/arrow_utils.py` to reduce dependencies from data onto the giant `_private/utils.py` file. Once "Ray Storage" is fully deprecated, we can move these utils into data fully.

Signed-off-by: Edward Oakes <[email protected]>
1 parent 81755fb commit 936539c
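The caller-facing effect of the refactor is that the version helper now hands back a `packaging.version.Version` (or `None` when pyarrow is absent) instead of a raw string, so the per-call-site `parse_version` wrapping seen in the deletions below goes away; the new helper also caches a failed import via `_PYARROW_INSTALLED` rather than retrying it on every call. A minimal before/after sketch of the calling pattern, with the `"9.0.0"` threshold used purely as an illustrative cutoff:

from packaging.version import parse as parse_version

# Before: _get_pyarrow_version() returned Optional[str], so callers re-parsed it.
from ray._private.utils import _get_pyarrow_version
version_str = _get_pyarrow_version()
if version_str is not None and parse_version(version_str) >= parse_version("9.0.0"):
    ...

# After: get_pyarrow_version() returns Optional[packaging.version.Version] directly.
from ray._private.arrow_utils import get_pyarrow_version
version = get_pyarrow_version()
if version is not None and version >= parse_version("9.0.0"):
    ...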

21 files changed (+163, -176 lines)


python/ray/_private/arrow_utils.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import json
+from typing import Dict, Optional
+from urllib.parse import urlencode, unquote, urlparse, parse_qsl, urlunparse
+
+from packaging.version import Version
+from packaging.version import parse as parse_version
+
+_PYARROW_INSTALLED: Optional[bool] = None
+_PYARROW_VERSION: Optional[Version] = None
+
+
+def get_pyarrow_version() -> Optional[Version]:
+    """Get the version of the pyarrow package or None if not installed."""
+    global _PYARROW_INSTALLED, _PYARROW_VERSION
+    if _PYARROW_INSTALLED is False:
+        return None
+
+    if _PYARROW_INSTALLED is None:
+        try:
+            import pyarrow
+
+            _PYARROW_INSTALLED = True
+            if hasattr(pyarrow, "__version__"):
+                _PYARROW_VERSION = parse_version(pyarrow.__version__)
+        except ModuleNotFoundError:
+            _PYARROW_INSTALLED = False
+
+    return _PYARROW_VERSION
+
+
+def _add_url_query_params(url: str, params: Dict[str, str]) -> str:
+    """Add params to the provided url as query parameters.
+
+    If url already contains query parameters, they will be merged with params, with the
+    existing query parameters overriding any in params with the same parameter name.
+
+    Args:
+        url: The URL to add query parameters to.
+        params: The query parameters to add.
+
+    Returns:
+        URL with params added as query parameters.
+    """
+    # Unquote URL first so we don't lose existing args.
+    url = unquote(url)
+    # Parse URL.
+    parsed_url = urlparse(url)
+    # Merge URL query string arguments dict with new params.
+    base_params = params
+    params = dict(parse_qsl(parsed_url.query))
+    base_params.update(params)
+    # bool and dict values should be converted to json-friendly values.
+    base_params.update(
+        {
+            k: json.dumps(v)
+            for k, v in base_params.items()
+            if isinstance(v, (bool, dict))
+        }
+    )
+
+    # Convert URL arguments to proper query string.
+    encoded_params = urlencode(base_params, doseq=True)
+    # Replace query string in parsed URL with updated query string.
+    parsed_url = parsed_url._replace(query=encoded_params)
+    # Convert back to URL.
+    return urlunparse(parsed_url)
+
+
+def add_creatable_buckets_param_if_s3_uri(uri: str) -> str:
+    """If the provided URI is an S3 URL, add allow_bucket_creation=true as a query
+    parameter. For pyarrow >= 9.0.0, this is required in order to allow
+    ``S3FileSystem.create_dir()`` to create S3 buckets.
+
+    If the provided URI is not an S3 URL or if pyarrow < 9.0.0 is installed, we return
+    the URI unchanged.
+
+    Args:
+        uri: The URI that we'll add the query parameter to, if it's an S3 URL.
+
+    Returns:
+        A URI with the added allow_bucket_creation=true query parameter, if the provided
+        URI is an S3 URL; uri will be returned unchanged otherwise.
+    """
+
+    pyarrow_version = get_pyarrow_version()
+    if pyarrow_version is not None and pyarrow_version < parse_version("9.0.0"):
+        # This bucket creation query parameter is not required for pyarrow < 9.0.0.
+        return uri
+    parsed_uri = urlparse(uri)
+    if parsed_uri.scheme == "s3":
+        uri = _add_url_query_params(uri, {"allow_bucket_creation": True})
+    return uri
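To make the behavior of the relocated helpers concrete, here's a small usage sketch (the bucket and paths are made-up examples; the expected outputs in the comments assume pyarrow >= 9.0.0 is installed, since older installs short-circuit and return the URI untouched):

from ray._private.arrow_utils import add_creatable_buckets_param_if_s3_uri

# S3 URIs get the flag appended so S3FileSystem.create_dir() may create buckets.
add_creatable_buckets_param_if_s3_uri("s3://my-bucket/path")
# -> "s3://my-bucket/path?allow_bucket_creation=true"

# Existing query parameters are merged in and take precedence on name collisions.
add_creatable_buckets_param_if_s3_uri("s3://my-bucket/path?region=us-west-2")
# -> "s3://my-bucket/path?allow_bucket_creation=true&region=us-west-2"

# Non-S3 URIs pass through unchanged.
add_creatable_buckets_param_if_s3_uri("file:///tmp/data")
# -> "file:///tmp/data"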

python/ray/_private/storage.py

Lines changed: 4 additions & 3 deletions
@@ -4,9 +4,10 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional
 
-from ray._private.client_mode_hook import client_mode_hook
-from ray._private.utils import _add_creatable_buckets_param_if_s3_uri, load_class
 from ray._private.auto_init_hook import wrap_auto_init
+from ray._private.client_mode_hook import client_mode_hook
+from ray._private.arrow_utils import add_creatable_buckets_param_if_s3_uri
+from ray._private.utils import load_class
 
 if TYPE_CHECKING:
     import pyarrow.fs
@@ -452,7 +453,7 @@ def _init_filesystem(create_valid_file: bool = False, check_valid_file: bool = T
     else:
         # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add a
         # query arg enabling bucket creation if an S3 URI is provided.
-        _storage_uri = _add_creatable_buckets_param_if_s3_uri(_storage_uri)
+        _storage_uri = add_creatable_buckets_param_if_s3_uri(_storage_uri)
         _filesystem, _storage_prefix = pyarrow.fs.FileSystem.from_uri(_storage_uri)
 
     if os.name == "nt":

python/ray/_private/utils.py

Lines changed: 0 additions & 86 deletions
@@ -20,7 +20,6 @@
 import tempfile
 import threading
 import time
-from urllib.parse import urlencode, unquote, urlparse, parse_qsl, urlunparse
 import warnings
 from inspect import signature
 from pathlib import Path
@@ -71,7 +70,6 @@
 win32_AssignProcessToJobObject = None
 
 ENV_DISABLE_DOCKER_CPU_WARNING = "RAY_DISABLE_DOCKER_CPU_WARNING" in os.environ
-_PYARROW_VERSION = None
 
 # This global variable is used for testing only
 _CALLED_FREQ = defaultdict(lambda: 0)
@@ -1735,90 +1733,6 @@ def get_entrypoint_name():
         return "unknown"
 
 
-def _add_url_query_params(url: str, params: Dict[str, str]) -> str:
-    """Add params to the provided url as query parameters.
-
-    If url already contains query parameters, they will be merged with params, with the
-    existing query parameters overriding any in params with the same parameter name.
-
-    Args:
-        url: The URL to add query parameters to.
-        params: The query parameters to add.
-
-    Returns:
-        URL with params added as query parameters.
-    """
-    # Unquote URL first so we don't lose existing args.
-    url = unquote(url)
-    # Parse URL.
-    parsed_url = urlparse(url)
-    # Merge URL query string arguments dict with new params.
-    base_params = params
-    params = dict(parse_qsl(parsed_url.query))
-    base_params.update(params)
-    # bool and dict values should be converted to json-friendly values.
-    base_params.update(
-        {
-            k: json.dumps(v)
-            for k, v in base_params.items()
-            if isinstance(v, (bool, dict))
-        }
-    )
-
-    # Convert URL arguments to proper query string.
-    encoded_params = urlencode(base_params, doseq=True)
-    # Replace query string in parsed URL with updated query string.
-    parsed_url = parsed_url._replace(query=encoded_params)
-    # Convert back to URL.
-    return urlunparse(parsed_url)
-
-
-def _add_creatable_buckets_param_if_s3_uri(uri: str) -> str:
-    """If the provided URI is an S3 URL, add allow_bucket_creation=true as a query
-    parameter. For pyarrow >= 9.0.0, this is required in order to allow
-    ``S3FileSystem.create_dir()`` to create S3 buckets.
-
-    If the provided URI is not an S3 URL or if pyarrow < 9.0.0 is installed, we return
-    the URI unchanged.
-
-    Args:
-        uri: The URI that we'll add the query parameter to, if it's an S3 URL.
-
-    Returns:
-        A URI with the added allow_bucket_creation=true query parameter, if the provided
-        URI is an S3 URL; uri will be returned unchanged otherwise.
-    """
-    from packaging.version import parse as parse_version
-
-    pyarrow_version = _get_pyarrow_version()
-    if pyarrow_version is not None:
-        pyarrow_version = parse_version(pyarrow_version)
-    if pyarrow_version is not None and pyarrow_version < parse_version("9.0.0"):
-        # This bucket creation query parameter is not required for pyarrow < 9.0.0.
-        return uri
-    parsed_uri = urlparse(uri)
-    if parsed_uri.scheme == "s3":
-        uri = _add_url_query_params(uri, {"allow_bucket_creation": True})
-    return uri
-
-
-def _get_pyarrow_version() -> Optional[str]:
-    """Get the version of the installed pyarrow package, returned as a tuple of ints.
-    Returns None if the package is not found.
-    """
-    global _PYARROW_VERSION
-    if _PYARROW_VERSION is None:
-        try:
-            import pyarrow
-        except ModuleNotFoundError:
-            # pyarrow not installed, short-circuit.
-            pass
-        else:
-            if hasattr(pyarrow, "__version__"):
-                _PYARROW_VERSION = pyarrow.__version__
-    return _PYARROW_VERSION
-
-
 class DeferSigint(contextlib.AbstractContextManager):
     """Context manager that defers SIGINT signals until the context is left."""
 
python/ray/air/tests/test_arrow.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 import pytest
 from packaging.version import parse as parse_version
 
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.air.util.tensor_extensions.arrow import (
     ArrowConversionError,
     _convert_to_pyarrow_native_array,
@@ -51,9 +51,9 @@ def test_arrow_native_list_conversion(input, disable_fallback_to_object_extensio
     upon serialization into Arrow format (and are NOT converted to numpy
     tensor using extension)"""
 
-    if isinstance(input[0], pa.Scalar) and parse_version(
-        _get_pyarrow_version()
-    ) <= parse_version("13.0.0"):
+    if isinstance(input[0], pa.Scalar) and get_pyarrow_version() <= parse_version(
+        "13.0.0"
+    ):
         pytest.skip(
             "Pyarrow < 13.0 not able to properly infer native types from its own Scalars"
         )
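The tightened comparison works because `packaging.version.Version` objects order semantically, which is what each call site previously reconstructed by re-parsing the returned string. A quick illustration:

from packaging.version import parse as parse_version

# Version objects compare by release components, not lexicographically.
assert parse_version("9.0.0") < parse_version("13.0.0") < parse_version("14.0.1")
# Plain string comparison would get this wrong ("1" sorts before "9").
assert "13.0.0" < "9.0.0"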

python/ray/air/tests/test_tensor_extension.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 import pytest
 from packaging.version import parse as parse_version
 
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.air.util.tensor_extensions.arrow import (
     ArrowConversionError,
     ArrowTensorArray,
@@ -517,7 +517,7 @@ def test_arrow_tensor_array_getitem(chunked, restore_data_context, tensor_format
     if chunked:
         t_arr = pa.chunked_array(t_arr)
 
-    pyarrow_version = parse_version(_get_pyarrow_version())
+    pyarrow_version = get_pyarrow_version()
     if (
         chunked
         and pyarrow_version >= parse_version("8.0.0")
@@ -589,7 +589,7 @@ def test_arrow_variable_shaped_tensor_array_getitem(
     if chunked:
         t_arr = pa.chunked_array(t_arr)
 
-    pyarrow_version = parse_version(_get_pyarrow_version())
+    pyarrow_version = get_pyarrow_version()
     if (
         chunked
         and pyarrow_version >= parse_version("8.0.0")

python/ray/air/util/object_extensions/arrow.py

Lines changed: 2 additions & 3 deletions
@@ -7,13 +7,12 @@
 
 import ray.air.util.object_extensions.pandas
 from ray._private.serialization import pickle_dumps
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.util.annotations import PublicAPI
 
 MIN_PYARROW_VERSION_SCALAR_SUBCLASS = parse_version("9.0.0")
 
-_VER = _get_pyarrow_version()
-PYARROW_VERSION = None if _VER is None else parse_version(_VER)
+PYARROW_VERSION = get_pyarrow_version()
 
 
 def _object_extension_type_allowed() -> bool:

python/ray/air/util/tensor_extensions/arrow.py

Lines changed: 2 additions & 6 deletions
@@ -11,7 +11,7 @@
 import pyarrow as pa
 from packaging.version import parse as parse_version
 
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.air.util.tensor_extensions.utils import (
     _is_ndarray_variable_shaped_tensor,
     create_ragged_ndarray,
@@ -26,11 +26,7 @@
 from ray.util import log_once
 from ray.util.annotations import DeveloperAPI, PublicAPI
 
-
-PYARROW_VERSION = _get_pyarrow_version()
-if PYARROW_VERSION is not None:
-    PYARROW_VERSION = parse_version(PYARROW_VERSION)
-
+PYARROW_VERSION = get_pyarrow_version()
 # Minimum version of Arrow that supports ExtensionScalars.
 # TODO(Clark): Remove conditional definition once we only support Arrow 8.0.0+.
 MIN_PYARROW_VERSION_SCALAR = parse_version("8.0.0")

python/ray/data/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@
 import pandas  # noqa
 from packaging.version import parse as parse_version
 
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.data._internal.compute import ActorPoolStrategy
 from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions
 from ray.data._internal.execution.interfaces import (
@@ -80,8 +80,8 @@
 # disabled it's deserialization by default. To ensure that users can load data
 # written with earlier version of Ray Data, we enable auto-loading of serialized
 # tensor extensions.
-pyarrow_version = _get_pyarrow_version()
-if not isinstance(pyarrow_version, str):
+pyarrow_version = get_pyarrow_version()
+if pyarrow_version is None:
     # PyArrow is mocked in documentation builds. In this case, we don't need to do
     # anything.
     pass
@@ -93,7 +93,7 @@
     )
 
     if (
-        parse_version(pyarrow_version) >= parse_version("14.0.1")
+        pyarrow_version >= parse_version("14.0.1")
         and RAY_DATA_AUTOLOAD_PYEXTENSIONTYPE
     ):
         pa.PyExtensionType.set_auto_load(True)
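Since the helper now returns `Optional[Version]` rather than `Optional[str]`, the mocked-or-missing-pyarrow case is detected with an explicit `None` check instead of the old `isinstance(..., str)` test. A condensed sketch of the resulting guard, with the feature body elided:

from packaging.version import parse as parse_version

from ray._private.arrow_utils import get_pyarrow_version

pyarrow_version = get_pyarrow_version()
if pyarrow_version is None:
    # pyarrow is absent or mocked (e.g., in documentation builds); skip any
    # version-gated setup.
    pass
elif pyarrow_version >= parse_version("14.0.1"):
    # Only reach for behavior that exists in pyarrow >= 14.0.1.
    ...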

python/ray/data/_internal/arrow_block.py

Lines changed: 2 additions & 4 deletions
@@ -15,7 +15,7 @@
 
 import numpy as np
 
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.air.util.tensor_extensions.arrow import (
     convert_to_pyarrow_array,
@@ -189,9 +189,7 @@ def _build_tensor_row(
         element = row[col_name][0]
         # TODO(Clark): Reduce this to np.asarray(element) once we only support Arrow
         # 9.0.0+.
-        pyarrow_version = _get_pyarrow_version()
-        if pyarrow_version is not None:
-            pyarrow_version = parse_version(pyarrow_version)
+        pyarrow_version = get_pyarrow_version()
         if pyarrow_version is None or pyarrow_version >= parse_version("8.0.0"):
             assert isinstance(element, pyarrow.ExtensionScalar)
             if pyarrow_version is None or pyarrow_version >= parse_version("9.0.0"):

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 from packaging.version import parse as parse_version
 
 from ray._private.ray_constants import env_integer
-from ray._private.utils import _get_pyarrow_version
+from ray._private.arrow_utils import get_pyarrow_version
 from ray.air.util.tensor_extensions.arrow import (
     INT32_OVERFLOW_THRESHOLD,
     MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY,
@@ -206,7 +206,7 @@ def unify_schemas(
         schemas_to_unify = schemas
 
     try:
-        if parse_version(_get_pyarrow_version()) < MIN_PYARROW_VERSION_TYPE_PROMOTION:
+        if get_pyarrow_version() < MIN_PYARROW_VERSION_TYPE_PROMOTION:
            return pyarrow.unify_schemas(schemas_to_unify)
 
        # NOTE: By default type promotion (from "smaller" to "larger" types) is disabled,
@@ -554,7 +554,7 @@ def concat(
        # to vary b/w blocks
        #
        # NOTE: Type promotions aren't available in Arrow < 14.0
-        if parse_version(_get_pyarrow_version()) < parse_version("14.0.0"):
+        if get_pyarrow_version() < parse_version("14.0.0"):
            table = pyarrow.concat_tables(blocks, promote=True)
        else:
            arrow_promote_types_mode = "permissive" if promote_types else "default"
