Skip to content

Commit 02f9d57

Browse files
committed
make sshContext thread safe and fix the data race bug
1 parent cf1ec7e commit 02f9d57

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

context.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,14 @@ type Context interface {
9494
type sshContext struct {
9595
context.Context
9696
*sync.Mutex
97+
98+
values map[interface{}]interface{}
99+
valuesMu sync.Mutex
97100
}
98101

99102
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
100103
innerCtx, cancel := context.WithCancel(context.Background())
101-
ctx := &sshContext{innerCtx, &sync.Mutex{}}
104+
ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})}
102105
ctx.SetValue(ContextKeyServer, srv)
103106
perms := &Permissions{&gossh.Permissions{}}
104107
ctx.SetValue(ContextKeyPermissions, perms)
@@ -119,8 +122,19 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
119122
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
120123
}
121124

125+
func (ctx *sshContext) Value(key interface{}) interface{} {
126+
ctx.valuesMu.Lock()
127+
defer ctx.valuesMu.Unlock()
128+
if v, ok := ctx.values[key]; ok {
129+
return v
130+
}
131+
return ctx.Context.Value(key)
132+
}
133+
122134
func (ctx *sshContext) SetValue(key, value interface{}) {
123-
ctx.Context = context.WithValue(ctx.Context, key, value)
135+
ctx.valuesMu.Lock()
136+
defer ctx.valuesMu.Unlock()
137+
ctx.values[key] = value
124138
}
125139

126140
func (ctx *sshContext) User() string {

context_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package ssh
22

3-
import "testing"
3+
import (
4+
"testing"
5+
"time"
6+
)
47

58
func TestSetPermissions(t *testing.T) {
69
t.Parallel()
@@ -45,3 +48,38 @@ func TestSetValue(t *testing.T) {
4548
t.Fatal(err)
4649
}
4750
}
51+
52+
func TestSetValueConcurrency(t *testing.T) {
53+
ctx, cancel := newContext(nil)
54+
defer cancel()
55+
56+
go func() {
57+
for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue
58+
_, _ = ctx.Deadline()
59+
_ = ctx.Err()
60+
_ = ctx.Value("foo")
61+
select {
62+
case <-ctx.Done():
63+
break
64+
default:
65+
}
66+
}
67+
}()
68+
ctx.SetValue("bar", -1) // a context value which never changes
69+
now := time.Now()
70+
var cnt int64
71+
go func() {
72+
for time.Since(now) < 100*time.Millisecond {
73+
cnt++
74+
ctx.SetValue("foo", cnt) // a context value which changes a lot
75+
}
76+
cancel()
77+
}()
78+
<-ctx.Done()
79+
if ctx.Value("foo") != cnt {
80+
t.Fatal("context.Value(foo) doesn't match latest SetValue")
81+
}
82+
if ctx.Value("bar") != -1 {
83+
t.Fatal("context.Value(bar) doesn't match latest SetValue")
84+
}
85+
}

0 commit comments

Comments
 (0)