Skip to content

Commit 55464fb

Browse files
committed
Fix: SQLite cosine function + tests
1 parent 231e5f4 commit 55464fb

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

.vscode/settings.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@
104104
"ranges": "cpp",
105105
"span": "cpp",
106106
"format": "cpp",
107-
"charconv": "cpp"
107+
"charconv": "cpp",
108+
"strstream": "cpp"
108109
},
109110
"cSpell.words": [
110111
"allclose",
@@ -153,7 +154,9 @@
153154
"SLOC",
154155
"Sonatype",
155156
"sorensen",
157+
"sqeuclidean",
156158
"Struct",
159+
"swar",
157160
"tanimoto",
158161
"tqdm",
159162
"uninitialize",

python/lib_sqlite.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ static void sqlite_dense(sqlite3_context* context, int argc, sqlite3_value** arg
7979

8080
// Parse the floating-point numbers
8181
std::from_chars_result result1 = std::from_chars(vec1, vec1 + bytes1, parsed1[i]);
82-
std::from_chars_result result2 = std::from_chars(vec2, vec2 + bytes2, parsed1[i]);
82+
std::from_chars_result result2 = std::from_chars(vec2, vec2 + bytes2, parsed2[i]);
8383
if (result1.ec != std::errc() || result2.ec != std::errc()) {
8484
sqlite3_result_error(context, "Number can't be parsed", -1);
8585
return;

python/scripts/test_sqlite.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import sqlite3
2+
import json
3+
import math
4+
import pytest
5+
import numpy as np
6+
7+
import usearch
8+
from usearch.io import load_matrix, save_matrix
9+
from usearch.index import search
10+
from usearch.eval import random_vectors
11+
12+
from usearch.index import Match, Matches, BatchMatches, Index, Indexes
13+
14+
15+
dimensions = [3, 97, 256]
16+
batch_sizes = [1, 77, 100]
17+
18+
19+
def test_sqlite_distances():
20+
conn = sqlite3.connect(":memory:")
21+
conn.enable_load_extension(True)
22+
conn.load_extension(usearch.sqlite)
23+
24+
cursor = conn.cursor()
25+
26+
# Create a table with additional columns for f32 and f16 BLOBs
27+
cursor.execute(
28+
"""
29+
CREATE TABLE IF NOT EXISTS vector_table (
30+
id INTEGER PRIMARY KEY,
31+
vector_json JSON,
32+
vector_f32 BLOB,
33+
vector_f16 BLOB
34+
)
35+
"""
36+
)
37+
38+
# Generate and insert random vectors
39+
num_vectors = 3 # Number of vectors to generate
40+
dim = 4 # Dimension of each vector
41+
vectors = []
42+
43+
for i in range(num_vectors):
44+
# Generate a random 256-dimensional vector
45+
vector = np.random.rand(dim)
46+
vectors.append(vector)
47+
48+
# Convert the vector to f32 and f16
49+
vector_f32 = np.float32(vector)
50+
vector_f16 = np.float16(vector)
51+
52+
# Insert the vector into the database as JSON and as BLOBs
53+
cursor.execute(
54+
"""
55+
INSERT INTO vector_table (vector_json, vector_f32, vector_f16) VALUES (?, ?, ?)
56+
""",
57+
(json.dumps(vector.tolist()), vector_f32.tobytes(), vector_f16.tobytes()),
58+
)
59+
60+
# Commit changes
61+
conn.commit()
62+
63+
similarities = """
64+
SELECT
65+
a.id AS id1,
66+
b.id AS id2,
67+
distance_cosine_f32(a.vector_json, b.vector_json) AS cosine_similarity_json,
68+
distance_cosine_f32(a.vector_f32, b.vector_f32) AS cosine_similarity_f32,
69+
distance_cosine_f16(a.vector_f16, b.vector_f16) AS cosine_similarity_f16
70+
FROM
71+
vector_table AS a,
72+
vector_table AS b
73+
WHERE
74+
a.id < b.id;
75+
"""
76+
cursor.execute(similarities)
77+
78+
for a, b, similarity_json, similarity_f32, similarity_f16 in cursor.fetchall():
79+
assert math.isclose(similarity_json, similarity_f32, abs_tol=0.1)
80+
assert math.isclose(similarity_json, similarity_f16, abs_tol=0.1)

python/usearch/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
import importlib
2+
13
from usearch.compiled import (
24
VERSION_MAJOR,
35
VERSION_MINOR,
46
VERSION_PATCH,
57
)
68

79
__version__ = f"{VERSION_MAJOR}.{VERSION_MINOR}.{VERSION_PATCH}"
10+
11+
# The same binary file (.so, .dll, or .dylib) that contains the pre-compiled
12+
# USearch code also contains the SQLite3 binding
13+
sqlite = importlib.util.find_spec("usearch.compiled").origin

0 commit comments

Comments
 (0)