Skip to content

Commit 2d98663

Browse files
committed
refactor: merge readers and writers
1 parent 03de907 commit 2d98663

File tree

1 file changed

+62
-51
lines changed

1 file changed

+62
-51
lines changed

modules/zstd/zstd.go

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,73 +10,49 @@ import (
1010
"github.com/klauspost/compress/zstd"
1111
)
1212

13-
type Writer zstd.Encoder
13+
type Writer struct {
14+
enc *zstd.Encoder
1415

15-
var _ io.WriteCloser = (*Writer)(nil)
16-
17-
func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
18-
zstdW, err := zstd.NewWriter(w, opts...)
19-
if err != nil {
20-
return nil, err
21-
}
22-
return (*Writer)(zstdW), nil
23-
}
24-
25-
func (w *Writer) Write(p []byte) (int, error) {
26-
return (*zstd.Encoder)(w).Write(p)
27-
}
28-
29-
func (w *Writer) Close() error {
30-
return (*zstd.Encoder)(w).Close()
16+
skw seekable.Writer
17+
buf []byte
18+
n int
3119
}
3220

33-
type Reader zstd.Decoder
34-
35-
var _ io.ReadCloser = (*Reader)(nil)
21+
var _ io.WriteCloser = (*Writer)(nil)
3622

37-
func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
38-
zstdR, err := zstd.NewReader(r, opts...)
23+
func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
24+
enc, err := zstd.NewWriter(w, opts...)
3925
if err != nil {
4026
return nil, err
4127
}
42-
return (*Reader)(zstdR), nil
43-
}
44-
45-
func (r *Reader) Read(p []byte) (int, error) {
46-
return (*zstd.Decoder)(r).Read(p)
47-
}
48-
49-
func (r *Reader) Close() error {
50-
(*zstd.Decoder)(r).Close() // no error returned
51-
return nil
52-
}
53-
54-
type SeekableWriter struct {
55-
buf []byte
56-
n int
57-
w seekable.Writer
28+
return &Writer{
29+
enc: enc,
30+
}, nil
5831
}
5932

60-
var _ io.WriteCloser = (*SeekableWriter)(nil)
61-
62-
func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) {
63-
zstdW, err := zstd.NewWriter(nil, opts...)
33+
func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*Writer, error) {
34+
enc, err := zstd.NewWriter(nil, opts...)
6435
if err != nil {
6536
return nil, err
6637
}
6738

68-
seekableW, err := seekable.NewWriter(w, zstdW)
39+
skw, err := seekable.NewWriter(w, enc)
6940
if err != nil {
7041
return nil, err
7142
}
7243

73-
return &SeekableWriter{
44+
return &Writer{
45+
enc: enc,
46+
skw: skw,
7447
buf: make([]byte, blockSize),
75-
w: seekableW,
7648
}, nil
7749
}
7850

79-
func (w *SeekableWriter) Write(p []byte) (int, error) {
51+
func (w *Writer) Write(p []byte) (int, error) {
52+
if w.skw != nil {
53+
return w.enc.Write(p)
54+
}
55+
8056
written := 0
8157
for len(p) > 0 {
8258
n := copy(w.buf[w.n:], p)
@@ -85,7 +61,7 @@ func (w *SeekableWriter) Write(p []byte) (int, error) {
8561
p = p[n:]
8662

8763
if w.n == len(w.buf) {
88-
if _, err := w.w.Write(w.buf); err != nil {
64+
if _, err := w.skw.Write(w.buf); err != nil {
8965
return written, err
9066
}
9167
w.n = 0
@@ -94,13 +70,48 @@ func (w *SeekableWriter) Write(p []byte) (int, error) {
9470
return written, nil
9571
}
9672

97-
func (w *SeekableWriter) Close() error {
98-
if w.n > 0 {
99-
if _, err := w.w.Write(w.buf[:w.n]); err != nil {
73+
func (w *Writer) Close() error {
74+
if w.skw != nil {
75+
if w.n > 0 {
76+
if _, err := w.skw.Write(w.buf[:w.n]); err != nil {
77+
return err
78+
}
79+
}
80+
if err := w.skw.Close(); err != nil {
10081
return err
10182
}
10283
}
103-
return w.w.Close()
84+
return w.enc.Close()
85+
}
86+
87+
type Reader struct {
88+
dec *zstd.Decoder
89+
skr seekable.Reader
90+
}
91+
92+
var _ io.ReadCloser = (*Reader)(nil)
93+
94+
func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
95+
dec, err := zstd.NewReader(r, opts...)
96+
if err != nil {
97+
return nil, err
98+
}
99+
return &Reader{
100+
dec: dec,
101+
}, nil
102+
}
103+
104+
func (r *Reader) Read(p []byte) (int, error) {
105+
return r.dec.Read(p)
106+
}
107+
108+
func (r *Reader) Close() error {
109+
r.dec.Close() // no error returned
110+
return nil
111+
}
112+
113+
func (r *Reader) SeekReader() (seekable.Reader, error) {
114+
return r.skr
104115
}
105116

106117
type SeekableReader seekable.Reader

0 commit comments

Comments
 (0)