Skip to content

Commit 4b7bb05

Browse files
committed
fix(chunker): correctly determine chunk midpoint when empty chunks are present
Previously ["foo", "", "bar", "baz"] would be token-counted as "foobarbaz" (subchunks joined with "") rather than "foo  bar baz" (subchunks joined with the separator) when getting the midpoint index
1 parent 41ad7f5 commit 4b7bb05

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

griptape/chunkers/base_chunker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _chunk_recursively(self, chunk: str, current_separator: Optional[ChunkSepara
6262

6363
if len(non_empty_subchunks) > 1:
6464
# Find what combination of subchunks results in the most balanced split of the chunk.
65-
midpoint_index = self.__find_midpoint_index(subchunks, half_token_count)
65+
midpoint_index = self.__find_midpoint_index(separator, subchunks, half_token_count)
6666

6767
# Create the two subchunks based on the best separator.
6868
first_subchunk, second_subchunk = self.__get_subchunks(separator, subchunks, midpoint_index)
@@ -98,12 +98,12 @@ def __get_subchunks(self, separator: ChunkSeparator, subchunks: list[str], balan
9898

9999
return first_subchunk, second_subchunk
100100

101-
def __find_midpoint_index(self, subchunks: list[str], half_token_count: int) -> int:
101+
def __find_midpoint_index(self, separator: ChunkSeparator, subchunks: list[str], half_token_count: int) -> int:
102102
midpoint_index = -1
103103
best_midpoint_distance = float("inf")
104104

105105
for index, _ in enumerate(subchunks):
106-
subchunk_tokens_count = self.tokenizer.count_tokens("".join(subchunks[: index + 1]))
106+
subchunk_tokens_count = self.tokenizer.count_tokens(separator.value.join(subchunks[: index + 1]))
107107

108108
midpoint_distance = abs(subchunk_tokens_count - half_token_count)
109109
if midpoint_distance < best_midpoint_distance:

tests/unit/chunkers/test_markdown_chunker.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_chunk(self, chunker):
2424
]
2525
chunks = chunker.chunk("".join(text))
2626

27-
assert len(chunks) == 6
27+
assert len(chunks) == 7
2828

2929
for chunk in chunks:
3030
assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS
@@ -33,12 +33,14 @@ def test_chunk(self, chunker):
3333
assert chunks[1].value.startswith("## Header 2\nfoo-0")
3434
assert chunks[2].value.startswith("foo-0.")
3535
assert chunks[3].value.startswith("## Header 3\nfoo-0")
36-
assert chunks[4].value.startswith("foo-10.")
37-
assert chunks[5].value.startswith("foo-16.")
36+
assert chunks[4].value.startswith("foo-5.")
37+
assert chunks[5].value.startswith("foo-12.")
38+
assert chunks[6].value.startswith("foo-19.")
3839

3940
assert chunks[0].value.endswith(". foo-5.")
4041
assert chunks[1].value.endswith(". foo-5.")
4142
assert chunks[2].value.endswith(". foo-5.")
42-
assert chunks[3].value.endswith(". foo-9.")
43-
assert chunks[4].value.endswith(". foo-15.")
44-
assert chunks[5].value.endswith(". foo-24.")
43+
assert chunks[3].value.endswith(". foo-4.")
44+
assert chunks[4].value.endswith(". foo-11.")
45+
assert chunks[5].value.endswith(". foo-18.")
46+
assert chunks[6].value.endswith(". foo-24.")

tests/unit/chunkers/test_text_chunker.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ def test_large_chunks(self, chunker):
5656
assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS
5757

5858
assert chunks[0].value.startswith("foo-0!")
59-
assert chunks[1].value.startswith("foo-11!")
60-
assert chunks[2].value.startswith("foo-17!")
59+
assert chunks[1].value.startswith("foo-7!")
60+
assert chunks[2].value.startswith("foo-13!")
6161
assert chunks[3].value.startswith("foo-0.")
6262

63-
assert chunks[0].value.endswith("! foo-10!")
64-
assert chunks[1].value.endswith("! foo-16!")
63+
assert chunks[0].value.endswith("! foo-6!")
64+
assert chunks[1].value.endswith("! foo-12!")
6565
assert chunks[2].value.endswith("! foo-24!")
6666
assert chunks[3].value.endswith(". foo-11.")
6767

@@ -92,19 +92,19 @@ def test_separators(self, chunker):
9292
assert chunker.tokenizer.count_tokens(chunk.value) <= MAX_TOKENS
9393

9494
assert chunks[0].value.startswith("foo-0!")
95-
assert chunks[1].value.startswith("foo-11!")
96-
assert chunks[2].value.startswith("foo-17!")
95+
assert chunks[1].value.startswith("foo-7!")
96+
assert chunks[2].value.startswith("foo-13!")
9797
assert chunks[3].value.startswith("foo-0.")
9898
assert chunks[4].value.startswith("foo-0?")
99-
assert chunks[5].value.startswith("foo-9?")
99+
assert chunks[5].value.startswith("foo-7?")
100100
assert chunks[6].value.startswith("foo-0")
101101
assert chunks[7].value.startswith("foo-8")
102102

103-
assert chunks[0].value.endswith("! foo-10!")
104-
assert chunks[1].value.endswith("! foo-16!")
103+
assert chunks[0].value.endswith("! foo-6!")
104+
assert chunks[1].value.endswith("! foo-12!")
105105
assert chunks[2].value.endswith("! foo-24!")
106106
assert chunks[3].value.endswith(". foo-11.")
107-
assert chunks[4].value.endswith("? foo-8?")
107+
assert chunks[4].value.endswith("? foo-6?")
108108
assert chunks[5].value.endswith("? foo-12?")
109109
assert chunks[6].value.endswith(" foo-7")
110110
assert chunks[7].value.endswith(" foo-16")
@@ -138,3 +138,15 @@ def test_artifact_reference(self, chunker):
138138

139139
for chunk in chunks:
140140
assert chunk.reference is None
141+
142+
def test_midpoint_index_empty_subchunks(self, chunker):
143+
# This tests that a midpoint index is correctly found when there are some empty subchunks
144+
    # Previously ["foo", "", "bar", "baz"] would be token-counted as "foobarbaz" rather than
    # "foo  bar baz" (joined with the separator)
145+
# when calculating the midpoint index.
146+
# https://github.com/griptape-ai/griptape/issues/1796
147+
chunker.max_tokens = 3
148+
149+
assert len(chunker.chunk("foo bar baz")) == 1
150+
assert len(chunker.chunk("foo bar baz ")) == 2
151+
152+
    assert len(chunker.chunk("foo  bar baz")) == 2

0 commit comments

Comments
 (0)