Skip to content

Commit bcefca8

Browse files
authored
Modifying operation handling to support backend changes (#423)
* modifying operation handling to support backend changes
1 parent 7295ea4 commit bcefca8

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

firebase_admin/ml.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@
5353
_GCS_TFLITE_URI_PATTERN = re.compile(
5454
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
5555
_RESOURCE_NAME_PATTERN = re.compile(
56-
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
56+
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
5757
_OPERATION_NAME_PATTERN = re.compile(
58-
r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
59-
r'/operation/[^/]+$')
58+
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/operations/[^/]+$')
6059

6160

6261
def _get_ml_service(app):
@@ -712,11 +711,10 @@ def _validate_model_id(model_id):
712711
raise ValueError('Model ID format is invalid.')
713712

714713

715-
def _validate_and_parse_operation_name(op_name):
716-
matcher = _OPERATION_NAME_PATTERN.match(op_name)
717-
if not matcher:
714+
def _validate_operation_name(op_name):
715+
if not _OPERATION_NAME_PATTERN.match(op_name):
718716
raise ValueError('Operation name format is invalid.')
719-
return matcher.group('project_id'), matcher.group('model_id')
717+
return op_name
720718

721719

722720
def _validate_display_name(display_name):
@@ -793,7 +791,7 @@ def __init__(self, app):
793791
base_url=_MLService.OPERATION_URL)
794792

795793
def get_operation(self, op_name):
796-
_validate_and_parse_operation_name(op_name)
794+
_validate_operation_name(op_name)
797795
try:
798796
return self._operation_client.body('get', url=op_name)
799797
except requests.exceptions.RequestException as error:
@@ -841,8 +839,12 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
841839
raise _utils.handle_operation_error(operation.get('error'))
842840
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
843841

844-
op_name = operation.get('name')
845-
_, model_id = _validate_and_parse_operation_name(op_name)
842+
op_name = _validate_operation_name(operation.get('name'))
843+
metadata = operation.get('metadata', {})
844+
metadata_type = metadata.get('@type', '')
845+
if not metadata_type.endswith('ModelOperationMetadata'):
846+
raise TypeError('Unknown type of operation metadata.')
847+
_, model_id = _validate_and_parse_name(metadata.get('name'))
846848
current_attempt = 0
847849
start_time = datetime.datetime.now()
848850
stop_time = (None if max_time_seconds is None else

tests/test_ml.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
28-
PROJECT_ID = 'myProject1'
28+
PROJECT_ID = 'my-project-1'
2929
PAGE_TOKEN = 'pageToken'
3030
NEXT_PAGE_TOKEN = 'nextPageToken'
3131

@@ -85,11 +85,11 @@
8585
}
8686
}
8787

88-
OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1)
88+
OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID)
8989
OPERATION_NOT_DONE_JSON_1 = {
9090
'name': OPERATION_NAME_1,
9191
'metadata': {
92-
'@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata',
92+
'@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata',
9393
'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1),
9494
'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING'
9595
}
@@ -265,10 +265,10 @@
265265
INVALID_OP_NAME_ARGS = [
266266
'abc',
267267
'123',
268-
'projects/operations/project/1234/model/abc/operation/123',
269-
'operations/project/model/abc/operation/123',
270-
'operations/project/123/model/$#@/operation/123',
271-
'operations/project/1234/model/abc/operation/123/extrathing',
268+
'operations/project/1234/model/abc/operation/123',
269+
'projects/operations/123',
270+
'projects/$#@/operations/123',
271+
'projects/1234/operations/123/extrathing',
272272
]
273273
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
274274
'1 and {0}'.format(ml._MAX_PAGE_SIZE)
@@ -348,9 +348,9 @@ def teardown_class(cls):
348348
testutils.cleanup_apps()
349349

350350
@staticmethod
351-
def _op_url(project_id, model_id):
351+
def _op_url(project_id):
352352
return BASE_URL + \
353-
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
353+
'projects/{0}/operations/123'.format(project_id)
354354

355355
def test_model_success_err_state_lro(self):
356356
model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON)
@@ -534,7 +534,7 @@ def test_wait_for_unlocked(self):
534534
assert model == FULL_MODEL_PUBLISHED
535535
assert len(recorder) == 1
536536
assert recorder[0].method == 'GET'
537-
assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1)
537+
assert recorder[0].url == TestModel._op_url(PROJECT_ID)
538538

539539
def test_wait_for_unlocked_timeout(self):
540540
recorder = instrument_ml_service(
@@ -564,9 +564,9 @@ def _url(project_id):
564564
return BASE_URL + 'projects/{0}/models'.format(project_id)
565565

566566
@staticmethod
567-
def _op_url(project_id, model_id):
567+
def _op_url(project_id):
568568
return BASE_URL + \
569-
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
569+
'projects/{0}/operations/123'.format(project_id)
570570

571571
@staticmethod
572572
def _get_url(project_id, model_id):
@@ -660,9 +660,9 @@ def _url(project_id, model_id):
660660
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
661661

662662
@staticmethod
663-
def _op_url(project_id, model_id):
663+
def _op_url(project_id):
664664
return BASE_URL + \
665-
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
665+
'projects/{0}/operations/123'.format(project_id)
666666

667667
def test_immediate_done(self):
668668
instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE)
@@ -765,9 +765,9 @@ def _get_url(project_id, model_id):
765765
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
766766

767767
@staticmethod
768-
def _op_url(project_id, model_id):
768+
def _op_url(project_id):
769769
return BASE_URL + \
770-
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
770+
'projects/{0}/operations/123'.format(project_id)
771771

772772
@pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS)
773773
def test_immediate_done(self, publish_function, published):

0 commit comments

Comments
 (0)