@@ -24,9 +24,11 @@ import (
24
24
"fmt"
25
25
"net"
26
26
"runtime/debug"
27
+ "strings"
27
28
"testing"
28
29
"time"
29
30
31
+ "github.com/cloudwego/gopkg/protocol/ttheader"
30
32
"github.com/cloudwego/netpoll"
31
33
32
34
"github.com/cloudwego/kitex/internal/mocks"
@@ -87,17 +89,54 @@ func (m *mockNetpollConn) SetOnDisconnect(onDisconnect netpoll.OnDisconnect) err
87
89
return nil
88
90
}
89
91
92
+ type mockTracer struct {
93
+ finishFunc func (ctx context.Context )
94
+ }
95
+
96
+ func (m * mockTracer ) Start (ctx context.Context ) context.Context {
97
+ return ctx
98
+ }
99
+
100
+ func (m * mockTracer ) Finish (ctx context.Context ) {
101
+ m .finishFunc (ctx )
102
+ }
103
+
104
+ type mockHeaderFrameReadHandler struct {
105
+ ripTag string
106
+ }
107
+
108
+ func (m * mockHeaderFrameReadHandler ) OnReadStream (ctx context.Context , ihd IntHeader , shd StrHeader ) (context.Context , error ) {
109
+ ri := rpcinfo .GetRPCInfo (ctx )
110
+ fi := rpcinfo .AsMutableEndpointInfo (ri .From ())
111
+ if rip , ok := shd [ttheader .HeaderTransRemoteAddr ]; ok {
112
+ fi .SetTag (m .ripTag , rip )
113
+ }
114
+ return ctx , nil
115
+ }
116
+
90
117
func TestOnStream (t * testing.T ) {
118
+ ripTag := "rip"
119
+ tracer := & mockTracer {}
120
+ traceCtl := & rpcinfo.TraceController {}
121
+ traceCtl .Append (tracer )
91
122
factory := NewSvrTransHandlerFactory ()
92
123
rawTransHdl , err := factory .NewTransHandler (& remote.ServerOption {
93
124
SvcSearcher : mock_remote .NewDefaultSvcSearcher (),
94
125
InitOrResetRPCInfoFunc : func (info rpcinfo.RPCInfo , addr net.Addr ) rpcinfo.RPCInfo {
95
- return rpcinfo .NewRPCInfo (nil , nil ,
126
+ return rpcinfo .NewRPCInfo (rpcinfo .NewEndpointInfo (
127
+ mocks .MockService2Name , mocks .Mock2Method , nil , make (map [string ]string )), nil ,
96
128
rpcinfo .NewInvocation (mocks .MockServiceName , mocks .MockStreamingMethod ),
97
129
rpcinfo .NewRPCConfig (),
98
130
rpcinfo .NewRPCStats ())
99
131
},
100
- TracerCtl : & rpcinfo.TraceController {},
132
+ TracerCtl : traceCtl ,
133
+ TTHeaderStreamingOptions : remote.TTHeaderStreamingOptions {
134
+ TransportOptions : []interface {}{
135
+ WithServerHeaderFrameHandler (& mockHeaderFrameReadHandler {
136
+ ripTag : ripTag ,
137
+ }),
138
+ },
139
+ },
101
140
})
102
141
test .Assert (t , err == nil , err )
103
142
transHdl := rawTransHdl .(* svrTransHandler )
@@ -124,30 +163,53 @@ func TestOnStream(t *testing.T) {
124
163
}()
125
164
126
165
t .Run ("invoking handler successfully" , func (t * testing.T ) {
166
+ tracer .finishFunc = func (ctx context.Context ) {
167
+ ri := rpcinfo .GetRPCInfo (ctx )
168
+ test .Assert (t , ri != nil , ri )
169
+ rip , ok := ri .From ().Tag (ripTag )
170
+ test .Assert (t , ok )
171
+ test .Assert (t , rip == "127.0.0.1:8888" , rip )
172
+ }
127
173
transHdl .SetInvokeHandleFunc (func (ctx context.Context , req , resp interface {}) (err error ) {
128
174
return nil
129
175
})
130
176
err = EncodeFrame (context .Background (), wbuf , & Frame {
131
177
streamFrame : streamFrame {
132
178
sid : 1 ,
133
179
method : mocks .MockStreamingMethod ,
180
+ header : map [string ]string {
181
+ ttheader .HeaderTransRemoteAddr : "127.0.0.1:8888" ,
182
+ },
134
183
},
135
184
typ : headerFrameType ,
136
185
})
137
186
test .Assert (t , err == nil , err )
138
187
err = wbuf .Flush ()
139
188
test .Assert (t , err == nil , err )
140
- nctx , ss , err := transHdl .provider .OnStream (ctx , mockConn )
189
+ nctx , st , err := transHdl .provider .OnStream (ctx , mockConn )
141
190
test .Assert (t , err == nil , err )
142
- err = transHdl .OnStream (nctx , mockConn , ss )
191
+ err = transHdl .OnStream (nctx , mockConn , st )
143
192
test .Assert (t , err == nil , err )
144
193
})
145
194
146
195
t .Run ("invoking handler panic" , func (t * testing.T ) {
196
+ tracer .finishFunc = func (ctx context.Context ) {
197
+ ri := rpcinfo .GetRPCInfo (ctx )
198
+ test .Assert (t , ri != nil , ri )
199
+ rip , ok := ri .From ().Tag (ripTag )
200
+ test .Assert (t , ok )
201
+ test .Assert (t , rip == "127.0.0.1:8888" , rip )
202
+ ok , pErr := ri .Stats ().Panicked ()
203
+ test .Assert (t , ok )
204
+ test .Assert (t , errors .Is (pErr .(error ), kerrors .ErrPanic ), pErr )
205
+ }
147
206
transHdl .SetInvokeHandleFunc (func (ctx context.Context , req , resp interface {}) (err error ) {
148
207
defer func () {
149
208
if handlerErr := recover (); handlerErr != nil {
209
+ ri := rpcinfo .GetRPCInfo (ctx )
150
210
err = kerrors .ErrPanic .WithCauseAndStack (fmt .Errorf ("[panic] %s" , handlerErr ), string (debug .Stack ()))
211
+ rpcStats := rpcinfo .AsMutableRPCStats (ri .Stats ())
212
+ rpcStats .SetPanicked (err )
151
213
}
152
214
}()
153
215
panic ("test" )
@@ -156,20 +218,33 @@ func TestOnStream(t *testing.T) {
156
218
streamFrame : streamFrame {
157
219
sid : 1 ,
158
220
method : mocks .MockStreamingMethod ,
221
+ header : map [string ]string {
222
+ ttheader .HeaderTransRemoteAddr : "127.0.0.1:8888" ,
223
+ },
159
224
},
160
225
typ : headerFrameType ,
161
226
})
162
227
test .Assert (t , err == nil , err )
163
228
err = wbuf .Flush ()
164
229
test .Assert (t , err == nil , err )
165
- nctx , ss , err := transHdl .provider .OnStream (ctx , mockConn )
230
+ nctx , st , err := transHdl .provider .OnStream (ctx , mockConn )
166
231
test .Assert (t , err == nil , err )
167
- err = transHdl .OnStream (nctx , mockConn , ss )
232
+ err = transHdl .OnStream (nctx , mockConn , st )
168
233
test .Assert (t , errors .Is (err , kerrors .ErrPanic ), err )
169
234
transHdl .OnError (ctx , err , mockConn )
170
235
})
171
236
172
237
t .Run ("invoking handler throws biz error" , func (t * testing.T ) {
238
+ tracer .finishFunc = func (ctx context.Context ) {
239
+ ri := rpcinfo .GetRPCInfo (ctx )
240
+ test .Assert (t , ri != nil , ri )
241
+ rip , ok := ri .From ().Tag (ripTag )
242
+ test .Assert (t , ok )
243
+ test .Assert (t , rip == "127.0.0.1:8888" , rip )
244
+ bizErr := ri .Invocation ().BizStatusErr ()
245
+ test .Assert (t , bizErr .BizStatusCode () == 10000 , bizErr )
246
+ test .Assert (t , strings .Contains (bizErr .BizMessage (), "biz-error test" ), bizErr )
247
+ }
173
248
transHdl .SetInvokeHandleFunc (func (ctx context.Context , req , resp interface {}) (err error ) {
174
249
ri := rpcinfo .GetRPCInfo (ctx )
175
250
defer func () {
@@ -186,15 +261,18 @@ func TestOnStream(t *testing.T) {
186
261
streamFrame : streamFrame {
187
262
sid : 1 ,
188
263
method : mocks .MockStreamingMethod ,
264
+ header : map [string ]string {
265
+ ttheader .HeaderTransRemoteAddr : "127.0.0.1:8888" ,
266
+ },
189
267
},
190
268
typ : headerFrameType ,
191
269
})
192
270
test .Assert (t , err == nil , err )
193
271
err = wbuf .Flush ()
194
272
test .Assert (t , err == nil , err )
195
- nctx , ss , err := transHdl .provider .OnStream (ctx , mockConn )
273
+ nctx , st , err := transHdl .provider .OnStream (ctx , mockConn )
196
274
test .Assert (t , err == nil , err )
197
- err = transHdl .OnStream (nctx , mockConn , ss )
275
+ err = transHdl .OnStream (nctx , mockConn , st )
198
276
test .Assert (t , err == nil , err )
199
277
})
200
278
}
0 commit comments