Skip to content

Commit 0d406e4

Browse files
authored
chore: remove kwargs from selector function (#2058)
1 parent c18c145 commit 0d406e4

File tree

5 files changed: +30 additions, -59 deletions

narwhals/_arrow/expr.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def __init__(
4343
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
4444
backend_version: tuple[int, ...],
4545
version: Version,
46-
kwargs: dict[str, Any],
46+
kwargs: dict[str, Any] | None = None,
4747
) -> None:
4848
self._call = call
4949
self._depth = depth
@@ -53,7 +53,7 @@ def __init__(
5353
self._alias_output_names = alias_output_names
5454
self._backend_version = backend_version
5555
self._version = version
56-
self._kwargs = kwargs
56+
self._kwargs = kwargs or {}
5757

5858
def __repr__(self: Self) -> str: # pragma: no cover
5959
return f"ArrowExpr(depth={self._depth}, function_name={self._function_name}, "
@@ -117,7 +117,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
117117
alias_output_names=None,
118118
backend_version=backend_version,
119119
version=version,
120-
kwargs={},
121120
)
122121

123122
@classmethod
@@ -148,7 +147,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
148147
alias_output_names=None,
149148
backend_version=backend_version,
150149
version=version,
151-
kwargs={},
152150
)
153151

154152
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
@@ -315,7 +313,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
315313
alias_output_names=alias_output_names,
316314
backend_version=self._backend_version,
317315
version=self._version,
318-
kwargs={**self._kwargs, "name": name},
316+
kwargs=self._kwargs,
319317
)
320318

321319
def null_count(self: Self) -> Self:

narwhals/_arrow/namespace.py

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def _create_expr_from_callable(
4949
function_name: str,
5050
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
5151
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
52-
kwargs: dict[str, Any],
52+
kwargs: dict[str, Any] | None = None,
5353
) -> ArrowExpr:
5454
from narwhals._arrow.expr import ArrowExpr
5555

@@ -75,7 +75,6 @@ def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr:
7575
alias_output_names=None,
7676
backend_version=self._backend_version,
7777
version=self._version,
78-
kwargs={},
7978
)
8079

8180
def _create_series_from_scalar(
@@ -142,7 +141,6 @@ def len(self: Self) -> ArrowExpr:
142141
alias_output_names=None,
143142
backend_version=self._backend_version,
144143
version=self._version,
145-
kwargs={},
146144
)
147145

148146
def all(self: Self) -> ArrowExpr:
@@ -165,7 +163,6 @@ def all(self: Self) -> ArrowExpr:
165163
alias_output_names=None,
166164
backend_version=self._backend_version,
167165
version=self._version,
168-
kwargs={},
169166
)
170167

171168
def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
@@ -188,7 +185,6 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
188185
alias_output_names=None,
189186
backend_version=self._backend_version,
190187
version=self._version,
191-
kwargs={},
192188
)
193189

194190
def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
@@ -202,7 +198,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
202198
function_name="all_horizontal",
203199
evaluate_output_names=combine_evaluate_output_names(*exprs),
204200
alias_output_names=combine_alias_output_names(*exprs),
205-
kwargs={"exprs": exprs},
206201
)
207202

208203
def any_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
@@ -387,11 +382,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
387382
function_name="concat_str",
388383
evaluate_output_names=combine_evaluate_output_names(*exprs),
389384
alias_output_names=combine_alias_output_names(*exprs),
390-
kwargs={
391-
"exprs": exprs,
392-
"separator": separator,
393-
"ignore_nulls": ignore_nulls,
394-
},
395385
)
396386

397387

@@ -466,7 +456,6 @@ def then(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen:
466456
alias_output_names=getattr(value, "_alias_output_names", None),
467457
backend_version=self._backend_version,
468458
version=self._version,
469-
kwargs={"value": value},
470459
)
471460

472461

@@ -481,7 +470,7 @@ def __init__(
481470
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
482471
backend_version: tuple[int, ...],
483472
version: Version,
484-
kwargs: dict[str, Any],
473+
kwargs: dict[str, Any] | None = None,
485474
) -> None:
486475
self._backend_version = backend_version
487476
self._version = version
@@ -490,7 +479,7 @@ def __init__(
490479
self._function_name = function_name
491480
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
492481
self._alias_output_names = alias_output_names
493-
self._kwargs = kwargs
482+
self._kwargs = kwargs or {}
494483

495484
def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr:
496485
# type ignore because we are setting the `_call` attribute to a

narwhals/_arrow/selectors.py

Lines changed: 7 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
3636
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
3737
return [col for col in df.columns if df.schema[col] in dtypes]
3838

39-
return selector(self, func, evaluate_output_names, {"dtypes": dtypes})
39+
return selector(self, func, evaluate_output_names)
4040

4141
def matches(self: Self, pattern: str) -> ArrowSelector:
4242
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
@@ -45,7 +45,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
4545
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
4646
return [col for col in df.columns if re.search(pattern, col)]
4747

48-
return selector(self, func, evaluate_output_names, {"pattern": pattern})
48+
return selector(self, func, evaluate_output_names)
4949

5050
def numeric(self: Self) -> ArrowSelector:
5151
dtypes = import_dtypes_module(self._version)
@@ -82,7 +82,7 @@ def all(self: Self) -> ArrowSelector:
8282
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
8383
return [df[col] for col in df.columns]
8484

85-
return selector(self, func, lambda df: df.columns, {})
85+
return selector(self, func, lambda df: df.columns)
8686

8787
def datetime(
8888
self: Self,
@@ -118,7 +118,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
118118
)
119119
]
120120

121-
return selector(self, func, evaluate_output_names, {})
121+
return selector(self, func, evaluate_output_names)
122122

123123

124124
class ArrowSelector(ArrowExpr):
@@ -134,7 +134,6 @@ def _to_expr(self: Self) -> ArrowExpr:
134134
alias_output_names=self._alias_output_names,
135135
backend_version=self._backend_version,
136136
version=self._version,
137-
kwargs=self._kwargs,
138137
)
139138

140139
def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any:
@@ -151,9 +150,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
151150
rhs_names = other._evaluate_output_names(df)
152151
return [x for x in lhs_names if x not in rhs_names]
153152

154-
return selector(
155-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
156-
)
153+
return selector(self, call, evaluate_output_names)
157154
else:
158155
return self._to_expr() - other
159156

@@ -175,9 +172,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
175172
rhs_names = other._evaluate_output_names(df)
176173
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
177174

178-
return selector(
179-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
180-
)
175+
return selector(self, call, evaluate_output_names)
181176
else:
182177
return self._to_expr() | other
183178

@@ -195,9 +190,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
195190
rhs_names = other._evaluate_output_names(df)
196191
return [x for x in lhs_names if x in rhs_names]
197192

198-
return selector(
199-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
200-
)
193+
return selector(self, call, evaluate_output_names)
201194

202195
else:
203196
return self._to_expr() & other
@@ -210,7 +203,6 @@ def selector(
210203
context: _LimitedContext,
211204
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
212205
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
213-
kwargs: dict[str, Any],
214206
/,
215207
) -> ArrowSelector:
216208
return ArrowSelector(
@@ -221,5 +213,4 @@ def selector(
221213
alias_output_names=None,
222214
backend_version=context._backend_version,
223215
version=context._version,
224-
kwargs=kwargs,
225216
)

narwhals/_pandas_like/expr.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -431,29 +431,29 @@ def over(self: Self, keys: list[str]) -> Self:
431431
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
432432
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
433433

434-
reverse = self._kwargs.get("reverse", False)
435-
if reverse:
436-
msg = (
437-
"Cumulative operation with `reverse=True` is not supported in "
438-
"over context for pandas-like backend."
439-
)
440-
raise NotImplementedError(msg)
441-
434+
unsupported_reverse_msg = (
435+
"Cumulative operation with `reverse=True` is not supported in "
436+
"over context for pandas-like backend."
437+
)
442438
if function_name == "cum_count":
439+
if self._kwargs["reverse"]:
440+
raise NotImplementedError(unsupported_reverse_msg)
443441
plx = self.__narwhals_namespace__()
444442
df = df.with_columns(~plx.col(*output_names).is_null())
445443

446444
if function_name == "shift":
447445
kwargs = {"periods": self._kwargs["n"]}
448446
elif function_name == "rank":
449-
_method = self._kwargs.get("method", "average")
447+
_method = self._kwargs["method"]
450448
kwargs = {
451449
"method": "first" if _method == "ordinal" else _method,
452450
"ascending": not self._kwargs["descending"],
453451
"na_option": "keep",
454452
"pct": False,
455453
}
456454
else: # Cumulative operation
455+
if self._kwargs["reverse"]:
456+
raise NotImplementedError(unsupported_reverse_msg)
457457
kwargs = {"skipna": True}
458458

459459
res_native = getattr(

narwhals/_pandas_like/selectors.py

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
3737
def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
3838
return [col for col in df.columns if df.schema[col] in dtypes]
3939

40-
return selector(self, func, evaluate_output_names, {"dtypes": dtypes})
40+
return selector(self, func, evaluate_output_names)
4141

4242
def matches(self: Self, pattern: str) -> PandasSelector:
4343
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
@@ -46,7 +46,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
4646
def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
4747
return [col for col in df.columns if re.search(pattern, col)]
4848

49-
return selector(self, func, evaluate_output_names, {"pattern": pattern})
49+
return selector(self, func, evaluate_output_names)
5050

5151
def numeric(self: Self) -> PandasSelector:
5252
dtypes = import_dtypes_module(self._version)
@@ -83,7 +83,7 @@ def all(self: Self) -> PandasSelector:
8383
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
8484
return [df[col] for col in df.columns]
8585

86-
return selector(self, func, lambda df: df.columns, {})
86+
return selector(self, func, lambda df: df.columns)
8787

8888
def datetime(
8989
self: Self,
@@ -119,7 +119,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
119119
)
120120
]
121121

122-
return selector(self, func, evaluate_output_names, {})
122+
return selector(self, func, evaluate_output_names)
123123

124124

125125
class PandasSelector(PandasLikeExpr):
@@ -155,9 +155,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
155155
rhs_names = other._evaluate_output_names(df)
156156
return [x for x in lhs_names if x not in rhs_names]
157157

158-
return selector(
159-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
160-
)
158+
return selector(self, call, evaluate_output_names)
161159
else:
162160
return self._to_expr() - other
163161

@@ -179,9 +177,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
179177
rhs_names = other._evaluate_output_names(df)
180178
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
181179

182-
return selector(
183-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
184-
)
180+
return selector(self, call, evaluate_output_names)
185181
else:
186182
return self._to_expr() | other
187183

@@ -199,9 +195,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
199195
rhs_names = other._evaluate_output_names(df)
200196
return [x for x in lhs_names if x in rhs_names]
201197

202-
return selector(
203-
self, call, evaluate_output_names, {**self._kwargs, "other": other}
204-
)
198+
return selector(self, call, evaluate_output_names)
205199
else:
206200
return self._to_expr() & other
207201

@@ -213,7 +207,6 @@ def selector(
213207
context: _FullContext,
214208
call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]],
215209
evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]],
216-
kwargs: dict[str, Any],
217210
/,
218211
) -> PandasSelector:
219212
return PandasSelector(
@@ -225,5 +218,5 @@ def selector(
225218
implementation=context._implementation,
226219
backend_version=context._backend_version,
227220
version=context._version,
228-
kwargs=kwargs,
221+
kwargs={},
229222
)

0 commit comments

Comments
 (0)