Skip to content

Modifying operation handling to support backend changes #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
r'/operation/[^/]+$')
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/operations/[^/]+$')


def _get_ml_service(app):
Expand Down Expand Up @@ -712,11 +711,10 @@ def _validate_model_id(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_and_parse_operation_name(op_name):
matcher = _OPERATION_NAME_PATTERN.match(op_name)
if not matcher:
def _validate_operation_name(op_name):
if not _OPERATION_NAME_PATTERN.match(op_name):
raise ValueError('Operation name format is invalid.')
return matcher.group('project_id'), matcher.group('model_id')
return op_name


def _validate_display_name(display_name):
Expand Down Expand Up @@ -793,7 +791,7 @@ def __init__(self, app):
base_url=_MLService.OPERATION_URL)

def get_operation(self, op_name):
_validate_and_parse_operation_name(op_name)
_validate_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
Expand Down Expand Up @@ -841,8 +839,12 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
raise _utils.handle_operation_error(operation.get('error'))
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')

op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)
op_name = _validate_operation_name(operation.get('name'))
metadata = operation.get('metadata', {})
metadata_type = metadata.get('@type', '')
if not metadata_type.endswith('ModelOperationMetadata'):
raise TypeError('Unknown type of operation metadata.')
_, model_id = _validate_and_parse_name(metadata.get('name'))
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
Expand Down
32 changes: 16 additions & 16 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
PROJECT_ID = 'myProject1'
PROJECT_ID = 'my-project-1'
PAGE_TOKEN = 'pageToken'
NEXT_PAGE_TOKEN = 'nextPageToken'

Expand Down Expand Up @@ -85,11 +85,11 @@
}
}

OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1)
OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID)
OPERATION_NOT_DONE_JSON_1 = {
'name': OPERATION_NAME_1,
'metadata': {
'@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata',
'@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata',
'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1),
'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING'
}
Expand Down Expand Up @@ -265,10 +265,10 @@
INVALID_OP_NAME_ARGS = [
'abc',
'123',
'projects/operations/project/1234/model/abc/operation/123',
'operations/project/model/abc/operation/123',
'operations/project/123/model/$#@/operation/123',
'operations/project/1234/model/abc/operation/123/extrathing',
'operations/project/1234/model/abc/operation/123',
'projects/operations/123',
'projects/$#@/operations/123',
'projects/1234/operations/123/extrathing',
]
PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
'1 and {0}'.format(ml._MAX_PAGE_SIZE)
Expand Down Expand Up @@ -348,9 +348,9 @@ def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _op_url(project_id, model_id):
def _op_url(project_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
'projects/{0}/operations/123'.format(project_id)

def test_model_success_err_state_lro(self):
model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON)
Expand Down Expand Up @@ -534,7 +534,7 @@ def test_wait_for_unlocked(self):
assert model == FULL_MODEL_PUBLISHED
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].url == TestModel._op_url(PROJECT_ID)

def test_wait_for_unlocked_timeout(self):
recorder = instrument_ml_service(
Expand Down Expand Up @@ -564,9 +564,9 @@ def _url(project_id):
return BASE_URL + 'projects/{0}/models'.format(project_id)

@staticmethod
def _op_url(project_id, model_id):
def _op_url(project_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
'projects/{0}/operations/123'.format(project_id)

@staticmethod
def _get_url(project_id, model_id):
Expand Down Expand Up @@ -660,9 +660,9 @@ def _url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
def _op_url(project_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
'projects/{0}/operations/123'.format(project_id)

def test_immediate_done(self):
instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE)
Expand Down Expand Up @@ -765,9 +765,9 @@ def _get_url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
def _op_url(project_id, model_id):
def _op_url(project_id):
return BASE_URL + \
'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
'projects/{0}/operations/123'.format(project_id)

@pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS)
def test_immediate_done(self, publish_function, published):
Expand Down