Skip to content

Migrated the db module to the new exception types #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
403: exceptions.PERMISSION_DENIED,
404: exceptions.NOT_FOUND,
409: exceptions.CONFLICT,
412: exceptions.FAILED_PRECONDITION,
429: exceptions.RESOURCE_EXHAUSTED,
500: exceptions.INTERNAL,
503: exceptions.UNAVAILABLE,
Expand Down
95 changes: 47 additions & 48 deletions firebase_admin/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from six.moves import urllib

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import _http_client
from firebase_admin import _sseclient
from firebase_admin import _utils
Expand Down Expand Up @@ -209,7 +210,7 @@ def get(self, etag=False, shallow=False):

Raises:
ValueError: If both ``etag`` and ``shallow`` are set to True.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if etag:
if shallow:
Expand All @@ -236,7 +237,7 @@ def get_if_changed(self, etag):

Raises:
ValueError: If the ETag is not a string.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if not isinstance(etag, six.string_types):
raise ValueError('ETag must be a string.')
Expand All @@ -258,7 +259,7 @@ def set(self, value):
Raises:
ValueError: If the provided value is None.
TypeError: If the value is not JSON-serializable.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if value is None:
raise ValueError('Value must not be None.')
Expand All @@ -281,7 +282,7 @@ def set_if_unchanged(self, expected_etag, value):

Raises:
ValueError: If the value is None, or if expected_etag is not a string.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
# pylint: disable=missing-raises-doc
if not isinstance(expected_etag, six.string_types):
Expand All @@ -293,11 +294,11 @@ def set_if_unchanged(self, expected_etag, value):
headers = self._client.headers(
'put', self._add_suffix(), json=value, headers={'if-match': expected_etag})
return True, value, headers.get('ETag')
except ApiCallError as error:
detail = error.detail
if detail.response is not None and 'ETag' in detail.response.headers:
etag = detail.response.headers['ETag']
snapshot = detail.response.json()
except exceptions.FailedPreconditionError as error:
http_response = error.http_response
if http_response is not None and 'ETag' in http_response.headers:
etag = http_response.headers['ETag']
snapshot = http_response.json()
return False, snapshot, etag
else:
raise error
Expand All @@ -317,7 +318,7 @@ def push(self, value=''):
Raises:
ValueError: If the value is None.
TypeError: If the value is not JSON-serializable.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if value is None:
raise ValueError('Value must not be None.')
Expand All @@ -333,7 +334,7 @@ def update(self, value):

Raises:
ValueError: If value is empty or not a dictionary.
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if not value or not isinstance(value, dict):
raise ValueError('Value argument must be a non-empty dictionary.')
Expand All @@ -345,7 +346,7 @@ def delete(self):
"""Deletes this node from the database.

Raises:
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
self._client.request('delete', self._add_suffix())

Expand All @@ -371,7 +372,7 @@ def listen(self, callback):
ListenerRegistration: An object that can be used to stop the event listener.

Raises:
ApiCallError: If an error occurs while starting the initial HTTP connection.
FirebaseError: If an error occurs while starting the initial HTTP connection.
"""
session = _sseclient.KeepAuthSession(self._client.credential)
return self._listen_with_session(callback, session)
Expand All @@ -387,9 +388,9 @@ def transaction(self, transaction_update):
value of this reference into a new value. If another client writes to this location before
the new value is successfully saved, the update function is called again with the new
current value, and the write will be retried. In case of repeated failures, this method
will retry the transaction up to 25 times before giving up and raising a TransactionError.
The update function may also force an early abort by raising an exception instead of
returning a value.
will retry the transaction up to 25 times before giving up and raising a
TransactionAbortedError. The update function may also force an early abort by raising an
exception instead of returning a value.

Args:
transaction_update: A function which will be passed the current data stored at this
Expand All @@ -402,7 +403,7 @@ def transaction(self, transaction_update):
object: New value of the current database Reference (only if the transaction commits).

Raises:
TransactionError: If the transaction aborts after exhausting all retry attempts.
TransactionAbortedError: If the transaction aborts after exhausting all retry attempts.
ValueError: If transaction_update is not a function.
"""
if not callable(transaction_update):
Expand All @@ -416,7 +417,8 @@ def transaction(self, transaction_update):
if success:
return new_data
tries += 1
raise TransactionError('Transaction aborted after failed retries.')

raise TransactionAbortedError('Transaction aborted after failed retries.')

def order_by_child(self, path):
"""Returns a Query that orders data by child values.
Expand Down Expand Up @@ -468,7 +470,7 @@ def _listen_with_session(self, callback, session):
sse = _sseclient.SSEClient(url, session)
return ListenerRegistration(callback, sse)
except requests.exceptions.RequestException as error:
raise ApiCallError(_Client.extract_error_message(error), error)
raise _Client.handle_rtdb_error(error)


class Query(object):
Expand Down Expand Up @@ -614,28 +616,19 @@ def get(self):
object: Decoded JSON result of the Query.

Raises:
ApiCallError: If an error occurs while communicating with the remote database server.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
result = self._client.body('get', self._pathurl, params=self._querystr)
if isinstance(result, (dict, list)) and self._order_by != '$priority':
return _Sorter(result, self._order_by).get()
return result


class ApiCallError(Exception):
"""Represents an Exception encountered while invoking the Firebase database server API."""

def __init__(self, message, error):
Exception.__init__(self, message)
self.detail = error


class TransactionError(Exception):
"""Represents an Exception encountered while performing a transaction."""
class TransactionAbortedError(exceptions.AbortedError):
"""A transaction was aborted aftr exceeding the maximum number of retries."""

def __init__(self, message):
Exception.__init__(self, message)

exceptions.AbortedError.__init__(self, message)


class _Sorter(object):
Expand Down Expand Up @@ -934,7 +927,7 @@ def request(self, method, url, **kwargs):
Response: An HTTP response object.

Raises:
ApiCallError: If an error occurs while making the HTTP call.
FirebaseError: If an error occurs while making the HTTP call.
"""
query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params)
extra_params = kwargs.get('params')
Expand All @@ -950,33 +943,39 @@ def request(self, method, url, **kwargs):
try:
return super(_Client, self).request(method, url, **kwargs)
except requests.exceptions.RequestException as error:
raise ApiCallError(_Client.extract_error_message(error), error)
raise _Client.handle_rtdb_error(error)

@classmethod
def handle_rtdb_error(cls, error):
"""Converts an error encountered while calling RTDB into a FirebaseError."""
if error.response is None:
return _utils.handle_requests_error(error)

message = cls._extract_error_message(error.response)
return _utils.handle_requests_error(error, message=message)

@classmethod
def extract_error_message(cls, error):
"""Extracts an error message from an exception.
def _extract_error_message(cls, response):
"""Extracts an error message from an error response.

If the server has not sent any response, simply converts the exception into a string.
If the server has sent a JSON response with an 'error' field, which is the typical
behavior of the Realtime Database REST API, parses the response to retrieve the error
message. If the server has sent a non-JSON response, returns the full response
as the error message.

Args:
error: An exception raised by the requests library.

Returns:
str: A string error message extracted from the exception.
"""
if error.response is None:
return str(error)
message = None
try:
data = error.response.json()
# RTDB error format: {"error": "text message"}
data = response.json()
if isinstance(data, dict):
return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown'))
message = data.get('error')
except ValueError:
pass
return '{0}\nReason: {1}'.format(error, error.response.content.decode())

if not message:
message = 'Unexpected response from database: {0}'.format(response.content.decode())

return message


class _EmulatorAdminCredentials(google.auth.credentials.Credentials):
Expand Down
33 changes: 15 additions & 18 deletions integration/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import firebase_admin
from firebase_admin import db
from firebase_admin import exceptions
from integration import conftest
from tests import testutils

Expand Down Expand Up @@ -359,30 +360,26 @@ def init_ref(self, path, app):
admin_ref.set('test')
assert admin_ref.get() == 'test'

def check_permission_error(self, excinfo):
assert isinstance(excinfo.value, db.ApiCallError)
assert 'Reason: Permission denied' in str(excinfo.value)

def test_no_access(self, app, override_app):
path = '_adminsdk/python/admin'
self.init_ref(path, app)
user_ref = db.reference(path, override_app)
with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
assert user_ref.get()
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
user_ref.set('test2')
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

def test_read(self, app, override_app):
path = '_adminsdk/python/protected/user2'
self.init_ref(path, app)
user_ref = db.reference(path, override_app)
assert user_ref.get() == 'test'
with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
user_ref.set('test2')
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

def test_read_write(self, app, override_app):
path = '_adminsdk/python/protected/user1'
Expand All @@ -394,9 +391,9 @@ def test_read_write(self, app, override_app):

def test_query(self, override_app):
user_ref = db.reference('_adminsdk/python/protected', override_app)
with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
user_ref.order_by_key().limit_to_first(2).get()
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

def test_none_auth_override(self, app, none_override_app):
path = '_adminsdk/python/public'
Expand All @@ -405,14 +402,14 @@ def test_none_auth_override(self, app, none_override_app):
assert public_ref.get() == 'test'

ref = db.reference('_adminsdk/python', none_override_app)
with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
assert ref.child('protected/user1').get()
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
assert ref.child('protected/user2').get()
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'

with pytest.raises(db.ApiCallError) as excinfo:
with pytest.raises(exceptions.UnauthenticatedError) as excinfo:
assert ref.child('admin').get()
self.check_permission_error(excinfo)
assert str(excinfo.value) == 'Permission denied'
2 changes: 1 addition & 1 deletion snippets/database/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def increment_votes(current_value):
try:
new_vote_count = upvotes_ref.transaction(increment_votes)
print('Transaction completed')
except db.TransactionError:
except db.TransactionAbortedError:
print('Transaction failed to commit')
# [END transaction]

Expand Down
Loading