Skip to content

Commit 3c39121

Browse files
author
Mikhail Sakhnov
committed
handle data package to agent in the same goroutine
cherry pick unit test from the irozzo-1A/apiserver-hetwork-proxy fork
1 parent d8635dc commit 3c39121

File tree

2 files changed

+169
-34
lines changed

2 files changed

+169
-34
lines changed

pkg/server/server.go

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -756,18 +756,19 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, stream agent.AgentServic
756756
klog.V(5).InfoS("Received data from agent", "bytes", len(resp.Data), "agentID", agentID, "connectionID", resp.ConnectID)
757757
kasConnCtx, isFromAgent := s.forwardConnectionManager.Get(resp.ConnectID)
758758
if isFromAgent {
759-
kasConnCtx.dataCh <- resp.Data
759+
s.handleDataPacketFromAgent(resp, kasConnCtx)
760760
continue
761-
}
762-
frontend, err := s.getFrontend(agentID, resp.ConnectID)
763-
if err != nil {
764-
klog.ErrorS(err, "could not get frontend client", "connectionID", resp.ConnectID)
765-
break
766-
}
767-
if err := frontend.send(pkt); err != nil {
768-
klog.ErrorS(err, "send to client stream failure", "serverID", s.serverID, "agentID", agentID, "connectionID", resp.ConnectID)
769761
} else {
770-
klog.V(5).InfoS("DATA sent to frontend")
762+
frontend, err := s.getFrontend(agentID, resp.ConnectID)
763+
if err != nil {
764+
klog.ErrorS(err, "could not get frontend client", "connectionID", resp.ConnectID)
765+
break
766+
}
767+
if err := frontend.send(pkt); err != nil {
768+
klog.ErrorS(err, "send to client stream failure", "serverID", s.serverID, "agentID", agentID, "connectionID", resp.ConnectID)
769+
} else {
770+
klog.V(5).InfoS("DATA sent to frontend")
771+
}
771772
}
772773

773774
case client.PacketType_CLOSE_RSP:
@@ -830,10 +831,8 @@ func (s *ProxyServer) handleDialRequest(pkt *client.Packet, backend Backend) {
830831
// Odd identifiers are used for connections from node to master network,
831832
// increment by 2 to maintain the invariant.
832833
connID := atomic.AddInt64(&s.nextConnID, 2)
833-
dataCh := make(chan []byte, 5)
834834
kasConnCtx := &connContext{
835-
conn: conn,
836-
dataCh: dataCh,
835+
conn: conn,
837836
cleanFunc: func() {
838837
klog.V(4).InfoS("close connection", "connectionID", connID)
839838
req := &client.Packet{
@@ -853,7 +852,6 @@ func (s *ProxyServer) handleDialRequest(pkt *client.Packet, backend Backend) {
853852
klog.ErrorS(err, "error occurred while closing connection", "connectionID", connID)
854853
}
855854

856-
close(dataCh)
857855
s.forwardConnectionManager.Delete(connID)
858856
},
859857
backend: backend,
@@ -869,13 +867,11 @@ func (s *ProxyServer) handleDialRequest(pkt *client.Packet, backend Backend) {
869867

870868
// proxy data to and from KAS for the connection
871869
go s.agentToProxy(connID, kasConnCtx)
872-
go s.proxyToAgent(connID, kasConnCtx)
873870
}
874871

875872
type connContext struct {
876873
conn net.Conn
877874
cleanFunc func()
878-
dataCh chan []byte
879875
cleanOnce sync.Once
880876
backend Backend
881877
}
@@ -915,24 +911,22 @@ func (s *ProxyServer) agentToProxy(connID int64, ctx *connContext) {
915911
}
916912
}
917913

918-
func (s *ProxyServer) proxyToAgent(connID int64, ctx *connContext) {
919-
defer ctx.cleanup()
920-
921-
for d := range ctx.dataCh {
922-
pos := 0
923-
for {
924-
n, err := ctx.conn.Write(d[pos:])
925-
if err == nil {
926-
klog.V(4).InfoS("write to remote", "connectionID", connID, "lastData", n)
927-
break
928-
} else if n > 0 {
929-
// https://golang.org/pkg/io/#Writer specifies return non nil error if n < len(d)
930-
klog.ErrorS(err, "write to remote with failure", "connectionID", connID, "lastData", n)
931-
pos += n
932-
} else {
933-
klog.ErrorS(err, "conn write failure", "connectionID", connID)
934-
return
935-
}
914+
func (s *ProxyServer) handleDataPacketFromAgent(resp *client.Data, kasConnCtx *connContext) {
915+
pos := 0
916+
connID := resp.ConnectID
917+
d := resp.Data
918+
for {
919+
n, err := kasConnCtx.conn.Write(d[pos:])
920+
if err == nil {
921+
klog.V(4).InfoS("write to remote", "connectionID", connID, "lastData", n)
922+
break
923+
} else if n > 0 {
924+
// https://golang.org/pkg/io/#Writer specifies return non nil error if n < len(d)
925+
klog.ErrorS(err, "write to remote with failure", "connectionID", connID, "lastData", n)
926+
pos += n
927+
} else {
928+
klog.ErrorS(err, "conn write failure", "connectionID", connID)
929+
return
936930
}
937931
}
938932
}

pkg/server/server_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ import (
2020
"context"
2121
"fmt"
2222
"io"
23+
"math/rand"
24+
"net"
2325
"reflect"
26+
"sync"
2427
"testing"
2528

2629
"github.com/golang/mock/gomock"
@@ -32,6 +35,8 @@ import (
3235
fakeauthenticationv1 "k8s.io/client-go/kubernetes/typed/authentication/v1/fake"
3336
k8stesting "k8s.io/client-go/testing"
3437

38+
client "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client"
39+
"sigs.k8s.io/apiserver-network-proxy/proto/agent"
3540
agentmock "sigs.k8s.io/apiserver-network-proxy/proto/agent/mocks"
3641
"sigs.k8s.io/apiserver-network-proxy/proto/header"
3742
)
@@ -218,3 +223,139 @@ func TestAddRemoveFrontends(t *testing.T) {
218223
t.Errorf("expected %v, got %v", e, a)
219224
}
220225
}
226+
227+
type testStream struct {
228+
agent.AgentService_ConnectServer
229+
ch chan *client.Packet
230+
}
231+
232+
func (f testStream) SendHeader(md metadata.MD) error {
233+
return nil
234+
}
235+
236+
func (f testStream) Context() context.Context {
237+
ctx := context.Background()
238+
ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(header.AgentID, "test-agent-1"))
239+
return ctx
240+
}
241+
242+
func (t testStream) Send(packet *client.Packet) error {
243+
return nil
244+
}
245+
246+
func (t testStream) Recv() (*client.Packet, error) {
247+
v, ok := <-t.ch
248+
if !ok {
249+
return nil, io.EOF
250+
}
251+
return v, nil
252+
}
253+
254+
func TestNodeToControlPlane(t *testing.T) {
255+
testData := []byte("this is test data")
256+
stub := gomock.NewController(t)
257+
defer stub.Finish()
258+
var wg sync.WaitGroup
259+
260+
// create proxy server
261+
p := NewProxyServer("", []ProxyStrategy{ProxyStrategyDefault}, 1,
262+
&AgentTokenAuthenticationOptions{}, false)
263+
264+
// start a controlplane server
265+
lis, err := net.Listen("tcp", ":0")
266+
if err != nil {
267+
panic(err)
268+
}
269+
270+
var receivedData []byte
271+
wg.Add(1)
272+
go func() {
273+
defer wg.Done()
274+
conn, err := lis.Accept()
275+
if err != nil {
276+
panic(err)
277+
}
278+
for {
279+
var data [20]byte
280+
n, err := conn.Read(data[:])
281+
if err == io.EOF {
282+
break
283+
}
284+
285+
if err != nil {
286+
panic(err)
287+
}
288+
receivedData = append(receivedData, data[:n]...)
289+
}
290+
}()
291+
292+
stream := &testStream{ch: make(chan *client.Packet, 15)}
293+
wg.Add(1)
294+
295+
go func() {
296+
connID := int64(3)
297+
stream.ch <- packetDialReq(lis.Addr().String())
298+
stream.ch <- packetData(connID, testData)
299+
stream.ch <- packetDataEOF(connID)
300+
stream.ch <- packetCloseReq(connID)
301+
close(stream.ch)
302+
wg.Done()
303+
}()
304+
// connect the agent to proxy server
305+
err = p.Connect(stream)
306+
if err != nil {
307+
panic(err)
308+
}
309+
wg.Wait()
310+
if !reflect.DeepEqual(testData, receivedData) {
311+
t.Fatalf("Expected: %v Got: %v", testData, receivedData)
312+
}
313+
}
314+
315+
func packetCloseReq(connID int64) *client.Packet {
316+
return &client.Packet{
317+
Type: client.PacketType_CLOSE_REQ,
318+
Payload: &client.Packet_CloseRequest{
319+
CloseRequest: &client.CloseRequest{
320+
ConnectID: connID,
321+
},
322+
},
323+
}
324+
}
325+
326+
func packetData(connID int64, testData []byte) *client.Packet {
327+
return &client.Packet{
328+
Type: client.PacketType_DATA,
329+
Payload: &client.Packet_Data{
330+
Data: &client.Data{
331+
ConnectID: connID,
332+
Data: testData,
333+
},
334+
},
335+
}
336+
}
337+
338+
func packetDialReq(addr string) *client.Packet {
339+
return &client.Packet{
340+
Type: client.PacketType_DIAL_REQ,
341+
Payload: &client.Packet_DialRequest{
342+
DialRequest: &client.DialRequest{
343+
Protocol: "tcp",
344+
Address: addr,
345+
Random: rand.Int63(),
346+
},
347+
},
348+
}
349+
}
350+
351+
func packetDataEOF(connID int64) *client.Packet {
352+
return &client.Packet{
353+
Type: client.PacketType_DATA,
354+
Payload: &client.Packet_Data{
355+
Data: &client.Data{
356+
ConnectID: connID,
357+
Error: io.EOF.Error(),
358+
},
359+
},
360+
}
361+
}

0 commit comments

Comments
 (0)