@@ -15,6 +15,8 @@

 import unittest

+import numpy as np
+
 from transformers import MllamaProcessor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
@@ -177,3 +179,119 @@ def test_apply_chat_template(self):
         rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
         rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
         self.assertEqual(rendered_list, rendered_str)
+
+    def test_process_interleaved_images_prompts_image_splitting(self):
+        # Test that a single image is processed correctly
+        inputs = self.processor(images=self.image2, size={"width": 224, "height": 224})
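+        # pixel_values layout (inferred from the asserted shapes below):
+        # (batch, max_num_images, max_image_tiles, channels, tile_height, tile_width)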
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224))
+
+        # Test that text is processed correctly
+        text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
+        inputs = self.processor(text=text)
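+        # 128000 and 128001 are the Llama 3 <|begin_of_text|> and <|end_of_text|> token ids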
+        expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
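+        # with no images in the input, no cross-attention mask should be returned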
+        self.assertIsNone(inputs.get("cross_attention_mask"))
+
+        # Test a single sample with image and text
+        image_str = "<|image|>"
+        text_str = "This is a test sentence."
+        text = image_str + text_str
+        inputs = self.processor(
+            text=text,
+            images=self.image1,
+            size={"width": 128, "height": 128},
+        )
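+        # the <|image|> placeholder is kept ahead of <|begin_of_text|> in the tokenized output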
+        expected_ids = [self.image_token_id, self.bos_token_id] + [2028, 374, 264, 1296, 11914, 13]
+
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 128, 128))
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
+        cross_attention_mask = inputs["cross_attention_mask"]
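+        # cross_attention_mask layout (inferred from the asserted shapes):
+        # (batch, text_seq_len, max_num_images, max_image_tiles)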
+        self.assertEqual(cross_attention_mask.shape, (1, 8, 1, 4))
+        self.assertTrue(
+            np.all(cross_attention_mask == 1), f"Cross attention mask is not all ones: {cross_attention_mask}"
+        )
+
+        # Test batch
+        text = [
+            "<|image|>This is a test sentence.",
+            "This is a test sentence.<|image|><|image|>This is a test sentence.",
+        ]
+        # fmt: off
+        expected_ids = [
+            [self.image_token_id, self.bos_token_id, 2028, 374, 264, 1296, 11914, 13],
+            [self.bos_token_id, 2028, 374, 264, 1296, 11914, 13, self.image_token_id, self.image_token_id, 2028, 374, 264, 1296, 11914, 13],
+        ]
+        # fmt: on
+        images = [[self.image1], [self.image1, self.image2]]
+        inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})
+
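+        # both samples are padded to the batch maximum of two images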
+        self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256))
+        for input_ids_i, attention_mask_i, expected_ids_i in zip(inputs["input_ids"], inputs["attention_mask"], expected_ids):
+            pad_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 0]
+            input_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 1]
+            self.assertEqual(input_ids, expected_ids_i)
+            self.assertEqual(pad_ids, [self.pad_token_id] * len(pad_ids))
+
+        cross_attention_mask = inputs["cross_attention_mask"]
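+        # 15 is the padded length of the longer text sequence in the batch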
+        self.assertEqual(cross_attention_mask.shape, (2, 15, 2, 4))
+
+        # Check that all text tokens of the first sample attend to the first tile of its single image
+        first_sample_mask = cross_attention_mask[0].copy()
+        first_image_first_tile_attention = first_sample_mask[:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        # zero out the first tile of the first image (a view into first_sample_mask);
+        # the sample's whole mask should then be all zeros
+        first_image_first_tile_attention[:, :1, :1] = 0
+        self.assertTrue(np.all(first_sample_mask == 0), f"Cross attention mask is not all zeros: {first_sample_mask}")
+
+        # Second sample: text from the first <|image|> token onward attends to the first image's first
+        # tile, and text from the second <|image|> token onward attends to both tiles of the second image
+        second_sample_mask = cross_attention_mask[1].copy()
+        first_image_first_tile_attention = second_sample_mask[7:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        second_image_two_tiles_attention = second_sample_mask[8:, 1:2, :2]  # text tokens, images, tiles
+        self.assertTrue(np.all(second_image_two_tiles_attention == 1), f"Cross attention mask is not all ones: {second_image_two_tiles_attention}")
+
+        # zero out both image masks; nothing else in the sample's mask should be set
+        second_sample_mask[7:, :1, :1] = 0
+        second_sample_mask[8:, 1:2, :2] = 0
+        self.assertTrue(np.all(second_sample_mask == 0), f"Cross attention mask is not all zeros: {second_sample_mask}")
+
+    def test_process_interleaved_images_prompts_image_error(self):
+        text = [
+            "This is a test sentence.",
+            "In this other sentence we try some good things",
+        ]
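+        # plain text with no <|image|> placeholders is valid without images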
+        inputs = self.processor(text=text, images=None, padding=True)
+        self.assertIsNotNone(inputs["input_ids"])
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things",
+        ]
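+        # an <|image|> placeholder with no images provided should raise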
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
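+        # an empty image list for one of the samples should raise as well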
+        images = [[self.image1], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=images, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        images = [[self.image1], [self.image2]]
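+        # one image per <|image|> placeholder is a valid pairing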
+        inputs = self.processor(text=text, images=images, padding=True)
+
+        images = [[self.image1, self.image2], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=images, padding=True)