Skip to content

Commit f4062be

Browse files
committed
chore(simulate_data.py): Adding type hints and doing some v light refactoring
1 parent 828ba2e commit f4062be

File tree

1 file changed

+66
-30
lines changed

1 file changed

+66
-30
lines changed

causalpy/data/simulate_data.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,24 @@
1515
Functions that generate data sets used in examples
1616
"""
1717

18+
from typing import Any
19+
1820
import numpy as np
1921
import pandas as pd
2022
from scipy.stats import dirichlet, gamma, norm, uniform
2123
from statsmodels.nonparametric.smoothers_lowess import lowess
2224

23-
default_lowess_kwargs = {"frac": 0.2, "it": 0}
24-
RANDOM_SEED = 8927
25-
rng = np.random.default_rng(RANDOM_SEED)
25+
default_lowess_kwargs: dict[str, float] = {"frac": 0.2, "it": 0}
26+
RANDOM_SEED: int = 8927
27+
rng: np.random.Generator = np.random.default_rng(RANDOM_SEED)
2628

2729

2830
def _smoothed_gaussian_random_walk(
29-
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
30-
):
31+
gaussian_random_walk_mu: float,
32+
gaussian_random_walk_sigma: float,
33+
N: int,
34+
lowess_kwargs: dict[str, Any],
35+
) -> tuple[np.ndarray, np.ndarray]:
3136
"""
3237
Generates Gaussian random walk data and applies LOWESS
3338
@@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk(
4853

4954

5055
def generate_synthetic_control_data(
51-
N=100,
52-
treatment_time=70,
53-
grw_mu=0.25,
54-
grw_sigma=1,
55-
lowess_kwargs=default_lowess_kwargs,
56-
):
56+
N: int = 100,
57+
treatment_time: int = 70,
58+
grw_mu: float = 0.25,
59+
grw_sigma: float = 1,
60+
lowess_kwargs: dict[str, Any] | None = None,
61+
) -> tuple[pd.DataFrame, np.ndarray]:
5762
"""
5863
Generates data for synthetic control example.
5964
@@ -73,6 +78,8 @@ def generate_synthetic_control_data(
7378
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7479
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
7580
"""
81+
if lowess_kwargs is None:
82+
lowess_kwargs = default_lowess_kwargs
7683

7784
# 1. Generate non-treated variables
7885
df = pd.DataFrame(
@@ -108,8 +115,12 @@ def generate_synthetic_control_data(
108115

109116

110117
def generate_time_series_data(
111-
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
112-
):
118+
N: int = 100,
119+
treatment_time: int = 70,
120+
beta_temp: float = -1,
121+
beta_linear: float = 0.5,
122+
beta_intercept: float = 3,
123+
) -> pd.DataFrame:
113124
"""
114125
Generates interrupted time series example data
115126
@@ -155,7 +166,7 @@ def generate_time_series_data(
155166
return df
156167

157168

158-
def generate_time_series_data_seasonal(treatment_time):
169+
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
159170
"""
160171
Generates 10 years of monthly data with seasonality
161172
"""
@@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
183194
return df
184195

185196

186-
def generate_time_series_data_simple(treatment_time, slope=0.0):
197+
def generate_time_series_data_simple(
198+
treatment_time: pd.Timestamp, slope: float = 0.0
199+
) -> pd.DataFrame:
187200
"""Generate simple interrupted time series data, with no seasonality or temporal
188201
structure.
189202
"""
@@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205218
return df
206219

207220

208-
def generate_did():
221+
def generate_did() -> pd.DataFrame:
209222
"""
210223
Generate Difference in Differences data
211224
@@ -223,8 +236,14 @@ def generate_did():
223236

224237
# local functions
225238
def outcome(
226-
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
227-
):
239+
t: np.ndarray,
240+
control_intercept: float,
241+
treat_intercept_delta: float,
242+
trend: float,
243+
Δ: float,
244+
group: np.ndarray,
245+
post_treatment: np.ndarray,
246+
) -> np.ndarray:
228247
"""Compute the outcome of each unit"""
229248
return (
230249
control_intercept
@@ -257,8 +276,8 @@ def outcome(
257276

258277

259278
def generate_regression_discontinuity_data(
260-
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
261-
):
279+
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
280+
) -> pd.DataFrame:
262281
"""
263282
Generate regression discontinuity example data
264283
@@ -272,12 +291,12 @@ def generate_regression_discontinuity_data(
272291
... ) # doctest: +SKIP
273292
"""
274293

275-
def is_treated(x):
294+
def is_treated(x: np.ndarray) -> np.ndarray:
276295
"""Check if x was treated"""
277296
return np.greater_equal(x, true_treatment_threshold)
278297

279-
def impact(x):
280-
"""Assign true_causal_impact to all treaated entries"""
298+
def impact(x: np.ndarray) -> np.ndarray:
299+
"""Assign true_causal_impact to all treated entries"""
281300
y = np.zeros(len(x))
282301
y[is_treated(x)] = true_causal_impact
283302
return y
@@ -289,8 +308,11 @@ def impact(x):
289308

290309

291310
def generate_ancova_data(
292-
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
293-
):
311+
N: int = 200,
312+
pre_treatment_means: np.ndarray = np.array([10, 12]),
313+
treatment_effect: float = 2,
314+
sigma: float = 1,
315+
) -> pd.DataFrame:
294316
"""
295317
Generate ANCOVA example data
296318
@@ -310,7 +332,7 @@ def generate_ancova_data(
310332
return df
311333

312334

313-
def generate_geolift_data():
335+
def generate_geolift_data() -> pd.DataFrame:
314336
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
315337
countries. The treated unit `Denmark` is a weighted combination of the untreated
316338
units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +382,7 @@ def generate_geolift_data():
360382
return df
361383

362384

363-
def generate_multicell_geolift_data():
385+
def generate_multicell_geolift_data() -> pd.DataFrame:
364386
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
365387
countries. The treated unit `Denmark` is a weighted combination of the untreated
366388
units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +444,9 @@ def generate_multicell_geolift_data():
422444
# -----------------
423445

424446

425-
def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
447+
def generate_seasonality(
448+
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
449+
) -> np.ndarray:
426450
"""Generate monthly seasonality by sampling from a Gaussian process with a
427451
Gaussian kernel, using numpy code"""
428452
# Generate the covariance matrix
@@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436460
return seasonality
437461

438462

439-
def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
463+
def periodic_kernel(
464+
x1: np.ndarray,
465+
x2: np.ndarray,
466+
period: float = 1,
467+
length_scale: float = 1,
468+
amplitude: float = 1,
469+
) -> np.ndarray:
440470
"""Generate a periodic kernel for a Gaussian process"""
441471
return amplitude**2 * np.exp(
442472
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
443473
)
444474

445475

446-
def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
476+
def create_series(
477+
n: int = 52,
478+
amplitude: float = 1,
479+
length_scale: float = 2,
480+
n_years: int = 4,
481+
intercept: float = 3,
482+
) -> np.ndarray:
447483
"""
448484
Returns numpy tile with generated seasonality data repeated over
449485
multiple years

0 commit comments

Comments
 (0)