From 9caddb18bcb89a76c153bf2c23d13bfeae77d86e Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 20 Mar 2020 11:55:43 -0400 Subject: [PATCH 1/3] modifying operation handling to support backend changes --- firebase_admin/ml.py | 21 +++++++++++---------- tests/test_ml.py | 32 ++++++++++++++++---------------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index b7d8b818b..e91f45bb5 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -53,10 +53,9 @@ _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( - r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') + r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( - r'^operations/project/(?P[^/]+)/model/(?P[A-Za-z0-9_-]{1,60})' + - r'/operation/[^/]+$') + r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') def _get_ml_service(app): @@ -712,11 +711,10 @@ def _validate_model_id(model_id): raise ValueError('Model ID format is invalid.') -def _validate_and_parse_operation_name(op_name): - matcher = _OPERATION_NAME_PATTERN.match(op_name) - if not matcher: +def _validate_operation_name(op_name): + if not _OPERATION_NAME_PATTERN.match(op_name): raise ValueError('Operation name format is invalid.') - return matcher.group('project_id'), matcher.group('model_id') + return op_name def _validate_display_name(display_name): @@ -793,7 +791,7 @@ def __init__(self, app): base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): - _validate_and_parse_operation_name(op_name) + _validate_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: @@ -841,8 +839,11 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds raise _utils.handle_operation_error(operation.get('error')) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) + op_name = _validate_operation_name(operation.get('name')) + metadata = operation.get('metadata') + if metadata is None or '@type' not in metadata or 'ModelOperationMetadata' not in metadata.get('@type'): + raise TypeError('Unknown type of operation metadata.') + _, model_id = _validate_and_parse_name(metadata.get('name')) current_attempt = 0 start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else diff --git a/tests/test_ml.py b/tests/test_ml.py index 6accba1cb..439aeaac3 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,7 +25,7 @@ BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' -PROJECT_ID = 'myProject1' +PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' @@ -85,11 +85,11 @@ } } -OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1) +OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) OPERATION_NOT_DONE_JSON_1 = { 'name': OPERATION_NAME_1, 'metadata': { - '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' } @@ -265,10 +265,10 @@ INVALID_OP_NAME_ARGS = [ 'abc', '123', - 'projects/operations/project/1234/model/abc/operation/123', - 'operations/project/model/abc/operation/123', - 'operations/project/123/model/$#@/operation/123', - 'operations/project/1234/model/abc/operation/123/extrathing', + 'operations/project/1234/model/abc/operation/123', + 'projects/operations/123', + 'projects/$#@/operations/123', + 'projects/1234/operations/123/extrathing', ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(ml._MAX_PAGE_SIZE) @@ -348,9 +348,9 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -534,7 +534,7 @@ def test_wait_for_unlocked(self): assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestModel._op_url(PROJECT_ID) def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -564,9 +564,9 @@ def _url(project_id): return BASE_URL + 'projects/{0}/models'.format(project_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) @staticmethod def _get_url(project_id, model_id): @@ -660,9 +660,9 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -765,9 +765,9 @@ def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): From 69f9b7e8880dbb5cf793afd10e54c56e299aee63 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 20 Mar 2020 12:21:00 -0400 Subject: [PATCH 2/3] fix lint --- firebase_admin/ml.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index e91f45bb5..3f9f8c23a 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -841,7 +841,8 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds op_name = _validate_operation_name(operation.get('name')) metadata = operation.get('metadata') - if metadata is None or '@type' not in metadata or 'ModelOperationMetadata' not in metadata.get('@type'): + if (metadata is None or '@type' not in metadata or + 'ModelOperationMetadata' not in metadata.get('@type')): raise TypeError('Unknown type of operation metadata.') _, model_id = _validate_and_parse_name(metadata.get('name')) current_attempt = 0 From 21367ab882eb01894e8d56d94ccb401a919e8036 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 20 Mar 2020 14:24:41 -0400 Subject: [PATCH 3/3] review comments --- firebase_admin/ml.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 3f9f8c23a..d6c14c7ac 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -840,9 +840,9 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') op_name = _validate_operation_name(operation.get('name')) - metadata = operation.get('metadata') - if (metadata is None or '@type' not in metadata or - 'ModelOperationMetadata' not in metadata.get('@type')): + metadata = operation.get('metadata', {}) + metadata_type = metadata.get('@type', '') + if not metadata_type.endswith('ModelOperationMetadata'): raise TypeError('Unknown type of operation metadata.') _, model_id = _validate_and_parse_name(metadata.get('name')) current_attempt = 0