Skip to content

refactor: Generic CompliantSelector #2064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 69 commits into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
028098e
refactor: replace `lambda df: df.columns`
dangotbanned Feb 20, 2025
c597ba7
refactor: rename `get_columns` -> `get_column_names`
dangotbanned Feb 20, 2025
d26fd81
chore(typing): add missing context for `CompliantDataFrame`
dangotbanned Feb 21, 2025
230daf9
fix(typing): temp widen `.collect` from `CompliantDataFrame`
dangotbanned Feb 21, 2025
2ab3305
feat(typing): provide `PandasLikeSeries` to `PandasLikeDataFrame`
dangotbanned Feb 21, 2025
47816b9
feat: Adds generic `CompliantSelector` & `Namespace`
dangotbanned Feb 21, 2025
c2e56d1
feat: reimplement `_pandas_like.selectors`
dangotbanned Feb 21, 2025
706732a
feat: reimplement `_arrow.selectors`
dangotbanned Feb 21, 2025
7bb8cb2
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 21, 2025
8bf4a49
refactor: utilize updated `CompliantExpr`
dangotbanned Feb 21, 2025
1e2e151
fix(typing): update `_dask` collect as well
dangotbanned Feb 21, 2025
5a36c81
fix: maybe resolve `<3.11` protocol bug
dangotbanned Feb 21, 2025
5e45350
ignore coverage on init
dangotbanned Feb 21, 2025
8fe1ce4
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 21, 2025
42c2f41
refactor: Also drop `kwargs` in `_to_expr`
dangotbanned Feb 21, 2025
4e89a2e
add `CompliantLazyFrame.columns`
dangotbanned Feb 21, 2025
9210a17
chore(typing): add `Compliant(Lazy|Data)Frame.schema`
dangotbanned Feb 21, 2025
5382778
feat: add some default iteration methods
dangotbanned Feb 21, 2025
91c86a0
feat: build out lazy-support
dangotbanned Feb 21, 2025
a8c9d13
feat: reimplement `_dask.selectors`
dangotbanned Feb 21, 2025
d6a2e41
refactor: rename `DataFrameT` -> `FrameT`
dangotbanned Feb 21, 2025
f32c348
feat: reimplement `_duckdb.selectors`
dangotbanned Feb 21, 2025
d785eb3
fix: guard against back-compat `duckdb`
dangotbanned Feb 21, 2025
90229c5
feat: reimplement `_spark_like.selectors`
dangotbanned Feb 21, 2025
180b9a3
chore: remove duplicate import
dangotbanned Feb 21, 2025
bfed23f
chore(typing): ignore valid `_dask.selectors` warnings
dangotbanned Feb 22, 2025
f060f22
refactor: define `CompliantSelector.__repr__`
dangotbanned Feb 22, 2025
3c5ea87
refactor(typing): always use concrete class in `_selector() -> ...`
dangotbanned Feb 22, 2025
b30f020
chore: move comments to discussion threads
dangotbanned Feb 22, 2025
eaa5b1c
style: order w/ unimplemented methods first
dangotbanned Feb 22, 2025
92a96cd
ignore `DAR201`
dangotbanned Feb 22, 2025
29eb2d0
degrade the DX to satisfy pre-commit :disappointed:
dangotbanned Feb 22, 2025
547d62a
feat: reuse & define `CompliantNamespace.selectors`
dangotbanned Feb 22, 2025
e6f2d07
refactor: use `self` not `context`
dangotbanned Feb 22, 2025
9325639
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 23, 2025
13882a5
refactor: remove `CompliantSeriesWithDType`
dangotbanned Feb 23, 2025
7f65b4b
fix(typing): resolve some variance issues
dangotbanned Feb 23, 2025
c11dc95
fix(typing): resolve more `mypy` errors
dangotbanned Feb 23, 2025
ef85eda
fix(typing): resolve `when()` related issues
dangotbanned Feb 23, 2025
0301a4b
chore(typing): Ignore all `_dask` warnings
dangotbanned Feb 23, 2025
253e2ea
Merge branch 'main' into more-dedup-1
dangotbanned Feb 23, 2025
6d19ba0
Merge branch 'main' into more-dedup-1
dangotbanned Feb 23, 2025
0d6d2fc
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 24, 2025
7201454
chore(typing): resolve `mypy` strict errors
dangotbanned Feb 24, 2025
a4edcb7
fix(DRAFT): try giving a `super(...)` hints for `3.8`
dangotbanned Feb 24, 2025
cc00fdc
don't forget `self`
dangotbanned Feb 24, 2025
b650c78
fix(DRAFT): try removing `CompliantSelectorNamespace.__init__`
dangotbanned Feb 24, 2025
acd8cbf
revert (72014548b6f49b864164167c2372fa43778556ad)
dangotbanned Feb 24, 2025
3be58e6
refactor: try reusing `Protocol`'s `Generic
dangotbanned Feb 24, 2025
d57c877
refactor: try the same for `CompliantExpr`
dangotbanned Feb 24, 2025
b192deb
fix(DRAFT): try removing `Protocol` altogether?
dangotbanned Feb 24, 2025
75c5a81
fix: only replace `Protocol` for `CompliantExpr`
dangotbanned Feb 24, 2025
372f1eb
ignore coverage, remove comments
dangotbanned Feb 24, 2025
7979909
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 25, 2025
0242421
Merge branch 'main' into more-dedup-1
dangotbanned Feb 26, 2025
1abc05a
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Feb 26, 2025
819bf3b
Merge branch 'main' into more-dedup-1
dangotbanned Feb 27, 2025
3da91c4
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Mar 1, 2025
c26d7ff
chore(typing): ignore spark issues
dangotbanned Mar 1, 2025
b4f6a51
:unamused:
dangotbanned Mar 1, 2025
cbc1ee1
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Mar 1, 2025
ec5dd0c
refactor: Reuse eager-backend `iter_columns`
dangotbanned Mar 1, 2025
d03cc22
Merge branch 'main' into more-dedup-1
dangotbanned Mar 2, 2025
fb05225
Merge remote-tracking branch 'upstream/main' into more-dedup-1
dangotbanned Mar 2, 2025
c180c7e
refactor: define `CompliantLazyFrame._iter_columns`
dangotbanned Mar 2, 2025
e85c934
docs: move note back into code
dangotbanned Mar 2, 2025
a43d5e5
refactor: remove duplicate import
dangotbanned Mar 2, 2025
f5765ff
typo
dangotbanned Mar 2, 2025
cd31cbe
docs: add notes on `3.8` compat code
dangotbanned Mar 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 24 additions & 188 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,45 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Sequence

from narwhals._arrow.expr import ArrowExpr
from narwhals.utils import _parse_time_unit_and_time_zone
from narwhals.utils import dtype_matches_time_unit_and_time_zone
from narwhals.utils import import_dtypes_module
from narwhals._selectors import CompliantSelector
from narwhals._selectors import EagerSelectorNamespace

if TYPE_CHECKING:
from datetime import timezone

from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
from narwhals.utils import _LimitedContext


class ArrowSelectorNamespace:
def __init__(self: Self, context: _LimitedContext, /) -> None:
self._backend_version = context._backend_version
self._version = context._version

def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [df[col] for col in df.columns if df.schema[col] in dtypes]

def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return selector(self, func, evaluate_output_names)

def matches(self: Self, pattern: str) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [df[col] for col in df.columns if re.search(pattern, col)]

def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if re.search(pattern, col)]
from narwhals._selectors import EvalNames
from narwhals._selectors import EvalSeries
from narwhals.utils import _FullContext

return selector(self, func, evaluate_output_names)

def numeric(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype(
[
dtypes.Int128,
dtypes.Int64,
dtypes.Int32,
dtypes.Int16,
dtypes.Int8,
dtypes.UInt128,
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
dtypes.UInt8,
dtypes.Float64,
dtypes.Float32,
],
)

def categorical(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.Categorical])

def string(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.String])

def boolean(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.Boolean])

def all(self: Self) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [df[col] for col in df.columns]

return selector(self, func, lambda df: df.columns)

def datetime(
self: Self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
def _selector(
self,
call: EvalSeries[ArrowDataFrame, ArrowSeries],
evaluate_output_names: EvalNames[ArrowDataFrame],
/,
) -> ArrowSelector:
dtypes = import_dtypes_module(version=self._version)
time_units, time_zones = _parse_time_unit_and_time_zone(
time_unit=time_unit, time_zone=time_zone
return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [
df[col]
for col in df.columns
if dtype_matches_time_unit_and_time_zone(
dtype=df.schema[col],
dtypes=dtypes,
time_units=time_units,
time_zones=time_zones,
)
]

def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [
col
for col in df.columns
if dtype_matches_time_unit_and_time_zone(
dtype=df.schema[col],
dtypes=dtypes,
time_units=time_units,
time_zones=time_zones,
)
]

return selector(self, func, evaluate_output_names)

def __init__(self: Self, context: _FullContext, /) -> None:
self._implementation = context._implementation
self._backend_version = context._backend_version
self._version = context._version

class ArrowSelector(ArrowExpr):
def __repr__(self: Self) -> str: # pragma: no cover
return f"ArrowSelector(depth={self._depth}, function_name={self._function_name})"

class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc]
def _to_expr(self: Self) -> ArrowExpr:
return ArrowExpr(
self._call,
Expand All @@ -135,82 +50,3 @@ def _to_expr(self: Self) -> ArrowExpr:
backend_version=self._backend_version,
version=self._version,
)

def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any:
if isinstance(other, ArrowSelector):

def call(df: ArrowDataFrame) -> list[ArrowSeries]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names]

def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return selector(self, call, evaluate_output_names)
else:
return self._to_expr() - other

def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any:
if isinstance(other, ArrowSelector):

def call(df: ArrowDataFrame) -> list[ArrowSeries]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
rhs = other._call(df)
return [
*(x for x, name in zip(lhs, lhs_names) if name not in rhs_names),
*rhs,
]

def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return selector(self, call, evaluate_output_names)
else:
return self._to_expr() | other

def __and__(self: Self, other: Self | Any) -> ArrowSelector | Any:
if isinstance(other, ArrowSelector):

def call(df: ArrowDataFrame) -> list[ArrowSeries]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
return [x for x, name in zip(lhs, lhs_names) if name in rhs_names]

def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return selector(self, call, evaluate_output_names)

else:
return self._to_expr() & other

def __invert__(self: Self) -> ArrowSelector:
return ArrowSelectorNamespace(self).all() - self


def selector(
context: _LimitedContext,
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
/,
) -> ArrowSelector:
return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=context._backend_version,
version=context._version,
)
3 changes: 3 additions & 0 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.typing import CompliantNamespace
from narwhals.utils import Implementation
from narwhals.utils import get_column_names

if TYPE_CHECKING:
Expand All @@ -38,6 +39,8 @@


class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] (#2044)
_implementation: Implementation = Implementation.DASK

@property
def selectors(self: Self) -> DaskSelectorNamespace:
return DaskSelectorNamespace(self)
Expand Down
Loading
Loading