Skip to content

Commit 0f77a2d

Browse files
authored
Merge pull request #2293 from divyam234/master
feat: add support for direct sslnegotiation
2 parents 04bcc02 + ddd966f commit 0f77a2d

File tree

3 files changed

+190
-1
lines changed

3 files changed

+190
-1
lines changed

pgconn/config.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ type Config struct {
5151
KerberosSpn string
5252
Fallbacks []*FallbackConfig
5353

54+
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
55+
5456
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
5557
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
5658
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
@@ -318,6 +320,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
318320
"sslkey": {},
319321
"sslcert": {},
320322
"sslrootcert": {},
323+
"sslnegotiation": {},
321324
"sslpassword": {},
322325
"sslsni": {},
323326
"krbspn": {},
@@ -386,6 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
386389
config.Port = fallbacks[0].Port
387390
config.TLSConfig = fallbacks[0].TLSConfig
388391
config.Fallbacks = fallbacks[1:]
392+
config.SSLNegotiation = settings["sslnegotiation"]
389393

390394
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
391395
if err == nil {
@@ -449,6 +453,7 @@ func parseEnvSettings() map[string]string {
449453
"PGSSLSNI": "sslsni",
450454
"PGSSLROOTCERT": "sslrootcert",
451455
"PGSSLPASSWORD": "sslpassword",
456+
"PGSSLNEGOTIATION": "sslnegotiation",
452457
"PGTARGETSESSIONATTRS": "target_session_attrs",
453458
"PGSERVICE": "service",
454459
"PGSERVICEFILE": "servicefile",
@@ -646,6 +651,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
646651
sslkey := settings["sslkey"]
647652
sslpassword := settings["sslpassword"]
648653
sslsni := settings["sslsni"]
654+
sslnegotiation := settings["sslnegotiation"]
649655

650656
// Match libpq default behavior
651657
if sslmode == "" {
@@ -657,6 +663,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
657663

658664
tlsConfig := &tls.Config{}
659665

666+
if sslnegotiation == "direct" {
667+
tlsConfig.NextProtos = []string{"postgresql"}
668+
if sslmode == "prefer" {
669+
sslmode = "require"
670+
}
671+
}
672+
660673
if sslrootcert != "" {
661674
var caCertPool *x509.CertPool
662675

pgconn/pgconn.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
325325
if connectConfig.tlsConfig != nil {
326326
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
327327
pgConn.contextWatcher.Watch(ctx)
328-
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
328+
var (
329+
tlsConn net.Conn
330+
err error
331+
)
332+
if config.SSLNegotiation == "direct" {
333+
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
334+
} else {
335+
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
336+
}
329337
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
330338
if err != nil {
331339
pgConn.conn.Close()

pgconn/pgconn_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"os"
1515
"strconv"
1616
"strings"
17+
"sync/atomic"
1718
"testing"
1819
"time"
1920

@@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
38193820
}
38203821
}
38213822

3823+
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
3824+
t.Parallel()
3825+
3826+
tests := []struct {
3827+
name string
3828+
connString string
3829+
expectDirectNego bool
3830+
}{
3831+
{
3832+
name: "Default negotiation (postgres)",
3833+
connString: "sslmode=require",
3834+
expectDirectNego: false,
3835+
},
3836+
{
3837+
name: "Direct negotiation",
3838+
connString: "sslmode=require sslnegotiation=direct",
3839+
expectDirectNego: true,
3840+
},
3841+
{
3842+
name: "Explicit postgres negotiation",
3843+
connString: "sslmode=require sslnegotiation=postgres",
3844+
expectDirectNego: false,
3845+
},
3846+
}
3847+
3848+
for _, tt := range tests {
3849+
tt := tt
3850+
t.Run(tt.name, func(t *testing.T) {
3851+
t.Parallel()
3852+
3853+
script := &pgmock.Script{
3854+
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
3855+
}
3856+
3857+
ln, err := net.Listen("tcp", "127.0.0.1:")
3858+
require.NoError(t, err)
3859+
defer ln.Close()
3860+
3861+
_, port, err := net.SplitHostPort(ln.Addr().String())
3862+
require.NoError(t, err)
3863+
3864+
var directNegoObserved atomic.Bool
3865+
3866+
serverErrCh := make(chan error, 1)
3867+
go func() {
3868+
defer close(serverErrCh)
3869+
3870+
conn, err := ln.Accept()
3871+
if err != nil {
3872+
serverErrCh <- fmt.Errorf("accept error: %w", err)
3873+
return
3874+
}
3875+
defer conn.Close()
3876+
3877+
conn.SetDeadline(time.Now().Add(5 * time.Second))
3878+
3879+
firstByte := make([]byte, 1)
3880+
_, err = conn.Read(firstByte)
3881+
if err != nil {
3882+
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
3883+
return
3884+
}
3885+
3886+
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
3887+
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
3888+
directNegoObserved.Store(isDirect)
3889+
3890+
var tlsConn *tls.Conn
3891+
3892+
if !isDirect {
3893+
// Handle standard PostgreSQL SSL negotiation
3894+
// Read the rest of the SSL request message
3895+
sslRequestRemainder := make([]byte, 7)
3896+
_, err = io.ReadFull(conn, sslRequestRemainder)
3897+
if err != nil {
3898+
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
3899+
return
3900+
}
3901+
3902+
// Send SSL acceptance response
3903+
_, err = conn.Write([]byte("S"))
3904+
if err != nil {
3905+
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
3906+
return
3907+
}
3908+
3909+
// Setup TLS server without needing to reuse the first byte
3910+
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
3911+
if err != nil {
3912+
serverErrCh <- fmt.Errorf("cert error: %w", err)
3913+
return
3914+
}
3915+
3916+
tlsConn = tls.Server(conn, &tls.Config{
3917+
Certificates: []tls.Certificate{cert},
3918+
})
3919+
} else {
3920+
// Handle direct TLS negotiation
3921+
// Setup TLS server with the first byte already read
3922+
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
3923+
if err != nil {
3924+
serverErrCh <- fmt.Errorf("cert error: %w", err)
3925+
return
3926+
}
3927+
3928+
// Use a wrapper to inject the first byte back into the TLS handshake
3929+
bufConn := &prefixConn{
3930+
Conn: conn,
3931+
prefixData: firstByte,
3932+
}
3933+
3934+
tlsConn = tls.Server(bufConn, &tls.Config{
3935+
Certificates: []tls.Certificate{cert},
3936+
})
3937+
}
3938+
3939+
// Complete TLS handshake
3940+
if err := tlsConn.Handshake(); err != nil {
3941+
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
3942+
return
3943+
}
3944+
defer tlsConn.Close()
3945+
3946+
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
3947+
if err != nil {
3948+
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
3949+
return
3950+
}
3951+
}()
3952+
3953+
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
3954+
tt.connString, port)
3955+
3956+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
3957+
defer cancel()
3958+
3959+
conn, err := pgconn.Connect(ctx, connStr)
3960+
3961+
require.NoError(t, err)
3962+
3963+
defer conn.Close(ctx)
3964+
3965+
err = <-serverErrCh
3966+
require.NoError(t, err)
3967+
3968+
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
3969+
})
3970+
}
3971+
}
3972+
3973+
// prefixConn implements a net.Conn that prepends some data to the first Read
3974+
type prefixConn struct {
3975+
net.Conn
3976+
prefixData []byte
3977+
prefixConsumed bool
3978+
}
3979+
3980+
func (c *prefixConn) Read(b []byte) (n int, err error) {
3981+
if !c.prefixConsumed && len(c.prefixData) > 0 {
3982+
n = copy(b, c.prefixData)
3983+
c.prefixData = c.prefixData[n:]
3984+
c.prefixConsumed = len(c.prefixData) == 0
3985+
return n, nil
3986+
}
3987+
return c.Conn.Read(b)
3988+
}
3989+
38223990
// https://github.com/jackc/pgx/issues/1920
38233991
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
38243992
t.Parallel()

0 commit comments

Comments
 (0)