Skip to content

Commit fe99d36

Browse files
committed
Add waitstream strategy
1 parent e3d398e commit fe99d36

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

x/configurl/module.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ func RegisterDefaultProviders(c *ProviderContainer) *ProviderContainer {
6565
registerWebsocketStreamDialer(&c.StreamDialers, "ws", c.StreamDialers.NewInstance)
6666
registerWebsocketPacketDialer(&c.PacketDialers, "ws", c.StreamDialers.NewInstance)
6767

68+
registerWaitStreamDialer(&c.StreamDialers, "waitstream", c.StreamDialers.NewInstance)
69+
6870
return c
6971
}
7072

x/configurl/wait_stream.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright 2024 The Outline Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package configurl
16+
17+
import (
18+
"context"
19+
20+
"github.com/Jigsaw-Code/outline-sdk/transport"
21+
"github.com/Jigsaw-Code/outline-sdk/x/wait_stream"
22+
)
23+
24+
func registerWaitStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) {
25+
r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) {
26+
sd, err := newSD(ctx, config.BaseConfig)
27+
if err != nil {
28+
return nil, err
29+
}
30+
return wait_stream.NewStreamDialer(sd)
31+
})
32+
}

x/wait_stream/stream_dialer.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2024 The Outline Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package wait_stream
16+
17+
import (
18+
"context"
19+
"errors"
20+
"net"
21+
22+
"github.com/Jigsaw-Code/outline-sdk/transport"
23+
"github.com/Jigsaw-Code/outline-sdk/x/sockopt"
24+
)
25+
26+
type waitStreamDialer struct {
27+
dialer transport.StreamDialer
28+
}
29+
30+
var _ transport.StreamDialer = (*waitStreamDialer)(nil)
31+
32+
func NewStreamDialer(dialer transport.StreamDialer) (transport.StreamDialer, error) {
33+
if dialer == nil {
34+
return nil, errors.New("argument dialer must not be nil")
35+
}
36+
return &waitStreamDialer{dialer: dialer}, nil
37+
}
38+
39+
func (d *waitStreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
40+
innerConn, err := d.dialer.DialStream(ctx, remoteAddr)
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
tcpInnerConn, ok := innerConn.(*net.TCPConn)
46+
if !ok {
47+
return nil, errors.New("wait_stream strategy: expected base dialer to return TCPConn")
48+
}
49+
50+
tcpOptions, err := sockopt.NewTCPOptions(tcpInnerConn)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
dw := NewWriter(innerConn, tcpOptions)
56+
57+
return transport.WrapConn(innerConn, innerConn, dw), nil
58+
}

x/wait_stream/writer.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2024 The Outline Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package wait_stream
16+
17+
import (
18+
"io"
19+
20+
"github.com/Jigsaw-Code/outline-sdk/x/sockopt"
21+
)
22+
23+
type waitStreamWriter struct {
24+
conn io.Writer
25+
tcpOptions sockopt.TCPOptions
26+
}
27+
28+
var _ io.Writer = (*waitStreamWriter)(nil)
29+
30+
func NewWriter(conn io.Writer, tcpOptions sockopt.TCPOptions) io.Writer {
31+
return &waitStreamWriter{
32+
conn: conn,
33+
tcpOptions: tcpOptions,
34+
}
35+
}
36+
37+
func (w *waitStreamWriter) Write(data []byte) (written int, err error) {
38+
written, err = w.conn.Write(data)
39+
40+
// This may not be implemented, so it's best effort really.
41+
/*
42+
waitUntilBytesAreSentErr := w.tcpOptions.WaitUntilBytesAreSent()
43+
if waitUntilBytesAreSentErr != nil && !errors.Is(waitUntilBytesAreSentErr, errors.ErrUnsupported) {
44+
return written, fmt.Errorf("error when waiting for stream to send all bytes: %w", waitUntilBytesAreSentErr)
45+
}
46+
*/
47+
48+
return
49+
}

0 commit comments

Comments
 (0)