Skip to content

Commit 469d297

Browse files
authored
Merge pull request #125 from tursodatabase/with_schemadb
Make it possible to signal to the driver that we're connecting to schema db and use nontransactional batches in such case
2 parents 518abad + f20faf4 commit 469d297

File tree

4 files changed

+49
-21
lines changed

4 files changed

+49
-21
lines changed

libsql/internal/hrana/stream_request.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func ExecuteStoredStream(sqlId int32, params shared.Params, wantRows bool) (*Str
3838
return &StreamRequest{Type: "execute", Stmt: stmt}, nil
3939
}
4040

41-
func BatchStream(sqls []string, params []shared.Params, wantRows bool) (*StreamRequest, error) {
41+
func BatchStream(sqls []string, params []shared.Params, wantRows bool, transactional bool) (*StreamRequest, error) {
4242
batch := &Batch{}
4343
for idx, sql := range sqls {
4444
s := sql
@@ -50,19 +50,23 @@ func BatchStream(sqls []string, params []shared.Params, wantRows bool) (*StreamR
5050
return nil, err
5151
}
5252
var condition *BatchCondition
53-
if idx > 0 {
54-
prev_idx := int32(idx - 1)
55-
condition = &BatchCondition{
56-
Type: "ok",
57-
Step: &prev_idx,
53+
if transactional {
54+
if idx > 0 {
55+
prev_idx := int32(idx - 1)
56+
condition = &BatchCondition{
57+
Type: "ok",
58+
Step: &prev_idx,
59+
}
5860
}
5961
}
6062
batch.Add(*stmt, condition)
6163
}
62-
rollback := "ROLLBACK"
63-
last_idx := int32(len(sqls) - 1)
64-
batch.Add(Stmt{Sql: &rollback, WantRows: false},
65-
&BatchCondition{Type: "not", Cond: &BatchCondition{Type: "ok", Step: &last_idx}})
64+
if transactional {
65+
rollback := "ROLLBACK"
66+
last_idx := int32(len(sqls) - 1)
67+
batch.Add(Stmt{Sql: &rollback, WantRows: false},
68+
&BatchCondition{Type: "not", Cond: &BatchCondition{Type: "ok", Step: &last_idx}})
69+
}
6670
return &StreamRequest{Type: "batch", Batch: batch}, nil
6771
}
6872

libsql/internal/http/driver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ import (
66
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2"
77
)
88

9-
func Connect(url, jwt, host string) driver.Conn {
10-
return hranaV2.Connect(url, jwt, host)
9+
func Connect(url, jwt, host string, schemaDb bool) driver.Conn {
10+
return hranaV2.Connect(url, jwt, host, schemaDb)
1111
}

libsql/internal/http/hranaV2/hranaV2.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ func init() {
3636
commitHash = "unknown"
3737
}
3838

39-
func Connect(url, jwt, host string) driver.Conn {
40-
return &hranaV2Conn{url, jwt, host, "", false, 0}
39+
func Connect(url, jwt, host string, schemaDb bool) driver.Conn {
40+
return &hranaV2Conn{url, jwt, host, schemaDb, "", false, 0}
4141
}
4242

4343
type hranaV2Stmt struct {
@@ -85,6 +85,7 @@ type hranaV2Conn struct {
8585
url string
8686
jwt string
8787
host string
88+
schemaDb bool
8889
baton string
8990
streamClosed bool
9091
replicationIndex uint64
@@ -298,7 +299,7 @@ func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driv
298299
}
299300
msg.Add(*executeStream)
300301
} else {
301-
batchStream, err := hrana.BatchStream(stmts, params, wantRows)
302+
batchStream, err := hrana.BatchStream(stmts, params, wantRows, !h.schemaDb)
302303
if err != nil {
303304
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", query, err)
304305
}
@@ -338,7 +339,11 @@ func (h *hranaV2Conn) ExecContext(ctx context.Context, query string, args []driv
338339
}
339340
lastInsertRowId := int64(0)
340341
affectedRowCount := int64(0)
341-
for idx := 0; idx < len(res.StepResults)-1; idx++ {
342+
upperBound := len(res.StepResults)
343+
if !h.schemaDb {
344+
upperBound -= 1
345+
}
346+
for idx := 0; idx < upperBound; idx++ {
342347
r := res.StepResults[idx]
343348
rowId := r.GetLastInsertRowId()
344349
if rowId > 0 {
@@ -458,8 +463,10 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri
458463
if err != nil {
459464
return nil, err
460465
}
461-
res.StepResults = res.StepResults[:len(res.StepResults)-1]
462-
res.StepErrors = res.StepErrors[:len(res.StepErrors)-1]
466+
if !h.schemaDb {
467+
res.StepResults = res.StepResults[:len(res.StepResults)-1]
468+
res.StepErrors = res.StepErrors[:len(res.StepErrors)-1]
469+
}
463470
return shared.NewRows(&BatchResultRowsProvider{res}), nil
464471
default:
465472
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "unknown response type")

libsql/sql.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type config struct {
1717
authToken *string
1818
tls *bool
1919
proxy *string
20+
schemaDb *bool
2021
}
2122

2223
type Option interface {
@@ -65,6 +66,16 @@ func WithProxy(proxy string) Option {
6566
})
6667
}
6768

69+
func WithSchemaDb(schemaDb bool) Option {
70+
return option(func(o *config) error {
71+
if o.tls != nil {
72+
return fmt.Errorf("schemaDb already set")
73+
}
74+
o.schemaDb = &schemaDb
75+
return nil
76+
})
77+
}
78+
6879
func (c config) connector(dbPath string) (driver.Connector, error) {
6980
u, err := url.Parse(dbPath)
7081
if err != nil {
@@ -144,11 +155,16 @@ func (c config) connector(dbPath string) (driver.Connector, error) {
144155
}
145156
}
146157

158+
schemaDb := false
159+
if c.schemaDb != nil {
160+
schemaDb = *c.schemaDb
161+
}
162+
147163
if u.Scheme == "wss" || u.Scheme == "ws" {
148164
return wsConnector{url: u.String(), authToken: authToken}, nil
149165
}
150166
if u.Scheme == "https" || u.Scheme == "http" {
151-
return httpConnector{url: u.String(), authToken: authToken, host: host}, nil
167+
return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb}, nil
152168
}
153169

154170
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)
@@ -172,10 +188,11 @@ type httpConnector struct {
172188
url string
173189
authToken string
174190
host string
191+
schemaDb bool
175192
}
176193

177194
func (c httpConnector) Connect(_ctx context.Context) (driver.Conn, error) {
178-
return http.Connect(c.url, c.authToken, c.host), nil
195+
return http.Connect(c.url, c.authToken, c.host, c.schemaDb), nil
179196
}
180197

181198
func (c httpConnector) Driver() driver.Driver {
@@ -324,7 +341,7 @@ func (d Driver) Open(dbUrl string) (driver.Conn, error) {
324341
return ws.Connect(u.String(), jwt)
325342
}
326343
if u.Scheme == "https" || u.Scheme == "http" {
327-
return http.Connect(u.String(), jwt, u.Host), nil
344+
return http.Connect(u.String(), jwt, u.Host, false), nil
328345
}
329346

330347
return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)

0 commit comments

Comments
 (0)