diff --git a/arangodb/README.md b/arangodb/README.md index ae3a6c39..74f5a680 100644 --- a/arangodb/README.md +++ b/arangodb/README.md @@ -21,9 +21,13 @@ A ArangoDB storage driver using `arangodb/go-driver` and [arangodb/go-driver](ht ### Signatures ```go func New(config ...Config) Storage +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) func (s *Storage) Get(key string) ([]byte, error) +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error func (s *Storage) Set(key string, val []byte, exp time.Duration) error +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error func (s *Storage) Delete(key string) error +func (s *Storage) ResetWithContext(ctx context.Context) error func (s *Storage) Reset() error func (s *Storage) Close() error func (s *Storage) Conn() driver.Client diff --git a/arangodb/arangodb.go b/arangodb/arangodb.go index 19a72300..ce3b59fa 100644 --- a/arangodb/arangodb.go +++ b/arangodb/arangodb.go @@ -116,14 +116,12 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext value by key with given context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - ctx := context.Background() - // Check if the document exists // to avoid errors later exists, err := s.collection.DocumentExists(ctx, key) @@ -151,8 +149,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return utils.UnsafeBytes(model.Val), nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value with given context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -169,7 +172,6 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { Val: valStr, Exp: expireAt, } - ctx := context.Background() // Arango does not support documents with the same key // So we need to check if the document exists @@ -188,20 +190,35 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } -// Delete value by key -func (s *Storage) Delete(key string) error { +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext value by key with given context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } - _, err := s.collection.RemoveDocument(context.Background(), key) + _, err := s.collection.RemoveDocument(ctx, key) return err } +// Delete value by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all keys with given context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.collection.Truncate(ctx) +} + // Reset all keys // truncate the collection func (s *Storage) Reset() error { - return s.collection.Truncate(context.Background()) + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/arangodb/arangodb_test.go b/arangodb/arangodb_test.go index 3b4010b3..d661f789 100644 --- a/arangodb/arangodb_test.go +++ b/arangodb/arangodb_test.go @@ -67,6 +67,22 @@ func Test_ArangoDB_Set(t *testing.T) { require.NoError(t, err) } +func Test_ArangoDB_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_ArangoDB_Upsert(t *testing.T) { var ( key = "john" @@ -100,6 +116,26 @@ func Test_ArangoDB_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_ArangoDB_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_ArangoDB_Set_Expiration(t *testing.T) { var ( key = "john" @@ -156,6 +192,29 @@ func Test_ArangoDB_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_ArangoDB_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_ArangoDB_Reset(t *testing.T) { val := []byte("doe") @@ -180,6 +239,33 @@ func Test_ArangoDB_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_ArangoDB_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.Equal(t, err, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_ArangoDB_Non_UTF8(t *testing.T) { val := []byte("0xF5") diff --git a/azureblob/azureblob.go b/azureblob/azureblob.go index cae22a61..b6ba7bcb 100644 --- a/azureblob/azureblob.go +++ b/azureblob/azureblob.go @@ -48,13 +48,12 @@ func New(config ...Config) *Storage { return storage } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - ctx, cancel := s.requestContext() - defer cancel() + resp, err := s.client.DownloadStream(ctx, s.container, key, nil) if err != nil { return []byte{}, err @@ -63,55 +62,81 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + ctx, cancel := s.requestContext() + defer cancel() + + return s.GetWithContext(ctx, key) +} + +// SetWithContext sets key with value +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 { return nil } - ctx, cancel := s.requestContext() - defer cancel() + _, err := s.client.UploadBuffer(ctx, s.container, key, val, nil) return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set sets key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + ctx, cancel := s.requestContext() + defer cancel() + + return s.SetWithContext(ctx, key, val, exp) +} + +// DeleteWithContext deletes entry by key +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - ctx, cancel := s.requestContext() - defer cancel() _, err := s.client.DeleteBlob(ctx, s.container, key, nil) return err } -// Reset all entries -func (s *Storage) Reset() error { +// Delete deletes entry by key +func (s *Storage) Delete(key string) error { ctx, cancel := s.requestContext() defer cancel() + + return s.DeleteWithContext(ctx, key) +} + +// ResetWithContext resets all entries +func (s *Storage) ResetWithContext(ctx context.Context) error { //_, err := s.client.DeleteContainer(ctx, s.container, nil) //return err pager := s.client.NewListBlobsFlatPager(s.container, nil) - errCounter := 0 + for pager.More() { resp, err := pager.NextPage(ctx) if err != nil { - errCounter = errCounter + 1 + return err } + for _, v := range resp.Segment.BlobItems { _, err = s.client.DeleteBlob(ctx, s.container, *v.Name, nil) if err != nil { - errCounter = errCounter + 1 + return err } } } - if errCounter > 0 { - return fmt.Errorf("%d errors occured while resetting", errCounter) - } + return nil } +// Reset resets all entries +func (s *Storage) Reset() error { + ctx, cancel := s.requestContext() + defer cancel() + + return s.ResetWithContext(ctx) +} + // Conn returns storage client func (s *Storage) Conn() *azblob.Client { return s.client diff --git a/azureblob/azureblob_test.go b/azureblob/azureblob_test.go index ddbdfb9e..2160189a 100644 --- a/azureblob/azureblob_test.go +++ b/azureblob/azureblob_test.go @@ -63,6 +63,26 @@ func Test_AzureBlob_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_AzureBlob_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_AzureBlob_Set(t *testing.T) { var ( key = "john" @@ -76,6 +96,22 @@ func Test_AzureBlob_Set(t *testing.T) { require.NoError(t, err) } +func Test_AzureBlob_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_AzureBlob_Delete(t *testing.T) { var ( key = "john" @@ -101,6 +137,29 @@ func Test_AzureBlob_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_AzureBlob_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_AzureBlob_Override(t *testing.T) { var ( key = "john" @@ -172,6 +231,33 @@ func Test_AzureBlob_Conn(t *testing.T) { require.True(t, testStore.Conn() != nil) } +func Test_AzureBlob_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_AzureBlob_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/clickhouse/clickhouse.go b/clickhouse/clickhouse.go index 1056bc7b..3d58f564 100644 --- a/clickhouse/clickhouse.go +++ b/clickhouse/clickhouse.go @@ -12,7 +12,6 @@ import ( type Storage struct { session driver.Conn - context context.Context table string } @@ -47,12 +46,11 @@ func New(configuration Config) (*Storage, error) { return &Storage{ session: conn, - context: ctx, table: configuration.Table, }, nil } -func (s *Storage) Set(key string, value []byte, expiration time.Duration) error { +func (s *Storage) SetWithContext(ctx context.Context, key string, value []byte, expiration time.Duration) error { if len(key) <= 0 || len(value) <= 0 { return nil } @@ -65,7 +63,7 @@ func (s *Storage) Set(key string, value []byte, expiration time.Duration) error return s. session. Exec( - s.context, + ctx, insertDataString, driver.Named("table", s.table), driver.Named("key", key), @@ -74,7 +72,11 @@ func (s *Storage) Set(key string, value []byte, expiration time.Duration) error ) } -func (s *Storage) Get(key string) ([]byte, error) { +func (s *Storage) Set(key string, value []byte, expiration time.Duration) error { + return s.SetWithContext(context.Background(), key, value, expiration) +} + +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) == 0 { return []byte{}, nil } @@ -82,7 +84,7 @@ func (s *Storage) Get(key string) ([]byte, error) { var result schema row := s.session.QueryRow( - s.context, + ctx, selectDataString, driver.Named("table", s.table), driver.Named("key", key), @@ -109,16 +111,28 @@ func (s *Storage) Get(key string) ([]byte, error) { return []byte(result.Value), nil } -func (s *Storage) Delete(key string) error { +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) == 0 { return nil } - return s.session.Exec(s.context, deleteDataString, driver.Named("table", s.table), driver.Named("key", key)) + return s.session.Exec(ctx, deleteDataString, driver.Named("table", s.table), driver.Named("key", key)) +} + +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.session.Exec(ctx, resetDataString, driver.Named("table", s.table)) } func (s *Storage) Reset() error { - return s.session.Exec(s.context, resetDataString, driver.Named("table", s.table)) + return s.ResetWithContext(context.Background()) } func (s *Storage) Close() error { diff --git a/clickhouse/clickhouse_test.go b/clickhouse/clickhouse_test.go index 16bc3eff..5d03769d 100644 --- a/clickhouse/clickhouse_test.go +++ b/clickhouse/clickhouse_test.go @@ -83,6 +83,21 @@ func Test_Connection(t *testing.T) { defer client.Close() } +func Test_SetWithContext(t *testing.T) { + client := newTestStore(t, Config{ + Engine: Memory, + Table: "test_table", + Clean: true, + }) + defer client.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := client.SetWithContext(ctx, "somekey", []byte("somevalue"), 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Set(t *testing.T) { client := newTestStore(t, Config{ Engine: Memory, @@ -107,6 +122,25 @@ func Test_Set_With_Exp(t *testing.T) { require.NoError(t, err) } +func Test_GetWithContext(t *testing.T) { + client := newTestStore(t, Config{ + Engine: Memory, + Table: "test_table", + Clean: true, + }) + defer client.Close() + + err := client.Set("somekey", []byte("somevalue"), 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + value, err := client.GetWithContext(ctx, "somekey") + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, []byte{}, value) +} + func Test_Get(t *testing.T) { client := newTestStore(t, Config{ Engine: Memory, @@ -150,6 +184,29 @@ func Test_Get_With_Exp(t *testing.T) { assert.Equal(t, []byte{}, value) } +func Test_DeleteWithContext(t *testing.T) { + client := newTestStore(t, Config{ + Engine: Memory, + Table: "test_table", + Clean: true, + }) + + defer client.Close() + + err := client.Set("somekeytodelete", []byte("somevalue"), time.Second*5) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = client.DeleteWithContext(ctx, "somekeytodelete") + require.ErrorIs(t, err, context.Canceled) + + value, err := client.Get("somekeytodelete") + require.NoError(t, err) + require.Equal(t, []byte("somevalue"), value) +} + func Test_Delete(t *testing.T) { client := newTestStore(t, Config{ Engine: Memory, @@ -171,6 +228,29 @@ func Test_Delete(t *testing.T) { assert.Equal(t, []byte{}, value) } +func Test_ResetWithContext(t *testing.T) { + client := newTestStore(t, Config{ + Engine: Memory, + Table: "test_table", + Clean: true, + }) + + defer client.Close() + + err := client.Set("testkey", []byte("somevalue"), 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = client.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + value, err := client.Get("testkey") + require.NoError(t, err) + require.Equal(t, []byte("somevalue"), value) +} + func Test_Reset(t *testing.T) { client := newTestStore(t, Config{ Engine: Memory, diff --git a/cloudflarekv/cloudflarekv.go b/cloudflarekv/cloudflarekv.go index e79c05c5..b9afdf75 100644 --- a/cloudflarekv/cloudflarekv.go +++ b/cloudflarekv/cloudflarekv.go @@ -55,8 +55,8 @@ func New(config ...Config) *Storage { return storage } -func (s *Storage) Get(key string) ([]byte, error) { - resp, err := s.api.GetWorkersKV(context.Background(), cloudflare.AccountIdentifier(s.accountID), cloudflare.GetWorkersKVParams{NamespaceID: s.namespaceID, Key: key}) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + resp, err := s.api.GetWorkersKV(ctx, cloudflare.AccountIdentifier(s.accountID), cloudflare.GetWorkersKVParams{NamespaceID: s.namespaceID, Key: key}) if err != nil { log.Printf("Error occur in GetWorkersKV: %v", err) @@ -66,8 +66,12 @@ func (s *Storage) Get(key string) ([]byte, error) { return resp, nil } -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { - _, err := s.api.WriteWorkersKVEntry(context.Background(), cloudflare.AccountIdentifier(s.accountID), cloudflare.WriteWorkersKVEntryParams{ +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + _, err := s.api.WriteWorkersKVEntry(ctx, cloudflare.AccountIdentifier(s.accountID), cloudflare.WriteWorkersKVEntryParams{ NamespaceID: s.namespaceID, Key: key, Value: val, @@ -81,8 +85,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } -func (s *Storage) Delete(key string) error { - _, err := s.api.DeleteWorkersKVEntry(context.Background(), cloudflare.AccountIdentifier(s.accountID), cloudflare.DeleteWorkersKVEntryParams{ +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + _, err := s.api.DeleteWorkersKVEntry(ctx, cloudflare.AccountIdentifier(s.accountID), cloudflare.DeleteWorkersKVEntryParams{ NamespaceID: s.namespaceID, Key: key, }) @@ -95,14 +103,18 @@ func (s *Storage) Delete(key string) error { return nil } -func (s *Storage) Reset() error { +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *Storage) ResetWithContext(ctx context.Context) error { var ( cursor string keys []string ) for { - resp, err := s.api.ListWorkersKVKeys(context.Background(), cloudflare.AccountIdentifier(s.accountID), cloudflare.ListWorkersKVsParams{ + resp, err := s.api.ListWorkersKVKeys(ctx, cloudflare.AccountIdentifier(s.accountID), cloudflare.ListWorkersKVsParams{ NamespaceID: s.namespaceID, Cursor: cursor, }) @@ -119,7 +131,7 @@ func (s *Storage) Reset() error { keys = append(keys, name) } - _, err = s.api.DeleteWorkersKVEntries(context.Background(), cloudflare.AccountIdentifier(s.accountID), cloudflare.DeleteWorkersKVEntriesParams{ + _, err = s.api.DeleteWorkersKVEntries(ctx, cloudflare.AccountIdentifier(s.accountID), cloudflare.DeleteWorkersKVEntriesParams{ NamespaceID: s.namespaceID, Keys: keys, }) @@ -140,6 +152,10 @@ func (s *Storage) Reset() error { return nil } +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + func (s *Storage) Close() error { return nil } diff --git a/cloudflarekv/cloudflarekv_test.go b/cloudflarekv/cloudflarekv_test.go index e6d23ac7..78c77a9a 100644 --- a/cloudflarekv/cloudflarekv_test.go +++ b/cloudflarekv/cloudflarekv_test.go @@ -2,18 +2,21 @@ package cloudflarekv import ( "bytes" + "context" + "fmt" "os" "testing" "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { +var testStore *Storage - var testStore *Storage +func TestMain(m *testing.M) { testStore = New(Config{ - Key: "test", + Key: "test", + Reset: true, }) code := m.Run() @@ -25,12 +28,6 @@ func TestMain(m *testing.M) { func Test_CloudflareKV_Get(t *testing.T) { t.Parallel() - var testStore *Storage - - testStore = New(Config{ - Key: "test", - }) - var ( key = "john" val = []byte("doe") @@ -55,14 +52,30 @@ func Test_CloudflareKV_Get(t *testing.T) { _ = testStore.Close() } -func Test_CloudflareKV_Set(t *testing.T) { +func Test_CloudflareKV_GetWithContext(t *testing.T) { t.Parallel() - var testStore *Storage + var ( + key = "john" + val = []byte("doe") + ) - testStore = New(Config{ - Key: "test", - }) + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + val, err = testStore.GetWithContext(ctx, key) + fmt.Println(err) + require.ErrorContains(t, err, context.Canceled.Error()) + require.Nil(t, val) + + _ = testStore.Close() +} + +func Test_CloudflareKV_Set(t *testing.T) { + t.Parallel() var ( key = "john" @@ -76,14 +89,21 @@ func Test_CloudflareKV_Set(t *testing.T) { _ = testStore.Close() } -func Test_CloudflareKV_Delete(t *testing.T) { - t.Parallel() +func Test_CloudflareKV_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) - var testStore *Storage + ctx, cancel := context.WithCancel(context.Background()) + cancel() - testStore = New(Config{ - Key: "test", - }) + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + +func Test_CloudflareKV_Delete(t *testing.T) { + t.Parallel() var ( key = "john" @@ -99,14 +119,28 @@ func Test_CloudflareKV_Delete(t *testing.T) { _ = testStore.Close() } -func Test_CloudflareKV_Reset(t *testing.T) { - t.Parallel() +func Test_CloudflareKV_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) - var testStore *Storage + err := testStore.Set(key, val, 0) + require.NoError(t, err) - testStore = New(Config{ - Key: "test", - }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + +func Test_CloudflareKV_Reset(t *testing.T) { + t.Parallel() err := testStore.Reset() @@ -114,14 +148,33 @@ func Test_CloudflareKV_Reset(t *testing.T) { _ = testStore.Close() } -func Test_CloudflareKV_Close(t *testing.T) { - t.Parallel() - var testStore *Storage +func Test_CloudflareKV_ResetWithContext(t *testing.T) { + val := []byte("doe") - testStore = New(Config{ - Key: "test", - }) + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + +func Test_CloudflareKV_Close(t *testing.T) { + t.Parallel() require.Nil(t, testStore.Close()) @@ -131,25 +184,12 @@ func Test_CloudflareKV_Close(t *testing.T) { func Test_CloudflareKV_Conn(t *testing.T) { t.Parallel() - var testStore *Storage - - testStore = New(Config{ - Key: "test", - }) - require.NotNil(t, testStore.Conn()) _ = testStore.Close() } func Benchmark_CloudflareKV_Set(b *testing.B) { - - var testStore *Storage - - testStore = New(Config{ - Key: "test", - }) - b.ReportAllocs() b.ResetTimer() @@ -164,13 +204,6 @@ func Benchmark_CloudflareKV_Set(b *testing.B) { } func Benchmark_CloudflareKV_Get(b *testing.B) { - - var testStore *Storage - - testStore = New(Config{ - Key: "test", - }) - err := testStore.Set("john", []byte("doe"), 0) require.NoError(b, err) @@ -187,13 +220,6 @@ func Benchmark_CloudflareKV_Get(b *testing.B) { } func Benchmark_CloudflareKV_SetAndDelete(b *testing.B) { - - var testStore *Storage - - testStore = New(Config{ - Key: "test", - }) - b.ReportAllocs() b.ResetTimer() diff --git a/cloudflarekv/test_module.go b/cloudflarekv/test_module.go index cd996076..11d2902c 100644 --- a/cloudflarekv/test_module.go +++ b/cloudflarekv/test_module.go @@ -54,7 +54,7 @@ func (t *TestModule) GetWorkersKV(ctx context.Context, rc *cloudflare.ResourceCo return nil, err } - req, err := http.NewRequest(http.MethodPost, t.baseUrl+"/getworkerskvvaluebykey", bytes.NewReader(marshalledBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.baseUrl+"/getworkerskvvaluebykey", bytes.NewReader(marshalledBody)) if err != nil { log.Println("Error occur in /getworkerskvvaluebykey > making http call") @@ -95,7 +95,7 @@ func (t *TestModule) WriteWorkersKVEntry(ctx context.Context, rc *cloudflare.Res return cloudflare.Response{}, err } - req, err := http.NewRequest(http.MethodPost, t.baseUrl+"/writeworkerskvkeyvaluepair", bytes.NewReader(marshalledBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.baseUrl+"/writeworkerskvkeyvaluepair", bytes.NewReader(marshalledBody)) if err != nil { log.Println("Error occur in /writeworkerskvkeyvaluepair > making http call") @@ -134,7 +134,7 @@ func (t *TestModule) DeleteWorkersKVEntry(ctx context.Context, rc *cloudflare.Re return cloudflare.Response{}, err } - req, err := http.NewRequest(http.MethodDelete, t.baseUrl+"/deleteworkerskvpairbykey", bytes.NewReader(marshalledBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, t.baseUrl+"/deleteworkerskvpairbykey", bytes.NewReader(marshalledBody)) if err != nil { log.Println("Error occur in /deleteworkerskvpairbykey > making http call") @@ -173,7 +173,7 @@ func (t *TestModule) ListWorkersKVKeys(ctx context.Context, rc *cloudflare.Resou return cloudflare.ListStorageKeysResponse{}, err } - req, err := http.NewRequest(http.MethodPost, t.baseUrl+"/listworkerskvkeys", bytes.NewReader(marshalledBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.baseUrl+"/listworkerskvkeys", bytes.NewReader(marshalledBody)) if err != nil { log.Println("Error occur in /listworkerskvkeys > making http call") @@ -225,7 +225,7 @@ func (t *TestModule) DeleteWorkersKVEntries(ctx context.Context, rc *cloudflare. return cloudflare.Response{}, err } - req, err := http.NewRequest(http.MethodDelete, t.baseUrl+"/deleteworkerskventries", bytes.NewReader(marshalledBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, t.baseUrl+"/deleteworkerskventries", bytes.NewReader(marshalledBody)) if err != nil { log.Println("Error occur in /deleteworkerskventries > making new request") diff --git a/coherence/coherence.go b/coherence/coherence.go index fcf7e5c2..4ecd2a03 100644 --- a/coherence/coherence.go +++ b/coherence/coherence.go @@ -22,7 +22,6 @@ const ( type Storage struct { session *coh.Session namedCache coh.NamedCache[string, []byte] - ctx context.Context } // Config defines configuration options for Coherence connection. @@ -141,12 +140,11 @@ func newCoherenceStorage(session *coh.Session, cacheName string, nearCacheTimeou return &Storage{ session: session, namedCache: nc, - ctx: context.Background(), }, nil } -func (s *Storage) Get(key string) ([]byte, error) { - v, err := s.namedCache.Get(s.ctx, key) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + v, err := s.namedCache.Get(ctx, key) if err != nil { return nil, err } @@ -156,18 +154,34 @@ func (s *Storage) Get(key string) ([]byte, error) { return *v, nil } +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + _, err := s.namedCache.PutWithExpiry(ctx, key, val, exp) + return err +} + func (s *Storage) Set(key string, val []byte, exp time.Duration) error { - _, err := s.namedCache.PutWithExpiry(s.ctx, key, val, exp) + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + _, err := s.namedCache.Remove(ctx, key) return err } func (s *Storage) Delete(key string) error { - _, err := s.namedCache.Remove(s.ctx, key) - return err + return s.DeleteWithContext(context.Background(), key) +} + +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.namedCache.Truncate(ctx) } func (s *Storage) Reset() error { - return s.namedCache.Truncate(s.ctx) + return s.ResetWithContext(context.Background()) } func (s *Storage) Close() error { diff --git a/coherence/coherence_test.go b/coherence/coherence_test.go index c9e12188..f5f1b300 100644 --- a/coherence/coherence_test.go +++ b/coherence/coherence_test.go @@ -5,6 +5,7 @@ package coherence */ import ( "context" + "fmt" "os" "testing" "time" @@ -116,6 +117,26 @@ func Test_Coherence_Set_And_Get(t *testing.T) { require.NotNil(t, testStore.Conn()) } +func Test_Coherence_SetContext_And_Get(t *testing.T) { + var val []byte + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + + err := testStore.SetWithContext(ctx, key1, value1, 1*time.Nanosecond) + require.ErrorIs(t, err, context.DeadlineExceeded) + + val, err = testStore.Get(key1) + require.NoError(t, err) + fmt.Println(string(val)) + require.True(t, len(val) == 0) + + require.NotNil(t, testStore.Conn()) +} + func Test_Coherence_Set_Override(t *testing.T) { var val []byte diff --git a/couchbase/couchbase.go b/couchbase/couchbase.go index 77c8a1f7..5e523b58 100644 --- a/couchbase/couchbase.go +++ b/couchbase/couchbase.go @@ -1,6 +1,7 @@ package couchbase import ( + "context" "time" "github.com/couchbase/gocb/v2" @@ -43,8 +44,10 @@ func New(config ...Config) *Storage { return &Storage{cb: cb, bucket: b} } -func (s *Storage) Get(key string) ([]byte, error) { - out, err := s.bucket.DefaultCollection().Get(key, nil) +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + out, err := s.bucket.DefaultCollection().Get(key, &gocb.GetOptions{ + Context: ctx, + }) if err != nil { switch e := err.(type) { case *gocb.KeyValueError: @@ -66,9 +69,14 @@ func (s *Storage) Get(key string) ([]byte, error) { return value, nil } -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if _, err := s.bucket.DefaultCollection().Upsert(key, val, &gocb.UpsertOptions{ - Expiry: exp, + Context: ctx, + Expiry: exp, }); err != nil { return err } @@ -76,15 +84,31 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } -func (s *Storage) Delete(key string) error { - if _, err := s.bucket.DefaultCollection().Remove(key, nil); err != nil { +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + if _, err := s.bucket.DefaultCollection().Remove(key, &gocb.RemoveOptions{ + Context: ctx, + }); err != nil { return err } return nil } +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.cb.Buckets().FlushBucket(s.bucket.Name(), &gocb.FlushBucketOptions{ + Context: ctx, + }) +} + func (s *Storage) Reset() error { - return s.cb.Buckets().FlushBucket(s.bucket.Name(), nil) + return s.ResetWithContext(context.Background()) } func (s *Storage) Close() error { diff --git a/couchbase/couchbase_test.go b/couchbase/couchbase_test.go index 335441ff..b2cc9ace 100644 --- a/couchbase/couchbase_test.go +++ b/couchbase/couchbase_test.go @@ -61,6 +61,17 @@ func TestSetCouchbase_ShouldReturnNoError(t *testing.T) { require.NoError(t, err) } +func TestSetWithContextCouchbase_ContextCancelled_ShouldReturnError(t *testing.T) { + testStore := newTestStore(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + + err := testStore.SetWithContext(ctx, "test", []byte("test"), 0) + + require.Error(t, err) +} + func TestGetCouchbase_ShouldReturnNil_WhenDocumentNotFound(t *testing.T) { testStore := newTestStore(t) defer testStore.Close() diff --git a/dynamodb/dynamodb.go b/dynamodb/dynamodb.go index 25e9e57a..813466ae 100644 --- a/dynamodb/dynamodb.go +++ b/dynamodb/dynamodb.go @@ -17,9 +17,8 @@ import ( // Storage interface that is implemented by storage providers type Storage struct { - db *awsdynamodb.Client - table string - requestTimeout time.Duration + db *awsdynamodb.Client + table string } // "k" is used as table column name for the key. @@ -77,11 +76,8 @@ func New(config Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { - ctx, cancel := s.requestContext() - defer cancel() - +// GetWithContext value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { k := make(map[string]types.AttributeValue) k[keyAttrName] = &types.AttributeValueMemberS{ Value: key, @@ -108,11 +104,12 @@ func (s *Storage) Get(key string) ([]byte, error) { return item.V, err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { - ctx, cancel := s.requestContext() - defer cancel() +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} +// Set key with value +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -134,11 +131,12 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { - ctx, cancel := s.requestContext() - defer cancel() +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} +// Delete entry by key +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil @@ -157,11 +155,12 @@ func (s *Storage) Delete(key string) error { return err } -// Reset all entries, including unexpired -func (s *Storage) Reset() error { - ctx, cancel := s.requestContext() - defer cancel() +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} +// Reset all entries, including unexpired +func (s *Storage) ResetWithContext(ctx context.Context) error { deleteTableInput := awsdynamodb.DeleteTableInput{ TableName: &s.table, } @@ -169,14 +168,17 @@ func (s *Storage) Reset() error { return err } +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the database func (s *Storage) Close() error { return nil } func (s *Storage) createTable(cfg Config, describeTableInput awsdynamodb.DescribeTableInput) error { - ctx, cancel := s.requestContext() - defer cancel() + ctx := context.Background() keyAttrType := "S" // For "string" keyType := "HASH" // As opposed to "RANGE" @@ -225,14 +227,6 @@ func (s *Storage) createTable(cfg Config, describeTableInput awsdynamodb.Describ return nil } -// Context for making requests will timeout if a non-zero timeout is configured -func (s *Storage) requestContext() (context.Context, context.CancelFunc) { - if s.requestTimeout > 0 { - return context.WithTimeout(context.Background(), s.requestTimeout) - } - return context.Background(), func() {} -} - func returnAWSConfig(cfg Config) (aws.Config, error) { if cfg.Credentials != (Credentials{}) { credentials := credentials.NewStaticCredentialsProvider(cfg.Credentials.AccessKey, cfg.Credentials.SecretAccessKey, "") diff --git a/dynamodb/dynamodb_test.go b/dynamodb/dynamodb_test.go index f2b95f8f..1a1ff4f8 100644 --- a/dynamodb/dynamodb_test.go +++ b/dynamodb/dynamodb_test.go @@ -60,6 +60,22 @@ func Test_DynamoDB_Set(t *testing.T) { require.NoError(t, err) } +func Test_DynamoDB_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_DynamoDB_Set_Override(t *testing.T) { var ( key = "john" @@ -93,6 +109,22 @@ func Test_DynamoDB_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_DynamoDB_GetWithContext(t *testing.T) { + var ( + key = "john" + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + testStore := newTestStore(t) + defer testStore.Close() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_DynamoDB_Get_NotExist(t *testing.T) { testStore := newTestStore(t) defer testStore.Close() @@ -122,6 +154,29 @@ func Test_DynamoDB_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_DynamoDB_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_DynamoDB_Reset(t *testing.T) { val := []byte("doe") @@ -146,6 +201,33 @@ func Test_DynamoDB_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_DynamoDB_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_DynamoDB_Close(t *testing.T) { testStore := newTestStore(t) require.Nil(t, testStore.Close()) diff --git a/etcd/etcd.go b/etcd/etcd.go index 5ddd7984..8fa81cb2 100644 --- a/etcd/etcd.go +++ b/etcd/etcd.go @@ -4,7 +4,7 @@ import ( "context" "time" - "go.etcd.io/etcd/client/v3" + clientv3 "go.etcd.io/etcd/client/v3" ) type Storage struct { @@ -32,11 +32,11 @@ func New(config ...Config) *Storage { return store } -func (s *Storage) Get(key string) ([]byte, error) { +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - item, err := s.db.Get(context.Background(), key) + item, err := s.db.Get(ctx, key) if err != nil { return nil, err } @@ -48,18 +48,22 @@ func (s *Storage) Get(key string) ([]byte, error) { return item.Kvs[0].Value, nil } -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil } - lease, err := s.db.Grant(context.Background(), int64(exp.Seconds())) + lease, err := s.db.Grant(ctx, int64(exp.Seconds())) if err != nil { return err } - _, err = s.db.Put(context.Background(), key, string(val), clientv3.WithLease(lease.ID)) + _, err = s.db.Put(ctx, key, string(val), clientv3.WithLease(lease.ID)) if err != nil { return err } @@ -67,12 +71,16 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return nil } -func (s *Storage) Delete(key string) error { +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - _, err := s.db.Delete(context.Background(), key) + _, err := s.db.Delete(ctx, key) if err != nil { return err } @@ -80,8 +88,12 @@ func (s *Storage) Delete(key string) error { return nil } -func (s *Storage) Reset() error { - _, err := s.db.Delete(context.Background(), "", clientv3.WithPrefix()) +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *Storage) ResetWithContext(ctx context.Context) error { + _, err := s.db.Delete(ctx, "", clientv3.WithPrefix()) if err != nil { return err } @@ -89,6 +101,10 @@ func (s *Storage) Reset() error { return nil } +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + func (s *Storage) Close() error { return s.db.Close() } diff --git a/etcd/etcd_test.go b/etcd/etcd_test.go index 7232f864..cb8d64c5 100644 --- a/etcd/etcd_test.go +++ b/etcd/etcd_test.go @@ -1,6 +1,7 @@ package etcd import ( + "context" "os" "testing" "time" @@ -49,6 +50,36 @@ func TestSetAndGet_GetShouldReturn_SettedValueWithoutError(t *testing.T) { require.Equal(t, val, []byte("fiber_test_value")) } +func TestSetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + +func TestSetAndGetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, result) +} + func TestSetAndGet_GetShouldReturnNil_WhenTTLExpired(t *testing.T) { err := testStore.Set("test", []byte("fiber_test_value"), 3*time.Second) require.NoError(t, err) @@ -72,6 +103,21 @@ func TestSetAndDelete_DeleteShouldReturn_NoError(t *testing.T) { require.NoError(t, err) } +func TestDeleteWithContext(t *testing.T) { + err := testStore.Set("test", []byte("fiber_test_value"), 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, "test") + require.ErrorIs(t, err, context.Canceled) + + val, err := testStore.Get("test") + require.NoError(t, err) + require.Equal(t, val, []byte("fiber_test_value")) +} + func TestSetAndReset_ResetShouldReturn_NoError(t *testing.T) { err := testStore.Set("test", []byte("fiber_test_value"), 0) require.NoError(t, err) @@ -83,6 +129,21 @@ func TestSetAndReset_ResetShouldReturn_NoError(t *testing.T) { require.NoError(t, err) } +func TestResetWithContext(t *testing.T) { + err := testStore.Set("test", []byte("fiber_test_value"), 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + val, err := testStore.Get("test") + require.NoError(t, err) + require.Equal(t, val, []byte("fiber_test_value")) +} + func TestClose_CloseShouldReturn_NoError(t *testing.T) { err := testStore.Close() require.NoError(t, err) diff --git a/minio/minio.go b/minio/minio.go index 4649020a..7427979a 100644 --- a/minio/minio.go +++ b/minio/minio.go @@ -18,7 +18,6 @@ import ( type Storage struct { minio *minio.Client cfg Config - ctx context.Context mu sync.Mutex } @@ -41,7 +40,7 @@ func New(config ...Config) *Storage { panic(err) } - storage := &Storage{minio: minioClient, cfg: cfg, ctx: context.Background()} + storage := &Storage{minio: minioClient, cfg: cfg} // Reset all entries if set to true if cfg.Reset { @@ -63,15 +62,14 @@ func New(config ...Config) *Storage { return storage } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { - +// GetWithContext value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, errors.New("the key value is required") } // get object - object, err := s.minio.GetObject(s.ctx, s.cfg.Bucket, key, s.cfg.GetObjectOptions) + object, err := s.minio.GetObject(ctx, s.cfg.Bucket, key, s.cfg.GetObjectOptions) if err != nil { return nil, err } @@ -91,9 +89,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return bb.Bytes(), nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} +// SetWithContext key with value with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 { return errors.New("the key value is required") } @@ -106,35 +108,43 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { s.cfg.PutObjectOptions.ContentType = http.DetectContentType(val) // put object - _, err := s.minio.PutObject(s.ctx, s.cfg.Bucket, key, file, file.Size(), s.cfg.PutObjectOptions) + _, err := s.minio.PutObject(ctx, s.cfg.Bucket, key, file, file.Size(), s.cfg.PutObjectOptions) s.mu.Unlock() return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} +// DeleteWithContext key with value with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return errors.New("the key value is required") } // remove - err := s.minio.RemoveObject(s.ctx, s.cfg.Bucket, key, s.cfg.RemoveObjectOptions) + err := s.minio.RemoveObject(ctx, s.cfg.Bucket, key, s.cfg.RemoveObjectOptions) return err } -// Reset all entries, including unexpired -func (s *Storage) Reset() error { +// Delete entry by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} +// ResetWithContext all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { objectsCh := make(chan minio.ObjectInfo) // Send object names that are needed to be removed to objectsCh go func() { defer close(objectsCh) // List all objects from a bucket-name with a matching prefix. - for object := range s.minio.ListObjects(s.ctx, s.cfg.Bucket, s.cfg.ListObjectsOptions) { + for object := range s.minio.ListObjects(ctx, s.cfg.Bucket, s.cfg.ListObjectsOptions) { if object.Err != nil { log.Println(object.Err) } @@ -147,13 +157,17 @@ func (s *Storage) Reset() error { } var errs []error - for err := range s.minio.RemoveObjects(s.ctx, s.cfg.Bucket, objectsCh, opts) { + for err := range s.minio.RemoveObjects(ctx, s.cfg.Bucket, objectsCh, opts) { errs = append(errs, err.Err) } return errors.Join(errs...) } +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the storage func (s *Storage) Close() error { return nil @@ -161,7 +175,7 @@ func (s *Storage) Close() error { // CheckBucket Check to see if bucket already exists func (s *Storage) CheckBucket() error { - exists, err := s.minio.BucketExists(s.ctx, s.cfg.Bucket) + exists, err := s.minio.BucketExists(context.Background(), s.cfg.Bucket) if !exists || err != nil { return errors.New("the specified bucket does not exist") } @@ -170,12 +184,12 @@ func (s *Storage) CheckBucket() error { // CreateBucket Bucket not found so Make a new bucket func (s *Storage) CreateBucket() error { - return s.minio.MakeBucket(s.ctx, s.cfg.Bucket, minio.MakeBucketOptions{Region: s.cfg.Region}) + return s.minio.MakeBucket(context.Background(), s.cfg.Bucket, minio.MakeBucketOptions{Region: s.cfg.Region}) } // RemoveBucket Bucket remove if bucket is empty func (s *Storage) RemoveBucket() error { - return s.minio.RemoveBucket(s.ctx, s.cfg.Bucket) + return s.minio.RemoveBucket(context.Background(), s.cfg.Bucket) } // Conn Return minio client diff --git a/minio/minio_test.go b/minio/minio_test.go index 2aeb7d81..35c5cfec 100644 --- a/minio/minio_test.go +++ b/minio/minio_test.go @@ -81,6 +81,25 @@ func Test_Get(t *testing.T) { require.Zero(t, len(result)) } +func Test_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Get_Empty_Key(t *testing.T) { var ( key = "" @@ -137,6 +156,21 @@ func Test_Set(t *testing.T) { require.NoError(t, err) } +func Test_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Set_Empty_Key(t *testing.T) { var ( key = "" @@ -185,6 +219,28 @@ func Test_Delete(t *testing.T) { require.NoError(t, err) } +func Test_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + valRet, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, valRet) +} + func Test_Delete_Empty_Key(t *testing.T) { var ( key = "" @@ -237,6 +293,31 @@ func Test_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_ResetWithContext(t *testing.T) { + var ( + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Reset_Not_Exists_Bucket(t *testing.T) { testStore := newTestStore(t) defer testStore.Close() diff --git a/mongodb/mongodb.go b/mongodb/mongodb.go index 4164551f..f424ea3a 100644 --- a/mongodb/mongodb.go +++ b/mongodb/mongodb.go @@ -121,12 +121,12 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - res := s.col.FindOne(context.Background(), bson.M{"key": key}) + res := s.col.FindOne(ctx, bson.M{"key": key}) item := s.acquireItem() if err := res.Err(); err != nil { @@ -149,11 +149,16 @@ func (s *Storage) Get(key string) ([]byte, error) { return item.Value, nil } -// Set key with value, replace if document exits +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets key with value, replace if document exits with context // // document will be remove automatically if exp is set, based on MongoDB TTL Indexes // Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -167,25 +172,43 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { item.Expiration = time.Now().Add(exp).UTC() } - _, err := s.col.ReplaceOne(context.Background(), filter, item, options.Replace().SetUpsert(true)) + _, err := s.col.ReplaceOne(ctx, filter, item, options.Replace().SetUpsert(true)) s.releaseItem(item) return err } -// Delete document by key -func (s *Storage) Delete(key string) error { +// Set sets key with value, replace if document exits +// +// document will be remove automatically if exp is set, based on MongoDB TTL Indexes +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext deletes document by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } - _, err := s.col.DeleteOne(context.Background(), bson.M{"key": key}) + _, err := s.col.DeleteOne(ctx, bson.M{"key": key}) return err } +// Delete deletes document by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// Reset all keys by drop collection with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.col.Drop(ctx) +} + // Reset all keys by drop collection func (s *Storage) Reset() error { - return s.col.Drop(context.Background()) + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/mongodb/mongodb_test.go b/mongodb/mongodb_test.go index 287519d7..836c9e33 100644 --- a/mongodb/mongodb_test.go +++ b/mongodb/mongodb_test.go @@ -66,6 +66,22 @@ func Test_MongoDB_Set(t *testing.T) { require.NoError(t, err) } +func Test_MongoDB_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_MongoDB_Set_Override(t *testing.T) { var ( key = "john" @@ -99,6 +115,26 @@ func Test_MongoDB_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_MongoDB_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_MongoDB_Set_Expiration(t *testing.T) { var ( key = "john" @@ -159,6 +195,29 @@ func Test_MongoDB_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_MongoDB_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MongoDB_Reset(t *testing.T) { val := []byte("doe") @@ -183,6 +242,33 @@ func Test_MongoDB_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_MongoDB_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MongoDB_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/mssql/mssql.go b/mssql/mssql.go index ff7c5af5..3e332973 100644 --- a/mssql/mssql.go +++ b/mssql/mssql.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "net/url" @@ -132,13 +133,13 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - row := s.db.QueryRow(s.sqlSelect, key) + row := s.db.QueryRowContext(ctx, s.sqlSelect, key) var ( data = []byte{} @@ -161,8 +162,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value and expiration time with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } @@ -172,24 +178,39 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { expSeconds = time.Now().Add(exp).Unix() } - _, err := s.db.Exec(s.sqlInsert, key, val, expSeconds) + _, err := s.db.ExecContext(ctx, s.sqlInsert, key, val, expSeconds) return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set key with value and expiration time +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext entry by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - _, err := s.db.Exec(s.sqlDelete, key) + _, err := s.db.ExecContext(ctx, s.sqlDelete, key) + return err +} + +// Delete entry by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all entries, including unexpired with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, s.sqlReset) return err } // Reset all entries, including unexpired func (s *Storage) Reset() error { - _, err := s.db.Exec(s.sqlReset) - return err + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/mssql/mssql_test.go b/mssql/mssql_test.go index bb209caf..501eb393 100644 --- a/mssql/mssql_test.go +++ b/mssql/mssql_test.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "os" "testing" @@ -26,6 +27,19 @@ func Test_MSSQL_Set(t *testing.T) { require.NoError(t, err) } +func Test_MSSQL_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_MSSQL_Set_Override(t *testing.T) { var ( key = "john" @@ -53,6 +67,23 @@ func Test_MSSQL_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_MSSQL_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_MSSQL_Set_Expiration(t *testing.T) { var ( key = "john" @@ -97,6 +128,26 @@ func Test_MSSQL_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_MSSQL_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MSSQL_Reset(t *testing.T) { val := []byte("doe") @@ -118,6 +169,30 @@ func Test_MSSQL_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_MSSQL_ResetWithContext(t *testing.T) { + val := []byte("doe") + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MSSQL_GC(t *testing.T) { testVal := []byte("doe") diff --git a/mysql/mysql.go b/mysql/mysql.go index 06509e64..ae0bfcd2 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "strings" @@ -105,12 +106,12 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - row := s.db.QueryRow(s.sqlSelect, key) + row := s.db.QueryRowContext(ctx, s.sqlSelect, key) // Add db response to data @@ -134,8 +135,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value and expiration time with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -144,26 +150,41 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { expSeconds = time.Now().Add(exp).Unix() } - _, err := s.db.Exec(s.sqlInsert, key, val, expSeconds, val, expSeconds) + _, err := s.db.ExecContext(ctx, s.sqlInsert, key, val, expSeconds, val, expSeconds) return err } -// Delete key by key -func (s *Storage) Delete(key string) error { +// Set key with value and expiration time +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext key by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } - _, err := s.db.Exec(s.sqlDelete, key) + _, err := s.db.ExecContext(ctx, s.sqlDelete, key) return err } -// Reset all keys -func (s *Storage) Reset() error { - _, err := s.db.Exec(s.sqlReset) +// Delete entry by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext resets all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, s.sqlReset) return err } +// Reset resets all keys +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the database func (s *Storage) Close() error { s.done <- struct{}{} diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go index 6160242c..c8391c3f 100644 --- a/mysql/mysql_test.go +++ b/mysql/mysql_test.go @@ -90,6 +90,22 @@ func Test_MYSQL_Set(t *testing.T) { require.NoError(t, err) } +func Test_MYSQL_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_MYSQL_Set_Override(t *testing.T) { var ( key = "john" @@ -123,6 +139,26 @@ func Test_MYSQL_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_MYSQL_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_MYSQL_Set_Expiration(t *testing.T) { var ( key = "john" @@ -183,6 +219,29 @@ func Test_MYSQL_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_MYSQL_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MYSQL_Reset(t *testing.T) { val := []byte("doe") @@ -207,6 +266,33 @@ func Test_MYSQL_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_MYSQL_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_MYSQL_GC(t *testing.T) { testVal := []byte("doe") diff --git a/nats/config.go b/nats/config.go index af38ec18..4d3049cb 100644 --- a/nats/config.go +++ b/nats/config.go @@ -22,6 +22,8 @@ type Config struct { KeyValueConfig jetstream.KeyValueConfig // Wait for connection to be established, default: 250ms WaitForConnection time.Duration + // Reset clears any existing keys in existing bucket default: false + Reset bool } // ConfigDefault is the default config diff --git a/nats/nats.go b/nats/nats.go index 05f710ad..a3b02584 100644 --- a/nats/nats.go +++ b/nats/nats.go @@ -20,7 +20,6 @@ type Storage struct { nc *nats.Conn kv jetstream.KeyValue err error - ctx context.Context cfg Config mu sync.RWMutex } @@ -42,7 +41,7 @@ func (s *Storage) connectHandler(nc *nats.Conn) { var err error s.kv, err = newNatsKV( nc, - s.ctx, + context.Background(), s.cfg.KeyValueConfig, ) if err != nil { @@ -118,7 +117,6 @@ func New(config ...Config) *Storage { cfg := configDefault(config...) storage := &Storage{ - ctx: cfg.Context, cfg: cfg, } @@ -156,11 +154,19 @@ func New(config ...Config) *Storage { // TODO improve this crude way to wait for the connection to be established time.Sleep(cfg.WaitForConnection) + // Reset bucket + if cfg.Reset { + err = storage.Reset() + if err != nil { + panic(err) + } + } + return storage } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } @@ -172,7 +178,7 @@ func (s *Storage) Get(key string) ([]byte, error) { return nil, fmt.Errorf("kv not initialized: %v", s.err) } - v, err := kv.Get(s.ctx, key) + v, err := kv.Get(ctx, key) if err != nil { if errors.Is(err, jetstream.ErrKeyNotFound) { return nil, nil @@ -185,15 +191,20 @@ func (s *Storage) Get(key string) ([]byte, error) { bytes.NewBuffer(v.Value())). Decode(&e) if err != nil || e.Expiry <= time.Now().Unix() { - _ = kv.Delete(s.ctx, key) + _ = kv.Delete(ctx, key) return nil, nil } return e.Data, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value and expiry with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } @@ -221,9 +232,9 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { } // set - _, err = kv.Put(s.ctx, key, e.Bytes()) + _, err = kv.Put(ctx, key, e.Bytes()) if errors.Is(err, jetstream.ErrKeyNotFound) { - _, err := kv.Create(s.ctx, key, e.Bytes()) + _, err := kv.Create(ctx, key, e.Bytes()) if err != nil { return fmt.Errorf("create: %w", err) } @@ -232,8 +243,13 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } -// Delete key by key -func (s *Storage) Delete(key string) error { +// Set key with value and expiry +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext key by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } @@ -246,18 +262,23 @@ func (s *Storage) Delete(key string) error { return fmt.Errorf("kv not initialized: %v", s.err) } - return kv.Delete(s.ctx, key) + return kv.Delete(ctx, key) } -// Reset all keys -func (s *Storage) Reset() error { +// Delete key by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { js, err := jetstream.New(s.nc) if err != nil { return fmt.Errorf("get jetstream: %w", err) } // Delete the bucket - err = js.DeleteKeyValue(s.ctx, s.cfg.KeyValueConfig.Bucket) + err = js.DeleteKeyValue(ctx, s.cfg.KeyValueConfig.Bucket) if err != nil { return fmt.Errorf("delete kv: %w", err) } @@ -267,7 +288,7 @@ func (s *Storage) Reset() error { defer s.mu.Unlock() s.kv, err = newNatsKV( s.nc, - s.ctx, + ctx, s.cfg.KeyValueConfig, ) if err != nil { @@ -279,6 +300,11 @@ func (s *Storage) Reset() error { return nil } +// Reset all keys +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the nats connection func (s *Storage) Close() error { s.mu.RLock() @@ -303,7 +329,7 @@ func (s *Storage) Keys() ([]string, error) { return nil, fmt.Errorf("kv not initialized: %v", s.err) } - keyLister, err := kv.ListKeys(s.ctx) + keyLister, err := kv.ListKeys(context.Background()) if err != nil { return nil, fmt.Errorf("keys: %w", err) diff --git a/nats/nats_test.go b/nats/nats_test.go index 5ff6d1d1..2227a61b 100644 --- a/nats/nats_test.go +++ b/nats/nats_test.go @@ -137,6 +137,7 @@ func newTestStore(t testing.TB) *Storage { Bucket: "test", Storage: jetstream.MemoryStorage, }, + Reset: true, }) } @@ -157,6 +158,25 @@ func Test_Storage_Nats_Set(t *testing.T) { require.Len(t, keys, 1) } +func Test_Storage_Nats_SetWithContext(t *testing.T) { + var ( + testStore = newTestStore(t) + key = "john" + val = []byte("doe") + ) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) + + keys, err := testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 0) +} + func Test_Storage_Nats_Set_Overwrite(t *testing.T) { var ( key = "john" @@ -202,6 +222,30 @@ func Test_Storage_Nats_Get(t *testing.T) { require.Len(t, keys, 1) } +func Test_Storage_Nats_GetWithContext(t *testing.T) { + var ( + testStore = newTestStore(t) + key = "john" + val = []byte("doe") + ) + + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.Set(key, val, 30*time.Second) + require.NoError(t, err) + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) + + keys, err := testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 1) +} + func Test_Storage_Nats_Set_Expiration(t *testing.T) { var ( key = "john" @@ -300,6 +344,35 @@ func Test_Storage_Nats_Delete(t *testing.T) { require.Nil(t, keys) } +func Test_Storage_Nats_DeleteWithContext(t *testing.T) { + var ( + testStore = newTestStore(t) + key = "john" + val = []byte("doe") + ) + + defer testStore.Close() + + err := testStore.Set(key, val, 5*time.Second) + require.NoError(t, err) + + keys, err := testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) + + require.NoError(t, testStore.Reset()) +} + func Test_Storage_Nats_Reset(t *testing.T) { testStore := newTestStore(t) defer testStore.Close() @@ -332,6 +405,41 @@ func Test_Storage_Nats_Reset(t *testing.T) { require.Nil(t, keys) } +func Test_Storage_Nats_ResetWithContext(t *testing.T) { + testStore := newTestStore(t) + defer testStore.Close() + + val := []byte("doe") + + err := testStore.Set("john1", val, 5*time.Second) + require.NoError(t, err) + + err = testStore.Set("john2", val, 5*time.Second) + require.NoError(t, err) + + keys, err := testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 2) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) + + keys, err = testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 2) +} + func Test_Storage_Nats_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/neo4j/neo4j.go b/neo4j/neo4j.go index 8e438e9b..794ebf39 100644 --- a/neo4j/neo4j.go +++ b/neo4j/neo4j.go @@ -101,13 +101,11 @@ func New(config ...Config) *Storage { return store } -func (s *Storage) Get(key string) ([]byte, error) { +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - ctx := context.Background() - res, err := neo4j.ExecuteQuery( ctx, s.db, s.cypherMatch, map[string]any{"key": key}, @@ -139,8 +137,12 @@ func (s *Storage) Get(key string) ([]byte, error) { return model.Val, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value and expiration time with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } @@ -156,8 +158,6 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { Exp: expireAt, } - ctx := context.Background() - _, err := neo4j.ExecuteQuery( ctx, s.db, s.cypherMerge, map[string]any{"key": data.Key, "val": data.Val, "exp": data.Exp}, @@ -167,20 +167,30 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } -// Delete value by key -func (s *Storage) Delete(key string) error { +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext value by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - _, err := neo4j.ExecuteQuery(context.Background(), s.db, s.cypherDelete, map[string]any{"key": key}, neo4j.EagerResultTransformer) + _, err := neo4j.ExecuteQuery(ctx, s.db, s.cypherDelete, map[string]any{"key": key}, neo4j.EagerResultTransformer) return err } -// Reset all keys. Remove all nodes -func (s *Storage) Reset() error { +// Delete value by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all keys with context. Remove all nodes +func (s *Storage) ResetWithContext(ctx context.Context) error { _, err := neo4j.ExecuteQuery( - context.Background(), s.db, s.cypherReset, + ctx, s.db, s.cypherReset, nil, neo4j.EagerResultTransformer, ) @@ -188,6 +198,11 @@ func (s *Storage) Reset() error { return err } +// Reset all keys. Remove all nodes +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the database func (s *Storage) Close() error { s.done <- struct{}{} diff --git a/neo4j/neo4j_test.go b/neo4j/neo4j_test.go index 35724cac..80c1681e 100644 --- a/neo4j/neo4j_test.go +++ b/neo4j/neo4j_test.go @@ -79,6 +79,19 @@ func Test_Neo4jStore_Set(t *testing.T) { require.NoError(t, err) } +func Test_Neo4jStore_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + err := testStore.SetWithContext(ctx, key, val, 10*time.Millisecond) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + func Test_Neo4jStore_Upsert(t *testing.T) { var ( key = "john" diff --git a/postgres/postgres.go b/postgres/postgres.go index 3250b187..3175906d 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -129,12 +129,12 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - row := s.db.QueryRow(context.Background(), s.sqlSelect, key) + row := s.db.QueryRow(ctx, s.sqlSelect, key) // Add db response to data var ( data []byte @@ -155,8 +155,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets key with value +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -165,26 +170,41 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { expSeconds = time.Now().Add(exp).Unix() } - _, err := s.db.Exec(context.Background(), s.sqlInsert, key, val, expSeconds, val, expSeconds) + _, err := s.db.Exec(ctx, s.sqlInsert, key, val, expSeconds, val, expSeconds) return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set sets key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext deletes entry by key +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } - _, err := s.db.Exec(context.Background(), s.sqlDelete, key) + _, err := s.db.Exec(ctx, s.sqlDelete, key) return err } -// Reset all entries, including unexpired -func (s *Storage) Reset() error { - _, err := s.db.Exec(context.Background(), s.sqlReset) +// Delete deletes entry by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext resets all entries with context, including unexpired ones +func (s *Storage) ResetWithContext(ctx context.Context) error { + _, err := s.db.Exec(ctx, s.sqlReset) return err } +// Reset resets all entries, including unexpired ones +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the database func (s *Storage) Close() error { s.done <- struct{}{} diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index 25cb7f6e..f20f8f1f 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -239,6 +239,22 @@ func Test_Postgres_Set(t *testing.T) { require.NoError(t, err) } +func Test_Postgres_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Postgres_Set_Override(t *testing.T) { var ( key = "john" @@ -272,6 +288,26 @@ func Test_Postgres_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_Postgres_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Postgres_Set_Expiration(t *testing.T) { var ( key = "john" @@ -332,6 +368,25 @@ func Test_Postgres_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_Postgres_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Postgres_Reset(t *testing.T) { val := []byte("doe") @@ -356,6 +411,33 @@ func Test_Postgres_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_Postgres_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.NotZero(t, len(result)) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.NotZero(t, len(result)) +} + func Test_Postgres_GC(t *testing.T) { testVal := []byte("doe") diff --git a/redis/redis.go b/redis/redis.go index 1f8ee288..748692a4 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -82,37 +82,57 @@ func New(config ...Config) *Storage { } } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - val, err := s.db.Get(context.Background(), key).Bytes() + val, err := s.db.Get(ctx, key).Bytes() if err == redis.Nil { return nil, nil } return val, err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext key with value with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } - return s.db.Set(context.Background(), key, val, exp).Err() + return s.db.Set(ctx, key, val, exp).Err() } -// Delete key by key -func (s *Storage) Delete(key string) error { +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext key by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - return s.db.Del(context.Background(), key).Err() + return s.db.Del(ctx, key).Err() +} + +// Delete key by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.db.FlushDB(ctx).Err() } // Reset all keys func (s *Storage) Reset() error { - return s.db.FlushDB(context.Background()).Err() + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/redis/redis_test.go b/redis/redis_test.go index 001130f8..d97126a7 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -1,6 +1,7 @@ package redis import ( + "context" "os" "testing" "time" @@ -63,6 +64,22 @@ func Test_Redis_Set(t *testing.T) { require.NoError(t, err) } +func Test_Redis_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Redis_Set_Override(t *testing.T) { var ( key = "john" @@ -104,6 +121,26 @@ func Test_Redis_Get(t *testing.T) { require.Len(t, keys, 1) } +func Test_Redis_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Redis_Expiration(t *testing.T) { var ( key = "john" @@ -161,6 +198,29 @@ func Test_Redis_Delete(t *testing.T) { require.Nil(t, keys) } +func Test_Redis_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Redis_Reset(t *testing.T) { val := []byte("doe") @@ -193,6 +253,37 @@ func Test_Redis_Reset(t *testing.T) { require.Nil(t, keys) } +func Test_Redis_ResetWithContext(t *testing.T) { + testStore := newTestStore(t) + defer testStore.Close() + + val := []byte("doe") + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + keys, err := testStore.Keys() + require.NoError(t, err) + require.Len(t, keys, 2) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Redis_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/rueidis/rueidis.go b/rueidis/rueidis.go index 856cf5e9..f6c99793 100644 --- a/rueidis/rueidis.go +++ b/rueidis/rueidis.go @@ -85,41 +85,61 @@ func New(config ...Config) *Storage { } } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - val, err := s.db.DoCache(context.Background(), s.db.B().Get().Key(key).Cache(), cacheTTL).AsBytes() + val, err := s.db.DoCache(ctx, s.db.B().Get().Key(key).Cache(), cacheTTL).AsBytes() if err != nil && rueidis.IsRedisNil(err) { return nil, nil } return val, err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets key with value with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } if exp > 0 { - return s.db.Do(context.Background(), s.db.B().Set().Key(key).Value(string(val)).Ex(exp).Build()).Error() + return s.db.Do(ctx, s.db.B().Set().Key(key).Value(string(val)).Ex(exp).Build()).Error() } else { - return s.db.Do(context.Background(), s.db.B().Set().Key(key).Value(string(val)).Build()).Error() + return s.db.Do(ctx, s.db.B().Set().Key(key).Value(string(val)).Build()).Error() } } -// Delete key by key -func (s *Storage) Delete(key string) error { +// Set sets key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext deletes key by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - return s.db.Do(context.Background(), s.db.B().Del().Key(key).Build()).Error() + return s.db.Do(ctx, s.db.B().Del().Key(key).Build()).Error() +} + +// Delete deletes key by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext resets all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.db.Do(ctx, s.db.B().Flushdb().Build()).Error() } -// Reset all keys +// Reset resets all keys func (s *Storage) Reset() error { - return s.db.Do(context.Background(), s.db.B().Flushdb().Build()).Error() + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/rueidis/rueidis_test.go b/rueidis/rueidis_test.go index 86111613..24b59ef6 100644 --- a/rueidis/rueidis_test.go +++ b/rueidis/rueidis_test.go @@ -1,6 +1,7 @@ package rueidis import ( + "context" "os" "testing" "time" @@ -61,6 +62,22 @@ func Test_Rueidis_Set(t *testing.T) { require.NoError(t, err) } +func Test_Rueidis_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Rueidis_Set_Override(t *testing.T) { var ( key = "john" @@ -94,6 +111,26 @@ func Test_Rueidis_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_Rueidis_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Rueidis_Expiration(t *testing.T) { var ( key = "john" @@ -143,6 +180,29 @@ func Test_Rueidis_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_Rueidis_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Rueidis_Reset(t *testing.T) { val := []byte("doe") @@ -167,6 +227,33 @@ func Test_Rueidis_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_Rueidis_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Rueidis_Close(t *testing.T) { testStore := newTestStore(t) require.Nil(t, testStore.Close()) diff --git a/s3/s3.go b/s3/s3.go index 2e189f4c..6d77ba11 100644 --- a/s3/s3.go +++ b/s3/s3.go @@ -78,17 +78,14 @@ func New(config ...Config) *Storage { return storage } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { var nsk *types.NoSuchKey if len(key) <= 0 { return nil, nil } - ctx, cancel := s.requestContext() - defer cancel() - buf := manager.NewWriteAtBuffer([]byte{}) _, err := s.downloader.Download(ctx, buf, &s3.GetObjectInput{ @@ -102,15 +99,20 @@ func (s *Storage) Get(key string) ([]byte, error) { return buf.Bytes(), err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + ctx, cancel := s.requestContext() + defer cancel() + + return s.GetWithContext(ctx, key) +} + +// SetWithContext key with value and expiration time with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 { return nil } - ctx, cancel := s.requestContext() - defer cancel() - _, err := s.uploader.Upload(ctx, &s3.PutObjectInput{ Bucket: &s.bucket, Key: aws.String(key), @@ -120,15 +122,20 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set key with value and expiration time +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + ctx, cancel := s.requestContext() + defer cancel() + + return s.SetWithContext(ctx, key, val, exp) +} + +// DeleteWithContext deletes entry by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - ctx, cancel := s.requestContext() - defer cancel() - _, err := s.svc.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: &s.bucket, Key: aws.String(key), @@ -137,11 +144,16 @@ func (s *Storage) Delete(key string) error { return err } -// Reset all entries, including unexpired -func (s *Storage) Reset() error { +// Delete deletes entry by key +func (s *Storage) Delete(key string) error { ctx, cancel := s.requestContext() defer cancel() + return s.DeleteWithContext(ctx, key) +} + +// ResetWithContext resets all entries, including unexpired ones with context +func (s *Storage) ResetWithContext(ctx context.Context) error { paginator := s3.NewListObjectsV2Paginator(s.svc, &s3.ListObjectsV2Input{ Bucket: &s.bucket, }) @@ -173,6 +185,14 @@ func (s *Storage) Reset() error { return nil } +// Reset resets all entries, including unexpired ones +func (s *Storage) Reset() error { + ctx, cancel := s.requestContext() + defer cancel() + + return s.ResetWithContext(ctx) +} + // Close the database func (s *Storage) Close() error { return nil diff --git a/s3/s3_test.go b/s3/s3_test.go index fe2ad8f2..1cddf094 100644 --- a/s3/s3_test.go +++ b/s3/s3_test.go @@ -1,6 +1,7 @@ package s3 import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -19,6 +20,22 @@ func Test_S3_Set(t *testing.T) { require.NoError(t, err) } +func Test_S3_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_S3_Set_Override(t *testing.T) { var ( key = "john" @@ -52,6 +69,26 @@ func Test_S3_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_S3_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_S3_Get_NotExist(t *testing.T) { testStore := newTestStore(t) defer testStore.Close() @@ -81,6 +118,29 @@ func Test_S3_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_S3_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_S3_Reset(t *testing.T) { val := []byte("doe") @@ -105,6 +165,33 @@ func Test_S3_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_S3_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_S3_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/scylladb/scylladb.go b/scylladb/scylladb.go index 359d88eb..35c58d56 100644 --- a/scylladb/scylladb.go +++ b/scylladb/scylladb.go @@ -1,6 +1,7 @@ package scylladb import ( + "context" "errors" "fmt" "strings" @@ -134,10 +135,10 @@ func (s *Storage) checkSchema(keyspace string) { } } -// Get retrieves a value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext retrieves a value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { var value []byte - if err := s.session.Query(s.selectQuery, key).Scan(&value); err != nil { + if err := s.session.Query(s.selectQuery, key).WithContext(ctx).Scan(&value); err != nil { if errors.Is(err, gocql.ErrNotFound) { return nil, nil } @@ -146,23 +147,43 @@ func (s *Storage) Get(key string) ([]byte, error) { return value, nil } -// Set sets a value by key -func (s *Storage) Set(key string, value []byte, expire time.Duration) error { +// Get retrieves a value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets a value by key with context +func (s *Storage) SetWithContext(ctx context.Context, key string, value []byte, expire time.Duration) error { var expiration int if expire != 0 { expiration = int(expire.Round(time.Second).Seconds()) } - return s.session.Query(s.insertQuery, key, value, expiration).Exec() + return s.session.Query(s.insertQuery, key, value, expiration).WithContext(ctx).Exec() +} + +// Set sets a value by key +func (s *Storage) Set(key string, value []byte, expire time.Duration) error { + return s.SetWithContext(context.Background(), key, value, expire) +} + +// DeleteWithContext removes a value by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { + return s.session.Query(s.deleteQuery, key).WithContext(ctx).Exec() } // Delete removes a value by key func (s *Storage) Delete(key string) error { - return s.session.Query(s.deleteQuery, key).Exec() + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext resets all values with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.session.Query(s.resetQuery).WithContext(ctx).Exec() } // Reset resets all values func (s *Storage) Reset() error { - return s.session.Query(s.resetQuery).Exec() + return s.ResetWithContext(context.Background()) } // Close closes the storage diff --git a/scylladb/scylladb_test.go b/scylladb/scylladb_test.go index 486a1982..97212fad 100644 --- a/scylladb/scylladb_test.go +++ b/scylladb/scylladb_test.go @@ -62,6 +62,22 @@ func Test_Scylla_Set(t *testing.T) { require.NoError(t, err) } +func Test_Scylla_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Scylla_Set_Override_Get(t *testing.T) { var ( key = "john" @@ -104,6 +120,26 @@ func Test_Scylla_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_Scylla_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Scylla_Set_Expiration_Get(t *testing.T) { var ( key = "john" @@ -153,6 +189,29 @@ func Test_Scylla_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_Scylla_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Scylla_Reset(t *testing.T) { var val = []byte("doe") @@ -177,6 +236,33 @@ func Test_Scylla_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_Scylla_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Scylla_Close(t *testing.T) { testStore := newTestStore(t) require.NoError(t, testStore.Close()) diff --git a/sqlite3/sqlite3.go b/sqlite3/sqlite3.go index 0e5f7733..66285750 100644 --- a/sqlite3/sqlite3.go +++ b/sqlite3/sqlite3.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "context" "database/sql" "fmt" "time" @@ -88,12 +89,12 @@ func New(config ...Config) *Storage { return store } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - row := s.db.QueryRow(s.sqlSelect, key) + row := s.db.QueryRowContext(ctx, s.sqlSelect, key) // Add db response to data var ( data = []byte{} @@ -113,8 +114,13 @@ func (s *Storage) Get(key string) ([]byte, error) { return data, nil } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets key with value and expiration time with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -123,26 +129,41 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { expSeconds = time.Now().Add(exp).Unix() } - _, err := s.db.Exec(s.sqlInsert, key, val, expSeconds) + _, err := s.db.ExecContext(ctx, s.sqlInsert, key, val, expSeconds) return err } -// Delete entry by key -func (s *Storage) Delete(key string) error { +// Set sets key with value and expiration time +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext deletes entry by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil } - _, err := s.db.Exec(s.sqlDelete, key) + _, err := s.db.ExecContext(ctx, s.sqlDelete, key) return err } -// Reset all entries, including unexpired -func (s *Storage) Reset() error { - _, err := s.db.Exec(s.sqlReset) +// Delete deletes entry by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext all entries, including unexpired ones with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, s.sqlReset) return err } +// Reset all entries, including unexpired ones +func (s *Storage) Reset() error { + return s.ResetWithContext(context.Background()) +} + // Close the database func (s *Storage) Close() error { s.done <- struct{}{} diff --git a/sqlite3/sqlite3_test.go b/sqlite3/sqlite3_test.go index c7965ac1..0c7cd4b9 100644 --- a/sqlite3/sqlite3_test.go +++ b/sqlite3/sqlite3_test.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "context" "database/sql" "testing" "time" @@ -23,6 +24,19 @@ func Test_SQLite3_Set(t *testing.T) { require.NoError(t, err) } +func Test_SQLite3_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_SQLite3_Set_Override(t *testing.T) { var ( key = "john" @@ -50,6 +64,23 @@ func Test_SQLite3_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_SQLite3_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_SQLite3_Set_Expiration(t *testing.T) { var ( key = "john" @@ -94,6 +125,26 @@ func Test_SQLite3_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_SQLite3_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_SQLite3_Reset(t *testing.T) { val := []byte("doe") @@ -115,6 +166,30 @@ func Test_SQLite3_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_SQLite3_ResetWithContext(t *testing.T) { + val := []byte("doe") + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_SQLite3_GC(t *testing.T) { testVal := []byte("doe") diff --git a/valkey/valkey.go b/valkey/valkey.go index 2bd8831b..fd345d9a 100644 --- a/valkey/valkey.go +++ b/valkey/valkey.go @@ -85,41 +85,61 @@ func New(config ...Config) *Storage { } } -// Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +// GetWithContext gets value by key with context +func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } - val, err := s.db.DoCache(context.Background(), s.db.B().Get().Key(key).Cache(), cacheTTL).AsBytes() + val, err := s.db.DoCache(ctx, s.db.B().Get().Key(key).Cache(), cacheTTL).AsBytes() if err != nil && valkey.IsValkeyNil(err) { return nil, nil } return val, err } -// Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +// Get gets value by key +func (s *Storage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +// SetWithContext sets key with value and expiration with context +func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if len(key) <= 0 || len(val) <= 0 { return nil } if exp > 0 { - return s.db.Do(context.Background(), s.db.B().Set().Key(key).Value(string(val)).Ex(exp).Build()).Error() + return s.db.Do(ctx, s.db.B().Set().Key(key).Value(string(val)).Ex(exp).Build()).Error() } else { - return s.db.Do(context.Background(), s.db.B().Set().Key(key).Value(string(val)).Build()).Error() + return s.db.Do(ctx, s.db.B().Set().Key(key).Value(string(val)).Build()).Error() } } -// Delete key by key -func (s *Storage) Delete(key string) error { +// Set sets key with value and expiration +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +// DeleteWithContext deletes key by key with context +func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if len(key) <= 0 { return nil } - return s.db.Do(context.Background(), s.db.B().Del().Key(key).Build()).Error() + return s.db.Do(ctx, s.db.B().Del().Key(key).Build()).Error() +} + +// Delete deletes key by key +func (s *Storage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +// ResetWithContext resets all keys with context +func (s *Storage) ResetWithContext(ctx context.Context) error { + return s.db.Do(ctx, s.db.B().Flushdb().Build()).Error() } -// Reset all keys +// Reset resets all keys func (s *Storage) Reset() error { - return s.db.Do(context.Background(), s.db.B().Flushdb().Build()).Error() + return s.ResetWithContext(context.Background()) } // Close the database diff --git a/valkey/valkey_test.go b/valkey/valkey_test.go index de9e9929..f95af6dc 100644 --- a/valkey/valkey_test.go +++ b/valkey/valkey_test.go @@ -1,6 +1,7 @@ package valkey import ( + "context" "os" "sync" "testing" @@ -79,6 +80,22 @@ func Test_Valkey_Set(t *testing.T) { require.NoError(t, err) } +func Test_Valkey_SetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := testStore.SetWithContext(ctx, key, val, 0) + require.ErrorIs(t, err, context.Canceled) +} + func Test_Valkey_Set_Override(t *testing.T) { var ( key = "john" @@ -112,6 +129,26 @@ func Test_Valkey_Get(t *testing.T) { require.Equal(t, val, result) } +func Test_Valkey_GetWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := testStore.GetWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, len(result)) +} + func Test_Valkey_Expiration(t *testing.T) { var ( key = "john" @@ -161,6 +198,28 @@ func Test_Valkey_Delete(t *testing.T) { require.Zero(t, len(result)) } +func Test_Valkey_DeleteWithContext(t *testing.T) { + var ( + key = "john" + val = []byte("doe") + ) + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set(key, val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.DeleteWithContext(ctx, key) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get(key) + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Valkey_Reset(t *testing.T) { val := []byte("doe") @@ -185,6 +244,33 @@ func Test_Valkey_Reset(t *testing.T) { require.Zero(t, len(result)) } +func Test_Valkey_ResetWithContext(t *testing.T) { + val := []byte("doe") + + testStore := newTestStore(t) + defer testStore.Close() + + err := testStore.Set("john1", val, 0) + require.NoError(t, err) + + err = testStore.Set("john2", val, 0) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = testStore.ResetWithContext(ctx) + require.ErrorIs(t, err, context.Canceled) + + result, err := testStore.Get("john1") + require.NoError(t, err) + require.Equal(t, val, result) + + result, err = testStore.Get("john2") + require.NoError(t, err) + require.Equal(t, val, result) +} + func Test_Valkey_Close(t *testing.T) { testStore := newTestStore(t) require.Nil(t, testStore.Close())