Skip to content

Commit 3b22a00

Browse files
committed
Revert "chore(simulate_data.py): Adding type hints and doing some v light refactoring"
This reverts commit dbe8d90.
1 parent dbe8d90 commit 3b22a00

File tree

1 file changed

+30
-66
lines changed

1 file changed

+30
-66
lines changed

causalpy/data/simulate_data.py

Lines changed: 30 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,19 @@
1515
Functions that generate data sets used in examples
1616
"""
1717

18-
from typing import Any
19-
2018
import numpy as np
2119
import pandas as pd
2220
from scipy.stats import dirichlet, gamma, norm, uniform
2321
from statsmodels.nonparametric.smoothers_lowess import lowess
2422

25-
default_lowess_kwargs: dict[str, float] = {"frac": 0.2, "it": 0}
26-
RANDOM_SEED: int = 8927
27-
rng: np.random.Generator = np.random.default_rng(RANDOM_SEED)
23+
default_lowess_kwargs = {"frac": 0.2, "it": 0}
24+
RANDOM_SEED = 8927
25+
rng = np.random.default_rng(RANDOM_SEED)
2826

2927

3028
def _smoothed_gaussian_random_walk(
31-
gaussian_random_walk_mu: float,
32-
gaussian_random_walk_sigma: float,
33-
N: int,
34-
lowess_kwargs: dict[str, Any],
35-
) -> tuple[np.ndarray, np.ndarray]:
29+
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
30+
):
3631
"""
3732
Generates Gaussian random walk data and applies LOWESS
3833
@@ -53,12 +48,12 @@ def _smoothed_gaussian_random_walk(
5348

5449

5550
def generate_synthetic_control_data(
56-
N: int = 100,
57-
treatment_time: int = 70,
58-
grw_mu: float = 0.25,
59-
grw_sigma: float = 1,
60-
lowess_kwargs: dict[str, Any] | None = None,
61-
) -> tuple[pd.DataFrame, np.ndarray]:
51+
N=100,
52+
treatment_time=70,
53+
grw_mu=0.25,
54+
grw_sigma=1,
55+
lowess_kwargs=default_lowess_kwargs,
56+
):
6257
"""
6358
Generates data for synthetic control example.
6459
@@ -78,8 +73,6 @@ def generate_synthetic_control_data(
7873
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7974
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
8075
"""
81-
if lowess_kwargs is None:
82-
lowess_kwargs = default_lowess_kwargs
8376

8477
# 1. Generate non-treated variables
8578
df = pd.DataFrame(
@@ -115,12 +108,8 @@ def generate_synthetic_control_data(
115108

116109

117110
def generate_time_series_data(
118-
N: int = 100,
119-
treatment_time: int = 70,
120-
beta_temp: float = -1,
121-
beta_linear: float = 0.5,
122-
beta_intercept: float = 3,
123-
) -> pd.DataFrame:
111+
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
112+
):
124113
"""
125114
Generates interrupted time series example data
126115
@@ -166,7 +155,7 @@ def generate_time_series_data(
166155
return df
167156

168157

169-
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
158+
def generate_time_series_data_seasonal(treatment_time):
170159
"""
171160
Generates 10 years of monthly data with seasonality
172161
"""
@@ -194,9 +183,7 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
194183
return df
195184

196185

197-
def generate_time_series_data_simple(
198-
treatment_time: pd.Timestamp, slope: float = 0.0
199-
) -> pd.DataFrame:
186+
def generate_time_series_data_simple(treatment_time, slope=0.0):
200187
"""Generate simple interrupted time series data, with no seasonality or temporal
201188
structure.
202189
"""
@@ -218,7 +205,7 @@ def generate_time_series_data_simple(
218205
return df
219206

220207

221-
def generate_did() -> pd.DataFrame:
208+
def generate_did():
222209
"""
223210
Generate Difference in Differences data
224211
@@ -236,14 +223,8 @@ def generate_did() -> pd.DataFrame:
236223

237224
# local functions
238225
def outcome(
239-
t: np.ndarray,
240-
control_intercept: float,
241-
treat_intercept_delta: float,
242-
trend: float,
243-
Δ: float,
244-
group: np.ndarray,
245-
post_treatment: np.ndarray,
246-
) -> np.ndarray:
226+
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
227+
):
247228
"""Compute the outcome of each unit"""
248229
return (
249230
control_intercept
@@ -276,8 +257,8 @@ def outcome(
276257

277258

278259
def generate_regression_discontinuity_data(
279-
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
280-
) -> pd.DataFrame:
260+
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
261+
):
281262
"""
282263
Generate regression discontinuity example data
283264
@@ -291,12 +272,12 @@ def generate_regression_discontinuity_data(
291272
... ) # doctest: +SKIP
292273
"""
293274

294-
def is_treated(x: np.ndarray) -> np.ndarray:
275+
def is_treated(x):
295276
"""Check if x was treated"""
296277
return np.greater_equal(x, true_treatment_threshold)
297278

298-
def impact(x: np.ndarray) -> np.ndarray:
299-
"""Assign true_causal_impact to all treated entries"""
279+
def impact(x):
280+
"""Assign true_causal_impact to all treated entries"""
300281
y = np.zeros(len(x))
301282
y[is_treated(x)] = true_causal_impact
302283
return y
@@ -308,11 +289,8 @@ def impact(x: np.ndarray) -> np.ndarray:
308289

309290

310291
def generate_ancova_data(
311-
N: int = 200,
312-
pre_treatment_means: np.ndarray = np.array([10, 12]),
313-
treatment_effect: float = 2,
314-
sigma: float = 1,
315-
) -> pd.DataFrame:
292+
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
293+
):
316294
"""
317295
Generate ANCOVA example data
318296
@@ -332,7 +310,7 @@ def generate_ancova_data(
332310
return df
333311

334312

335-
def generate_geolift_data() -> pd.DataFrame:
313+
def generate_geolift_data():
336314
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
337315
countries. The treated unit `Denmark` is a weighted combination of the untreated
338316
units. We additionally specify a treatment effect which takes effect after the
@@ -382,7 +360,7 @@ def generate_geolift_data() -> pd.DataFrame:
382360
return df
383361

384362

385-
def generate_multicell_geolift_data() -> pd.DataFrame:
363+
def generate_multicell_geolift_data():
386364
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
387365
countries. The treated unit `Denmark` is a weighted combination of the untreated
388366
units. We additionally specify a treatment effect which takes effect after the
@@ -444,9 +422,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:
444422
# -----------------
445423

446424

447-
def generate_seasonality(
448-
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
449-
) -> np.ndarray:
425+
def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
450426
"""Generate monthly seasonality by sampling from a Gaussian process with a
451427
Gaussian kernel, using numpy code"""
452428
# Generate the covariance matrix
@@ -460,26 +436,14 @@ def generate_seasonality(
460436
return seasonality
461437

462438

463-
def periodic_kernel(
464-
x1: np.ndarray,
465-
x2: np.ndarray,
466-
period: float = 1,
467-
length_scale: float = 1,
468-
amplitude: float = 1,
469-
) -> np.ndarray:
439+
def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
470440
"""Generate a periodic kernel for Gaussian process"""
471441
return amplitude**2 * np.exp(
472442
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
473443
)
474444

475445

476-
def create_series(
477-
n: int = 52,
478-
amplitude: float = 1,
479-
length_scale: float = 2,
480-
n_years: int = 4,
481-
intercept: float = 3,
482-
) -> np.ndarray:
446+
def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
483447
"""
484448
Returns numpy tile with generated seasonality data repeated over
485449
multiple years

0 commit comments

Comments
 (0)