Skip to content

Commit e580b90

Browse files
committed
Add SSPI context flag validation functions
This change adds two verification functions for SSPI context: VerifyFlags and VerifySelectiveFlags. The former is used to confirm that all flags requested for the context were honored. The latter is used to confirm if some subset of the requested flags were honored. Signed-off-by: Monis Khan <[email protected]>
1 parent b6e6a13 commit e580b90

File tree

4 files changed

+255
-0
lines changed

4 files changed

+255
-0
lines changed

flags_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package sspi
2+
3+
import (
4+
"strconv"
5+
"testing"
6+
)
7+
8+
func Test_verifySelectiveFlags(t *testing.T) {
9+
type args struct {
10+
flags uint32
11+
establishedFlags uint32
12+
}
13+
tests := []struct {
14+
name string
15+
args args
16+
wantValid bool
17+
wantMissing uint32
18+
wantExtra uint32
19+
}{
20+
{
21+
name: "all zeros",
22+
args: args{
23+
flags: binary("00000"),
24+
establishedFlags: binary("00000"),
25+
},
26+
wantValid: true,
27+
wantMissing: binary("00000"),
28+
wantExtra: binary("00000"),
29+
},
30+
{
31+
name: "all ones",
32+
args: args{
33+
flags: binary("11111"),
34+
establishedFlags: binary("11111"),
35+
},
36+
wantValid: true,
37+
wantMissing: binary("00000"),
38+
wantExtra: binary("00000"),
39+
},
40+
{
41+
name: "missing one bit",
42+
args: args{
43+
flags: binary("11111"),
44+
establishedFlags: binary("11011"),
45+
},
46+
wantValid: false,
47+
wantMissing: binary("00100"),
48+
wantExtra: binary("00000"),
49+
},
50+
{
51+
name: "missing two bits",
52+
args: args{
53+
flags: binary("11111"),
54+
establishedFlags: binary("01011"),
55+
},
56+
wantValid: false,
57+
wantMissing: binary("10100"),
58+
wantExtra: binary("00000"),
59+
},
60+
{
61+
name: "missing all bits",
62+
args: args{
63+
flags: binary("11101"),
64+
establishedFlags: binary("00000"),
65+
},
66+
wantValid: false,
67+
wantMissing: binary("11101"),
68+
wantExtra: binary("00000"),
69+
},
70+
{
71+
name: "one extra bit",
72+
args: args{
73+
flags: binary("00111"),
74+
establishedFlags: binary("01111"),
75+
},
76+
wantValid: true,
77+
wantMissing: binary("00000"),
78+
wantExtra: binary("01000"),
79+
},
80+
{
81+
name: "two extra bits",
82+
args: args{
83+
flags: binary("01000"),
84+
establishedFlags: binary("11001"),
85+
},
86+
wantValid: true,
87+
wantMissing: binary("00000"),
88+
wantExtra: binary("10001"),
89+
},
90+
{
91+
name: "all extra bits",
92+
args: args{
93+
flags: binary("00000"),
94+
establishedFlags: binary("11111"),
95+
},
96+
wantValid: true,
97+
wantMissing: binary("00000"),
98+
wantExtra: binary("11111"),
99+
},
100+
{
101+
name: "missing and extra bits",
102+
args: args{
103+
flags: binary("00101"),
104+
establishedFlags: binary("11001"),
105+
},
106+
wantValid: false,
107+
wantMissing: binary("00100"),
108+
wantExtra: binary("11000"),
109+
},
110+
}
111+
for _, tt := range tests {
112+
t.Run(tt.name, func(t *testing.T) {
113+
gotValid, gotMissing, gotExtra := verifySelectiveFlags(tt.args.flags, tt.args.establishedFlags)
114+
if gotValid != tt.wantValid {
115+
t.Errorf("verifySelectiveFlags() gotValid = %v, want %v", gotValid, tt.wantValid)
116+
}
117+
if gotMissing != tt.wantMissing {
118+
t.Errorf("verifySelectiveFlags() gotMissing = %v, want %v", gotMissing, tt.wantMissing)
119+
}
120+
if gotExtra != tt.wantExtra {
121+
t.Errorf("verifySelectiveFlags() gotExtra = %v, want %v", gotExtra, tt.wantExtra)
122+
}
123+
})
124+
}
125+
}
126+
127+
func binary(b string) uint32 {
128+
n, err := strconv.ParseUint(b, 2, 32)
129+
if err != nil {
130+
panic(err) // programmer error due to invalid test data
131+
}
132+
return uint32(n)
133+
}

negotiate/negotiate.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,18 @@ func (c *ClientContext) DecryptMessage(msg []byte, seqno uint32) (uint32, []byte
327327
return decryptMessage(c.sctxt, msg, seqno)
328328
}
329329

330+
// VerifyFlags determines if all flags used to construct the client context
331+
// were honored (see NewClientContextWithFlags). It should be called after c.Update.
332+
func (c *ClientContext) VerifyFlags() error {
333+
return c.sctxt.VerifyFlags()
334+
}
335+
336+
// VerifySelectiveFlags determines if the given flags were honored (see NewClientContextWithFlags).
337+
// It should be called after c.Update.
338+
func (c *ClientContext) VerifySelectiveFlags(flags uint32) error {
339+
return c.sctxt.VerifySelectiveFlags(flags)
340+
}
341+
330342
// ServerContext is used by the server to manage all steps of Negotiate
331343
// negotiation. Once authentication is completed the context can be
332344
// used to impersonate client.

negotiate/negotiate_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,86 @@ func TestSignatureEncryption(t *testing.T) {
331331
t.Logf("client verified server signature")
332332
}
333333

334+
func TestFlagVerification(t *testing.T) {
335+
clientCred, err := negotiate.AcquireCurrentUserCredentials()
336+
if err != nil {
337+
t.Fatal(err)
338+
}
339+
defer clientCred.Release()
340+
341+
serverCred, err := negotiate.AcquireServerCredentials("")
342+
if err != nil {
343+
t.Fatal(err)
344+
}
345+
defer serverCred.Release()
346+
347+
const desiredFlags = sspi.ISC_REQ_CONFIDENTIALITY |
348+
sspi.ISC_REQ_INTEGRITY |
349+
sspi.ISC_REQ_MUTUAL_AUTH |
350+
sspi.ISC_REQ_REPLAY_DETECT |
351+
sspi.ISC_REQ_SEQUENCE_DETECT
352+
353+
client, toServerToken, err := negotiate.NewClientContextWithFlags(clientCred, "", desiredFlags)
354+
if err != nil {
355+
t.Fatal(err)
356+
}
357+
defer client.Release()
358+
359+
if len(toServerToken) == 0 {
360+
t.Fatal("token for server cannot be empty")
361+
}
362+
363+
server, serverDone, toClientToken, err := negotiate.NewServerContext(serverCred, toServerToken)
364+
if err != nil {
365+
t.Fatal(err)
366+
}
367+
defer server.Release()
368+
369+
if len(toClientToken) == 0 {
370+
t.Fatal("token for client cannot be empty")
371+
}
372+
373+
errMsg := "sspi: invalid flags check: desired=100000000 requested=10000000000011110 missing=100000000 extra=10000000000011110"
374+
375+
var clientDone bool
376+
for {
377+
if len(toClientToken) == 0 {
378+
break
379+
}
380+
clientDone, toServerToken, err = client.Update(toClientToken)
381+
if err != nil {
382+
t.Fatal(err)
383+
}
384+
385+
// verify all flags
386+
if err := client.VerifyFlags(); err != nil {
387+
t.Fatal(err)
388+
}
389+
// verify a subset of flags
390+
if err := client.VerifySelectiveFlags(sspi.ISC_REQ_MUTUAL_AUTH); err != nil {
391+
t.Fatal(err)
392+
}
393+
// try to verify a flag that was not initially requested
394+
if err := client.VerifySelectiveFlags(sspi.ISC_REQ_ALLOCATE_MEMORY); err == nil || err.Error() != errMsg {
395+
t.Fatalf("wrong error found: %v", err)
396+
}
397+
398+
if len(toServerToken) == 0 {
399+
break
400+
}
401+
serverDone, toClientToken, err = server.Update(toServerToken)
402+
if err != nil {
403+
t.Fatal(err)
404+
}
405+
}
406+
if !clientDone {
407+
t.Fatal("client authentication should be completed now")
408+
}
409+
if !serverDone {
410+
t.Fatal("server authentication should be completed now")
411+
}
412+
}
413+
334414
func copyArray(a []byte) []byte {
335415
b := make([]byte, len(a))
336416
copy(b, a)

sspi.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package sspi
88

99
import (
10+
"fmt"
1011
"syscall"
1112
"time"
1213
"unsafe"
@@ -185,6 +186,35 @@ func (c *Context) Sizes() (uint32, uint32, uint32, uint32, error) {
185186
return s.MaxToken, s.MaxSignature, s.BlockSize, s.SecurityTrailer, nil
186187
}
187188

189+
// VerifyFlags determines if all flags used to construct the context
190+
// were honored (see NewClientContext). It should be called after c.Update.
191+
func (c *Context) VerifyFlags() error {
192+
return c.VerifySelectiveFlags(c.RequestedFlags)
193+
}
194+
195+
// VerifySelectiveFlags determines if the given flags were honored (see NewClientContext).
196+
// It should be called after c.Update.
197+
func (c *Context) VerifySelectiveFlags(flags uint32) error {
198+
if valid, missing, extra := verifySelectiveFlags(flags, c.RequestedFlags); !valid {
199+
return fmt.Errorf("sspi: invalid flags check: desired=%b requested=%b missing=%b extra=%b", flags, c.RequestedFlags, missing, extra)
200+
}
201+
if valid, missing, extra := verifySelectiveFlags(flags, c.EstablishedFlags); !valid {
202+
return fmt.Errorf("sspi: invalid flags: desired=%b established=%b missing=%b extra=%b", flags, c.EstablishedFlags, missing, extra)
203+
}
204+
return nil
205+
}
206+
207+
// verifySelectiveFlags determines if all bits requested in flags are set in establishedFlags.
208+
// missing represents the bits set in flags that are not set in establishedFlags.
209+
// extra represents the bits set in establishedFlags that are not set in flags.
210+
// valid is true and missing is zero when establishedFlags has all of the requested flags.
211+
func verifySelectiveFlags(flags, establishedFlags uint32) (valid bool, missing, extra uint32) {
212+
missing = flags&establishedFlags ^ flags
213+
extra = flags | establishedFlags ^ flags
214+
valid = missing == 0
215+
return valid, missing, extra
216+
}
217+
188218
// NewSecBufferDesc returns an initialized SecBufferDesc describing the
189219
// provided SecBuffer.
190220
func NewSecBufferDesc(b []SecBuffer) *SecBufferDesc {

0 commit comments

Comments
 (0)