27
27
import six
28
28
29
29
30
+ from six .moves import urllib
30
31
from firebase_admin import _http_client
31
32
from firebase_admin import _utils
32
33
from firebase_admin import exceptions
@@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
200
201
data_copy = dict (data )
201
202
tflite_format = None
202
203
tflite_format_data = data_copy .pop ('tfliteModel' , None )
204
+ data_copy .pop ('@type' , None ) # Returned by Operations. (Not needed)
203
205
if tflite_format_data :
204
206
tflite_format = TFLiteFormat .from_dict (tflite_format_data )
205
207
model = Model (model_format = tflite_format )
@@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
495
497
return TFLiteGCSModelSource (gcs_tflite_uri = gcs_uri , app = app )
496
498
497
499
@staticmethod
498
- def _assert_tf_version_1_enabled ():
500
+ def _assert_tf_enabled ():
499
501
if not _TF_ENABLED :
500
502
raise ImportError ('Failed to import the tensorflow library for Python. Make sure '
501
503
'to install the tensorflow module.' )
502
- if not tf .VERSION .startswith ('1.' ):
503
- raise ImportError ('Expected tensorflow version 1.x, but found {0}' .format (tf .VERSION ))
504
+ if not tf .version .VERSION .startswith ('1.' ) and not tf .version .VERSION .startswith ('2.' ):
505
+ raise ImportError ('Expected tensorflow version 1.x or 2.x, but found {0}'
506
+ .format (tf .version .VERSION ))
507
+
508
+ @staticmethod
509
+ def _tf_convert_from_saved_model (saved_model_dir ):
510
+ # Same for both v1.x and v2.x
511
+ converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
512
+ return converter .convert ()
513
+
514
+ @staticmethod
515
+ def _tf_convert_from_keras_model (keras_model ):
516
+ # Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
517
+ if tf .version .VERSION .startswith ('1.' ):
518
+ keras_file = 'firebase_keras_model.h5'
519
+ tf .keras .models .save_model (keras_model , keras_file )
520
+ converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
521
+ return converter .convert ()
522
+ else :
523
+ converter = tf .lite .TFLiteConverter .from_keras_model (keras_model )
524
+ return converter .convert ()
504
525
505
526
@classmethod
506
527
def from_saved_model (cls , saved_model_dir , bucket_name = None , app = None ):
@@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
518
539
Raises:
519
540
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
520
541
"""
521
- TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
522
- converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
523
- tflite_model = converter .convert ()
542
+ TFLiteGCSModelSource ._assert_tf_enabled ()
543
+ tflite_model = TFLiteGCSModelSource ._tf_convert_from_saved_model (saved_model_dir )
524
544
open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
525
545
return TFLiteGCSModelSource .from_tflite_model_file (
526
546
'firebase_mlkit_model.tflite' , bucket_name , app )
@@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
541
561
Raises:
542
562
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
543
563
"""
544
- TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
545
- keras_file = 'keras_model.h5'
546
- tf .keras .models .save_model (keras_model , keras_file )
547
- converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
548
- tflite_model = converter .convert ()
564
+ TFLiteGCSModelSource ._assert_tf_enabled ()
565
+ tflite_model = TFLiteGCSModelSource ._tf_convert_from_keras_model (keras_model )
549
566
open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
550
567
return TFLiteGCSModelSource .from_tflite_model_file (
551
568
'firebase_mlkit_model.tflite' , bucket_name , app )
@@ -810,28 +827,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
810
827
"""
811
828
if not isinstance (operation , dict ):
812
829
raise TypeError ('Operation must be a dictionary.' )
813
- op_name = operation .get ('name' )
814
- _ , model_id = _validate_and_parse_operation_name (op_name )
815
-
816
- current_attempt = 0
817
- start_time = datetime .datetime .now ()
818
- stop_time = (None if max_time_seconds is None else
819
- start_time + datetime .timedelta (seconds = max_time_seconds ))
820
- while wait_for_operation and not operation .get ('done' ):
821
- # We just got this operation. Wait before getting another
822
- # so we don't exceed the GetOperation maximum request rate.
823
- self ._exponential_backoff (current_attempt , stop_time )
824
- operation = self .get_operation (op_name )
825
- current_attempt += 1
826
830
827
831
if operation .get ('done' ):
832
+ # Operations which are immediately done don't have an operation name
828
833
if operation .get ('response' ):
829
834
return operation .get ('response' )
830
835
elif operation .get ('error' ):
831
836
raise _utils .handle_operation_error (operation .get ('error' ))
832
-
833
- # If the operation is not complete or timed out, return a (locked) model instead
834
- return get_model (model_id ).as_dict ()
837
+ raise exceptions .UnknownError (message = 'Internal Error: Malformed Operation.' )
838
+ else :
839
+ op_name = operation .get ('name' )
840
+ _ , model_id = _validate_and_parse_operation_name (op_name )
841
+ current_attempt = 0
842
+ start_time = datetime .datetime .now ()
843
+ stop_time = (None if max_time_seconds is None else
844
+ start_time + datetime .timedelta (seconds = max_time_seconds ))
845
+ while wait_for_operation and not operation .get ('done' ):
846
+ # We just got this operation. Wait before getting another
847
+ # so we don't exceed the GetOperation maximum request rate.
848
+ self ._exponential_backoff (current_attempt , stop_time )
849
+ operation = self .get_operation (op_name )
850
+ current_attempt += 1
851
+
852
+ if operation .get ('done' ):
853
+ if operation .get ('response' ):
854
+ return operation .get ('response' )
855
+ elif operation .get ('error' ):
856
+ raise _utils .handle_operation_error (operation .get ('error' ))
857
+
858
+ # If the operation is not complete or timed out, return a (locked) model instead
859
+ return get_model (model_id ).as_dict ()
835
860
836
861
837
862
def create_model (self , model ):
@@ -844,12 +869,12 @@ def create_model(self, model):
844
869
845
870
def update_model (self , model , update_mask = None ):
846
871
_validate_model (model , update_mask )
847
- data = { 'model' : model . as_dict ( for_upload = True )}
872
+ path = 'models/{0}' . format ( model . model_id )
848
873
if update_mask is not None :
849
- data [ 'updateMask' ] = update_mask
874
+ path = path + '?updateMask={0}' . format ( update_mask )
850
875
try :
851
876
return self .handle_operation (
852
- self ._client .body ('patch' , url = 'models/{0}' . format ( model . model_id ) , json = data ))
877
+ self ._client .body ('patch' , url = path , json = model . as_dict ( for_upload = True ) ))
853
878
except requests .exceptions .RequestException as error :
854
879
raise _utils .handle_platform_error_from_requests (error )
855
880
@@ -876,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
876
901
_validate_list_filter (list_filter )
877
902
_validate_page_size (page_size )
878
903
_validate_page_token (page_token )
879
- payload = {}
904
+ params = {}
880
905
if list_filter :
881
- payload [ 'list_filter ' ] = list_filter
906
+ params [ 'filter ' ] = list_filter
882
907
if page_size :
883
- payload ['page_size' ] = page_size
908
+ params ['page_size' ] = page_size
884
909
if page_token :
885
- payload ['page_token' ] = page_token
910
+ params ['page_token' ] = page_token
911
+ path = 'models'
912
+ if params :
913
+ # pylint: disable=too-many-function-args
914
+ param_str = urllib .parse .urlencode (sorted (params .items ()), True )
915
+ path = path + '?' + param_str
886
916
try :
887
- return self ._client .body ('get' , url = 'models' , json = payload )
917
+ return self ._client .body ('get' , url = path )
888
918
except requests .exceptions .RequestException as error :
889
919
raise _utils .handle_platform_error_from_requests (error )
890
920
0 commit comments