@@ -81,12 +81,33 @@ def pathify(path):
8181 return Path (path + ext )
8282
8383
84- def _pytest_pyfunc_call (obj , pyfuncitem ):
85- testfunction = pyfuncitem .obj
86- funcargs = pyfuncitem .funcargs
87- testargs = {arg : funcargs [arg ] for arg in pyfuncitem ._fixtureinfo .argnames }
88- obj .result = testfunction (** testargs )
89- return True
84+ def generate_test_name (item ):
85+ """
86+ Generate a unique name for the hash for this test.
87+ """
88+ if item .cls is not None :
89+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
90+ else :
91+ name = f"{ item .module .__name__ } .{ item .name } "
92+ return name
93+
94+
95+ def wrap_figure_interceptor (plugin , item ):
96+ """
97+ Intercept and store figures returned by test functions.
98+ """
99+ # Only intercept figures on marked figure tests
100+ if get_compare (item ) is not None :
101+
102+ # Use the full test name as a key to ensure correct figure is being retrieved
103+ test_name = generate_test_name (item )
104+
105+ def figure_interceptor (store , obj ):
106+ def wrapper (* args , ** kwargs ):
107+ store .return_value [test_name ] = obj (* args , ** kwargs )
108+ return wrapper
109+
110+ item .obj = figure_interceptor (plugin , item .obj )
90111
91112
92113def pytest_report_header (config , startdir ):
@@ -275,6 +296,7 @@ def __init__(self,
275296 self ._generated_hash_library = {}
276297 self ._test_results = {}
277298 self ._test_stats = None
299+ self .return_value = {}
278300
279301 # https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
280302 # turn debug prints on only if "-vv" or more passed
@@ -287,7 +309,7 @@ def generate_filename(self, item):
287309 Given a pytest item, generate the figure filename.
288310 """
289311 if self .config .getini ('mpl-use-full-test-name' ):
290- filename = self . generate_test_name (item ) + '.png'
312+ filename = generate_test_name (item ) + '.png'
291313 else :
292314 compare = get_compare (item )
293315 # Find test name to use as plot name
@@ -298,21 +320,11 @@ def generate_filename(self, item):
298320 filename = str (pathify (filename ))
299321 return filename
300322
301- def generate_test_name (self , item ):
302- """
303- Generate a unique name for the hash for this test.
304- """
305- if item .cls is not None :
306- name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
307- else :
308- name = f"{ item .module .__name__ } .{ item .name } "
309- return name
310-
311323 def make_test_results_dir (self , item ):
312324 """
313325 Generate the directory to put the results in.
314326 """
315- test_name = pathify (self . generate_test_name (item ))
327+ test_name = pathify (generate_test_name (item ))
316328 results_dir = self .results_dir / test_name
317329 results_dir .mkdir (exist_ok = True , parents = True )
318330 return results_dir
@@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
526538 pytest .fail (f"Can't find hash library at path { hash_library_filename } " )
527539
528540 hash_library = self .load_hash_library (hash_library_filename )
529- hash_name = self . generate_test_name (item )
541+ hash_name = generate_test_name (item )
530542 baseline_hash = hash_library .get (hash_name , None )
531543 summary ['baseline_hash' ] = baseline_hash
532544
@@ -607,13 +619,17 @@ def pytest_runtest_call(self, item): # noqa
607619 with plt .style .context (style , after_reset = True ), switch_backend (backend ):
608620
609621 # Run test and get figure object
622+ wrap_figure_interceptor (self , item )
610623 yield
611- fig = self .result
624+ test_name = generate_test_name (item )
625+ if test_name not in self .return_value :
626+ # Test function did not complete successfully
627+ return
628+ fig = self .return_value [test_name ]
612629
613630 if remove_text :
614631 remove_ticks_and_titles (fig )
615632
616- test_name = self .generate_test_name (item )
617633 result_dir = self .make_test_results_dir (item )
618634
619635 summary = {
@@ -677,10 +693,6 @@ def pytest_runtest_call(self, item): # noqa
677693 if summary ['status' ] == 'skipped' :
678694 pytest .skip (summary ['status_msg' ])
679695
680- @pytest .hookimpl (tryfirst = True )
681- def pytest_pyfunc_call (self , pyfuncitem ):
682- return _pytest_pyfunc_call (self , pyfuncitem )
683-
684696 def generate_summary_json (self ):
685697 json_file = self .results_dir / 'results.json'
686698 with open (json_file , 'w' ) as f :
@@ -732,13 +744,16 @@ class FigureCloser:
732744
733745 def __init__ (self , config ):
734746 self .config = config
747+ self .return_value = {}
735748
736749 @pytest .hookimpl (hookwrapper = True )
737750 def pytest_runtest_call (self , item ):
751+ wrap_figure_interceptor (self , item )
738752 yield
739753 if get_compare (item ) is not None :
740- close_mpl_figure (self .result )
741-
742- @pytest .hookimpl (tryfirst = True )
743- def pytest_pyfunc_call (self , pyfuncitem ):
744- return _pytest_pyfunc_call (self , pyfuncitem )
754+ test_name = generate_test_name (item )
755+ if test_name not in self .return_value :
756+ # Test function did not complete successfully
757+ return
758+ fig = self .return_value [test_name ]
759+ close_mpl_figure (fig )
0 commit comments