refactor v3 data types #2874

```diff
@@ -355,9 +355,10 @@ def __init__(
         object.__setattr__(self, "index_location", index_location_parsed)

-        # Use instance-local lru_cache to avoid memory leaks
-        object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
-        object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
-        object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
+        # TODO: fix these when we don't get hashability errors for certain numpy dtypes
+        # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
+        # object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
+        # object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

     # todo: typedict return type
     def __getstate__(self) -> dict[str, Any]:
```

Review comment: Is this something that needs fixing before this PR is merged?
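
For context on the hashability issue: `functools.lru_cache` hashes every argument to build its cache key, so the instance-local wrapping being commented out above fails as soon as an unhashable object (such as a dict-backed dtype wrapper) is passed in. A minimal standalone sketch of the pattern and the failure mode, not zarr code:

```python
# Minimal standalone sketch (not zarr code): the instance-local lru_cache pattern
# used above, and how an unhashable argument breaks it.
from functools import lru_cache


class ChunkSpecCache:
    def __init__(self) -> None:
        # Wrap the bound method per instance so the cache is dropped with the
        # instance (a decorator on the method would keep `self` cached globally).
        object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))

    def _get_chunk_spec(self, shape: tuple[int, ...], dtype: object) -> tuple[tuple[int, ...], object]:
        return shape, dtype


cache = ChunkSpecCache()
print(cache._get_chunk_spec((8, 8), "float32"))  # cached: all arguments are hashable

try:
    cache._get_chunk_spec((8, 8), {"name": "custom"})  # dicts are not hashable
except TypeError as exc:
    print(exc)  # unhashable type: 'dict'
```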

```diff
@@ -98,19 +98,21 @@
     ArrayV3MetadataDict,
     T_ArrayMetadata,
 )
+from zarr.core.metadata.dtype import DTypeWrapper
 from zarr.core.metadata.v2 import (
     _default_compressor,
     _default_filters,
     parse_compressor,
     parse_filters,
 )
-from zarr.core.metadata.v3 import DataType, parse_node_type_array
+from zarr.core.metadata.v3 import parse_node_type_array
 from zarr.core.sync import sync
 from zarr.errors import MetadataValidationError
 from zarr.registry import (
     _parse_array_array_codec,
     _parse_array_bytes_codec,
     _parse_bytes_bytes_codec,
+    get_data_type_from_numpy,
     get_pipeline_class,
 )
 from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path
@@ -578,7 +580,7 @@
     """
     store_path = await make_store_path(store)

-    dtype_parsed = parse_dtype(dtype, zarr_format)
+    dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
     shape = parse_shapelike(shape)

     if chunks is not None and chunk_shape is not None:
@@ -677,7 +679,7 @@
     """

     shape = parse_shapelike(shape)
-    codecs = list(codecs) if codecs is not None else _get_default_codecs(np.dtype(dtype))
+    codecs = list(codecs) if codecs is not None else _get_default_codecs(dtype)
     chunk_key_encoding_parsed: ChunkKeyEncodingLike
     if chunk_key_encoding is None:
         chunk_key_encoding_parsed = {"name": "default", "separator": "/"}
@@ -691,13 +693,23 @@
             category=UserWarning,
             stacklevel=2,
         )

+    # resolve the numpy dtype into zarr v3 datatype
+    zarr_data_type = get_data_type_from_numpy(dtype)
+
+    if fill_value is None:
+        # v3 spec will not allow a null fill value
+        fill_value_parsed = zarr_data_type.default_value
+    else:
+        fill_value_parsed = fill_value
+
     chunk_grid_parsed = RegularChunkGrid(chunk_shape=chunk_shape)
     return ArrayV3Metadata(
         shape=shape,
-        data_type=dtype,
+        data_type=zarr_data_type,
         chunk_grid=chunk_grid_parsed,
         chunk_key_encoding=chunk_key_encoding_parsed,
-        fill_value=fill_value,
+        fill_value=fill_value_parsed,
         codecs=codecs,
         dimension_names=tuple(dimension_names) if dimension_names else None,
         attributes=attributes or {},
```

Review comment (on the fill value handling): should (or are) we be casting this scalar somewhere?

Reply: Good point, we can use …
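
On the reviewer's question about casting: one plausible place to do it is right where `fill_value_parsed` is assigned, converting the user-supplied scalar to the resolved numpy dtype. A hedged sketch of that conversion with plain numpy; the helper name `cast_fill_value` is hypothetical and not part of this PR:

```python
# Hedged sketch, not the PR's implementation: cast a user-supplied fill value to
# the array's numpy dtype before it is written into the v3 metadata.
# `cast_fill_value` is a hypothetical helper name.
import numpy as np


def cast_fill_value(fill_value: object, np_dtype: np.dtype) -> np.generic:
    """Return the fill value as a numpy scalar of the target dtype."""
    return np_dtype.type(fill_value)


print(cast_fill_value(0, np.dtype("float32")))          # np.float32(0.0)
print(cast_fill_value(1 + 2j, np.dtype("complex64")))   # np.complex64(1+2j)
```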
```diff
@@ -1682,7 +1694,7 @@
     def _info(
         self, count_chunks_initialized: int | None = None, count_bytes_stored: int | None = None
     ) -> Any:
-        _data_type: np.dtype[Any] | DataType
+        _data_type: np.dtype[Any] | DTypeWrapper
         if isinstance(self.metadata, ArrayV2Metadata):
             _data_type = self.metadata.dtype
         else:
@@ -4203,17 +4215,11 @@
     """
     Get the default ArrayArrayCodecs, ArrayBytesCodec, and BytesBytesCodec for a given dtype.
     """
-    dtype = DataType.from_numpy(np_dtype)
-    if dtype == DataType.string:
-        dtype_key = "string"
-    elif dtype == DataType.bytes:
-        dtype_key = "bytes"
-    else:
-        dtype_key = "numeric"
+    dtype = get_data_type_from_numpy(np_dtype)

-    default_filters = zarr_config.get("array.v3_default_filters").get(dtype_key)
-    default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype_key)
-    default_compressors = zarr_config.get("array.v3_default_compressors").get(dtype_key)
+    default_filters = zarr_config.get("array.v3_default_filters").get(dtype.kind)
+    default_serializer = zarr_config.get("array.v3_default_serializer").get(dtype.kind)
+    default_compressors = zarr_config.get("array.v3_default_compressors").get(dtype.kind)

     filters = tuple(_parse_array_array_codec(codec_dict) for codec_dict in default_filters)
     serializer = _parse_array_bytes_codec(default_serializer)
```
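
For reference, the removed branch grouped every dtype into one of three config categories before looking up the default codecs. A standalone sketch of that grouping follows; `dtype_config_key` is a hypothetical helper, and the numpy-kind checks only approximate what `DataType.from_numpy` did:

```python
# Standalone sketch of the old three-way grouping; `dtype_config_key` is a
# hypothetical helper and the kind checks are approximations.
import numpy as np


def dtype_config_key(np_dtype: np.dtype) -> str:
    """Map a numpy dtype to the config key used to look up default v3 codecs."""
    if np_dtype.kind in ("U", "T"):  # unicode / variable-length string dtypes
        return "string"
    if np_dtype.kind == "S":  # fixed-length byte strings
        return "bytes"
    return "numeric"  # everything else falls back to the numeric defaults


print(dtype_config_key(np.dtype("float32")))  # numeric
print(dtype_config_key(np.dtype("S4")))       # bytes
print(dtype_config_key(np.dtype("U8")))       # string
```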

New file (`@@ -0,0 +1,3 @@`):

```python
from zarr.core.dtype.core import ZarrDType

__all__ = ["ZarrDType"]
```

New file (`@@ -0,0 +1,196 @@`):

````python
"""
# Overview

This module provides a proof-of-concept standalone interface for managing dtypes in the zarr-python codebase.

The `ZarrDType` class introduced in this module effectively acts as a replacement for `np.dtype` throughout the
zarr-python codebase. It attempts to encapsulate all relevant runtime information necessary for working with
dtypes in the context of the Zarr V3 specification (e.g. whether it is a core dtype, how many bytes it occupies,
and what its endianness is). By providing this abstraction, the module aims to:

- Simplify dtype management within zarr-python
- Support runtime flexibility and custom extensions
- Remove unnecessary dependencies on the numpy API

## Extensibility

The module attempts to support user-driven extensions, allowing developers to introduce custom dtypes
without requiring immediate changes to zarr-python. Extensions can leverage the current entrypoint mechanism,
enabling integration of experimental features. Over time, widely adopted extensions may be formalized through
inclusion in zarr-python or standardized via a Zarr Enhancement Proposal (ZEP), but this is not essential.

## Examples

### Core `dtype` Registration

The following example demonstrates how to register a built-in `dtype` in the core codebase:

```python
import numpy as np

from zarr.core.dtype import ZarrDType
from zarr.registry import register_v3dtype


class Float16(ZarrDType):
    zarr_spec_format = "3"
    experimental = False
    endianness = "little"
    byte_count = 2
    to_numpy = np.dtype('float16')


register_v3dtype(Float16)
```

### Entrypoint Extension

The following example demonstrates how users can register a new `bfloat16` dtype for Zarr.
This approach adheres to the existing Zarr entrypoint pattern as much as possible, ensuring
consistency with other extensions. The code below would typically be part of a Python package
that specifies the entrypoints for the extension:

```python
import ml_dtypes
import numpy as np

from zarr.core.dtype import ZarrDType  # User inherits from ZarrDType when creating their dtype


class Bfloat16(ZarrDType):
    zarr_spec_format = "3"
    experimental = True
    endianness = "little"
    byte_count = 2
    to_numpy = np.dtype('bfloat16')  # Enabled by importing ml_dtypes
    configuration_v3 = {
        "version": "example_value",
        "author": "example_value",
        "ml_dtypes_version": "example_value"
    }
```

### dtype lookup

The following examples demonstrate how to perform a lookup for the relevant ZarrDType, given
a string that matches the dtype Zarr specification ID, or a numpy dtype object:

```python
import numpy as np

from zarr.registry import get_v3dtype_class, get_v3dtype_class_from_numpy

get_v3dtype_class('complex64')  # returns little-endian Complex64 ZarrDType
get_v3dtype_class('not_registered_dtype')  # ValueError

get_v3dtype_class_from_numpy('>i2')  # returns big-endian Int16 ZarrDType
get_v3dtype_class_from_numpy(np.dtype('float32'))  # returns little-endian Float32 ZarrDType
get_v3dtype_class_from_numpy('i10')  # ValueError
```

### String dtypes

The following indicates one possibility for supporting variable-length strings. It uses the
entrypoint mechanism, as in the previous example. The Apache Arrow specification does not currently
include a dtype for fixed-length strings (only for fixed-length bytes), so "string" here implicitly
refers to variable-length string data (there may be some subtleties with codecs that mean this
needs to be refined further):

```python
import numpy as np

from zarr.core.dtype import ZarrDType  # User inherits from ZarrDType when creating their dtype

try:
    to_numpy = np.dtypes.StringDType()
except AttributeError:
    to_numpy = np.dtypes.ObjectDType()


class String(ZarrDType):
    zarr_spec_format = "3"
    experimental = True
    endianness = 'little'
    byte_count = None  # None is defined to mean variable
    to_numpy = to_numpy
```

### int4 dtype

There is currently considerable interest in the AI community in 'quantising' models - storing
models at reduced precision, while minimising loss of information content. There are a number
of sub-byte dtypes that the community is using, e.g. int4. Unfortunately numpy does not
currently have support for handling such sub-byte dtypes in an easy way. However, they can
still be held in a numpy array and then passed (in a zero-copy way) to something like pytorch,
which can handle them appropriately:

```python
import numpy as np

from zarr.core.dtype import ZarrDType  # User inherits from ZarrDType when creating their dtype


class Int4(ZarrDType):
    zarr_spec_format = "3"
    experimental = True
    endianness = 'little'
    byte_count = 1  # this is ugly, but byte_count could become bit_count if there was consensus
    to_numpy = np.dtype('B')  # could also be np.dtype('V1'), but this would prevent bit-twiddling
    configuration_v3 = {
        "version": "example_value",
        "author": "example_value",
    }
```
"""

from __future__ import annotations

from typing import Any, Literal

import numpy as np


class FrozenClassVariables(type):
    def __setattr__(cls, attr: str, value: object) -> None:
        if hasattr(cls, attr):
            raise ValueError(f"Attribute {attr} on ZarrDType class can not be changed once set.")
        else:
            # first assignment is allowed; later reassignment is forbidden
            super().__setattr__(attr, value)


class ZarrDType(metaclass=FrozenClassVariables):
    zarr_spec_format: Literal["2", "3"]  # the version of the zarr spec used
    experimental: bool  # is this in the core spec or not
    endianness: Literal[
        "big", "little", None
    ]  # None indicates not defined, i.e. single byte or byte strings
    byte_count: int | None  # None indicates variable count
    to_numpy: np.dtype[Any]  # may involve installing a numpy extension, e.g. ml_dtypes

    configuration_v3: dict | None  # TODO: understand better how this is recommended by the spec

    _zarr_spec_identifier: str  # implementation detail used to map to core spec

    def __init_subclass__(  # enforces that all required fields are set and runs basic sanity checks
        cls,
        **kwargs: Any,
    ) -> None:
        required_attrs = [
            "zarr_spec_format",
            "experimental",
            "endianness",
            "byte_count",
            "to_numpy",
        ]
        for attr in required_attrs:
            if not hasattr(cls, attr):
                raise ValueError(f"{attr} is a required attribute for a Zarr dtype.")

        if not hasattr(cls, "configuration_v3"):
            cls.configuration_v3 = None

        cls._zarr_spec_identifier = (
            "big_" + cls.__qualname__.lower()
            if cls.endianness == "big"
            else cls.__qualname__.lower()
        )  # how this dtype is identified in the core spec; the convention is to prefix big-endian dtypes with big_

        cls._validate()  # sanity check on basic requirements

        super().__init_subclass__(**kwargs)

    # TODO: add further checks
    @classmethod
    def _validate(cls) -> None:
        if cls.byte_count is not None and cls.byte_count <= 0:
            raise ValueError("byte_count must be a positive integer.")

        if cls.byte_count == 1 and cls.endianness is not None:
            raise ValueError("Endianness must be None for single-byte types.")
````