@@ -221,6 +221,37 @@ def pytest_configure(config):
221221 reference_dir = reference_dir ,
222222 generate_dir = generate_dir ,
223223 default_format = default_format ))
224+ else :
225+ config .pluginmanager .register (ArrayInterceptor (config ))
226+
227+
228+ def generate_test_name (item ):
229+ """
230+ Generate a unique name for this test.
231+ """
232+ if item .cls is not None :
233+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
234+ else :
235+ name = f"{ item .module .__name__ } .{ item .name } "
236+ return name
237+
238+
239+ def wrap_array_interceptor (plugin , item ):
240+ """
241+ Intercept and store arrays returned by test functions.
242+ """
243+ # Only intercept array on marked array tests
244+ if item .get_closest_marker ('array_compare' ) is not None :
245+
246+ # Use the full test name as a key to ensure correct array is being retrieved
247+ test_name = generate_test_name (item )
248+
249+ def array_interceptor (store , obj ):
250+ def wrapper (* args , ** kwargs ):
251+ store .return_value [test_name ] = obj (* args , ** kwargs )
252+ return wrapper
253+
254+ item .obj = array_interceptor (plugin , item .obj )
224255
225256
226257class ArrayComparison (object ):
@@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
230261 self .reference_dir = reference_dir
231262 self .generate_dir = generate_dir
232263 self .default_format = default_format
264+ self .return_value = {}
233265
234- def pytest_runtest_setup (self , item ):
266+ @pytest .hookimpl (hookwrapper = True )
267+ def pytest_runtest_call (self , item ):
235268
236269 compare = item .get_closest_marker ('array_compare' )
237270
238271 if compare is None :
272+ yield
239273 return
240274
241275 file_format = compare .kwargs .get ('file_format' , self .default_format )
@@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):
255289
256290 write_kwargs = compare .kwargs .get ('write_kwargs' , {})
257291
258- original = item .function
292+ reference_dir = compare .kwargs .get ('reference_dir' , None )
293+ if reference_dir is None :
294+ if self .reference_dir is None :
295+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
296+ else :
297+ reference_dir = self .reference_dir
298+ else :
299+ if not reference_dir .startswith (('http://' , 'https://' )):
300+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
259301
260- @wraps (item .function )
261- def item_function_wrapper (* args , ** kwargs ):
302+ baseline_remote = reference_dir .startswith ('http' )
262303
263- reference_dir = compare .kwargs .get ('reference_dir' , None )
264- if reference_dir is None :
265- if self .reference_dir is None :
266- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
267- else :
268- reference_dir = self .reference_dir
269- else :
270- if not reference_dir .startswith (('http://' , 'https://' )):
271- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
272-
273- baseline_remote = reference_dir .startswith ('http' )
274-
275- # Run test and get figure object
276- import inspect
277- if inspect .ismethod (original ): # method
278- array = original (* args [1 :], ** kwargs )
279- else : # function
280- array = original (* args , ** kwargs )
281-
282- # Find test name to use as plot name
283- filename = compare .kwargs .get ('filename' , None )
284- if filename is None :
285- if single_reference :
286- filename = original .__name__ + '.' + extension
287- else :
288- filename = item .name + '.' + extension
289- filename = filename .replace ('[' , '_' ).replace (']' , '_' )
290- filename = filename .replace ('_.' + extension , '.' + extension )
291-
292- # What we do now depends on whether we are generating the reference
293- # files or simply running the test.
294- if self .generate_dir is None :
295-
296- # Save the figure
297- result_dir = tempfile .mkdtemp ()
298- test_array = os .path .abspath (os .path .join (result_dir , filename ))
299-
300- FORMATS [file_format ].write (test_array , array , ** write_kwargs )
301-
302- # Find path to baseline array
303- if baseline_remote :
304- baseline_file_ref = _download_file (reference_dir + filename )
305- else :
306- baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
307-
308- if not os .path .exists (baseline_file_ref ):
309- raise Exception ("""File not found for comparison test
310- Generated file:
311- \t {test}
312- This is expected for new tests.""" .format (
313- test = test_array ))
314-
315- # setuptools may put the baseline arrays in non-accessible places,
316- # copy to our tmpdir to be sure to keep them in case of failure
317- baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
318- shutil .copyfile (baseline_file_ref , baseline_file )
319-
320- identical , msg = FORMATS [file_format ].compare (baseline_file , test_array , atol = atol , rtol = rtol )
321-
322- if identical :
323- shutil .rmtree (result_dir )
324- else :
325- raise Exception (msg )
304+ # Run test and get array object
305+ wrap_array_interceptor (self , item )
306+ yield
307+ test_name = generate_test_name (item )
308+ if test_name not in self .return_value :
309+ # Test function did not complete successfully
310+ return
311+ array = self .return_value [test_name ]
312+
313+ # Find test name to use as plot name
314+ filename = compare .kwargs .get ('filename' , None )
315+ if filename is None :
316+ filename = item .name + '.' + extension
317+ if not single_reference :
318+ filename = filename .replace ('[' , '_' ).replace (']' , '_' )
319+ filename = filename .replace ('_.' + extension , '.' + extension )
320+
321+ # What we do now depends on whether we are generating the reference
322+ # files or simply running the test.
323+ if self .generate_dir is None :
324+
325+ # Save the figure
326+ result_dir = tempfile .mkdtemp ()
327+ test_array = os .path .abspath (os .path .join (result_dir , filename ))
326328
329+ FORMATS [file_format ].write (test_array , array , ** write_kwargs )
330+
331+ # Find path to baseline array
332+ if baseline_remote :
333+ baseline_file_ref = _download_file (reference_dir + filename )
327334 else :
335+ baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
336+
337+ if not os .path .exists (baseline_file_ref ):
338+ raise Exception ("""File not found for comparison test
339+ Generated file:
340+ \t {test}
341+ This is expected for new tests.""" .format (
342+ test = test_array ))
328343
329- if not os .path .exists (self .generate_dir ):
330- os .makedirs (self .generate_dir )
344+ # setuptools may put the baseline arrays in non-accessible places,
345+ # copy to our tmpdir to be sure to keep them in case of failure
346+ baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
347+ shutil .copyfile (baseline_file_ref , baseline_file )
331348
332- FORMATS [file_format ].write ( os . path . abspath ( os . path . join ( self . generate_dir , filename )), array , ** write_kwargs )
349+ identical , msg = FORMATS [file_format ].compare ( baseline_file , test_array , atol = atol , rtol = rtol )
333350
334- pytest .skip ("Skipping test, since generating data" )
351+ if identical :
352+ shutil .rmtree (result_dir )
353+ else :
354+ raise Exception (msg )
335355
336- if item .cls is not None :
337- setattr (item .cls , item .function .__name__ , item_function_wrapper )
338356 else :
339- item .obj = item_function_wrapper
357+
358+ if not os .path .exists (self .generate_dir ):
359+ os .makedirs (self .generate_dir )
360+
361+ FORMATS [file_format ].write (os .path .abspath (os .path .join (self .generate_dir , filename )), array , ** write_kwargs )
362+
363+ pytest .skip ("Skipping test, since generating data" )
364+
365+
366+ class ArrayInterceptor :
367+ """
368+ This is used in place of ArrayComparison when the array comparison option is not used,
369+ to make sure that we still intercept arrays returned by tests.
370+ """
371+
372+ def __init__ (self , config ):
373+ self .config = config
374+ self .return_value = {}
375+
376+ @pytest .hookimpl (hookwrapper = True )
377+ def pytest_runtest_call (self , item ):
378+
379+ if item .get_closest_marker ('array_compare' ) is not None :
380+ wrap_array_interceptor (self , item )
381+
382+ yield
383+ return
0 commit comments