Skip to content

Commit 6ca6da9

Browse files
committed
fix(ttstream): metrics missing caused by server-side rpcinfo not set correctly
1 parent b3ca38d commit 6ca6da9

File tree

3 files changed

+123
-45
lines changed

3 files changed

+123
-45
lines changed

pkg/remote/trans/ttstream/server_handler.go

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ import (
2525
"runtime/debug"
2626
"sync"
2727

28+
"github.com/bytedance/gopkg/cloud/metainfo"
29+
2830
"github.com/cloudwego/kitex/pkg/endpoint"
2931
"github.com/cloudwego/kitex/pkg/gofunc"
3032
"github.com/cloudwego/kitex/pkg/kerrors"
3133
"github.com/cloudwego/kitex/pkg/klog"
3234
"github.com/cloudwego/kitex/pkg/remote"
35+
"github.com/cloudwego/kitex/pkg/remote/trans/ttstream/ktx"
3336
"github.com/cloudwego/kitex/pkg/rpcinfo"
3437
"github.com/cloudwego/kitex/pkg/streaming"
3538
ktransport "github.com/cloudwego/kitex/transport"
@@ -118,7 +121,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
118121
}()
119122
// connection level goroutine
120123
for {
121-
nctx, ss, nerr := t.provider.OnStream(ctx, conn)
124+
nctx, st, nerr := t.provider.OnStream(ctx, conn)
122125
if nerr != nil {
123126
if errors.Is(nerr, io.EOF) {
124127
return nil
@@ -130,7 +133,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
130133
// stream level goroutine
131134
gofunc.GoFunc(nctx, func() {
132135
defer wg.Done()
133-
err := t.OnStream(nctx, conn, ss)
136+
err := t.OnStream(nctx, conn, st)
134137
if err != nil && !errors.Is(err, io.EOF) {
135138
t.OnError(nctx, err, conn)
136139
}
@@ -142,7 +145,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
142145
// - create server stream
143146
// - process server stream
144147
// - close server stream
145-
func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss streaming.ServerStream) (err error) {
148+
func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, st *stream) (err error) {
146149
ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
147150
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
148151
defer func() {
@@ -153,23 +156,37 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream
153156
}()
154157

155158
ink := ri.Invocation().(rpcinfo.InvocationSetter)
156-
if si, ok := ss.(StreamInfo); ok {
157-
sinfo := t.opt.SvcSearcher.SearchService(si.Service(), si.Method(), false)
158-
if sinfo == nil {
159-
return remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", si.Service()))
160-
}
161-
minfo := sinfo.MethodInfo(si.Method())
162-
if minfo == nil {
163-
return remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", si.Method()))
164-
}
165-
ink.SetServiceName(sinfo.ServiceName)
166-
ink.SetMethodName(si.Method())
167-
ink.SetStreamingMode(minfo.StreamingMode())
168-
if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil {
169-
_ = mutableTo.SetMethod(si.Method())
159+
sinfo := t.opt.SvcSearcher.SearchService(st.Service(), st.Method(), false)
160+
if sinfo == nil {
161+
return remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", st.Service()))
162+
}
163+
minfo := sinfo.MethodInfo(st.Method())
164+
if minfo == nil {
165+
return remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", st.Method()))
166+
}
167+
ink.SetServiceName(sinfo.ServiceName)
168+
ink.SetMethodName(st.Method())
169+
ink.SetStreamingMode(minfo.StreamingMode())
170+
if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil {
171+
_ = mutableTo.SetMethod(st.Method())
172+
}
173+
rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(st.TransportProtocol())
174+
175+
// headerHandler return a new stream level ctx
176+
// it contains rpcinfo modified by HeaderHandler
177+
if t.provider.headerHandler != nil {
178+
ctx, err = t.provider.headerHandler.OnReadStream(ctx, st.meta, st.header)
179+
if err != nil {
180+
return err
170181
}
171-
rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(si.TransportProtocol())
172182
}
183+
// register metainfo into ctx
184+
ctx = metainfo.SetMetaInfoFromMap(ctx, st.header)
185+
ss := newServerStream(st)
186+
187+
// cancel ctx when OnStreamFinish
188+
ctx, cancelFunc := ktx.WithCancel(ctx)
189+
ctx = context.WithValue(ctx, serverStreamCancelCtxKey{}, cancelFunc)
173190

174191
ctx = t.startTracer(ctx, ri)
175192
defer func() {

pkg/remote/trans/ttstream/server_handler_test.go

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ import (
2424
"fmt"
2525
"net"
2626
"runtime/debug"
27+
"strings"
2728
"testing"
2829
"time"
2930

31+
"github.com/cloudwego/gopkg/protocol/ttheader"
3032
"github.com/cloudwego/netpoll"
3133

3234
"github.com/cloudwego/kitex/internal/mocks"
@@ -87,17 +89,54 @@ func (m *mockNetpollConn) SetOnDisconnect(onDisconnect netpoll.OnDisconnect) err
8789
return nil
8890
}
8991

92+
type mockTracer struct {
93+
finishFunc func(ctx context.Context)
94+
}
95+
96+
func (m *mockTracer) Start(ctx context.Context) context.Context {
97+
return ctx
98+
}
99+
100+
func (m *mockTracer) Finish(ctx context.Context) {
101+
m.finishFunc(ctx)
102+
}
103+
104+
type mockHeaderFrameReadHandler struct {
105+
ripTag string
106+
}
107+
108+
func (m *mockHeaderFrameReadHandler) OnReadStream(ctx context.Context, ihd IntHeader, shd StrHeader) (context.Context, error) {
109+
ri := rpcinfo.GetRPCInfo(ctx)
110+
fi := rpcinfo.AsMutableEndpointInfo(ri.From())
111+
if rip, ok := shd[ttheader.HeaderTransRemoteAddr]; ok {
112+
fi.SetTag(m.ripTag, rip)
113+
}
114+
return ctx, nil
115+
}
116+
90117
func TestOnStream(t *testing.T) {
118+
ripTag := "rip"
119+
tracer := &mockTracer{}
120+
traceCtl := &rpcinfo.TraceController{}
121+
traceCtl.Append(tracer)
91122
factory := NewSvrTransHandlerFactory()
92123
rawTransHdl, err := factory.NewTransHandler(&remote.ServerOption{
93124
SvcSearcher: mock_remote.NewDefaultSvcSearcher(),
94125
InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
95-
return rpcinfo.NewRPCInfo(nil, nil,
126+
return rpcinfo.NewRPCInfo(rpcinfo.NewEndpointInfo(
127+
mocks.MockService2Name, mocks.Mock2Method, nil, make(map[string]string)), nil,
96128
rpcinfo.NewInvocation(mocks.MockServiceName, mocks.MockStreamingMethod),
97129
rpcinfo.NewRPCConfig(),
98130
rpcinfo.NewRPCStats())
99131
},
100-
TracerCtl: &rpcinfo.TraceController{},
132+
TracerCtl: traceCtl,
133+
TTHeaderStreamingOptions: remote.TTHeaderStreamingOptions{
134+
TransportOptions: []interface{}{
135+
WithServerHeaderFrameHandler(&mockHeaderFrameReadHandler{
136+
ripTag: ripTag,
137+
}),
138+
},
139+
},
101140
})
102141
test.Assert(t, err == nil, err)
103142
transHdl := rawTransHdl.(*svrTransHandler)
@@ -124,30 +163,53 @@ func TestOnStream(t *testing.T) {
124163
}()
125164

126165
t.Run("invoking handler successfully", func(t *testing.T) {
166+
tracer.finishFunc = func(ctx context.Context) {
167+
ri := rpcinfo.GetRPCInfo(ctx)
168+
test.Assert(t, ri != nil, ri)
169+
rip, ok := ri.From().Tag(ripTag)
170+
test.Assert(t, ok)
171+
test.Assert(t, rip == "127.0.0.1:8888", rip)
172+
}
127173
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
128174
return nil
129175
})
130176
err = EncodeFrame(context.Background(), wbuf, &Frame{
131177
streamFrame: streamFrame{
132178
sid: 1,
133179
method: mocks.MockStreamingMethod,
180+
header: map[string]string{
181+
ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888",
182+
},
134183
},
135184
typ: headerFrameType,
136185
})
137186
test.Assert(t, err == nil, err)
138187
err = wbuf.Flush()
139188
test.Assert(t, err == nil, err)
140-
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
189+
nctx, st, err := transHdl.provider.OnStream(ctx, mockConn)
141190
test.Assert(t, err == nil, err)
142-
err = transHdl.OnStream(nctx, mockConn, ss)
191+
err = transHdl.OnStream(nctx, mockConn, st)
143192
test.Assert(t, err == nil, err)
144193
})
145194

146195
t.Run("invoking handler panic", func(t *testing.T) {
196+
tracer.finishFunc = func(ctx context.Context) {
197+
ri := rpcinfo.GetRPCInfo(ctx)
198+
test.Assert(t, ri != nil, ri)
199+
rip, ok := ri.From().Tag(ripTag)
200+
test.Assert(t, ok)
201+
test.Assert(t, rip == "127.0.0.1:8888", rip)
202+
ok, pErr := ri.Stats().Panicked()
203+
test.Assert(t, ok)
204+
test.Assert(t, errors.Is(pErr.(error), kerrors.ErrPanic), pErr)
205+
}
147206
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
148207
defer func() {
149208
if handlerErr := recover(); handlerErr != nil {
209+
ri := rpcinfo.GetRPCInfo(ctx)
150210
err = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[panic] %s", handlerErr), string(debug.Stack()))
211+
rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
212+
rpcStats.SetPanicked(err)
151213
}
152214
}()
153215
panic("test")
@@ -156,20 +218,33 @@ func TestOnStream(t *testing.T) {
156218
streamFrame: streamFrame{
157219
sid: 1,
158220
method: mocks.MockStreamingMethod,
221+
header: map[string]string{
222+
ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888",
223+
},
159224
},
160225
typ: headerFrameType,
161226
})
162227
test.Assert(t, err == nil, err)
163228
err = wbuf.Flush()
164229
test.Assert(t, err == nil, err)
165-
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
230+
nctx, st, err := transHdl.provider.OnStream(ctx, mockConn)
166231
test.Assert(t, err == nil, err)
167-
err = transHdl.OnStream(nctx, mockConn, ss)
232+
err = transHdl.OnStream(nctx, mockConn, st)
168233
test.Assert(t, errors.Is(err, kerrors.ErrPanic), err)
169234
transHdl.OnError(ctx, err, mockConn)
170235
})
171236

172237
t.Run("invoking handler throws biz error", func(t *testing.T) {
238+
tracer.finishFunc = func(ctx context.Context) {
239+
ri := rpcinfo.GetRPCInfo(ctx)
240+
test.Assert(t, ri != nil, ri)
241+
rip, ok := ri.From().Tag(ripTag)
242+
test.Assert(t, ok)
243+
test.Assert(t, rip == "127.0.0.1:8888", rip)
244+
bizErr := ri.Invocation().BizStatusErr()
245+
test.Assert(t, bizErr.BizStatusCode() == 10000, bizErr)
246+
test.Assert(t, strings.Contains(bizErr.BizMessage(), "biz-error test"), bizErr)
247+
}
173248
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
174249
ri := rpcinfo.GetRPCInfo(ctx)
175250
defer func() {
@@ -186,15 +261,18 @@ func TestOnStream(t *testing.T) {
186261
streamFrame: streamFrame{
187262
sid: 1,
188263
method: mocks.MockStreamingMethod,
264+
header: map[string]string{
265+
ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888",
266+
},
189267
},
190268
typ: headerFrameType,
191269
})
192270
test.Assert(t, err == nil, err)
193271
err = wbuf.Flush()
194272
test.Assert(t, err == nil, err)
195-
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
273+
nctx, st, err := transHdl.provider.OnStream(ctx, mockConn)
196274
test.Assert(t, err == nil, err)
197-
err = transHdl.OnStream(nctx, mockConn, ss)
275+
err = transHdl.OnStream(nctx, mockConn, st)
198276
test.Assert(t, err == nil, err)
199277
})
200278
}

pkg/remote/trans/ttstream/server_provider.go

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@ import (
2222
"net"
2323
"strconv"
2424

25-
"github.com/bytedance/gopkg/cloud/metainfo"
2625
"github.com/cloudwego/gopkg/protocol/thrift"
2726
"github.com/cloudwego/gopkg/protocol/ttheader"
2827
"github.com/cloudwego/netpoll"
2928

3029
"github.com/cloudwego/kitex/pkg/kerrors"
3130
"github.com/cloudwego/kitex/pkg/remote"
32-
"github.com/cloudwego/kitex/pkg/remote/trans/ttstream/ktx"
3331
"github.com/cloudwego/kitex/pkg/streaming"
3432
"github.com/cloudwego/kitex/pkg/utils"
3533
)
@@ -86,7 +84,7 @@ func (s *serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context
8684
return ctx, nil
8785
}
8886

89-
func (s *serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, streaming.ServerStream, error) {
87+
func (s *serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, *stream, error) {
9088
trans, _ := ctx.Value(serverTransCtxKey{}).(*transport)
9189
if trans == nil {
9290
return nil, nil, fmt.Errorf("server transport is nil")
@@ -98,22 +96,7 @@ func (s *serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.C
9896
return nil, nil, err
9997
}
10098
st.setMetaFrameHandler(s.metaHandler)
101-
102-
// headerHandler return a new stream level ctx
103-
if s.headerHandler != nil {
104-
ctx, err = s.headerHandler.OnReadStream(ctx, st.meta, st.header)
105-
if err != nil {
106-
return nil, nil, err
107-
}
108-
}
109-
// register metainfo into ctx
110-
ctx = metainfo.SetMetaInfoFromMap(ctx, st.header)
111-
ss := newServerStream(st)
112-
113-
// cancel ctx when OnStreamFinish
114-
ctx, cancelFunc := ktx.WithCancel(ctx)
115-
ctx = context.WithValue(ctx, serverStreamCancelCtxKey{}, cancelFunc)
116-
return ctx, ss, nil
99+
return ctx, st, nil
117100
}
118101

119102
func (s *serverProvider) OnStreamFinish(ctx context.Context, ss streaming.ServerStream, err error) (context.Context, error) {

0 commit comments

Comments
 (0)