
Tokenizers v0.20.2 fails on batches as tuples #1672

Closed
@OyvindTafjord

Description

Certain fast tokenizers now fail when a batch is given as a tuple; for example (on a MacBook M2 with transformers 4.46.1):

>>> from transformers import AutoTokenizer
>>> tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
>>> tok.batch_encode_plus(("hello there", "bye bye bye"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/oyvindt/miniconda3/envs/oe-eval/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3311, in batch_encode_plus
    return self._batch_encode_plus(
  File "/Users/oyvindt/miniconda3/envs/oe-eval/lib/python3.10/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py", line 127, in _batch_encode_plus
    return super()._batch_encode_plus(*args, **kwargs)
  File "/Users/oyvindt/miniconda3/envs/oe-eval/lib/python3.10/site-packages/transformers/tokenization_utils_fast.py", line 529, in _batch_encode_plus
    encodings = self._tokenizer.encode_batch(
TypeError: argument 'input': 'tuple' object cannot be converted to 'PyList'
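
As a workaround in this environment, converting the tuple to a list before calling appears to avoid the error (a minimal sketch using the same tokenizer as above):

>>> batch = ("hello there", "bye bye bye")
>>> tok.batch_encode_plus(list(batch))  # coercing the tuple to a list avoids the TypeError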

This works in v0.20.1; presumably the regression is related to PR #1665.

The batch_encode_plus code in transformers claims to accept both tuples and lists:

        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
            raise TypeError(
                f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})"
            )
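
If the stricter input type in the Rust binding is intentional, one possible fix on the transformers side (a hypothetical sketch, not an actual patch) would be to normalize the batch to a list right after this check, so that encode_batch always receives a real Python list:

        # Hypothetical fix sketch: tokenizers >= 0.20.2 no longer converts
        # tuples for encode_batch, so coerce to a list before dispatching.
        batch_text_or_text_pairs = list(batch_text_or_text_pairs)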
