Skip to content

Add SSPI context flag validation functions #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package sspi

import (
"strconv"
"testing"
)

func Test_verifySelectiveFlags(t *testing.T) {
type args struct {
flags uint32
establishedFlags uint32
}
tests := []struct {
name string
args args
wantValid bool
wantMissing uint32
wantExtra uint32
}{
{
name: "all zeros",
args: args{
flags: binary("00000"),
establishedFlags: binary("00000"),
},
wantValid: true,
wantMissing: binary("00000"),
wantExtra: binary("00000"),
},
{
name: "all ones",
args: args{
flags: binary("11111"),
establishedFlags: binary("11111"),
},
wantValid: true,
wantMissing: binary("00000"),
wantExtra: binary("00000"),
},
{
name: "missing one bit",
args: args{
flags: binary("11111"),
establishedFlags: binary("11011"),
},
wantValid: false,
wantMissing: binary("00100"),
wantExtra: binary("00000"),
},
{
name: "missing two bits",
args: args{
flags: binary("11111"),
establishedFlags: binary("01011"),
},
wantValid: false,
wantMissing: binary("10100"),
wantExtra: binary("00000"),
},
{
name: "missing all bits",
args: args{
flags: binary("11101"),
establishedFlags: binary("00000"),
},
wantValid: false,
wantMissing: binary("11101"),
wantExtra: binary("00000"),
},
{
name: "one extra bit",
args: args{
flags: binary("00111"),
establishedFlags: binary("01111"),
},
wantValid: true,
wantMissing: binary("00000"),
wantExtra: binary("01000"),
},
{
name: "two extra bits",
args: args{
flags: binary("01000"),
establishedFlags: binary("11001"),
},
wantValid: true,
wantMissing: binary("00000"),
wantExtra: binary("10001"),
},
{
name: "all extra bits",
args: args{
flags: binary("00000"),
establishedFlags: binary("11111"),
},
wantValid: true,
wantMissing: binary("00000"),
wantExtra: binary("11111"),
},
{
name: "missing and extra bits",
args: args{
flags: binary("00101"),
establishedFlags: binary("11001"),
},
wantValid: false,
wantMissing: binary("00100"),
wantExtra: binary("11000"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotValid, gotMissing, gotExtra := verifySelectiveFlags(tt.args.flags, tt.args.establishedFlags)
if gotValid != tt.wantValid {
t.Errorf("verifySelectiveFlags() gotValid = %v, want %v", gotValid, tt.wantValid)
}
if gotMissing != tt.wantMissing {
t.Errorf("verifySelectiveFlags() gotMissing = %v, want %v", gotMissing, tt.wantMissing)
}
if gotExtra != tt.wantExtra {
t.Errorf("verifySelectiveFlags() gotExtra = %v, want %v", gotExtra, tt.wantExtra)
}
})
}
}

func binary(b string) uint32 {
n, err := strconv.ParseUint(b, 2, 32)
if err != nil {
panic(err) // programmer error due to invalid test data
}
return uint32(n)
}
12 changes: 12 additions & 0 deletions negotiate/negotiate.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,18 @@ func (c *ClientContext) DecryptMessage(msg []byte, seqno uint32) (uint32, []byte
return decryptMessage(c.sctxt, msg, seqno)
}

// VerifyFlags determines if all flags used to construct the client context
// were honored (see NewClientContextWithFlags). It should be called after c.Update.
func (c *ClientContext) VerifyFlags() error {
return c.sctxt.VerifyFlags()
}

// VerifySelectiveFlags determines if the given flags were honored (see NewClientContextWithFlags).
// It should be called after c.Update.
func (c *ClientContext) VerifySelectiveFlags(flags uint32) error {
return c.sctxt.VerifySelectiveFlags(flags)
}

// ServerContext is used by the server to manage all steps of Negotiate
// negotiation. Once authentication is completed the context can be
// used to impersonate client.
Expand Down
80 changes: 80 additions & 0 deletions negotiate/negotiate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,86 @@ func TestSignatureEncryption(t *testing.T) {
t.Logf("client verified server signature")
}

func TestFlagVerification(t *testing.T) {
clientCred, err := negotiate.AcquireCurrentUserCredentials()
if err != nil {
t.Fatal(err)
}
defer clientCred.Release()

serverCred, err := negotiate.AcquireServerCredentials("")
if err != nil {
t.Fatal(err)
}
defer serverCred.Release()

const desiredFlags = sspi.ISC_REQ_CONFIDENTIALITY |
sspi.ISC_REQ_INTEGRITY |
sspi.ISC_REQ_MUTUAL_AUTH |
sspi.ISC_REQ_REPLAY_DETECT |
sspi.ISC_REQ_SEQUENCE_DETECT

client, toServerToken, err := negotiate.NewClientContextWithFlags(clientCred, "", desiredFlags)
if err != nil {
t.Fatal(err)
}
defer client.Release()

if len(toServerToken) == 0 {
t.Fatal("token for server cannot be empty")
}

server, serverDone, toClientToken, err := negotiate.NewServerContext(serverCred, toServerToken)
if err != nil {
t.Fatal(err)
}
defer server.Release()

if len(toClientToken) == 0 {
t.Fatal("token for client cannot be empty")
}

errMsg := "sspi: invalid flags check: desired=100000000 requested=10000000000011110 missing=100000000 extra=10000000000011110"

var clientDone bool
for {
if len(toClientToken) == 0 {
break
}
clientDone, toServerToken, err = client.Update(toClientToken)
if err != nil {
t.Fatal(err)
}

// verify all flags
if err := client.VerifyFlags(); err != nil {
t.Fatal(err)
}
// verify a subset of flags
if err := client.VerifySelectiveFlags(sspi.ISC_REQ_MUTUAL_AUTH); err != nil {
t.Fatal(err)
}
// try to verify a flag that was not initially requested
if err := client.VerifySelectiveFlags(sspi.ISC_REQ_ALLOCATE_MEMORY); err == nil || err.Error() != errMsg {
t.Fatalf("wrong error found: %v", err)
}

if len(toServerToken) == 0 {
break
}
serverDone, toClientToken, err = server.Update(toServerToken)
if err != nil {
t.Fatal(err)
}
}
if !clientDone {
t.Fatal("client authentication should be completed now")
}
if !serverDone {
t.Fatal("server authentication should be completed now")
}
}

func copyArray(a []byte) []byte {
b := make([]byte, len(a))
copy(b, a)
Expand Down
30 changes: 30 additions & 0 deletions sspi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package sspi

import (
"fmt"
"syscall"
"time"
"unsafe"
Expand Down Expand Up @@ -185,6 +186,35 @@ func (c *Context) Sizes() (uint32, uint32, uint32, uint32, error) {
return s.MaxToken, s.MaxSignature, s.BlockSize, s.SecurityTrailer, nil
}

// VerifyFlags determines if all flags used to construct the context
// were honored (see NewClientContext). It should be called after c.Update.
func (c *Context) VerifyFlags() error {
return c.VerifySelectiveFlags(c.RequestedFlags)
}

// VerifySelectiveFlags determines if the given flags were honored (see NewClientContext).
// It should be called after c.Update.
func (c *Context) VerifySelectiveFlags(flags uint32) error {
if valid, missing, extra := verifySelectiveFlags(flags, c.RequestedFlags); !valid {
return fmt.Errorf("sspi: invalid flags check: desired=%b requested=%b missing=%b extra=%b", flags, c.RequestedFlags, missing, extra)
}
if valid, missing, extra := verifySelectiveFlags(flags, c.EstablishedFlags); !valid {
return fmt.Errorf("sspi: invalid flags: desired=%b established=%b missing=%b extra=%b", flags, c.EstablishedFlags, missing, extra)
}
return nil
}

// verifySelectiveFlags determines if all bits requested in flags are set in establishedFlags.
// missing represents the bits set in flags that are not set in establishedFlags.
// extra represents the bits set in establishedFlags that are not set in flags.
// valid is true and missing is zero when establishedFlags has all of the requested flags.
func verifySelectiveFlags(flags, establishedFlags uint32) (valid bool, missing, extra uint32) {
missing = flags&establishedFlags ^ flags
extra = flags | establishedFlags ^ flags
valid = missing == 0
return valid, missing, extra
}

// NewSecBufferDesc returns an initialized SecBufferDesc describing the
// provided SecBuffer.
func NewSecBufferDesc(b []SecBuffer) *SecBufferDesc {
Expand Down