Skip to content

Commit f0643a3

Browse files
authored
Improve protocol error messages
To aid protocol error debugging, report all errors found in the first two bytes of a message header.
1 parent 2d6ee4c commit f0643a3

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

conn.go

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"math/rand"
1414
"net"
1515
"strconv"
16+
"strings"
1617
"sync"
1718
"time"
1819
"unicode/utf8"
@@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) {
794795
}
795796

796797
// 2. Read and parse first two bytes of frame header.
798+
// To aid debugging, collect and report all errors in the first two bytes
799+
// of the header.
800+
801+
var errors []string
797802

798803
p, err := c.read(2)
799804
if err != nil {
800805
return noFrame, err
801806
}
802807

803-
final := p[0]&finalBit != 0
804808
frameType := int(p[0] & 0xf)
809+
final := p[0]&finalBit != 0
810+
rsv1 := p[0]&rsv1Bit != 0
811+
rsv2 := p[0]&rsv2Bit != 0
812+
rsv3 := p[0]&rsv3Bit != 0
805813
mask := p[1]&maskBit != 0
806814
c.setReadRemaining(int64(p[1] & 0x7f))
807815

808816
c.readDecompress = false
809-
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
810-
c.readDecompress = true
811-
p[0] &^= rsv1Bit
817+
if rsv1 {
818+
if c.newDecompressionReader != nil {
819+
c.readDecompress = true
820+
} else {
821+
errors = append(errors, "RSV1 set")
822+
}
812823
}
813824

814-
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
815-
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
825+
if rsv2 {
826+
errors = append(errors, "RSV2 set")
827+
}
828+
829+
if rsv3 {
830+
errors = append(errors, "RSV3 set")
816831
}
817832

818833
switch frameType {
819834
case CloseMessage, PingMessage, PongMessage:
820835
if c.readRemaining > maxControlFramePayloadSize {
821-
return noFrame, c.handleProtocolError("control frame length > 125")
836+
errors = append(errors, "len > 125 for control")
822837
}
823838
if !final {
824-
return noFrame, c.handleProtocolError("control frame not final")
839+
errors = append(errors, "FIN not set on control")
825840
}
826841
case TextMessage, BinaryMessage:
827842
if !c.readFinal {
828-
return noFrame, c.handleProtocolError("message start before final message frame")
843+
errors = append(errors, "data before FIN")
829844
}
830845
c.readFinal = final
831846
case continuationFrame:
832847
if c.readFinal {
833-
return noFrame, c.handleProtocolError("continuation after final message frame")
848+
errors = append(errors, "continuation after FIN")
834849
}
835850
c.readFinal = final
836851
default:
837-
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
852+
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
853+
}
854+
855+
if mask != c.isServer {
856+
errors = append(errors, "bad MASK")
857+
}
858+
859+
if len(errors) > 0 {
860+
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
838861
}
839862

840863
// 3. Read and parse frame length as per
@@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) {
872895

873896
// 4. Handle frame masking.
874897

875-
if mask != c.isServer {
876-
return noFrame, c.handleProtocolError("incorrect mask flag")
877-
}
878-
879898
if mask {
880899
c.readMaskPos = 0
881900
p, err := c.read(len(c.readMaskKey))
@@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) {
935954
if len(payload) >= 2 {
936955
closeCode = int(binary.BigEndian.Uint16(payload))
937956
if !isValidReceivedCloseCode(closeCode) {
938-
return noFrame, c.handleProtocolError("invalid close code")
957+
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
939958
}
940959
closeText = string(payload[2:])
941960
if !utf8.ValidString(closeText) {
@@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) {
952971
}
953972

954973
func (c *Conn) handleProtocolError(message string) error {
955-
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
974+
data := FormatCloseMessage(CloseProtocolError, message)
975+
if len(data) > maxControlFramePayloadSize {
976+
data = data[:maxControlFramePayloadSize]
977+
}
978+
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
956979
return errors.New("websocket: " + message)
957980
}
958981

0 commit comments

Comments
 (0)