@@ -25,11 +25,14 @@ import (
25
25
"runtime/debug"
26
26
"sync"
27
27
28
+ "github.com/bytedance/gopkg/cloud/metainfo"
29
+
28
30
"github.com/cloudwego/kitex/pkg/endpoint"
29
31
"github.com/cloudwego/kitex/pkg/gofunc"
30
32
"github.com/cloudwego/kitex/pkg/kerrors"
31
33
"github.com/cloudwego/kitex/pkg/klog"
32
34
"github.com/cloudwego/kitex/pkg/remote"
35
+ "github.com/cloudwego/kitex/pkg/remote/trans/ttstream/ktx"
33
36
"github.com/cloudwego/kitex/pkg/rpcinfo"
34
37
"github.com/cloudwego/kitex/pkg/streaming"
35
38
ktransport "github.com/cloudwego/kitex/transport"
@@ -118,7 +121,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
118
121
}()
119
122
// connection level goroutine
120
123
for {
121
- nctx , ss , nerr := t .provider .OnStream (ctx , conn )
124
+ nctx , st , nerr := t .provider .OnStream (ctx , conn )
122
125
if nerr != nil {
123
126
if errors .Is (nerr , io .EOF ) {
124
127
return nil
@@ -130,7 +133,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
130
133
// stream level goroutine
131
134
gofunc .GoFunc (nctx , func () {
132
135
defer wg .Done ()
133
- err := t .OnStream (nctx , conn , ss )
136
+ err := t .OnStream (nctx , conn , st )
134
137
if err != nil && ! errors .Is (err , io .EOF ) {
135
138
t .OnError (nctx , err , conn )
136
139
}
@@ -142,7 +145,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
142
145
// - create server stream
143
146
// - process server stream
144
147
// - close server stream
145
- func (t * svrTransHandler ) OnStream (ctx context.Context , conn net.Conn , ss streaming. ServerStream ) (err error ) {
148
+ func (t * svrTransHandler ) OnStream (ctx context.Context , conn net.Conn , st * stream ) (err error ) {
146
149
ri := t .opt .InitOrResetRPCInfoFunc (nil , conn .RemoteAddr ())
147
150
ctx = rpcinfo .NewCtxWithRPCInfo (ctx , ri )
148
151
defer func () {
@@ -153,23 +156,35 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream
153
156
}()
154
157
155
158
ink := ri .Invocation ().(rpcinfo.InvocationSetter )
156
- if si , ok := ss .(StreamInfo ); ok {
157
- sinfo := t .opt .SvcSearcher .SearchService (si .Service (), si .Method (), false )
158
- if sinfo == nil {
159
- return remote .NewTransErrorWithMsg (remote .UnknownService , fmt .Sprintf ("unknown service %s" , si .Service ()))
160
- }
161
- minfo := sinfo .MethodInfo (si .Method ())
162
- if minfo == nil {
163
- return remote .NewTransErrorWithMsg (remote .UnknownMethod , fmt .Sprintf ("unknown method %s" , si .Method ()))
164
- }
165
- ink .SetServiceName (sinfo .ServiceName )
166
- ink .SetMethodName (si .Method ())
167
- ink .SetStreamingMode (minfo .StreamingMode ())
168
- if mutableTo := rpcinfo .AsMutableEndpointInfo (ri .To ()); mutableTo != nil {
169
- _ = mutableTo .SetMethod (si .Method ())
159
+ sinfo := t .opt .SvcSearcher .SearchService (st .Service (), st .Method (), false )
160
+ if sinfo == nil {
161
+ return remote .NewTransErrorWithMsg (remote .UnknownService , fmt .Sprintf ("unknown service %s" , st .Service ()))
162
+ }
163
+ minfo := sinfo .MethodInfo (st .Method ())
164
+ if minfo == nil {
165
+ return remote .NewTransErrorWithMsg (remote .UnknownMethod , fmt .Sprintf ("unknown method %s" , st .Method ()))
166
+ }
167
+ ink .SetServiceName (sinfo .ServiceName )
168
+ ink .SetMethodName (st .Method ())
169
+ ink .SetStreamingMode (minfo .StreamingMode ())
170
+ if mutableTo := rpcinfo .AsMutableEndpointInfo (ri .To ()); mutableTo != nil {
171
+ _ = mutableTo .SetMethod (st .Method ())
172
+ }
173
+ rpcinfo .AsMutableRPCConfig (ri .Config ()).SetTransportProtocol (st .TransportProtocol ())
174
+ // headerHandler return a new stream level ctx
175
+ if t .provider .headerHandler != nil {
176
+ ctx , err = t .provider .headerHandler .OnReadStream (ctx , st .meta , st .header )
177
+ if err != nil {
178
+ return err
170
179
}
171
- rpcinfo .AsMutableRPCConfig (ri .Config ()).SetTransportProtocol (si .TransportProtocol ())
172
180
}
181
+ // register metainfo into ctx
182
+ ctx = metainfo .SetMetaInfoFromMap (ctx , st .header )
183
+ ss := newServerStream (st )
184
+
185
+ // cancel ctx when OnStreamFinish
186
+ ctx , cancelFunc := ktx .WithCancel (ctx )
187
+ ctx = context .WithValue (ctx , serverStreamCancelCtxKey {}, cancelFunc )
173
188
174
189
ctx = t .startTracer (ctx , ri )
175
190
defer func () {
0 commit comments