3333import json
3434import shutil
3535import hashlib
36- import inspect
3736import logging
3837import tempfile
3938import warnings
4039import contextlib
4140from pathlib import Path
42- from functools import wraps
4341from urllib .request import urlopen
4442
4543import pytest
@@ -83,6 +81,14 @@ def pathify(path):
8381 return Path (path + ext )
8482
8583
84+ def _pytest_pyfunc_call (obj , pyfuncitem ):
85+ testfunction = pyfuncitem .obj
86+ funcargs = pyfuncitem .funcargs
87+ testargs = {arg : funcargs [arg ] for arg in pyfuncitem ._fixtureinfo .argnames }
88+ obj .result = testfunction (** testargs )
89+ return True
90+
91+
8692def pytest_report_header (config , startdir ):
8793 import matplotlib
8894 import matplotlib .ft2font
@@ -211,13 +217,11 @@ def close_mpl_figure(fig):
211217 plt .close (fig )
212218
213219
214- def get_marker (item , marker_name ):
215- if hasattr (item , 'get_closest_marker' ):
216- return item .get_closest_marker (marker_name )
217- else :
218- # "item.keywords.get" was deprecated in pytest 3.6
219- # See https://docs.pytest.org/en/latest/mark.html#updating-code
220- return item .keywords .get (marker_name )
220+ def get_compare (item ):
221+ """
222+ Return the mpl_image_compare marker for the given item.
223+ """
224+ return item .get_closest_marker ("mpl_image_compare" )
221225
222226
223227def path_is_not_none (apath ):
@@ -278,20 +282,14 @@ def __init__(self,
278282 logging .basicConfig (level = level )
279283 self .logger = logging .getLogger ('pytest-mpl' )
280284
281- def get_compare (self , item ):
282- """
283- Return the mpl_image_compare marker for the given item.
284- """
285- return get_marker (item , 'mpl_image_compare' )
286-
287285 def generate_filename (self , item ):
288286 """
289287 Given a pytest item, generate the figure filename.
290288 """
291289 if self .config .getini ('mpl-use-full-test-name' ):
292290 filename = self .generate_test_name (item ) + '.png'
293291 else :
294- compare = self . get_compare (item )
292+ compare = get_compare (item )
295293 # Find test name to use as plot name
296294 filename = compare .kwargs .get ('filename' , None )
297295 if filename is None :
@@ -304,7 +302,11 @@ def generate_test_name(self, item):
304302 """
305303 Generate a unique name for the hash for this test.
306304 """
307- return f"{ item .module .__name__ } .{ item .name } "
305+ if item .cls is not None :
306+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
307+ else :
308+ name = f"{ item .module .__name__ } .{ item .name } "
309+ return name
308310
309311 def make_test_results_dir (self , item ):
310312 """
@@ -319,7 +321,7 @@ def baseline_directory_specified(self, item):
319321 """
320322 Returns `True` if a non-default baseline directory is specified.
321323 """
322- compare = self . get_compare (item )
324+ compare = get_compare (item )
323325 item_baseline_dir = compare .kwargs .get ('baseline_dir' , None )
324326 return item_baseline_dir or self .baseline_dir or self .baseline_relative_dir
325327
@@ -330,7 +332,7 @@ def get_baseline_directory(self, item):
330332 Using the global and per-test configuration return the absolute
331333 baseline dir, if the baseline file is local else return base URL.
332334 """
333- compare = self . get_compare (item )
335+ compare = get_compare (item )
334336 baseline_dir = compare .kwargs .get ('baseline_dir' , None )
335337 if baseline_dir is None :
336338 if self .baseline_dir is None :
@@ -394,7 +396,7 @@ def generate_baseline_image(self, item, fig):
394396 """
395397 Generate reference figures.
396398 """
397- compare = self . get_compare (item )
399+ compare = get_compare (item )
398400 savefig_kwargs = compare .kwargs .get ('savefig_kwargs' , {})
399401
400402 if not os .path .exists (self .generate_dir ):
@@ -413,7 +415,7 @@ def generate_image_hash(self, item, fig):
413415 For a `matplotlib.figure.Figure`, returns the SHA256 hash as a hexadecimal
414416 string.
415417 """
416- compare = self . get_compare (item )
418+ compare = get_compare (item )
417419 savefig_kwargs = compare .kwargs .get ('savefig_kwargs' , {})
418420
419421 imgdata = io .BytesIO ()
@@ -436,7 +438,7 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
436438 if summary is None :
437439 summary = {}
438440
439- compare = self . get_compare (item )
441+ compare = get_compare (item )
440442 tolerance = compare .kwargs .get ('tolerance' , 2 )
441443 savefig_kwargs = compare .kwargs .get ('savefig_kwargs' , {})
442444
@@ -510,7 +512,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
510512 if summary is None :
511513 summary = {}
512514
513- compare = self . get_compare (item )
515+ compare = get_compare (item )
514516 savefig_kwargs = compare .kwargs .get ('savefig_kwargs' , {})
515517
516518 if not self .results_hash_library_name :
@@ -582,11 +584,13 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
582584 return
583585 return summary ['status_msg' ]
584586
585- def pytest_runtest_setup (self , item ): # noqa
587+ @pytest .hookimpl (hookwrapper = True )
588+ def pytest_runtest_call (self , item ): # noqa
586589
587- compare = self . get_compare (item )
590+ compare = get_compare (item )
588591
589592 if compare is None :
593+ yield
590594 return
591595
592596 import matplotlib .pyplot as plt
@@ -600,95 +604,82 @@ def pytest_runtest_setup(self, item): # noqa
600604 remove_text = compare .kwargs .get ('remove_text' , False )
601605 backend = compare .kwargs .get ('backend' , 'agg' )
602606
603- original = item .function
604-
605- @wraps (item .function )
606- def item_function_wrapper (* args , ** kwargs ):
607-
608- with plt .style .context (style , after_reset = True ), switch_backend (backend ):
609-
610- # Run test and get figure object
611- if inspect .ismethod (original ): # method
612- # In some cases, for example if setup_method is used,
613- # original appears to belong to an instance of the test
614- # class that is not the same as args[0], and args[0] is the
615- # one that has the correct attributes set up from setup_method
616- # so we ignore original.__self__ and use args[0] instead.
617- fig = original .__func__ (* args , ** kwargs )
618- else : # function
619- fig = original (* args , ** kwargs )
620-
621- if remove_text :
622- remove_ticks_and_titles (fig )
623-
624- test_name = self .generate_test_name (item )
625- result_dir = self .make_test_results_dir (item )
626-
627- summary = {
628- 'status' : None ,
629- 'image_status' : None ,
630- 'hash_status' : None ,
631- 'status_msg' : None ,
632- 'baseline_image' : None ,
633- 'diff_image' : None ,
634- 'rms' : None ,
635- 'tolerance' : None ,
636- 'result_image' : None ,
637- 'baseline_hash' : None ,
638- 'result_hash' : None ,
639- }
640-
641- # What we do now depends on whether we are generating the
642- # reference images or simply running the test.
643- if self .generate_dir is not None :
644- summary ['status' ] = 'skipped'
645- summary ['image_status' ] = 'generated'
646- summary ['status_msg' ] = 'Skipped test, since generating image.'
647- generate_image = self .generate_baseline_image (item , fig )
648- if self .results_always : # Make baseline image available in HTML
649- result_image = (result_dir / "baseline.png" ).absolute ()
650- shutil .copy (generate_image , result_image )
651- summary ['baseline_image' ] = \
652- result_image .relative_to (self .results_dir ).as_posix ()
653-
654- if self .generate_hash_library is not None :
655- summary ['hash_status' ] = 'generated'
656- image_hash = self .generate_image_hash (item , fig )
657- self ._generated_hash_library [test_name ] = image_hash
658- summary ['baseline_hash' ] = image_hash
659-
660- # Only test figures if not generating images
661- if self .generate_dir is None :
662- # Compare to hash library
663- if self .hash_library or compare .kwargs .get ('hash_library' , None ):
664- msg = self .compare_image_to_hash_library (item , fig , result_dir , summary = summary )
665-
666- # Compare against a baseline if specified
667- else :
668- msg = self .compare_image_to_baseline (item , fig , result_dir , summary = summary )
669-
670- close_mpl_figure (fig )
671-
672- if msg is None :
673- if not self .results_always :
674- shutil .rmtree (result_dir )
675- for image_type in ['baseline_image' , 'diff_image' , 'result_image' ]:
676- summary [image_type ] = None # image no longer exists
677- else :
678- self ._test_results [test_name ] = summary
679- pytest .fail (msg , pytrace = False )
607+ with plt .style .context (style , after_reset = True ), switch_backend (backend ):
608+
609+ # Run test and get figure object
610+ yield
611+ fig = self .result
612+
613+ if remove_text :
614+ remove_ticks_and_titles (fig )
615+
616+ test_name = self .generate_test_name (item )
617+ result_dir = self .make_test_results_dir (item )
618+
619+ summary = {
620+ 'status' : None ,
621+ 'image_status' : None ,
622+ 'hash_status' : None ,
623+ 'status_msg' : None ,
624+ 'baseline_image' : None ,
625+ 'diff_image' : None ,
626+ 'rms' : None ,
627+ 'tolerance' : None ,
628+ 'result_image' : None ,
629+ 'baseline_hash' : None ,
630+ 'result_hash' : None ,
631+ }
632+
633+ # What we do now depends on whether we are generating the
634+ # reference images or simply running the test.
635+ if self .generate_dir is not None :
636+ summary ['status' ] = 'skipped'
637+ summary ['image_status' ] = 'generated'
638+ summary ['status_msg' ] = 'Skipped test, since generating image.'
639+ generate_image = self .generate_baseline_image (item , fig )
640+ if self .results_always : # Make baseline image available in HTML
641+ result_image = (result_dir / "baseline.png" ).absolute ()
642+ shutil .copy (generate_image , result_image )
643+ summary ['baseline_image' ] = \
644+ result_image .relative_to (self .results_dir ).as_posix ()
645+
646+ if self .generate_hash_library is not None :
647+ summary ['hash_status' ] = 'generated'
648+ image_hash = self .generate_image_hash (item , fig )
649+ self ._generated_hash_library [test_name ] = image_hash
650+ summary ['baseline_hash' ] = image_hash
651+
652+ # Only test figures if not generating images
653+ if self .generate_dir is None :
654+ # Compare to hash library
655+ if self .hash_library or compare .kwargs .get ('hash_library' , None ):
656+ msg = self .compare_image_to_hash_library (item , fig , result_dir , summary = summary )
657+
658+ # Compare against a baseline if specified
659+ else :
660+ msg = self .compare_image_to_baseline (item , fig , result_dir , summary = summary )
680661
681662 close_mpl_figure (fig )
682663
683- self ._test_results [test_name ] = summary
664+ if msg is None :
665+ if not self .results_always :
666+ shutil .rmtree (result_dir )
667+ for image_type in ['baseline_image' , 'diff_image' , 'result_image' ]:
668+ summary [image_type ] = None # image no longer exists
669+ else :
670+ self ._test_results [test_name ] = summary
671+ pytest .fail (msg , pytrace = False )
684672
685- if summary ['status' ] == 'skipped' :
686- pytest .skip (summary ['status_msg' ])
673+ close_mpl_figure (fig )
687674
688- if item .cls is not None :
689- setattr (item .cls , item .function .__name__ , item_function_wrapper )
690- else :
691- item .obj = item_function_wrapper
675+ self ._test_results [test_name ] = summary
676+
677+ if summary ['status' ] == 'skipped' :
678+ pytest .skip (summary ['status_msg' ])
679+
680+ @pytest .hookimpl (tryfirst = True )
681+ def pytest_pyfunc_call (self , pyfuncitem ):
682+ return _pytest_pyfunc_call (self , pyfuncitem )
692683
693684 def generate_summary_json (self ):
694685 json_file = self .results_dir / 'results.json'
@@ -742,26 +733,12 @@ class FigureCloser:
742733 def __init__ (self , config ):
743734 self .config = config
744735
745- def pytest_runtest_setup (self , item ):
746-
747- compare = get_marker (item , 'mpl_image_compare' )
748-
749- if compare is None :
750- return
751-
752- original = item .function
753-
754- @wraps (item .function )
755- def item_function_wrapper (* args , ** kwargs ):
756-
757- if inspect .ismethod (original ): # method
758- fig = original .__func__ (* args , ** kwargs )
759- else : # function
760- fig = original (* args , ** kwargs )
761-
762- close_mpl_figure (fig )
736+ @pytest .hookimpl (hookwrapper = True )
737+ def pytest_runtest_call (self , item ):
738+ yield
739+ if get_compare (item ) is not None :
740+ close_mpl_figure (self .result )
763741
764- if item .cls is not None :
765- setattr (item .cls , item .function .__name__ , item_function_wrapper )
766- else :
767- item .obj = item_function_wrapper
742+ @pytest .hookimpl (tryfirst = True )
743+ def pytest_pyfunc_call (self , pyfuncitem ):
744+ return _pytest_pyfunc_call (self , pyfuncitem )
0 commit comments