Skip to content

Commit 26c79eb

Browse files
committed
Handle writes that could deadlock with reads from the server
This commit adds a background reader that can optionally buffer reads. It is used whenever a potentially blocking write is made to the server. The background reader is started on a slight delay so there should be no meaningful performance impact as it doesn't run for quick queries and its overhead is minimal relative to slower queries.
1 parent 85136a8 commit 26c79eb

File tree

5 files changed

+316
-15
lines changed

5 files changed

+316
-15
lines changed

pgconn/auth_scram.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
4242
Data: sc.clientFirstMessage(),
4343
}
4444
c.frontend.Send(saslInitialResponse)
45-
err = c.frontend.Flush()
45+
err = c.flushWithPotentialWriteReadDeadlock()
4646
if err != nil {
4747
return err
4848
}
@@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
6262
Data: []byte(sc.clientFinalMessage()),
6363
}
6464
c.frontend.Send(saslResponse)
65-
err = c.frontend.Flush()
65+
err = c.flushWithPotentialWriteReadDeadlock()
6666
if err != nil {
6767
return err
6868
}

pgconn/internal/bgreader/bgreader.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
2+
package bgreader
3+
4+
import (
5+
"io"
6+
"sync"
7+
8+
"github.com/jackc/pgx/v5/internal/iobufpool"
9+
)
10+
11+
const (
12+
bgReaderStatusStopped = iota
13+
bgReaderStatusRunning
14+
bgReaderStatusStopping
15+
)
16+
17+
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
18+
type BGReader struct {
19+
r io.Reader
20+
21+
cond *sync.Cond
22+
bgReaderStatus int32
23+
readResults []readResult
24+
}
25+
26+
type readResult struct {
27+
buf *[]byte
28+
err error
29+
}
30+
31+
// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
32+
// reader will stop automatically when the underlying reader returns an error.
33+
func (r *BGReader) Start() {
34+
r.cond.L.Lock()
35+
defer r.cond.L.Unlock()
36+
37+
switch r.bgReaderStatus {
38+
case bgReaderStatusStopped:
39+
r.bgReaderStatus = bgReaderStatusRunning
40+
go r.bgRead()
41+
case bgReaderStatusRunning:
42+
// no-op
43+
case bgReaderStatusStopping:
44+
r.bgReaderStatus = bgReaderStatusRunning
45+
}
46+
}
47+
48+
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
49+
// background reader is not running.
50+
func (r *BGReader) Stop() {
51+
r.cond.L.Lock()
52+
defer r.cond.L.Unlock()
53+
54+
switch r.bgReaderStatus {
55+
case bgReaderStatusStopped:
56+
// no-op
57+
case bgReaderStatusRunning:
58+
r.bgReaderStatus = bgReaderStatusStopping
59+
case bgReaderStatusStopping:
60+
// no-op
61+
}
62+
}
63+
64+
func (r *BGReader) bgRead() {
65+
keepReading := true
66+
for keepReading {
67+
buf := iobufpool.Get(8192)
68+
n, err := r.r.Read(*buf)
69+
*buf = (*buf)[:n]
70+
71+
r.cond.L.Lock()
72+
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
73+
if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
74+
r.bgReaderStatus = bgReaderStatusStopped
75+
keepReading = false
76+
}
77+
r.cond.L.Unlock()
78+
r.cond.Broadcast()
79+
}
80+
}
81+
82+
// Read implements the io.Reader interface.
83+
func (r *BGReader) Read(p []byte) (int, error) {
84+
r.cond.L.Lock()
85+
defer r.cond.L.Unlock()
86+
87+
if len(r.readResults) > 0 {
88+
return r.readFromReadResults(p)
89+
}
90+
91+
// There are no unread background read results and the background reader is stopped.
92+
if r.bgReaderStatus == bgReaderStatusStopped {
93+
return r.r.Read(p)
94+
}
95+
96+
// Wait for results from the background reader
97+
for len(r.readResults) == 0 {
98+
r.cond.Wait()
99+
}
100+
return r.readFromReadResults(p)
101+
}
102+
103+
// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held.
104+
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
105+
buf := r.readResults[0].buf
106+
var err error
107+
108+
n := copy(p, *buf)
109+
if n == len(*buf) {
110+
err = r.readResults[0].err
111+
iobufpool.Put(buf)
112+
if len(r.readResults) == 1 {
113+
r.readResults = nil
114+
} else {
115+
r.readResults = r.readResults[1:]
116+
}
117+
} else {
118+
*buf = (*buf)[n:]
119+
r.readResults[0].buf = buf
120+
}
121+
122+
return n, err
123+
}
124+
125+
func New(r io.Reader) *BGReader {
126+
return &BGReader{
127+
r: r,
128+
cond: &sync.Cond{
129+
L: &sync.Mutex{},
130+
},
131+
}
132+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package bgreader_test
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
"math/rand"
8+
"testing"
9+
"time"
10+
11+
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestBGReaderReadWhenStopped(t *testing.T) {
16+
r := bytes.NewReader([]byte("foo bar baz"))
17+
bgr := bgreader.New(r)
18+
buf, err := io.ReadAll(bgr)
19+
require.NoError(t, err)
20+
require.Equal(t, []byte("foo bar baz"), buf)
21+
}
22+
23+
func TestBGReaderReadWhenStarted(t *testing.T) {
24+
r := bytes.NewReader([]byte("foo bar baz"))
25+
bgr := bgreader.New(r)
26+
bgr.Start()
27+
buf, err := io.ReadAll(bgr)
28+
require.NoError(t, err)
29+
require.Equal(t, []byte("foo bar baz"), buf)
30+
}
31+
32+
type mockReadFunc func(p []byte) (int, error)
33+
34+
type mockReader struct {
35+
readFuncs []mockReadFunc
36+
}
37+
38+
func (r *mockReader) Read(p []byte) (int, error) {
39+
if len(r.readFuncs) == 0 {
40+
return 0, io.EOF
41+
}
42+
43+
fn := r.readFuncs[0]
44+
r.readFuncs = r.readFuncs[1:]
45+
46+
return fn(p)
47+
}
48+
49+
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
50+
rr := &mockReader{
51+
readFuncs: []mockReadFunc{
52+
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
53+
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
54+
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
55+
},
56+
}
57+
bgr := bgreader.New(rr)
58+
bgr.Start()
59+
buf := make([]byte, 3)
60+
n, err := bgr.Read(buf)
61+
require.NoError(t, err)
62+
require.EqualValues(t, 3, n)
63+
require.Equal(t, []byte("foo"), buf)
64+
}
65+
66+
func TestBGReaderErrorWhenStarted(t *testing.T) {
67+
rr := &mockReader{
68+
readFuncs: []mockReadFunc{
69+
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
70+
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
71+
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
72+
},
73+
}
74+
75+
bgr := bgreader.New(rr)
76+
bgr.Start()
77+
buf, err := io.ReadAll(bgr)
78+
require.Equal(t, []byte("foobarbaz"), buf)
79+
require.EqualError(t, err, "oops")
80+
}
81+
82+
func TestBGReaderErrorWhenStopped(t *testing.T) {
83+
rr := &mockReader{
84+
readFuncs: []mockReadFunc{
85+
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
86+
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
87+
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
88+
},
89+
}
90+
91+
bgr := bgreader.New(rr)
92+
buf, err := io.ReadAll(bgr)
93+
require.Equal(t, []byte("foobarbaz"), buf)
94+
require.EqualError(t, err, "oops")
95+
}
96+
97+
type numberReader struct {
98+
v uint8
99+
rng *rand.Rand
100+
}
101+
102+
func (nr *numberReader) Read(p []byte) (int, error) {
103+
n := nr.rng.Intn(len(p))
104+
for i := 0; i < n; i++ {
105+
p[i] = nr.v
106+
nr.v++
107+
}
108+
109+
return n, nil
110+
}
111+
112+
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
113+
// stopping the background worker from other goroutines.
114+
func TestBGReaderStress(t *testing.T) {
115+
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
116+
bgr := bgreader.New(nr)
117+
118+
bytesRead := 0
119+
var expected uint8
120+
buf := make([]byte, 10_000)
121+
rng := rand.New(rand.NewSource(0))
122+
123+
for bytesRead < 1_000_000 {
124+
randomNumber := rng.Intn(100)
125+
switch {
126+
case randomNumber < 10:
127+
go bgr.Start()
128+
case randomNumber < 20:
129+
go bgr.Stop()
130+
default:
131+
n, err := bgr.Read(buf)
132+
require.NoError(t, err)
133+
for i := 0; i < n; i++ {
134+
require.Equal(t, expected, buf[i])
135+
expected++
136+
}
137+
bytesRead += n
138+
}
139+
}
140+
}

pgconn/krb5.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
6363
Data: nextData,
6464
}
6565
c.frontend.Send(gssResponse)
66-
err = c.frontend.Flush()
66+
err = c.flushWithPotentialWriteReadDeadlock()
6767
if err != nil {
6868
return err
6969
}

0 commit comments

Comments
 (0)