diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 809ba9a41..c6720f081 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model): return converter.convert() @classmethod - def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: saved_model_dir: The saved model directory. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -541,16 +543,18 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open('firebase_ml_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_ml_model.tflite', bucket_name, app) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, bucket_name=None, app=None): + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: keras_model: A tf.keras model. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -563,9 +567,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open('firebase_ml_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_ml_model.tflite', bucket_name, app) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property def gcs_tflite_uri(self):