Skip to content

Commit 5ab0123

Browse files
committed
set_dtype
1 parent 5d5173b commit 5ab0123

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

pandas/core/categorical.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,27 @@ def _codes_for_groupby(self, sort):
616616
return self.reorder_categories(cat.categories)
617617

618618
def _set_dtype(self, dtype):
619-
# TODO: this could go fast from codes maybe?
620-
return type(self)(self.codes, dtype=dtype, fastpath=True)
619+
"""Internal method for directly updating the CategoricalDtype
620+
621+
Parameters
622+
----------
623+
dtype : CategoricalDtype
624+
625+
Notes
626+
-----
627+
We don't do any validation here. It's assumed that the dtype is
628+
a (valid) instance of `CategoricalDtype`.
629+
"""
630+
# We want to convert old codes -> new codes *without* going to values
631+
# [b, a, c, a, b, f] | original dtype: [a, b, c, d]
632+
# [1, 0, 2, 0, 1, .] | original codes
633+
# --------------- | ----------
634+
# [b, a, ., a, b, .] | new dtype: [b, a, e]
635+
# [0, 1, ., 1, 0, .] |
636+
mapping = dtype.categories.get_indexer_for(self.categories)
637+
codes = mapping[self.codes]
638+
codes[self.codes == -1] = -1
639+
return type(self)(codes, dtype=dtype, fastpath=True)
621640

622641
def set_ordered(self, value, inplace=False):
623642
"""

pandas/tests/test_categorical.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,60 @@ def test_ordered_api(self):
879879
tm.assert_index_equal(cat4.categories, Index(['b', 'c', 'a']))
880880
assert cat4.ordered
881881

882+
def test_set_dtype_same(self):
883+
c = Categorical(['a', 'b', 'c'])
884+
result = c._set_dtype(CategoricalDtype(['a', 'b', 'c']))
885+
tm.assert_categorical_equal(result, c)
886+
887+
def test_set_dtype_new_categories(self):
888+
c = Categorical(['a', 'b', 'c'])
889+
result = c._set_dtype(CategoricalDtype(['a', 'b', 'c', 'd']))
890+
tm.assert_numpy_array_equal(result.codes, c.codes)
891+
tm.assert_index_equal(result.dtype.categories,
892+
pd.Index(['a', 'b', 'c', 'd']))
893+
894+
def test_set_dtype_nans(self):
895+
c = Categorical(['a', 'b', np.nan])
896+
result = c._set_dtype(CategoricalDtype(['a', 'c']))
897+
tm.assert_numpy_array_equal(result.codes, np.array([0, -1, -1],
898+
dtype='int8'))
899+
900+
@pytest.mark.parametrize('values, categories, new_categories', [
901+
# No NaNs, same cats, same order
902+
(['a', 'b', 'a'], ['a', 'b'], ['a', 'b'],),
903+
# No NaNs, same cats, different order
904+
(['a', 'b', 'a'], ['a', 'b'], ['b', 'a'],),
905+
# Same, unsorted
906+
(['b', 'a', 'a'], ['a', 'b'], ['a', 'b'],),
907+
# No NaNs, unsorted, same cats, different order
908+
(['b', 'a', 'a'], ['a', 'b'], ['b', 'a'],),
909+
# NaNs
910+
(['a', 'b', 'c'], ['a', 'b'], ['a', 'b']),
911+
(['a', 'b', 'c'], ['a', 'b'], ['b', 'a']),
912+
(['b', 'a', 'c'], ['a', 'b'], ['a', 'b']),
913+
(['b', 'a', 'c'], ['a', 'b'], ['a', 'b']),
914+
# Introduce NaNs
915+
(['a', 'b', 'c'], ['a', 'b'], ['a']),
916+
(['a', 'b', 'c'], ['a', 'b'], ['b']),
917+
(['b', 'a', 'c'], ['a', 'b'], ['a']),
918+
(['b', 'a', 'c'], ['a', 'b'], ['a']),
919+
# No overlap
920+
(['a', 'b', 'c'], ['a', 'b'], ['d', 'e']),
921+
])
922+
@pytest.mark.parametrize('ordered', [True, False])
923+
def test_set_dtype_many(self, values, categories, new_categories,
924+
ordered):
925+
c = Categorical(values, categories)
926+
expected = Categorical(values, new_categories, ordered)
927+
result = c._set_dtype(expected.dtype)
928+
tm.assert_categorical_equal(result, expected)
929+
930+
def test_set_dtype_no_overlap(self):
931+
c = Categorical(['a', 'b', 'c'], ['d', 'e'])
932+
result = c._set_dtype(CategoricalDtype(['a', 'b']))
933+
expected = Categorical([None, None, None], categories=['a', 'b'])
934+
tm.assert_categorical_equal(result, expected)
935+
882936
def test_set_ordered(self):
883937

884938
cat = Categorical(["a", "b", "c", "a"], ordered=True)

0 commit comments

Comments
 (0)