2222
2323import pytest
2424
25+ import firebase_admin
2526from firebase_admin import exceptions
2627from firebase_admin import ml
2728from tests import testutils
3435except ImportError :
3536 _TF_ENABLED = False
3637
38+ try :
39+ from google .cloud import automl_v1
40+ _AUTOML_ENABLED = True
41+ except ImportError :
42+ _AUTOML_ENABLED = False
3743
3844def _random_identifier (prefix ):
3945 #pylint: disable=unused-variable
@@ -62,7 +68,6 @@ def _random_identifier(prefix):
6268 'file_name' : 'invalid_model.tflite'
6369}
6470
65-
6671@pytest .fixture
6772def firebase_model (request ):
6873 args = request .param
@@ -101,6 +106,7 @@ def _clean_up_model(model):
101106 try :
102107 # Try to delete the model.
103108 # Some tests delete the model as part of the test.
109+ model .wait_for_unlocked ()
104110 ml .delete_model (model .model_id )
105111 except exceptions .NotFoundError :
106112 pass
@@ -132,35 +138,45 @@ def check_model(model, args):
132138 assert model .locked is False
133139 assert model .etag is not None
134140
141+ # Model Format Checks
135142
136- def check_model_format (model , has_model_format = False , validation_error = None ):
137- if has_model_format :
138- assert model .validation_error == validation_error
139- assert model .published is False
140- assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
141- if validation_error :
142- assert model .model_format .size_bytes is None
143- assert model .model_hash is None
144- else :
145- assert model .model_format .size_bytes is not None
146- assert model .model_hash is not None
147- else :
148- assert model .model_format is None
149- assert model .validation_error == 'No model file has been uploaded.'
150- assert model .published is False
143+ def check_no_model_format (model ):
144+ assert model .model_format is None
145+ assert model .validation_error == 'No model file has been uploaded.'
146+ assert model .published is False
147+ assert model .model_hash is None
148+
149+
150+ def check_tflite_gcs_format (model , validation_error = None ):
151+ assert model .validation_error == validation_error
152+ assert model .published is False
153+ assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
154+ if validation_error :
155+ assert model .model_format .size_bytes is None
151156 assert model .model_hash is None
157+ else :
158+ assert model .model_format .size_bytes is not None
159+ assert model .model_hash is not None
160+
161+
162+ def check_tflite_automl_format (model ):
163+ assert model .validation_error is None
164+ assert model .published is False
165+ assert model .model_format .model_source .auto_ml_model .startswith ('projects/' )
166+ # Automl models don't have validation errors since they are references
167+ # to valid automl models.
152168
153169
154170@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
155171def test_create_simple_model (firebase_model ):
156172 check_model (firebase_model , NAME_AND_TAGS_ARGS )
157- check_model_format (firebase_model )
173+ check_no_model_format (firebase_model )
158174
159175
160176@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
161177def test_create_full_model (firebase_model ):
162178 check_model (firebase_model , FULL_MODEL_ARGS )
163- check_model_format (firebase_model , True )
179+ check_tflite_gcs_format (firebase_model )
164180
165181
166182@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
175191@pytest .mark .parametrize ('firebase_model' , [INVALID_FULL_MODEL_ARGS ], indirect = True )
176192def test_create_invalid_model (firebase_model ):
177193 check_model (firebase_model , INVALID_FULL_MODEL_ARGS )
178- check_model_format (firebase_model , True , 'Invalid flatbuffer format' )
194+ check_tflite_gcs_format (firebase_model , 'Invalid flatbuffer format' )
179195
180196
181197@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
182198def test_get_model (firebase_model ):
183199 get_model = ml .get_model (firebase_model .model_id )
184200 check_model (get_model , NAME_AND_TAGS_ARGS )
185- check_model_format (get_model )
201+ check_no_model_format (get_model )
186202
187203
188204@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -201,12 +217,12 @@ def test_update_model(firebase_model):
201217 firebase_model .display_name = new_model_name
202218 updated_model = ml .update_model (firebase_model )
203219 check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
204- check_model_format (updated_model )
220+ check_no_model_format (updated_model )
205221
206222 # Second call with same model does not cause error
207223 updated_model2 = ml .update_model (updated_model )
208224 check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
209- check_model_format (updated_model2 )
225+ check_no_model_format (updated_model2 )
210226
211227
212228@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -290,7 +306,7 @@ def test_delete_model(firebase_model):
290306
291307# Test tensor flow conversion functions if tensor flow is enabled.
292308#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
293- #'pip install tensorflow==2.0.0b ' for version 2 etc.
309+ #'pip install tensorflow==2.2.0 ' for version 2.2.0 etc.
294310
295311
296312def _clean_up_directory (save_dir ):
@@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
334350 _clean_up_directory (parent )
335351
336352
353+
337354@pytest .mark .skipif (not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
338355def test_from_keras_model (keras_model ):
339356 source = ml .TFLiteGCSModelSource .from_keras_model (keras_model , 'model2.tflite' )
@@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):
348365
349366 try :
350367 check_model (created_model , {'display_name' : model .display_name })
351- check_model_format (created_model , True )
368+ check_tflite_gcs_format (created_model )
352369 finally :
353370 _clean_up_model (created_model )
354371
@@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
371388 assert created_model .validation_error is None
372389 finally :
373390 _clean_up_model (created_model )
391+
392+
393+ # Test AutoML functionality if AutoML is enabled.
394+ #'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
395+ # You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
396+ # successful test. (Test is skipped otherwise)
397+
398+ @pytest .fixture
399+ def automl_model ():
400+ assert _AUTOML_ENABLED
401+
402+ # It takes > 20 minutes to train a model, so we expect a predefined AutoMl
403+ # model named 'admin_sdk_integ_test1' to exist in the project, or we skip
404+ # the test.
405+ automl_client = automl_v1 .AutoMlClient ()
406+ project_id = firebase_admin .get_app ().project_id
407+ parent = automl_client .location_path (project_id , 'us-central1' )
408+ models = automl_client .list_models (parent , filter_ = "display_name=admin_sdk_integ_test1" )
409+ # Expecting exactly one. (Ok to use last one if somehow more than 1)
410+ automl_ref = None
411+ for model in models :
412+ automl_ref = model .name
413+
414+ # Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
415+ if automl_ref is None :
416+ pytest .skip ("No pre-existing AutoML model found. Skipping test" )
417+
418+ source = ml .TFLiteAutoMlSource (automl_ref )
419+ tflite_format = ml .TFLiteFormat (model_source = source )
420+ ml_model = ml .Model (
421+ display_name = _random_identifier ('TestModel_automl_' ),
422+ tags = ['test_automl' ],
423+ model_format = tflite_format )
424+ model = ml .create_model (model = ml_model )
425+ yield model
426+ _clean_up_model (model )
427+
428+ @pytest .mark .skipif (not _AUTOML_ENABLED , reason = 'AutoML is required for this test.' )
429+ def test_automl_model (automl_model ):
430+ # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
431+ automl_model .wait_for_unlocked ()
432+
433+ check_model (automl_model , {
434+ 'display_name' : automl_model .display_name ,
435+ 'tags' : ['test_automl' ],
436+ })
437+ check_tflite_automl_format (automl_model )
0 commit comments