Skip to content

Commit 4015a0c

Browse files
authored
Merge pull request #2329 from flimzy/prepareConn
Add PrepareConn hook, which extends BeforeAcquire's behavior to allow…
2 parents fc334e4 + 48d27a9 commit 4015a0c

File tree

3 files changed

+106
-11
lines changed

3 files changed

+106
-11
lines changed

pgxpool/common_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
141141
// Can't test function equality, so just test that they are set or not.
142142
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
143143
assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
144+
assert.Equalf(t, expected.PrepareConn == nil, actual.PrepareConn == nil, "%s - PrepareConn", testName)
144145
assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
145146

146147
assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)

pgxpool/pool.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ type Pool struct {
8383
config *Config
8484
beforeConnect func(context.Context, *pgx.ConnConfig) error
8585
afterConnect func(context.Context, *pgx.Conn) error
86-
beforeAcquire func(context.Context, *pgx.Conn) bool
86+
prepareConn func(context.Context, *pgx.Conn) (bool, error)
8787
afterRelease func(*pgx.Conn) bool
8888
beforeClose func(*pgx.Conn)
8989
minConns int32
@@ -118,8 +118,22 @@ type Config struct {
118118
// BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the
119119
// acquisition or false to indicate that the connection should be destroyed and a different connection should be
120120
// acquired.
121+
//
122+
// Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take
123+
// precedence, ignoring BeforeAcquire.
121124
BeforeAcquire func(context.Context, *pgx.Conn) bool
122125

126+
// PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection
127+
// is considered valid. If the function returns a non-nil error, the instigating query will fail with the returned error.
128+
//
129+
// Specifically, this means that:
130+
//
131+
// - It must return true and a nil error to allow acquisition and the query to proceed.
132+
// - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error.
133+
// - If it returns false, and an error, the query will fail with the returned error, and the connection will be destroyed.
134+
// - If it returns false and a nil error, the connection will be returned to the pool, and the instigating query will be retried on a new connection.
135+
PrepareConn func(context.Context, *pgx.Conn) (bool, error)
136+
123137
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
124138
// return the connection to the pool or false to destroy the connection.
125139
AfterRelease func(*pgx.Conn) bool
@@ -189,11 +203,18 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
189203
panic("config must be created by ParseConfig")
190204
}
191205

206+
prepareConn := config.PrepareConn
207+
if prepareConn == nil && config.BeforeAcquire != nil {
208+
prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
209+
return config.BeforeAcquire(ctx, conn), nil
210+
}
211+
}
212+
192213
p := &Pool{
193214
config: config,
194215
beforeConnect: config.BeforeConnect,
195216
afterConnect: config.AfterConnect,
196-
beforeAcquire: config.BeforeAcquire,
217+
prepareConn: prepareConn,
197218
afterRelease: config.AfterRelease,
198219
beforeClose: config.BeforeClose,
199220
minConns: config.MinConns,
@@ -560,11 +581,23 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
560581
}
561582
}
562583

563-
if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
564-
return cr.getConn(p, res), nil
584+
if p.prepareConn != nil {
585+
ok, err := p.prepareConn(ctx, cr.conn)
586+
if !ok {
587+
res.Destroy()
588+
}
589+
if err != nil {
590+
if ok {
591+
res.Release()
592+
}
593+
return nil, err
594+
}
595+
if !ok {
596+
continue
597+
}
565598
}
566599

567-
res.Destroy()
600+
return cr.getConn(p, res), nil
568601
}
569602
}
570603

@@ -588,11 +621,14 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
588621
conns := make([]*Conn, 0, len(resources))
589622
for _, res := range resources {
590623
cr := res.Value()
591-
if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
592-
conns = append(conns, cr.getConn(p, res))
593-
} else {
594-
res.Destroy()
624+
if p.prepareConn != nil {
625+
ok, err := p.prepareConn(ctx, cr.conn)
626+
if !ok || err != nil {
627+
res.Destroy()
628+
continue
629+
}
595630
}
631+
conns = append(conns, cr.getConn(p, res))
596632
}
597633

598634
return conns

pgxpool/pool_test.go

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,64 @@ func TestPoolBeforeAcquire(t *testing.T) {
330330
assert.EqualValues(t, 12, acquireAttempts)
331331
}
332332

333+
func TestPoolPrepareConn(t *testing.T) {
334+
t.Parallel()
335+
336+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
337+
defer cancel()
338+
339+
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
340+
require.NoError(t, err)
341+
342+
acquireAttempts := 0
343+
344+
config.PrepareConn = func(context.Context, *pgx.Conn) (bool, error) {
345+
acquireAttempts++
346+
var err error
347+
if acquireAttempts%3 == 0 {
348+
err = errors.New("PrepareConn error")
349+
}
350+
return acquireAttempts%2 == 0, err
351+
}
352+
353+
db, err := pgxpool.NewWithConfig(ctx, config)
354+
require.NoError(t, err)
355+
t.Cleanup(db.Close)
356+
357+
var errorCount int
358+
conns := make([]*pgxpool.Conn, 0, 4)
359+
for {
360+
conn, err := db.Acquire(ctx)
361+
if err != nil {
362+
errorCount++
363+
continue
364+
}
365+
conns = append(conns, conn)
366+
if len(conns) == 4 {
367+
break
368+
}
369+
}
370+
const wantErrorCount = 3
371+
assert.Equal(t, wantErrorCount, errorCount, "Acquire() should have failed %d times", wantErrorCount)
372+
373+
for _, c := range conns {
374+
c.Release()
375+
}
376+
waitForReleaseToComplete()
377+
378+
assert.EqualValues(t, len(conns)*2+wantErrorCount-1, acquireAttempts)
379+
380+
conns = db.AcquireAllIdle(ctx)
381+
assert.Len(t, conns, 1)
382+
383+
for _, c := range conns {
384+
c.Release()
385+
}
386+
waitForReleaseToComplete()
387+
388+
assert.EqualValues(t, 14, acquireAttempts)
389+
}
390+
333391
func TestPoolAfterRelease(t *testing.T) {
334392
t.Parallel()
335393

@@ -1082,9 +1140,9 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
10821140
acquireAttempts := int64(0)
10831141
connectAttempts := int64(0)
10841142

1085-
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
1143+
config.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
10861144
atomic.AddInt64(&acquireAttempts, 1)
1087-
return true
1145+
return true, nil
10881146
}
10891147
config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error {
10901148
atomic.AddInt64(&connectAttempts, 1)

0 commit comments

Comments
 (0)