Commit c64be31

[MllamaProcessor] Update errors and API with multiple image (#33715)
* update error
* update and add a test
* update
* update
1 parent 2ef31de commit c64be31
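
Below is a minimal usage sketch of the per-sample image-list API this commit validates. The checkpoint name "meta-llama/Llama-3.2-11B-Vision" and the dummy 224x224 images are assumptions for illustration, not part of the commit; the real weights are gated.

import numpy as np
from PIL import Image
from transformers import MllamaProcessor

# Assumed checkpoint name; loading the real weights requires gated access.
processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

# Dummy PIL images so the sketch is self-contained.
image1, image2, image3 = (
    Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)) for _ in range(3)
)

text = [
    "<|image|>Describe the first photo.",
    "Compare these two: <|image|><|image|>",
]
images = [[image1], [image2, image3]]  # one list of images per prompt

# Valid: three <|image|> tokens in total, three images provided.
inputs = processor(text=text, images=images, padding=True)

# After this commit, this raises ValueError: three image tokens, two images.
processor(text=text, images=[[image1], [image2]], padding=True)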

2 files changed (+134, -16 lines)

src/transformers/models/mllama/processing_mllama.py

Lines changed: 16 additions & 16 deletions

@@ -12,11 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Processor class for Mllama.
-"""
 
-from statistics import mean
+"""Processor class for Mllama."""
+
 from typing import List, Optional, Union
 
 import numpy as np

@@ -296,25 +294,27 @@ def __call__(
             encoding = self.tokenizer(text, **text_kwargs)
             data.update(encoding)
 
+        n_images_in_images = [0]
         if images is not None:
             images = make_list_of_images(images)
             n_images_in_images = [len(sample) for sample in images]
 
-            if text is not None:
-                if (
-                    not all(batch_img_per_prompt == n_images_in_images for batch_img_per_prompt in n_images_in_text)
-                    and len(text) > 1
-                ):
-                    raise ValueError(
-                        f"The number of images in each batch {n_images_in_text} should be the same {n_images_in_images} should be the same. Yes, the model does not \
-                            support having a different number of images per batch."
-                    )
-                if int(mean(n_images_in_text)) != int(mean(n_images_in_images)):
+        if text is not None:
+            if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
+                batch_img == 0 for batch_img in n_images_in_text
+            ):
+                raise ValueError(
+                    "If a batch of text is provided, there should be either no images or at least one image per sample"
+                )
+            if sum(n_images_in_images) != sum(n_images_in_text):
+                if images is None:
+                    raise ValueError("No images were provided, but there are image tokens in the prompt")
+                else:
                     raise ValueError(
-                        f"The number of images in the text ({n_images_in_text}) should be the same as in the number of provided images ({n_images_in_images}) \
-                            should be the same."
+                        f"The number of image tokens ({sum(n_images_in_text)}) should be the same as the number of provided images ({sum(n_images_in_images)})"
                     )
 
+        if images is not None:
             image_features = self.image_processor(images, **images_kwargs)
             num_tiles = image_features.pop("num_tiles")
             data.update(image_features)
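
Restated as a standalone rule, the new checks reduce to the sketch below. The helper function is hypothetical, written only to make the behavior explicit; it mirrors the names in the diff.

# Hypothetical restatement of the validation added above; not part of the commit.
def check_image_counts(n_images_in_text, n_images_in_images, images):
    # A batch must be all-with-images or all-without: mixed batches are rejected.
    if any(n == 0 for n in n_images_in_text) and not all(n == 0 for n in n_images_in_text):
        raise ValueError(
            "If a batch of text is provided, there should be either no images or at least one image per sample"
        )
    # Beyond that, only the totals across the batch must match.
    if sum(n_images_in_images) != sum(n_images_in_text):
        if images is None:
            raise ValueError("No images were provided, but there are image tokens in the prompt")
        raise ValueError(
            f"The number of image tokens ({sum(n_images_in_text)}) should be the same as "
            f"the number of provided images ({sum(n_images_in_images)})"
        )

check_image_counts([1, 2], [1, 2], images=[["a"], ["b", "c"]])  # passes
check_image_counts([1, 0], [1, 0], images=[["a"], []])          # raises: mixed batch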

tests/models/mllama/test_processor_mllama.py

Lines changed: 118 additions & 0 deletions
@@ -15,6 +15,8 @@
 
 import unittest
 
+import numpy as np
+
 from transformers import MllamaProcessor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
@@ -177,3 +179,119 @@ def test_apply_chat_template(self):
         rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
         rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
         self.assertEqual(rendered_list, rendered_str)
+
+    def test_process_interleaved_images_prompts_image_splitting(self):
+        # Test that a single image is processed correctly
+        inputs = self.processor(images=self.image2, size={"width": 224, "height": 224})
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224))
+
+        # Test that text is processed correctly
+        text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
+        inputs = self.processor(text=text)
+        expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
+        self.assertEqual(inputs.get("cross_attention_mask"), None)
+
+        # Test a single sample with image and text
+        image_str = "<|image|>"
+        text_str = "This is a test sentence."
+        text = image_str + text_str
+        inputs = self.processor(
+            text=text,
+            images=self.image1,
+            size={"width": 128, "height": 128},
+        )
+        expected_ids = [self.image_token_id, self.bos_token_id] + [2028, 374, 264, 1296, 11914, 13]
+
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 128, 128))
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
+        cross_attention_mask = inputs["cross_attention_mask"]
+        self.assertEqual(cross_attention_mask.shape, (1, 8, 1, 4))
+        self.assertTrue(
+            np.all(cross_attention_mask == 1), f"Cross attention mask is not all ones: {cross_attention_mask}"
+        )
+
+        # Test batch
+        text = [
+            "<|image|>This is a test sentence.",
+            "This is a test sentence.<|image|><|image|>This is a test sentence.",
+        ]
+        # fmt: off
+        expected_ids = [
+            [self.image_token_id, self.bos_token_id, 2028, 374, 264, 1296, 11914, 13],
+            [self.bos_token_id, 2028, 374, 264, 1296, 11914, 13, self.image_token_id, self.image_token_id, 2028, 374, 264, 1296, 11914, 13],
+        ]
+        # fmt: on
+        images = [[self.image1], [self.image1, self.image2]]
+        inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})
+
+        self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256))
+        for input_ids_i, attention_mask_i, expected_ids_i in zip(inputs["input_ids"], inputs["attention_mask"], expected_ids):
+            pad_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 0]
+            input_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 1]
+            self.assertEqual(input_ids, expected_ids_i)
+            self.assertEqual(pad_ids, [self.pad_token_id] * len(pad_ids))
+
+        cross_attention_mask = inputs["cross_attention_mask"]
+        self.assertEqual(cross_attention_mask.shape, (2, 15, 2, 4))
+
+        # Check that only the first tile of the first sample is attended to by all text tokens
+        first_sample_mask = cross_attention_mask[0].copy()
+        first_image_first_tile_attention = first_sample_mask[:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        # zero out first tile of first image
+        first_image_first_tile_attention[:, :1, :1] = 0
+        self.assertTrue(np.all(first_sample_mask == 0), f"Cross attention mask is not all zeros: {first_sample_mask}")
+
+        # second sample
+        second_sample_mask = cross_attention_mask[1].copy()
+        first_image_first_tile_attention = second_sample_mask[7:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        second_image_two_tiles_attention = second_sample_mask[8:, 1:2, :2]  # text tokens, images, tiles
+        self.assertTrue(np.all(second_image_two_tiles_attention == 1), f"Cross attention mask is not all ones: {second_image_two_tiles_attention}")
+
+        # zero out both image masks
+        second_sample_mask[7:, :1, :1] = 0
+        second_sample_mask[8:, 1:2, :2] = 0
+        self.assertTrue(np.all(second_sample_mask == 0), f"Cross attention mask is not all zeros: {second_sample_mask}")
+
+    def test_process_interleaved_images_prompts_image_error(self):
+        text = [
+            "This is a test sentence.",
+            "In this other sentence we try some good things",
+        ]
+        inputs = self.processor(text=text, images=None, padding=True)
+        self.assertIsNotNone(inputs["input_ids"])
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things",
+        ]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
+        images = [[self.image1], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=images, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        images = [[self.image1], [self.image2]]
+        inputs = self.processor(text=text, images=images, padding=True)
+
+        images = [[self.image1, self.image2], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
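
The mask assertions above follow a (batch, text_seq_len, max_num_images, max_num_tiles) layout. Below is a small sketch of that indexing for the second batch sample; the token positions 7 and 8 come from the expected_ids in the test, while the tile counts and the "text attends from its image token onward" behavior are assumptions inferred from the asserted slices.

import numpy as np

# One sample: 15 text tokens, up to 2 images, up to 4 tiles each.
seq_len, max_images, max_tiles = 15, 2, 4
mask = np.zeros((seq_len, max_images, max_tiles), dtype=np.int64)

# The first <|image|> token sits at position 7, the second at position 8;
# text from an image token onward attends to that image's valid tiles.
mask[7:, 0, :1] = 1  # image 1 occupies a single tile (assumed)
mask[8:, 1, :2] = 1  # image 2 was split into two tiles (assumed)

assert np.all(mask[7:, :1, :1] == 1)  # the same slices the test zeroes out
assert np.all(mask[8:, 1:2, :2] == 1)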
