Skip to content

Commit fc21a98

Browse files
author
KJ Tsanaktsidis
committed
Remove registry for CredentialProvider
Instead, it can be used by passing the config object directly to a Connector.
1 parent 5042da9 commit fc21a98

File tree

6 files changed

+93
-106
lines changed

6 files changed

+93
-106
lines changed

README.md

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,6 @@ SELECT u.id FROM users as u
209209

210210
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
211211

212-
#### `credentialProvider`
213-
214-
```
215-
Type: string
216-
Valid Values: <name>
217-
Default: ""
218-
```
219-
220-
If set, this must refer to a credential provider name registered with `RegisterCredentialProvider`. When this is set, the username and password in the DSN will be ignored; instead, each time a conneciton is to be opened, the named credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire.
221-
222212
##### `interpolateParams`
223213

224214
```
@@ -377,6 +367,31 @@ Examples:
377367
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
378368
* [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'`
379369

370+
#### Non-DSN parameters
371+
372+
Some parameters (those that have types too complex to fit into a string) are not supported as part of a DSN string, but can only be specified by using the Connector interface. To use these parameters, set your database client up like so:
373+
374+
```go
375+
dbConfig := mysql.Config {
376+
Addr: "localhost:3306",
377+
// ... other parameters ...
378+
}
379+
connector, err := mysql.NewConnector(dbConfig)
380+
if err != nil {
381+
panic(error)
382+
}
383+
db := sql.OpenDB(connector)
384+
```
385+
386+
##### `CredentialProvider`
387+
388+
```
389+
Type: CredentialProviderFunc
390+
Default: nil
391+
```
392+
393+
If set, this must refer to a credential provider function of type `CredentialProviderFunc`. When this is set, the `User` and `Passwd` fields in the config will be ignored; instead, each time a connection is to be opened, the credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire.
394+
380395

381396
#### Examples
382397
```

auth.go

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@ import (
1515
"crypto/sha256"
1616
"crypto/x509"
1717
"encoding/pem"
18-
"fmt"
1918
"sync"
2019
)
2120

2221
// server pub keys registry
2322
var (
24-
serverPubKeyLock sync.RWMutex
25-
serverPubKeyRegistry map[string]*rsa.PublicKey
26-
credentialProviderLock sync.RWMutex
27-
credentialProviderRegistry map[string]CredentialProviderFunc
23+
serverPubKeyLock sync.RWMutex
24+
serverPubKeyRegistry map[string]*rsa.PublicKey
2825
)
2926

3027
// RegisterServerPubKey registers a server RSA public key which can be used to
@@ -89,39 +86,6 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
8986
// and the second the password.
9087
type CredentialProviderFunc func() (user string, password string, error error)
9188

92-
// RegisterCredentialProvider registers a function to be called on every connection open to
93-
// get the username and password to call
94-
func RegisterCredentialProvider(name string, providerFunc CredentialProviderFunc) {
95-
credentialProviderLock.Lock()
96-
if credentialProviderRegistry == nil {
97-
credentialProviderRegistry = make(map[string]CredentialProviderFunc)
98-
}
99-
credentialProviderRegistry[name] = providerFunc
100-
credentialProviderLock.Unlock()
101-
}
102-
103-
// DeregisterCredentialProvider removes a function registered with RegisterCredentialProvider
104-
func DeregisterCredentialProvider(name string) {
105-
credentialProviderLock.Lock()
106-
if credentialProviderRegistry != nil {
107-
delete(credentialProviderRegistry, name)
108-
}
109-
credentialProviderLock.Unlock()
110-
}
111-
112-
func getCredentialsFromConfig(cfg *Config) (user string, password string, error error) {
113-
if cfg.CredentialProvider != "" {
114-
credentialProviderLock.RLock()
115-
defer credentialProviderLock.RUnlock()
116-
cpFunc, ok := credentialProviderRegistry[cfg.CredentialProvider]
117-
if !ok {
118-
return "", "", fmt.Errorf("credential provider %s not registered", cfg.CredentialProvider)
119-
}
120-
return cpFunc()
121-
}
122-
return cfg.User, cfg.Passwd, nil
123-
}
124-
12589
// Hash password using pre 4.1 (old password) method
12690
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
12791
type myRnd struct {

connector.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8888
plugin = defaultAuthPlugin
8989
}
9090

91-
user, password, err := getCredentialsFromConfig(c.cfg)
91+
user, password, err := c.cfg.getCredentials()
9292
if err != nil {
9393
mc.cleanup()
9494
return nil, err

driver_test.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -125,36 +125,47 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT
125125
}
126126

127127
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
128+
cfg, err := ParseDSN(dsn)
129+
if err != nil {
130+
t.Fatalf("error formatting DSN")
131+
}
132+
runTestsWithConfig(t, cfg, tests...)
133+
}
134+
135+
func runTestsWithConfig(t *testing.T, cfg *Config, tests ...func(dbt *DBTest)) {
128136
if !available {
129137
t.Skipf("MySQL server not running on %s", netAddr)
130138
}
131139

132-
db, err := sql.Open("mysql", dsn)
140+
connector, err := NewConnector(cfg)
133141
if err != nil {
134142
t.Fatalf("error connecting: %s", err.Error())
135143
}
144+
db := sql.OpenDB(connector)
136145
defer db.Close()
137146

138147
db.Exec("DROP TABLE IF EXISTS test")
139148

140-
dsn2 := dsn + "&interpolateParams=true"
149+
cfg2 := cfg.Clone()
150+
cfg2.InterpolateParams = true
141151
var db2 *sql.DB
142-
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
143-
db2, err = sql.Open("mysql", dsn2)
144-
if err != nil {
145-
t.Fatalf("error connecting: %s", err.Error())
146-
}
152+
connector2, err := NewConnector(cfg2)
153+
if err != errInvalidDSNUnsafeCollation {
154+
db2 = sql.OpenDB(connector2)
147155
defer db2.Close()
156+
} else if err != nil {
157+
t.Fatalf("error connecting: %s", err.Error())
148158
}
149159

150-
dsn3 := dsn + "&multiStatements=true"
160+
cfg3 := cfg.Clone()
161+
cfg3.MultiStatements = true
151162
var db3 *sql.DB
152-
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
153-
db3, err = sql.Open("mysql", dsn3)
154-
if err != nil {
155-
t.Fatalf("error connecting: %s", err.Error())
156-
}
163+
connector3, err := NewConnector(cfg2)
164+
if err != errInvalidDSNUnsafeCollation {
165+
db3 = sql.OpenDB(connector3)
157166
defer db3.Close()
167+
} else if err != nil {
168+
t.Fatalf("error connecting: %s", err.Error())
158169
}
159170

160171
dbt := &DBTest{t, db}
@@ -3169,18 +3180,23 @@ func TestCredentialProviderFunc(t *testing.T) {
31693180
// to test that it really is having an effect.
31703181
shouldFailCreds := false
31713182
shouldFailError := false
3172-
RegisterCredentialProvider("TestCredentialProviderFunc", func() (string, string, error) {
3173-
if shouldFailCreds {
3174-
return "fail", "fail", nil
3175-
}
3176-
if shouldFailError {
3177-
return "", "", fmt.Errorf("credential_error")
3178-
}
3179-
return user, pass, nil
3180-
})
3181-
defer DeregisterCredentialProvider("TestCredentialProviderFunc")
3182-
dsn := fmt.Sprintf("%s/%s?timeout=30s&credentialProvider=TestCredentialProviderFunc", netAddr, dbname)
3183-
runTests(t, dsn, func(dbt *DBTest) {
3183+
cfg := &Config{
3184+
Addr: addr,
3185+
Net: prot,
3186+
DBName: dbname,
3187+
Collation: defaultCollation,
3188+
AllowNativePasswords: true,
3189+
CredentialProvider: func() (string, string, error) {
3190+
if shouldFailCreds {
3191+
return "fail", "fail", nil
3192+
}
3193+
if shouldFailError {
3194+
return "", "", fmt.Errorf("credential_error")
3195+
}
3196+
return user, pass, nil
3197+
},
3198+
}
3199+
runTestsWithConfig(t, cfg, func(dbt *DBTest) {
31843200
ctx := context.Background()
31853201
c1, err := dbt.db.Conn(ctx)
31863202
if err != nil {

dsn.go

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,23 @@ var (
3434
// If a new Config is created instead of being parsed from a DSN string,
3535
// the NewConfig function should be used, which sets default values.
3636
type Config struct {
37-
User string // Username
38-
Passwd string // Password (requires User)
39-
CredentialProvider string // Credential provider name registered with RegisterCredentialProvider
40-
Net string // Network type
41-
Addr string // Network address (requires Net)
42-
DBName string // Database name
43-
Params map[string]string // Connection parameters
44-
Collation string // Connection collation
45-
Loc *time.Location // Location for time.Time values
46-
MaxAllowedPacket int // Max packet size allowed
47-
ServerPubKey string // Server public key name
48-
pubKey *rsa.PublicKey // Server public key
49-
TLSConfig string // TLS configuration name
50-
tls *tls.Config // TLS configuration
51-
Timeout time.Duration // Dial timeout
52-
ReadTimeout time.Duration // I/O read timeout
53-
WriteTimeout time.Duration // I/O write timeout
37+
User string // Username
38+
Passwd string // Password (requires User)
39+
CredentialProvider CredentialProviderFunc // Credential provider function
40+
Net string // Network type
41+
Addr string // Network address (requires Net)
42+
DBName string // Database name
43+
Params map[string]string // Connection parameters
44+
Collation string // Connection collation
45+
Loc *time.Location // Location for time.Time values
46+
MaxAllowedPacket int // Max packet size allowed
47+
ServerPubKey string // Server public key name
48+
pubKey *rsa.PublicKey // Server public key
49+
TLSConfig string // TLS configuration name
50+
tls *tls.Config // TLS configuration
51+
Timeout time.Duration // Dial timeout
52+
ReadTimeout time.Duration // I/O read timeout
53+
WriteTimeout time.Duration // I/O write timeout
5454

5555
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
5656
AllowCleartextPasswords bool // Allows the cleartext client side plugin
@@ -348,16 +348,6 @@ func (cfg *Config) FormatDSN() string {
348348

349349
}
350350

351-
if cfg.CredentialProvider != "" {
352-
if hasParam {
353-
buf.WriteString("&credentialProvider=")
354-
} else {
355-
hasParam = true
356-
buf.WriteString("?credentialProvider=")
357-
}
358-
buf.WriteString(cfg.CredentialProvider)
359-
}
360-
361351
// other params
362352
if cfg.Params != nil {
363353
var params []string
@@ -624,8 +614,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
624614
if err != nil {
625615
return
626616
}
627-
case "credentialProvider":
628-
cfg.CredentialProvider = value
629617
default:
630618
// lazy init
631619
if cfg.Params == nil {
@@ -641,6 +629,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
641629
return
642630
}
643631

632+
func (cfg *Config) getCredentials() (user string, password string, err error) {
633+
if cfg.CredentialProvider != nil {
634+
return cfg.CredentialProvider()
635+
}
636+
return cfg.User, cfg.Passwd, nil
637+
}
638+
644639
func ensureHavePort(addr string) string {
645640
if _, _, err := net.SplitHostPort(addr); err != nil {
646641
return net.JoinHostPort(addr, "3306")

dsn_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ var testDSNs = []struct {
7171
}, {
7272
"tcp(de:ad:be:ef::ca:fe)/dbname",
7373
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
74-
}, {
75-
"tcp(localhost)/dbname?credentialProvider=foobar",
76-
&Config{Net: "tcp", Addr: "localhost:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CredentialProvider: "foobar"},
7774
}}
7875

7976
func TestDSNParser(t *testing.T) {

0 commit comments

Comments
 (0)