1515Functions that generate data sets used in examples
1616"""
1717
18- from typing import Any
19-
2018import numpy as np
2119import pandas as pd
2220from scipy .stats import dirichlet , gamma , norm , uniform
2321from statsmodels .nonparametric .smoothers_lowess import lowess
2422
25- default_lowess_kwargs : dict [ str , float ] = {"frac" : 0.2 , "it" : 0 }
26- RANDOM_SEED : int = 8927
27- rng : np . random . Generator = np .random .default_rng (RANDOM_SEED )
23+ default_lowess_kwargs = {"frac" : 0.2 , "it" : 0 }
24+ RANDOM_SEED = 8927
25+ rng = np .random .default_rng (RANDOM_SEED )
2826
2927
3028def _smoothed_gaussian_random_walk (
31- gaussian_random_walk_mu : float ,
32- gaussian_random_walk_sigma : float ,
33- N : int ,
34- lowess_kwargs : dict [str , Any ],
35- ) -> tuple [np .ndarray , np .ndarray ]:
29+ gaussian_random_walk_mu , gaussian_random_walk_sigma , N , lowess_kwargs
30+ ):
3631 """
3732 Generates Gaussian random walk data and applies LOWESS
3833
@@ -53,12 +48,12 @@ def _smoothed_gaussian_random_walk(
5348
5449
5550def generate_synthetic_control_data (
56- N : int = 100 ,
57- treatment_time : int = 70 ,
58- grw_mu : float = 0.25 ,
59- grw_sigma : float = 1 ,
60- lowess_kwargs : dict [ str , Any ] | None = None ,
61- ) -> tuple [ pd . DataFrame , np . ndarray ] :
51+ N = 100 ,
52+ treatment_time = 70 ,
53+ grw_mu = 0.25 ,
54+ grw_sigma = 1 ,
55+ lowess_kwargs = default_lowess_kwargs ,
56+ ):
6257 """
6358 Generates data for synthetic control example.
6459
@@ -78,8 +73,6 @@ def generate_synthetic_control_data(
7873 >>> from causalpy.data.simulate_data import generate_synthetic_control_data
7974 >>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
8075 """
81- if lowess_kwargs is None :
82- lowess_kwargs = default_lowess_kwargs
8376
8477 # 1. Generate non-treated variables
8578 df = pd .DataFrame (
@@ -115,12 +108,8 @@ def generate_synthetic_control_data(
115108
116109
117110def generate_time_series_data (
118- N : int = 100 ,
119- treatment_time : int = 70 ,
120- beta_temp : float = - 1 ,
121- beta_linear : float = 0.5 ,
122- beta_intercept : float = 3 ,
123- ) -> pd .DataFrame :
111+ N = 100 , treatment_time = 70 , beta_temp = - 1 , beta_linear = 0.5 , beta_intercept = 3
112+ ):
124113 """
125114 Generates interrupted time series example data
126115
@@ -166,7 +155,7 @@ def generate_time_series_data(
166155 return df
167156
168157
169- def generate_time_series_data_seasonal (treatment_time : pd . Timestamp ) -> pd . DataFrame :
158+ def generate_time_series_data_seasonal (treatment_time ) :
170159 """
171160 Generates 10 years of monthly data with seasonality
172161 """
@@ -194,9 +183,7 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
194183 return df
195184
196185
197- def generate_time_series_data_simple (
198- treatment_time : pd .Timestamp , slope : float = 0.0
199- ) -> pd .DataFrame :
186+ def generate_time_series_data_simple (treatment_time , slope = 0.0 ):
200187 """Generate simple interrupted time series data, with no seasonality or temporal
201188 structure.
202189 """
@@ -218,7 +205,7 @@ def generate_time_series_data_simple(
218205 return df
219206
220207
221- def generate_did () -> pd . DataFrame :
208+ def generate_did ():
222209 """
223210 Generate Difference in Differences data
224211
@@ -236,14 +223,8 @@ def generate_did() -> pd.DataFrame:
236223
237224 # local functions
238225 def outcome (
239- t : np .ndarray ,
240- control_intercept : float ,
241- treat_intercept_delta : float ,
242- trend : float ,
243- Δ : float ,
244- group : np .ndarray ,
245- post_treatment : np .ndarray ,
246- ) -> np .ndarray :
226+ t , control_intercept , treat_intercept_delta , trend , Δ , group , post_treatment
227+ ):
247228 """Compute the outcome of each unit"""
248229 return (
249230 control_intercept
@@ -276,8 +257,8 @@ def outcome(
276257
277258
278259def generate_regression_discontinuity_data (
279- N : int = 100 , true_causal_impact : float = 0.5 , true_treatment_threshold : float = 0.0
280- ) -> pd . DataFrame :
260+ N = 100 , true_causal_impact = 0.5 , true_treatment_threshold = 0.0
261+ ):
281262 """
282263 Generate regression discontinuity example data
283264
@@ -291,12 +272,12 @@ def generate_regression_discontinuity_data(
291272 ... ) # doctest: +SKIP
292273 """
293274
294- def is_treated (x : np . ndarray ) -> np . ndarray :
275+ def is_treated (x ) :
295276 """Check if x was treated"""
296277 return np .greater_equal (x , true_treatment_threshold )
297278
298- def impact (x : np . ndarray ) -> np . ndarray :
299- """Assign true_causal_impact to all treated entries"""
279+ def impact (x ) :
280+ """Assign true_causal_impact to all treated entries"""
300281 y = np .zeros (len (x ))
301282 y [is_treated (x )] = true_causal_impact
302283 return y
@@ -308,11 +289,8 @@ def impact(x: np.ndarray) -> np.ndarray:
308289
309290
310291def generate_ancova_data (
311- N : int = 200 ,
312- pre_treatment_means : np .ndarray = np .array ([10 , 12 ]),
313- treatment_effect : float = 2 ,
314- sigma : float = 1 ,
315- ) -> pd .DataFrame :
292+ N = 200 , pre_treatment_means = np .array ([10 , 12 ]), treatment_effect = 2 , sigma = 1
293+ ):
316294 """
317295 Generate ANCOVA example data
318296
@@ -332,7 +310,7 @@ def generate_ancova_data(
332310 return df
333311
334312
335- def generate_geolift_data () -> pd . DataFrame :
313+ def generate_geolift_data ():
336314 """Generate synthetic data for a geolift example. This will consist of 6 untreated
337315 countries. The treated unit `Denmark` is a weighted combination of the untreated
338316 units. We additionally specify a treatment effect which takes effect after the
@@ -382,7 +360,7 @@ def generate_geolift_data() -> pd.DataFrame:
382360 return df
383361
384362
385- def generate_multicell_geolift_data () -> pd . DataFrame :
363+ def generate_multicell_geolift_data ():
386364 """Generate synthetic data for a geolift example. This will consist of 6 untreated
387365 countries. The treated unit `Denmark` is a weighted combination of the untreated
388366 units. We additionally specify a treatment effect which takes effect after the
@@ -444,9 +422,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:
444422# -----------------
445423
446424
447- def generate_seasonality (
448- n : int = 12 , amplitude : float = 1 , length_scale : float = 0.5
449- ) -> np .ndarray :
425+ def generate_seasonality (n = 12 , amplitude = 1 , length_scale = 0.5 ):
450426 """Generate monthly seasonality by sampling from a Gaussian process with a
451427 Gaussian kernel, using numpy code"""
452428 # Generate the covariance matrix
@@ -460,26 +436,14 @@ def generate_seasonality(
460436 return seasonality
461437
462438
463- def periodic_kernel (
464- x1 : np .ndarray ,
465- x2 : np .ndarray ,
466- period : float = 1 ,
467- length_scale : float = 1 ,
468- amplitude : float = 1 ,
469- ) -> np .ndarray :
439+ def periodic_kernel (x1 , x2 , period = 1 , length_scale = 1 , amplitude = 1 ):
470440 """Generate a periodic kernel for a Gaussian process"""
471441 return amplitude ** 2 * np .exp (
472442 - 2 * np .sin (np .pi * np .abs (x1 - x2 ) / period ) ** 2 / length_scale ** 2
473443 )
474444
475445
476- def create_series (
477- n : int = 52 ,
478- amplitude : float = 1 ,
479- length_scale : float = 2 ,
480- n_years : int = 4 ,
481- intercept : float = 3 ,
482- ) -> np .ndarray :
446+ def create_series (n = 52 , amplitude = 1 , length_scale = 2 , n_years = 4 , intercept = 3 ):
483447 """
484448 Returns numpy tile with generated seasonality data repeated over
485449 multiple years
0 commit comments