@@ -14,6 +14,7 @@ import (
14
14
"os"
15
15
"strconv"
16
16
"strings"
17
+ "sync/atomic"
17
18
"testing"
18
19
"time"
19
20
@@ -3819,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
3819
3820
}
3820
3821
}
3821
3822
3823
+ func TestConnectWithDirectSSLNegotiation (t * testing.T ) {
3824
+ t .Parallel ()
3825
+
3826
+ tests := []struct {
3827
+ name string
3828
+ connString string
3829
+ expectDirectNego bool
3830
+ }{
3831
+ {
3832
+ name : "Default negotiation (postgres)" ,
3833
+ connString : "sslmode=require" ,
3834
+ expectDirectNego : false ,
3835
+ },
3836
+ {
3837
+ name : "Direct negotiation" ,
3838
+ connString : "sslmode=require sslnegotiation=direct" ,
3839
+ expectDirectNego : true ,
3840
+ },
3841
+ {
3842
+ name : "Explicit postgres negotiation" ,
3843
+ connString : "sslmode=require sslnegotiation=postgres" ,
3844
+ expectDirectNego : false ,
3845
+ },
3846
+ }
3847
+
3848
+ for _ , tt := range tests {
3849
+ tt := tt
3850
+ t .Run (tt .name , func (t * testing.T ) {
3851
+ t .Parallel ()
3852
+
3853
+ script := & pgmock.Script {
3854
+ Steps : pgmock .AcceptUnauthenticatedConnRequestSteps (),
3855
+ }
3856
+
3857
+ ln , err := net .Listen ("tcp" , "127.0.0.1:" )
3858
+ require .NoError (t , err )
3859
+ defer ln .Close ()
3860
+
3861
+ _ , port , err := net .SplitHostPort (ln .Addr ().String ())
3862
+ require .NoError (t , err )
3863
+
3864
+ var directNegoObserved atomic.Bool
3865
+
3866
+ serverErrCh := make (chan error , 1 )
3867
+ go func () {
3868
+ defer close (serverErrCh )
3869
+
3870
+ conn , err := ln .Accept ()
3871
+ if err != nil {
3872
+ serverErrCh <- fmt .Errorf ("accept error: %w" , err )
3873
+ return
3874
+ }
3875
+ defer conn .Close ()
3876
+
3877
+ conn .SetDeadline (time .Now ().Add (5 * time .Second ))
3878
+
3879
+ firstByte := make ([]byte , 1 )
3880
+ _ , err = conn .Read (firstByte )
3881
+ if err != nil {
3882
+ serverErrCh <- fmt .Errorf ("read first byte error: %w" , err )
3883
+ return
3884
+ }
3885
+
3886
+ // Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
3887
+ isDirect := firstByte [0 ] >= 20 && firstByte [0 ] <= 23
3888
+ directNegoObserved .Store (isDirect )
3889
+
3890
+ var tlsConn * tls.Conn
3891
+
3892
+ if ! isDirect {
3893
+ // Handle standard PostgreSQL SSL negotiation
3894
+ // Read the rest of the SSL request message
3895
+ sslRequestRemainder := make ([]byte , 7 )
3896
+ _ , err = io .ReadFull (conn , sslRequestRemainder )
3897
+ if err != nil {
3898
+ serverErrCh <- fmt .Errorf ("read ssl request remainder error: %w" , err )
3899
+ return
3900
+ }
3901
+
3902
+ // Send SSL acceptance response
3903
+ _ , err = conn .Write ([]byte ("S" ))
3904
+ if err != nil {
3905
+ serverErrCh <- fmt .Errorf ("write ssl acceptance error: %w" , err )
3906
+ return
3907
+ }
3908
+
3909
+ // Setup TLS server without needing to reuse the first byte
3910
+ cert , err := tls .X509KeyPair ([]byte (rsaCertPEM ), []byte (rsaKeyPEM ))
3911
+ if err != nil {
3912
+ serverErrCh <- fmt .Errorf ("cert error: %w" , err )
3913
+ return
3914
+ }
3915
+
3916
+ tlsConn = tls .Server (conn , & tls.Config {
3917
+ Certificates : []tls.Certificate {cert },
3918
+ })
3919
+ } else {
3920
+ // Handle direct TLS negotiation
3921
+ // Setup TLS server with the first byte already read
3922
+ cert , err := tls .X509KeyPair ([]byte (rsaCertPEM ), []byte (rsaKeyPEM ))
3923
+ if err != nil {
3924
+ serverErrCh <- fmt .Errorf ("cert error: %w" , err )
3925
+ return
3926
+ }
3927
+
3928
+ // Use a wrapper to inject the first byte back into the TLS handshake
3929
+ bufConn := & prefixConn {
3930
+ Conn : conn ,
3931
+ prefixData : firstByte ,
3932
+ }
3933
+
3934
+ tlsConn = tls .Server (bufConn , & tls.Config {
3935
+ Certificates : []tls.Certificate {cert },
3936
+ })
3937
+ }
3938
+
3939
+ // Complete TLS handshake
3940
+ if err := tlsConn .Handshake (); err != nil {
3941
+ serverErrCh <- fmt .Errorf ("TLS handshake error: %w" , err )
3942
+ return
3943
+ }
3944
+ defer tlsConn .Close ()
3945
+
3946
+ err = script .Run (pgproto3 .NewBackend (tlsConn , tlsConn ))
3947
+ if err != nil {
3948
+ serverErrCh <- fmt .Errorf ("pgmock run error: %w" , err )
3949
+ return
3950
+ }
3951
+ }()
3952
+
3953
+ connStr := fmt .Sprintf ("%s host=localhost port=%s sslmode=require sslinsecure=1" ,
3954
+ tt .connString , port )
3955
+
3956
+ ctx , cancel := context .WithTimeout (context .Background (), 3 * time .Second )
3957
+ defer cancel ()
3958
+
3959
+ conn , err := pgconn .Connect (ctx , connStr )
3960
+
3961
+ require .NoError (t , err )
3962
+
3963
+ defer conn .Close (ctx )
3964
+
3965
+ err = <- serverErrCh
3966
+ require .NoError (t , err )
3967
+
3968
+ require .Equal (t , tt .expectDirectNego , directNegoObserved .Load ())
3969
+ })
3970
+ }
3971
+ }
3972
+
3973
+ // prefixConn implements a net.Conn that prepends some data to the first Read
3974
+ type prefixConn struct {
3975
+ net.Conn
3976
+ prefixData []byte
3977
+ prefixConsumed bool
3978
+ }
3979
+
3980
+ func (c * prefixConn ) Read (b []byte ) (n int , err error ) {
3981
+ if ! c .prefixConsumed && len (c .prefixData ) > 0 {
3982
+ n = copy (b , c .prefixData )
3983
+ c .prefixData = c .prefixData [n :]
3984
+ c .prefixConsumed = len (c .prefixData ) == 0
3985
+ return n , nil
3986
+ }
3987
+ return c .Conn .Read (b )
3988
+ }
3989
+
3822
3990
// https://github.com/jackc/pgx/issues/1920
3823
3991
func TestFatalErrorReceivedInPipelineMode (t * testing.T ) {
3824
3992
t .Parallel ()
0 commit comments