Skip to content

Commit 930b66d

Browse files
authored
TST: Improve runtime of some unit tests (#62968)
1 parent 093586f commit 930b66d

File tree

8 files changed

+50
-67
lines changed

8 files changed

+50
-67
lines changed

pandas/tests/computation/test_eval.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,21 @@ def _eval_single_bin(lhs, cmp1, rhs, engine):
104104
ids=["DataFrame", "Series", "SeriesNaN", "DataFrameNaN", "float"],
105105
)
106106
def lhs(request):
107-
nan_df1 = DataFrame(np.random.default_rng(2).standard_normal((10, 5)))
108-
nan_df1[nan_df1 > 0.5] = np.nan
109-
110-
opts = (
111-
DataFrame(np.random.default_rng(2).standard_normal((10, 5))),
112-
Series(np.random.default_rng(2).standard_normal(5)),
113-
Series([1, 2, np.nan, np.nan, 5]),
114-
nan_df1,
115-
np.random.default_rng(2).standard_normal(),
116-
)
117-
return opts[request.param]
107+
rng = np.random.default_rng(2)
108+
if request.param == 0:
109+
return DataFrame(rng.standard_normal((10, 5)))
110+
elif request.param == 1:
111+
return Series(rng.standard_normal(5))
112+
elif request.param == 2:
113+
return Series([1, 2, np.nan, np.nan, 5])
114+
elif request.param == 3:
115+
nan_df1 = DataFrame(rng.standard_normal((10, 5)))
116+
nan_df1[nan_df1 > 0.5] = np.nan
117+
return nan_df1
118+
elif request.param == 4:
119+
return rng.standard_normal()
120+
else:
121+
raise ValueError(f"{request.param}")
118122

119123

120124
rhs = lhs

pandas/tests/indexing/multiindex/test_indexing_slow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def b(df, cols):
7171
return df.drop_duplicates(subset=cols[:-1])
7272

7373

74+
@pytest.mark.slow
7475
@pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning")
7576
@pytest.mark.parametrize("lexsort_depth", list(range(5)))
7677
@pytest.mark.parametrize("frame_fixture", ["a", "b"])

pandas/tests/io/parser/common/test_chunksize.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -229,17 +229,21 @@ def test_chunks_have_consistent_numerical_type(all_parsers, monkeypatch):
229229
assert result.a.dtype == float
230230

231231

232-
def test_warn_if_chunks_have_mismatched_type(all_parsers, using_infer_string):
232+
def test_warn_if_chunks_have_mismatched_type(
233+
all_parsers, using_infer_string, monkeypatch
234+
):
233235
warning_type = None
234236
parser = all_parsers
235-
size = 10000
237+
heuristic = 2**3
238+
size = 10
236239

237240
# see gh-3866: if chunks are different types and can't
238241
# be coerced using numerical types, then issue warning.
239242
if parser.engine == "c" and parser.low_memory:
240243
warning_type = DtypeWarning
241-
# Use larger size to hit warning path
242-
size = 499999
244+
# Use a size to hit warning path dictated by DEFAULT_BUFFER_HEURISTIC
245+
# monkeypatched below
246+
size = heuristic - 1
243247

244248
integers = [str(i) for i in range(size)]
245249
data = "a\n" + "\n".join(integers + ["a", "b"] + integers)
@@ -251,12 +255,14 @@ def test_warn_if_chunks_have_mismatched_type(all_parsers, using_infer_string):
251255
buf,
252256
)
253257
else:
254-
df = parser.read_csv_check_warnings(
255-
warning_type,
256-
r"Columns \(0: a\) have mixed types. "
257-
"Specify dtype option on import or set low_memory=False.",
258-
buf,
259-
)
258+
with monkeypatch.context() as m:
259+
m.setattr(libparsers, "DEFAULT_BUFFER_HEURISTIC", heuristic)
260+
df = parser.read_csv_check_warnings(
261+
warning_type,
262+
r"Columns \(0: a\) have mixed types. "
263+
"Specify dtype option on import or set low_memory=False.",
264+
buf,
265+
)
260266
if parser.engine == "c" and parser.low_memory:
261267
assert df.a.dtype == object
262268
elif using_infer_string:
@@ -295,30 +301,6 @@ def test_empty_with_nrows_chunksize(all_parsers, iterator):
295301
tm.assert_frame_equal(result, expected)
296302

297303

298-
def test_read_csv_memory_growth_chunksize(temp_file, all_parsers):
299-
# see gh-24805
300-
#
301-
# Let's just make sure that we don't crash
302-
# as we iteratively process all chunks.
303-
parser = all_parsers
304-
305-
with open(temp_file, "w", encoding="utf-8") as f:
306-
for i in range(1000):
307-
f.write(str(i) + "\n")
308-
309-
if parser.engine == "pyarrow":
310-
msg = "The 'chunksize' option is not supported with the 'pyarrow' engine"
311-
with pytest.raises(ValueError, match=msg):
312-
with parser.read_csv(temp_file, chunksize=20) as result:
313-
for _ in result:
314-
pass
315-
return
316-
317-
with parser.read_csv(temp_file, chunksize=20) as result:
318-
for _ in result:
319-
pass
320-
321-
322304
def test_chunksize_with_usecols_second_block_shorter(all_parsers):
323305
# GH#21211
324306
parser = all_parsers

pandas/tests/io/parser/test_parse_dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,11 @@ def test_bad_date_parse(all_parsers, cache, value):
265265
)
266266

267267

268-
@pytest.mark.parametrize("value", ["0"])
269-
def test_bad_date_parse_with_warning(all_parsers, cache, value):
268+
def test_bad_date_parse_with_warning(all_parsers, cache):
270269
# if we have an invalid date make sure that we handle this with
271270
# and w/o the cache properly.
272271
parser = all_parsers
273-
s = StringIO((f"{value},\n") * 50000)
272+
s = StringIO(("0,\n") * (start_caching_at + 1))
274273

275274
if parser.engine == "pyarrow":
276275
# pyarrow reads "0" as 0 (of type int64), and so

pandas/tests/libs/test_hashtable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_get_state(self, table_type, dtype):
247247
assert "n_buckets" in state
248248
assert "upper_bound" in state
249249

250-
@pytest.mark.parametrize("N", range(1, 110))
250+
@pytest.mark.parametrize("N", range(1, 110, 4))
251251
def test_no_reallocation(self, table_type, dtype, N):
252252
keys = np.arange(N).astype(dtype)
253253
preallocated_table = table_type(N)
@@ -517,7 +517,7 @@ def test_tracemalloc_for_empty_StringHashTable():
517517
assert get_allocated_khash_memory() == 0
518518

519519

520-
@pytest.mark.parametrize("N", range(1, 110))
520+
@pytest.mark.parametrize("N", range(1, 110, 4))
521521
def test_no_reallocation_StringHashTable(N):
522522
keys = np.arange(N).astype(np.str_).astype(np.object_)
523523
preallocated_table = ht.StringHashTable(N)

pandas/tests/plotting/frame/test_frame.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
date,
55
datetime,
66
)
7-
import gc
87
import itertools
98
import re
109
import string
11-
import weakref
1210

1311
import numpy as np
1412
import pytest
@@ -2173,15 +2171,15 @@ def test_memory_leak(self, kind):
21732171
index=date_range("2000-01-01", periods=10, freq="B"),
21742172
)
21752173

2176-
# Use a weakref so we can see if the object gets collected without
2177-
# also preventing it from being collected
2178-
ref = weakref.ref(df.plot(kind=kind, **args))
2179-
2180-
# have matplotlib delete all the figures
2181-
plt.close("all")
2182-
# force a garbage collection
2183-
gc.collect()
2184-
assert ref() is None
2174+
ax = df.plot(kind=kind, **args)
2175+
# https://github.com/pandas-dev/pandas/issues/9003#issuecomment-70544889
2176+
if kind in ["line", "area"]:
2177+
for i, (cached_data, _, _) in enumerate(ax._plot_data):
2178+
ser = df.iloc[:, i]
2179+
assert not tm.shares_memory(ser, cached_data)
2180+
tm.assert_numpy_array_equal(ser._values, cached_data._values)
2181+
else:
2182+
assert not hasattr(ax, "_plot_data")
21852183

21862184
def test_df_gridspec_patterns_vert_horiz(self):
21872185
# GH 10819

pandas/tests/resample/test_datetime_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def test_nearest_upsample_with_limit(tz_aware_fixture, freq, rule, unit):
526526

527527

528528
def test_resample_ohlc(unit):
529-
index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="Min")
529+
index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 2), freq="Min")
530530
s = Series(range(len(index)), index=index)
531531
s.index.name = "index"
532532
s.index = s.index.as_unit(unit)
@@ -1842,7 +1842,7 @@ def test_resample_equivalent_offsets(n1, freq1, n2, freq2, k, unit):
18421842
# GH 24127
18431843
n1_ = n1 * k
18441844
n2_ = n2 * k
1845-
dti = date_range("1991-09-05", "1991-09-12", freq=freq1).as_unit(unit)
1845+
dti = date_range("1991-09-05", "1991-09-06", freq=freq1).as_unit(unit)
18461846
ser = Series(range(len(dti)), index=dti)
18471847

18481848
result1 = ser.resample(str(n1_) + freq1).mean()

pandas/tests/resample/test_period_index.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_selection(self, freq, kwargs):
130130
def test_annual_upsample_cases(
131131
self, offset, period, conv, meth, month, simple_period_range_series
132132
):
133-
ts = simple_period_range_series("1/1/1990", "12/31/1991", freq=f"Y-{month}")
133+
ts = simple_period_range_series("1/1/1990", "12/31/1990", freq=f"Y-{month}")
134134
warn = FutureWarning if period == "B" else None
135135
msg = r"PeriodDtype\[B\] is deprecated"
136136
with tm.assert_produces_warning(warn, match=msg):
@@ -214,7 +214,7 @@ def test_quarterly_upsample(
214214
self, month, offset, period, convention, simple_period_range_series
215215
):
216216
freq = f"Q-{month}"
217-
ts = simple_period_range_series("1/1/1990", "12/31/1995", freq=freq)
217+
ts = simple_period_range_series("1/1/1990", "12/31/1991", freq=freq)
218218
warn = FutureWarning if period == "B" else None
219219
msg = r"PeriodDtype\[B\] is deprecated"
220220
with tm.assert_produces_warning(warn, match=msg):
@@ -396,8 +396,7 @@ def test_fill_method_and_how_upsample(self):
396396
@pytest.mark.parametrize("convention", ["start", "end"])
397397
def test_weekly_upsample(self, day, target, convention, simple_period_range_series):
398398
freq = f"W-{day}"
399-
ts = simple_period_range_series("1/1/1990", "12/31/1995", freq=freq)
400-
399+
ts = simple_period_range_series("1/1/1990", "07/31/1990", freq=freq)
401400
warn = None if target == "D" else FutureWarning
402401
msg = r"PeriodDtype\[B\] is deprecated"
403402
with tm.assert_produces_warning(warn, match=msg):

0 commit comments

Comments
 (0)