
Commit 2ece31b

MB-66395: Support batch processing of vector search requests
Requires: blevesearch/scorch_segment_api#62
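
Vector search requests that target the same field with identical k and search params are now grouped into a single batch. The first request in a group arms a timer (options.BatchExecutionDelay, falling back to segment.DefaultBatchExecutionDelay); when it fires, all queued query vectors are executed in one FAISS search, and each caller receives its slice of the results over a dedicated channel. The batch executor is cached per field alongside the FAISS index and is closed together with its cache entry.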
1 parent 4e38ae4 commit 2ece31b

File tree

6 files changed (+372, -28 lines)


faiss_vector_batch_executor.go

Lines changed: 211 additions & 0 deletions (new file)
// Copyright (c) 2025 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package zap

import (
	"encoding/json"
	"sync"
	"time"

	"github.com/RoaringBitmap/roaring/v2/roaring64"
	faiss "github.com/blevesearch/go-faiss"
	segment "github.com/blevesearch/scorch_segment_api/v2"
)

// batchKey represents a unique combination of k and params for batching.
type batchKey struct {
	k      int64
	params string // string representation of params for comparison
}

// batchRequest represents a single vector search request in a batch.
type batchRequest struct {
	qVector []float32
	result  chan batchResult
}

// batchGroup represents a group of requests with the same k and params.
type batchGroup struct {
	requests           []batchRequest
	vecIndex           *faiss.IndexImpl
	vecDocIDMap        map[int64]uint32
	vectorIDsToExclude []int64
}

// batchExecutor manages batched vector search requests.
type batchExecutor struct {
	batchDelay time.Duration

	m      sync.RWMutex
	groups map[batchKey]*batchGroup
}

func newBatchExecutor(options segment.InterpretVectorIndexOptions) *batchExecutor {
	batchDelay := options.BatchExecutionDelay
	if batchDelay <= 0 {
		batchDelay = segment.DefaultBatchExecutionDelay
	}

	return &batchExecutor{
		batchDelay: batchDelay,
		groups:     make(map[batchKey]*batchGroup),
	}
}

type batchResult struct {
	pl  segment.VecPostingsList
	err error
}

func (be *batchExecutor) close() {
	be.m.Lock()
	defer be.m.Unlock()

	for key, group := range be.groups {
		for _, req := range group.requests {
			close(req.result)
		}
		delete(be.groups, key)
	}
}

// queueRequest adds a vector search request to the appropriate batch group.
func (be *batchExecutor) queueRequest(qVector []float32, k int64, params json.RawMessage,
	vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
	vectorIDsToExclude []int64) <-chan batchResult {

	// Create a channel for the result.
	resultCh := make(chan batchResult, 1)

	// Create the batch key.
	key := batchKey{
		k:      k,
		params: string(params),
	}

	be.m.Lock()
	defer be.m.Unlock()

	// Get or create the batch group.
	group, exists := be.groups[key]
	if !exists {
		group = &batchGroup{
			requests:           make([]batchRequest, 0),
			vecIndex:           vecIndex,
			vecDocIDMap:        vecDocIDMap,
			vectorIDsToExclude: vectorIDsToExclude,
		}
		be.groups[key] = group
	}

	// Add the request to the group.
	group.requests = append(group.requests, batchRequest{
		qVector: qVector,
		result:  resultCh,
	})

	// If this is the first request in the group, start a timer to process the batch.
	if len(group.requests) == 1 {
		go be.processBatchAfterDelay(key, be.batchDelay)
	}

	return resultCh
}

// processBatchAfterDelay waits for the specified delay and then processes the batch.
func (be *batchExecutor) processBatchAfterDelay(key batchKey, delay time.Duration) {
	time.Sleep(delay)

	be.m.Lock()
	group, exists := be.groups[key]
	if !exists {
		be.m.Unlock()
		return
	}

	// Remove the group from the map before processing.
	delete(be.groups, key)
	be.m.Unlock()

	// Process the batch.
	be.processBatch(key, group)
}

// processBatch executes a batch of vector search requests.
func (be *batchExecutor) processBatch(key batchKey, group *batchGroup) {
	if len(group.requests) == 0 {
		return
	}

	// Prepare vectors for the batch search.
	vecs := make([]float32, 0, len(group.requests)*group.vecIndex.D())
	for _, req := range group.requests {
		vecs = append(vecs, req.qVector...)
	}

	// Execute the batch search.
	scores, ids, err := group.vecIndex.SearchWithoutIDs(vecs, key.k, group.vectorIDsToExclude,
		json.RawMessage(key.params))
	if err != nil {
		// Send the error to all channels.
		for _, req := range group.requests {
			req.result <- batchResult{
				err: err,
			}
			close(req.result)
		}
		return
	}

	// Calculate the number of results per request.
	resultsPerRequest := int(key.k)
	totalResults := len(scores)

	// Process results and send them to their respective channels.
	for i := range group.requests {
		pl := &VecPostingsList{
			postings: roaring64.New(),
		}

		// Calculate start and end indices for this request's results.
		startIdx := i * resultsPerRequest
		endIdx := startIdx + resultsPerRequest
		if endIdx > totalResults {
			endIdx = totalResults
		}

		// Get this request's results.
		currScores := scores[startIdx:endIdx]
		currIDs := ids[startIdx:endIdx]

		// Add the results to the postings list.
		for j := 0; j < len(currIDs); j++ {
			vecID := currIDs[j]
			if docID, ok := group.vecDocIDMap[vecID]; ok {
				code := getVectorCode(docID, currScores[j])
				pl.postings.Add(code)
			}
		}

		// Send the result to the channel.
		group.requests[i].result <- batchResult{
			pl: pl,
		}
		close(group.requests[i].result)
	}
}
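
Given the flat result layout above, request i's hits are the sub-slice [i*k, (i+1)*k) of the scores and IDs arrays: with three queued requests and k = 5, FAISS returns 15 scores and request 1 reads indices 5 through 9. A minimal caller-side sketch of the intended flow follows; searchOne and its parameters are hypothetical stand-ins for searcher code outside this commit.

// searchOne is a hypothetical helper illustrating the call pattern; the
// batchExec, index, and map arguments are assumed to come from the
// per-field vector index cache.
func searchOne(batchExec *batchExecutor, vecIndex *faiss.IndexImpl,
	vecDocIDMap map[int64]uint32, exclude []int64,
	qVector []float32, k int64, params json.RawMessage) (segment.VecPostingsList, error) {

	// Queue the request; it joins any pending batch with the same k/params.
	resultCh := batchExec.queueRequest(qVector, k, params, vecIndex,
		vecDocIDMap, exclude)

	// Block until the batch window elapses and the shared search completes.
	res := <-resultCh
	return res.pl, res.err
}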

faiss_vector_cache.go

Lines changed: 32 additions & 20 deletions
@@ -25,6 +25,7 @@ import (
 
 	"github.com/RoaringBitmap/roaring/v2"
 	faiss "github.com/blevesearch/go-faiss"
+	segment "github.com/blevesearch/scorch_segment_api/v2"
 )
 
 func newVectorIndexCache() *vectorIndexCache {
@@ -56,17 +57,17 @@ func (vc *vectorIndexCache) Clear() {
 // present. It also returns the batch executor for the field if it's present in the
 // cache.
 func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
-	loadDocVecIDMap bool, except *roaring.Bitmap) (
+	loadDocVecIDMap bool, except *roaring.Bitmap, options segment.InterpretVectorIndexOptions) (
 	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
-	vecIDsToExclude []int64, err error) {
+	vecIDsToExclude []int64, batchExec *batchExecutor, err error) {
 	vc.m.RLock()
 	entry, ok := vc.cache[fieldID]
 	if ok {
-		index, vecDocIDMap, docVecIDMap = entry.load()
+		index, vecDocIDMap, docVecIDMap, batchExec = entry.load()
 		vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
 		if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 {
 			vc.m.RUnlock()
-			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
 		}
 
 		vc.m.RUnlock()
@@ -76,14 +77,14 @@ func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
 		// typically seen for the first filtered query.
 		docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
 		vc.m.Unlock()
-		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
 	}
 
 	vc.m.RUnlock()
 	// acquiring a lock since this is modifying the cache.
 	vc.m.Lock()
 	defer vc.m.Unlock()
-	return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except)
+	return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except, options)
 }
 
 func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint32][]int64 {
@@ -104,21 +105,22 @@ func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint
 
 // Rebuilding the cache on a miss.
 func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
-	loadDocVecIDMap bool, except *roaring.Bitmap) (
+	loadDocVecIDMap bool, except *roaring.Bitmap, options segment.InterpretVectorIndexOptions) (
 	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
-	docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {
+	docVecIDMap map[uint32][]int64, vecIDsToExclude []int64,
+	batchExec *batchExecutor, err error) {
 
 	// Handle concurrent accesses (to avoid unnecessary work) by adding a
 	// check within the write lock here.
 	entry := vc.cache[fieldID]
 	if entry != nil {
-		index, vecDocIDMap, docVecIDMap = entry.load()
+		index, vecDocIDMap, docVecIDMap, batchExec = entry.load()
 		vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
 		if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 {
-			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+			return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
 		}
 		docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
-		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+		return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
 	}
 
 	// if the cache doesn't have the entry, construct the vector to doc id map and
@@ -154,16 +156,17 @@ func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
 
 	index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
 	if err != nil {
-		return nil, nil, nil, nil, err
+		return nil, nil, nil, nil, nil, err
 	}
 
-	vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap)
-	return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
+	batchExec = newBatchExecutor(options)
+	vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap, batchExec)
+	return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}
 
 func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
 	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, loadDocVecIDMap bool,
-	docVecIDMap map[uint32][]int64) {
+	docVecIDMap map[uint32][]int64, batchExec *batchExecutor) {
 	// the first time we've hit the cache, try to spawn a monitoring routine
 	// which will reconcile the moving averages for all the fields being hit
 	if len(vc.cache) == 0 {
@@ -178,7 +181,7 @@ func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
 		// longer time and thereby the index to be resident in the cache
 		// for longer time.
 		vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap,
-			loadDocVecIDMap, docVecIDMap, 0.4)
+			loadDocVecIDMap, docVecIDMap, 0.4, batchExec)
 	}
 }
 
@@ -272,15 +275,17 @@ func (e *ewma) add(val uint64) {
 // -----------------------------------------------------------------------------
 
 func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
-	loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64) *cacheEntry {
+	loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64,
+	batchExec *batchExecutor) *cacheEntry {
 	ce := &cacheEntry{
 		index:       index,
 		vecDocIDMap: vecDocIDMap,
 		tracker: &ewma{
 			alpha:  alpha,
 			sample: 1,
 		},
-		refs: 1,
+		refs:      1,
+		batchExec: batchExec,
 	}
 	if loadDocVecIDMap {
 		ce.docVecIDMap = docVecIDMap
@@ -299,6 +304,8 @@ type cacheEntry struct {
 	index       *faiss.IndexImpl
 	vecDocIDMap map[int64]uint32
 	docVecIDMap map[uint32][]int64
+
+	batchExec *batchExecutor
 }
 
 func (ce *cacheEntry) incHit() {
@@ -313,10 +320,14 @@ func (ce *cacheEntry) decRef() {
 	atomic.AddInt64(&ce.refs, -1)
 }
 
-func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]int64) {
+func (ce *cacheEntry) load() (
+	*faiss.IndexImpl,
+	map[int64]uint32,
+	map[uint32][]int64,
+	*batchExecutor) {
 	ce.incHit()
 	ce.addRef()
-	return ce.index, ce.vecDocIDMap, ce.docVecIDMap
+	return ce.index, ce.vecDocIDMap, ce.docVecIDMap, ce.batchExec
 }
 
 func (ce *cacheEntry) close() {
@@ -325,6 +336,7 @@ func (ce *cacheEntry) close() {
 		ce.index = nil
 		ce.vecDocIDMap = nil
 		ce.docVecIDMap = nil
+		ce.batchExec.close()
 	}()
 }
 