diff --git a/packages/core/src/transaction.ts b/packages/core/src/transaction.ts index f48a63a11..f3f27e828 100644 --- a/packages/core/src/transaction.ts +++ b/packages/core/src/transaction.ts @@ -210,6 +210,19 @@ class Transaction { return this._state === _states.ACTIVE } + /** + * Closes the transaction + * + * This method will roll back the transaction if it is not already committed or rolled back. + * + * @returns {Promise} An empty promise if closed successfully or error if any error happened during + */ + async close(): Promise { + if (this.isOpen()) { + await this.rollback() + } + } + _onErrorCallback(err: any): Promise { // error will be "acknowledged" by sending a RESET message // database will then forget about this transaction and cleanup all corresponding resources diff --git a/packages/core/test/transaction.test.ts b/packages/core/test/transaction.test.ts new file mode 100644 index 000000000..cf03b52db --- /dev/null +++ b/packages/core/test/transaction.test.ts @@ -0,0 +1,115 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ConnectionProvider, Transaction } from "../src"; +import { Bookmark } from "../src/internal/bookmark"; +import { ConnectionHolder } from "../src/internal/connection-holder"; +import FakeConnection from "./utils/connection.fake"; + +describe('Transaction', () => { + + describe('.close()', () => { + describe('when transaction is open', () => { + it('should roll back the transaction', async () => { + const connection = newFakeConnection() + const tx = newTransaction({ connection }) + + await tx.run('RETURN 1') + await tx.close() + + expect(connection.rollbackInvoked).toEqual(1) + }) + + it('should surface errors during the rollback', async () => { + const expectedError = new Error('rollback error') + const connection = newFakeConnection().withRollbackError(expectedError) + const tx = newTransaction({ connection }) + + await tx.run('RETURN 1') + + try { + await tx.close() + fail('should have thrown') + } catch (error) { + expect(error).toEqual(expectedError) + } + }) + }) + + describe('when transaction is closed', () => { + const commit = async (tx: Transaction) => tx.commit() + const rollback = async (tx: Transaction) => tx.rollback() + const error = async (tx: Transaction, conn: FakeConnection) => { + conn.withRollbackError(new Error('rollback error')) + return tx.rollback().catch(() => { }) + } + + it.each([ + ['commmited', commit], + ['rolled back', rollback], + ['with error', error] + ])('should not roll back the connection', async (_, operation) => { + const connection = newFakeConnection() + const tx = newTransaction({ connection }) + + await operation(tx, connection) + const rollbackInvokedAfterOperation = connection.rollbackInvoked + + await tx.close() + + expect(connection.rollbackInvoked).toEqual(rollbackInvokedAfterOperation) + }) + }) + }) +}) + +function newTransaction({ + connection, + fetchSize = 1000, + highRecordWatermark = 700, + lowRecordWatermark = 300 +}: { + connection: FakeConnection + fetchSize?: number + highRecordWatermark?: number, + lowRecordWatermark?: number +}): Transaction { + const connectionProvider = new ConnectionProvider() + connectionProvider.acquireConnection = () => Promise.resolve(connection) + connectionProvider.close = () => Promise.resolve() + + const connectionHolder = new ConnectionHolder({ connectionProvider }) + connectionHolder.initializeConnection() + + const transaction = new Transaction({ + connectionHolder, + onClose: () => { }, + onBookmark: (_: Bookmark) => { }, + onConnection: () => { }, + reactive: false, + fetchSize, + impersonatedUser: "" + }) + + return transaction +} + +function newFakeConnection(): FakeConnection { + return new FakeConnection() +} diff --git a/packages/core/test/utils/connection.fake.ts b/packages/core/test/utils/connection.fake.ts new file mode 100644 index 000000000..eeb5d1d61 --- /dev/null +++ b/packages/core/test/utils/connection.fake.ts @@ -0,0 +1,183 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Connection, ResultObserver, Record, ResultSummary } from '../../src' +import { ResultStreamObserver } from '../../src/internal/observers' + + +/** + * This class is like a mock of {@link Connection} that tracks invocations count. + * It tries to maintain same "interface" as {@link Connection}. + * It could be replaced with a proper mock by a library like testdouble. + * At the time of writing such libraries require {@link Proxy} support but browser tests execute in + * PhantomJS which does not support proxies. + */ +export default class FakeConnection extends Connection { + private _open: boolean + private _requestRoutingInformationMock: ((params: any) => void) | null + public creationTimestamp: number + public resetInvoked: number + public releaseInvoked: number + public seenQueries: string[] + public seenParameters: any[] + public seenProtocolOptions: any[] + private _server: any + public protocolVersion: number | undefined + public protocolErrorsHandled: number + public seenProtocolErrors: string[] + public seenRequestRoutingInformation: any[] + public rollbackInvoked: number + public _rollbackError: Error | null + + constructor() { + super() + + this._open = true + this._requestRoutingInformationMock = null + this.creationTimestamp = Date.now() + + this.resetInvoked = 0 + this.releaseInvoked = 0 + this.seenQueries = [] + this.seenParameters = [] + this.seenProtocolOptions = [] + this._server = {} + this.protocolVersion = undefined + this.protocolErrorsHandled = 0 + this.seenProtocolErrors = [] + this.seenRequestRoutingInformation = [] + this.rollbackInvoked = 0 + this._rollbackError = null + } + + protocol() { + // return fake protocol object that simply records seen queries and parameters + return { + run: (query: string, parameters: any | undefined, protocolOptions: any | undefined): ResultStreamObserver => { + this.seenQueries.push(query) + this.seenParameters.push(parameters) + this.seenProtocolOptions.push(protocolOptions) + return mockResultStreamObserver(query, parameters) + }, + commitTransaction: () => { + return mockResultStreamObserver('COMMIT', {}) + }, + beginTransaction: () => { + return Promise.resolve() + }, + rollbackTransaction: () => { + this.rollbackInvoked ++ + if (this._rollbackError !== null) { + return mockResultStreamObserverWithError('ROLLBACK', {}, this._rollbackError) + } + return mockResultStreamObserver('ROLLBACK', {}) + }, + requestRoutingInformation: (params: any | undefined) => { + this.seenRequestRoutingInformation.push(params) + if (this._requestRoutingInformationMock) { + this._requestRoutingInformationMock(params) + } + }, + version: this.protocolVersion + } + } + + resetAndFlush() { + this.resetInvoked++ + return Promise.resolve() + } + + _release() { + this.releaseInvoked++ + return Promise.resolve() + } + + isOpen() { + return this._open + } + + isNeverReleased() { + return this.isReleasedTimes(0) + } + + isReleasedOnce() { + return this.isReleasedTimes(1) + } + + isReleasedTimes(times: number) { + return this.resetInvoked === times && this.releaseInvoked === times + } + + _handleProtocolError(message: string) { + this.protocolErrorsHandled++ + this.seenProtocolErrors.push(message) + } + + withProtocolVersion(version: number) { + this.protocolVersion = version + return this + } + + withCreationTimestamp(value: number) { + this.creationTimestamp = value + return this + } + + withRequestRoutingInformationMock(requestRoutingInformationMock: (params: any) => void) { + this._requestRoutingInformationMock = requestRoutingInformationMock + return this + } + + withRollbackError(error: Error) { + this._rollbackError = error + return this + } + + closed() { + this._open = false + return this + } +} + +function mockResultStreamObserverWithError (query: string, parameters: any | undefined, error: Error) { + const observer = mockResultStreamObserver(query, parameters) + observer.subscribe = (observer: ResultObserver) => { + if (observer && observer.onError) { + observer.onError(error) + } + } + return observer +} + +function mockResultStreamObserver(query: string, parameters: any | undefined): ResultStreamObserver { + return { + onError: (error: any) => { }, + onCompleted: () => { }, + onNext: (result: any) => { }, + cancel: () => { }, + prepareToHandleSingleResponse: () => { }, + markCompleted: () => { }, + subscribe: (observer: ResultObserver) => { + if (observer && observer.onCompleted) { + observer.onCompleted(new ResultSummary(query, parameters, {})) + } + + } + } +} diff --git a/packages/neo4j-driver/src/transaction-rx.js b/packages/neo4j-driver/src/transaction-rx.js index 5fa713c5c..11acac062 100644 --- a/packages/neo4j-driver/src/transaction-rx.js +++ b/packages/neo4j-driver/src/transaction-rx.js @@ -90,4 +90,22 @@ export default class RxTransaction { .catch(err => observer.error(err)) }) } + + /** + * Closes the transaction + * + * This method will roll back the transaction if it is not already committed or rolled back. + * + * @returns {Observable} - An empty observable + */ + close () { + return new Observable(observer => { + this._txc + .close() + .then(() => { + observer.complete() + }) + .catch(err => observer.error(err)) + }) + } } diff --git a/packages/neo4j-driver/test/rx/transaction.test.js b/packages/neo4j-driver/test/rx/transaction.test.js index 783a0620b..14ca8fa0b 100644 --- a/packages/neo4j-driver/test/rx/transaction.test.js +++ b/packages/neo4j-driver/test/rx/transaction.test.js @@ -29,6 +29,7 @@ import { } from 'rxjs/operators' import neo4j from '../../src' import RxSession from '../../src/session-rx' +import RxTransaction from '../../src/transaction-rx' import sharedNeo4j from '../internal/shared-neo4j' import { newError } from 'neo4j-driver-core' @@ -148,6 +149,35 @@ describe('#integration-rx transaction', () => { expect(await countNodes(42)).toBe(0) }) + it('should run query and close', async () => { + if (protocolVersion < 4.0) { + return + } + + const result = await session + .beginTransaction() + .pipe( + flatMap(txc => + txc + .run('CREATE (n:Node {id: 42}) RETURN n') + .records() + .pipe( + map(r => r.get('n').properties.id), + concat(txc.close()) + ) + ), + materialize(), + toArray() + ) + .toPromise() + expect(result).toEqual([ + Notification.createNext(neo4j.int(42)), + Notification.createComplete() + ]) + + expect(await countNodes(42)).toBe(0) + }) + it('should run multiple queries and commit', async () => { await verifyCanRunMultipleQueries(true) }) @@ -720,3 +750,37 @@ describe('#integration-rx transaction', () => { .toPromise() } }) + +describe('#unit', () => { + describe('.close()', () => { + it('should delegate to the original Transaction', async () => { + const txc = { + close: jasmine.createSpy('close').and.returnValue(Promise.resolve()) + } + + const transaction = new RxTransaction(txc) + + await transaction.close().toPromise() + + expect(txc.close).toHaveBeenCalled() + }) + + it('should fail if to the original Transaction.close call fails', async () => { + const expectedError = new Error('expected') + const txc = { + close: jasmine + .createSpy('close') + .and.returnValue(Promise.reject(expectedError)) + } + + const transaction = new RxTransaction(txc) + + try { + await transaction.close().toPromise() + fail('should have thrown') + } catch (error) { + expect(error).toBe(expectedError) + } + }) + }) +}) diff --git a/packages/neo4j-driver/test/types/transaction-rx.test.ts b/packages/neo4j-driver/test/types/transaction-rx.test.ts index 6f9275d8b..b10139ceb 100644 --- a/packages/neo4j-driver/test/types/transaction-rx.test.ts +++ b/packages/neo4j-driver/test/types/transaction-rx.test.ts @@ -68,3 +68,7 @@ tx.commit() tx.rollback() .pipe(concat(of('rolled back'))) .subscribe(stringObserver) + +tx.close() + .pipe(concat(of('closed'))) + .subscribe(stringObserver) diff --git a/packages/neo4j-driver/types/transaction-rx.d.ts b/packages/neo4j-driver/types/transaction-rx.d.ts index ddf69708c..64d6494a7 100644 --- a/packages/neo4j-driver/types/transaction-rx.d.ts +++ b/packages/neo4j-driver/types/transaction-rx.d.ts @@ -26,6 +26,8 @@ declare interface RxTransaction { commit(): Observable rollback(): Observable + + close(): Observable } export default RxTransaction diff --git a/packages/testkit-backend/src/request-handlers.js b/packages/testkit-backend/src/request-handlers.js index 1380661c0..d6ac45346 100644 --- a/packages/testkit-backend/src/request-handlers.js +++ b/packages/testkit-backend/src/request-handlers.js @@ -285,6 +285,14 @@ export function TransactionRollback (context, data, wire) { .catch(e => wire.writeError(e)) } +export function TransactionClose (context, data, wire) { + const { txId: id } = data + const { tx } = context.getTx(id) + return tx.close() + .then(() => wire.writeResponse('Transaction', { id })) + .catch(e => wire.writeError(e)) +} + export function SessionLastBookmarks (context, data, wire) { const { sessionId } = data const session = context.getSession(sessionId) @@ -337,6 +345,7 @@ export function GetFeatures (_context, _params, wire) { 'Temporary:DriverMaxConnectionPoolSize', 'Temporary:FastFailingDiscovery', 'Temporary:ResultKeys', + 'Temporary:TransactionClose', ...SUPPORTED_TLS ] })