diff --git a/db/sql/sql.go b/db/sql/sql.go index 9f153ac13..29cdf66e8 100644 --- a/db/sql/sql.go +++ b/db/sql/sql.go @@ -367,7 +367,7 @@ func (db *DB) BeginTx(options ...transaction.Option) (tx transaction.Transaction return nil, err } - rawTx, err := db.DB.BeginTxx(params.Context, sqlOptions) + rawTx, err := db.DB.BeginTxx(safeMysqlContext(params.Context), sqlOptions) if err != nil { db.updateCounter(1, "begin.failed") return nil, err @@ -583,7 +583,7 @@ func (tx *Transaction) Exec(ctx context.Context, sql string, args ...interface{} func (tx *Transaction) exec(ctx context.Context, sql string, args ...interface{}) error { tx.logQuery(sql, args...) - _, err := tx.transaction.ExecContext(ctx, sql, args...) + _, err := tx.transaction.ExecContext(safeMysqlContext(ctx), sql, args...) return err } @@ -856,7 +856,7 @@ func buildSelect(sc *selectContext) (string, []interface{}, error) { func (tx *Transaction) executeSelect(ctx context.Context, sc *selectContext, sql string, args []interface{}) (list []*schema.Resource, total uint64, err error) { tx.logQuery(sql, args...) - rows, err := tx.transaction.QueryxContext(ctx, sql, args...) + rows, err := tx.transaction.QueryxContext(safeMysqlContext(ctx), sql, args...) if err != nil { return } @@ -955,7 +955,7 @@ func (tx *Transaction) Query(ctx context.Context, s *schema.Schema, query string defer tx.measureTime(time.Now(), s.ID, "query") tx.logQuery(query, arguments...) - rows, err := tx.transaction.QueryxContext(ctx, query, arguments...) + rows, err := tx.transaction.QueryxContext(safeMysqlContext(ctx), query, arguments...) if err != nil { return nil, fmt.Errorf("Failed to run query: %s", query) } @@ -1022,7 +1022,7 @@ func (tx *Transaction) Count(ctx context.Context, s *schema.Schema, filter trans return } result := map[string]interface{}{} - err = tx.transaction.QueryRowxContext(ctx, sql, args...).MapScan(result) + err = tx.transaction.QueryRowxContext(safeMysqlContext(ctx), sql, args...).MapScan(result) if err != nil { return } @@ -1089,7 +1089,7 @@ func (tx *Transaction) StateFetch(ctx context.Context, s *schema.Schema, filter return } tx.logQuery(sql, args...) - rows, err := tx.transaction.QueryxContext(ctx, sql, args...) + rows, err := tx.transaction.QueryxContext(safeMysqlContext(ctx), sql, args...) if err != nil { return } @@ -1265,3 +1265,10 @@ func (db *DB) SetMaxOpenConns(maxIdleConns int) { // db.DB.SetMaxOpenConns(maxIdleConns) // db.DB.SetMaxIdleConns(maxIdleConns) } + +// Mysql driver does not support graceful cancellation via context.Cancel() +// This function (and its usages) should be removed when (if) this defect got fixed: +// https://github.com/go-sql-driver/mysql/issues/731 +func safeMysqlContext(_ context.Context) context.Context { + return context.Background() +}