diff --git a/janitor/functions.py b/janitor/functions.py index 2fe77556b..6d4789dcc 100644 --- a/janitor/functions.py +++ b/janitor/functions.py @@ -8,7 +8,38 @@ import re -def clean_names(df): +def _strip_underscores(df, strip_underscores=None): + """ + Strip underscores from the beginning, end or both of the + of the DataFrames column names. + + .. code-block:: python + + df = _strip_underscores(df, strip_underscores='left') + + :param df: The pandas DataFrame object. + :param strip_underscores: (optional) Removes the outer underscores from all + column names. Default None keeps outer underscores. Values can be + either 'left', 'right' or 'both' or the respective shorthand 'l', 'r' + and True. + :returns: A pandas DataFrame. + """ + underscore_options = [None, 'left', 'right', 'both', 'l', 'r', True] + if strip_underscores not in underscore_options: + raise JanitorError( + """strip_underscores must be one of: %s""" % underscore_options + ) + + if strip_underscores in ['left', 'l']: + df = df.rename(columns=lambda x: x.lstrip('_')) + elif strip_underscores in ['right', 'r']: + df = df.rename(columns=lambda x: x.rstrip('_')) + elif strip_underscores == 'both' or strip_underscores is True: + df = df.rename(columns=lambda x: x.strip('_')) + return df + + +def clean_names(df, strip_underscores=None): """ Clean column names. @@ -29,6 +60,10 @@ def clean_names(df): df = jn.DataFrame(df).clean_names() :param df: The pandas DataFrame object. + :param strip_underscores: (optional) Removes the outer underscores from all + column names. Default None keeps outer underscores. Values can be + either 'left', 'right' or 'both' or the respective shorthand 'l', 'r' + and True. :returns: A pandas DataFrame. """ df = df.rename( @@ -47,6 +82,7 @@ def clean_names(df): ) df = df.rename(columns=lambda x: re.sub('_+', '_', x)) + df = _strip_underscores(df, strip_underscores) return df @@ -190,7 +226,7 @@ def get_features_targets(df, target_columns, feature_columns=None): if isinstance(target_columns, str): xcols = [c for c in df.columns if target_columns != c] elif (isinstance(target_columns, list) - or isinstance(target_columns, tuple)): + or isinstance(target_columns, tuple)): xcols = [c for c in df.columns if c not in target_columns] X = df[xcols] return X, Y diff --git a/tests/test_functions.py b/tests/test_functions.py index 6af5b2498..7242ceab8 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -159,3 +159,91 @@ def test_multiindex_clean_names_pipe(multiindex_dataframe): expected_columns = pd.MultiIndex(levels=levels, labels=labels) assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_both(multiindex_dataframe): + df = multiindex_dataframe.rename(columns=lambda x: '_' + x) + df = clean_names(multiindex_dataframe, strip_underscores='both') + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_true(multiindex_dataframe): + df = multiindex_dataframe.rename(columns=lambda x: '_' + x) + df = clean_names(multiindex_dataframe, strip_underscores=True) + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_right(multiindex_dataframe): + df = clean_names(multiindex_dataframe, strip_underscores='right') + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_r(multiindex_dataframe): + df = clean_names(multiindex_dataframe, strip_underscores='r') + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_left(multiindex_dataframe): + df = multiindex_dataframe.rename(columns=lambda x: '_' + x) + df = clean_names(multiindex_dataframe, strip_underscores='left') + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino_'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns) + + +def test_clean_names_strip_underscores_l(multiindex_dataframe): + df = multiindex_dataframe.rename(columns=lambda x: '_' + x) + df = clean_names(multiindex_dataframe, strip_underscores='l') + + levels = [ + ['a', 'bell_chart', 'decorated_elephant'], + ['b', 'normal_distribution', 'r_i_p_rhino_'] + ] + + labels = [[1, 0, 2], [1, 0, 2]] + + expected_columns = pd.MultiIndex(levels=levels, labels=labels) + assert set(df.columns) == set(expected_columns)