Skip to content

Commit c6dafbb

Browse files
authored
Modify Operation Handling to not require a name for Done Operations (#371)
* Firebase ML Kit Modify Operation Handling to not require a name for Done Operations * Adding support for TensorFlow 2.x
1 parent 7b4731f commit c6dafbb

File tree

2 files changed

+93
-83
lines changed

2 files changed

+93
-83
lines changed

firebase_admin/mlkit.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import six
2828

2929

30+
from six.moves import urllib
3031
from firebase_admin import _http_client
3132
from firebase_admin import _utils
3233
from firebase_admin import exceptions
@@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
200201
data_copy = dict(data)
201202
tflite_format = None
202203
tflite_format_data = data_copy.pop('tfliteModel', None)
204+
data_copy.pop('@type', None) # Returned by Operations. (Not needed)
203205
if tflite_format_data:
204206
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
205207
model = Model(model_format=tflite_format)
@@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
495497
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)
496498

497499
@staticmethod
498-
def _assert_tf_version_1_enabled():
500+
def _assert_tf_enabled():
499501
if not _TF_ENABLED:
500502
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
501503
'to install the tensorflow module.')
502-
if not tf.VERSION.startswith('1.'):
503-
raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION))
504+
if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'):
505+
raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}'
506+
.format(tf.version.VERSION))
507+
508+
@staticmethod
509+
def _tf_convert_from_saved_model(saved_model_dir):
510+
# Same for both v1.x and v2.x
511+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
512+
return converter.convert()
513+
514+
@staticmethod
515+
def _tf_convert_from_keras_model(keras_model):
516+
# Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
517+
if tf.version.VERSION.startswith('1.'):
518+
keras_file = 'firebase_keras_model.h5'
519+
tf.keras.models.save_model(keras_model, keras_file)
520+
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
521+
return converter.convert()
522+
else:
523+
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
524+
return converter.convert()
504525

505526
@classmethod
506527
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
@@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
518539
Raises:
519540
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
520541
"""
521-
TFLiteGCSModelSource._assert_tf_version_1_enabled()
522-
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
523-
tflite_model = converter.convert()
542+
TFLiteGCSModelSource._assert_tf_enabled()
543+
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
524544
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
525545
return TFLiteGCSModelSource.from_tflite_model_file(
526546
'firebase_mlkit_model.tflite', bucket_name, app)
@@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
541561
Raises:
542562
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
543563
"""
544-
TFLiteGCSModelSource._assert_tf_version_1_enabled()
545-
keras_file = 'keras_model.h5'
546-
tf.keras.models.save_model(keras_model, keras_file)
547-
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
548-
tflite_model = converter.convert()
564+
TFLiteGCSModelSource._assert_tf_enabled()
565+
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
549566
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
550567
return TFLiteGCSModelSource.from_tflite_model_file(
551568
'firebase_mlkit_model.tflite', bucket_name, app)
@@ -810,28 +827,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
810827
"""
811828
if not isinstance(operation, dict):
812829
raise TypeError('Operation must be a dictionary.')
813-
op_name = operation.get('name')
814-
_, model_id = _validate_and_parse_operation_name(op_name)
815-
816-
current_attempt = 0
817-
start_time = datetime.datetime.now()
818-
stop_time = (None if max_time_seconds is None else
819-
start_time + datetime.timedelta(seconds=max_time_seconds))
820-
while wait_for_operation and not operation.get('done'):
821-
# We just got this operation. Wait before getting another
822-
# so we don't exceed the GetOperation maximum request rate.
823-
self._exponential_backoff(current_attempt, stop_time)
824-
operation = self.get_operation(op_name)
825-
current_attempt += 1
826830

827831
if operation.get('done'):
832+
# Operations which are immediately done don't have an operation name
828833
if operation.get('response'):
829834
return operation.get('response')
830835
elif operation.get('error'):
831836
raise _utils.handle_operation_error(operation.get('error'))
832-
833-
# If the operation is not complete or timed out, return a (locked) model instead
834-
return get_model(model_id).as_dict()
837+
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
838+
else:
839+
op_name = operation.get('name')
840+
_, model_id = _validate_and_parse_operation_name(op_name)
841+
current_attempt = 0
842+
start_time = datetime.datetime.now()
843+
stop_time = (None if max_time_seconds is None else
844+
start_time + datetime.timedelta(seconds=max_time_seconds))
845+
while wait_for_operation and not operation.get('done'):
846+
# We just got this operation. Wait before getting another
847+
# so we don't exceed the GetOperation maximum request rate.
848+
self._exponential_backoff(current_attempt, stop_time)
849+
operation = self.get_operation(op_name)
850+
current_attempt += 1
851+
852+
if operation.get('done'):
853+
if operation.get('response'):
854+
return operation.get('response')
855+
elif operation.get('error'):
856+
raise _utils.handle_operation_error(operation.get('error'))
857+
858+
# If the operation is not complete or timed out, return a (locked) model instead
859+
return get_model(model_id).as_dict()
835860

836861

837862
def create_model(self, model):
@@ -844,12 +869,12 @@ def create_model(self, model):
844869

845870
def update_model(self, model, update_mask=None):
846871
_validate_model(model, update_mask)
847-
data = {'model': model.as_dict(for_upload=True)}
872+
path = 'models/{0}'.format(model.model_id)
848873
if update_mask is not None:
849-
data['updateMask'] = update_mask
874+
path = path + '?updateMask={0}'.format(update_mask)
850875
try:
851876
return self.handle_operation(
852-
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
877+
self._client.body('patch', url=path, json=model.as_dict(for_upload=True)))
853878
except requests.exceptions.RequestException as error:
854879
raise _utils.handle_platform_error_from_requests(error)
855880

@@ -876,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
876901
_validate_list_filter(list_filter)
877902
_validate_page_size(page_size)
878903
_validate_page_token(page_token)
879-
payload = {}
904+
params = {}
880905
if list_filter:
881-
payload['list_filter'] = list_filter
906+
params['filter'] = list_filter
882907
if page_size:
883-
payload['page_size'] = page_size
908+
params['page_size'] = page_size
884909
if page_token:
885-
payload['page_token'] = page_token
910+
params['page_token'] = page_token
911+
path = 'models'
912+
if params:
913+
# pylint: disable=too-many-function-args
914+
param_str = urllib.parse.urlencode(sorted(params.items()), True)
915+
path = path + '?' + param_str
886916
try:
887-
return self._client.body('get', url='models', json=payload)
917+
return self._client.body('get', url=path)
888918
except requests.exceptions.RequestException as error:
889919
raise _utils.handle_platform_error_from_requests(error)
890920

tests/test_mlkit.py

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -158,23 +158,21 @@
158158
}
159159

160160
OPERATION_DONE_MODEL_JSON_1 = {
161-
'name': OPERATION_NAME_1,
162161
'done': True,
163162
'response': CREATED_UPDATED_MODEL_JSON_1
164163
}
165164
OPERATION_MALFORMED_JSON_1 = {
166-
'name': OPERATION_NAME_1,
167165
'done': True,
168166
# if done is true then either response or error should be populated
169167
}
170168
OPERATION_MISSING_NAME = {
169+
# Name is required if the operation is not done.
171170
'done': False
172171
}
173172
OPERATION_ERROR_CODE = 400
174173
OPERATION_ERROR_MSG = "Invalid argument"
175174
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
176175
OPERATION_ERROR_JSON_1 = {
177-
'name': OPERATION_NAME_1,
178176
'done': True,
179177
'error': {
180178
'code': OPERATION_ERROR_CODE,
@@ -609,17 +607,10 @@ def test_operation_error(self):
609607
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
610608

611609
def test_malformed_operation(self):
612-
recorder = instrument_mlkit_service(
613-
status=[200, 200],
614-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
615-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
616-
model = mlkit.create_model(MODEL_1)
617-
assert model == expected_model
618-
assert len(recorder) == 2
619-
assert recorder[0].method == 'POST'
620-
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
621-
assert recorder[1].method == 'GET'
622-
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
610+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
611+
with pytest.raises(Exception) as excinfo:
612+
mlkit.create_model(MODEL_1)
613+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
623614

624615
def test_rpc_error_create(self):
625616
create_recorder = instrument_mlkit_service(
@@ -708,17 +699,10 @@ def test_operation_error(self):
708699
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
709700

710701
def test_malformed_operation(self):
711-
recorder = instrument_mlkit_service(
712-
status=[200, 200],
713-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
714-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
715-
model = mlkit.update_model(MODEL_1)
716-
assert model == expected_model
717-
assert len(recorder) == 2
718-
assert recorder[0].method == 'PATCH'
719-
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
720-
assert recorder[1].method == 'GET'
721-
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
702+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
703+
with pytest.raises(Exception) as excinfo:
704+
mlkit.update_model(MODEL_1)
705+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
722706

723707
def test_rpc_error(self):
724708
create_recorder = instrument_mlkit_service(
@@ -779,7 +763,13 @@ def teardown_class(cls):
779763
testutils.cleanup_apps()
780764

781765
@staticmethod
782-
def _url(project_id, model_id):
766+
def _update_url(project_id, model_id):
767+
update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format(
768+
project_id, model_id)
769+
return BASE_URL + update_url
770+
771+
@staticmethod
772+
def _get_url(project_id, model_id):
783773
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
784774

785775
@staticmethod
@@ -794,10 +784,9 @@ def test_immediate_done(self, publish_function, published):
794784
assert model == CREATED_UPDATED_MODEL_1
795785
assert len(recorder) == 1
796786
assert recorder[0].method == 'PATCH'
797-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
787+
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
798788
body = json.loads(recorder[0].body.decode())
799-
assert body.get('model', {}).get('state', {}).get('published', None) is published
800-
assert body.get('updateMask', {}) == 'state.published'
789+
assert body.get('state', {}).get('published', None) is published
801790

802791
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
803792
def test_returns_locked(self, publish_function):
@@ -810,9 +799,9 @@ def test_returns_locked(self, publish_function):
810799
assert model == expected_model
811800
assert len(recorder) == 2
812801
assert recorder[0].method == 'PATCH'
813-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
802+
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
814803
assert recorder[1].method == 'GET'
815-
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
804+
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)
816805

817806
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
818807
def test_operation_error(self, publish_function):
@@ -824,17 +813,10 @@ def test_operation_error(self, publish_function):
824813

825814
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
826815
def test_malformed_operation(self, publish_function):
827-
recorder = instrument_mlkit_service(
828-
status=[200, 200],
829-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
830-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
831-
model = publish_function(MODEL_ID_1)
832-
assert model == expected_model
833-
assert len(recorder) == 2
834-
assert recorder[0].method == 'PATCH'
835-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
836-
assert recorder[1].method == 'GET'
837-
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
816+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
817+
with pytest.raises(Exception) as excinfo:
818+
publish_function(MODEL_ID_1)
819+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
838820

839821
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
840822
def test_rpc_error(self, publish_function):
@@ -996,12 +978,10 @@ def test_list_models_with_all_args(self):
996978
page_token=PAGE_TOKEN)
997979
assert len(recorder) == 1
998980
assert recorder[0].method == 'GET'
999-
assert recorder[0].url == TestListModels._url(PROJECT_ID)
1000-
assert json.loads(recorder[0].body.decode()) == {
1001-
'list_filter': 'display_name=displayName3',
1002-
'page_size': 10,
1003-
'page_token': PAGE_TOKEN
1004-
}
981+
assert recorder[0].url == (
982+
TestListModels._url(PROJECT_ID) +
983+
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
984+
.format(PAGE_TOKEN))
1005985
assert isinstance(models_page, mlkit.ListModelsPage)
1006986
assert len(models_page.models) == 1
1007987
assert models_page.models[0] == MODEL_3

0 commit comments

Comments
 (0)