1515Functions that generate data sets used in examples
1616"""
1717
18+ from typing import Any
19+
1820import numpy as np
1921import pandas as pd
2022from scipy .stats import dirichlet , gamma , norm , uniform
2123from statsmodels .nonparametric .smoothers_lowess import lowess
2224
23- default_lowess_kwargs = {"frac" : 0.2 , "it" : 0 }
24- RANDOM_SEED = 8927
25- rng = np .random .default_rng (RANDOM_SEED )
25+ default_lowess_kwargs : dict [ str , float ] = {"frac" : 0.2 , "it" : 0 }
26+ RANDOM_SEED : int = 8927
27+ rng : np . random . Generator = np .random .default_rng (RANDOM_SEED )
2628
2729
2830def _smoothed_gaussian_random_walk (
29- gaussian_random_walk_mu , gaussian_random_walk_sigma , N , lowess_kwargs
30- ):
31+ gaussian_random_walk_mu : float ,
32+ gaussian_random_walk_sigma : float ,
33+ N : int ,
34+ lowess_kwargs : dict [str , Any ],
35+ ) -> tuple [np .ndarray , np .ndarray ]:
3136 """
3237 Generates Gaussian random walk data and applies LOWESS
3338
@@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk(
4853
4954
5055def generate_synthetic_control_data (
51- N = 100 ,
52- treatment_time = 70 ,
53- grw_mu = 0.25 ,
54- grw_sigma = 1 ,
55- lowess_kwargs = default_lowess_kwargs ,
56- ):
56+ N : int = 100 ,
57+ treatment_time : int = 70 ,
58+ grw_mu : float = 0.25 ,
59+ grw_sigma : float = 1 ,
60+ lowess_kwargs : dict [ str , Any ] | None = None ,
61+ ) -> tuple [ pd . DataFrame , np . ndarray ] :
5762 """
5863 Generates data for synthetic control example.
5964
@@ -73,6 +78,8 @@ def generate_synthetic_control_data(
7378 >>> from causalpy.data.simulate_data import generate_synthetic_control_data
7479 >>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
7580 """
81+ if lowess_kwargs is None :
82+ lowess_kwargs = default_lowess_kwargs
7683
7784 # 1. Generate non-treated variables
7885 df = pd .DataFrame (
@@ -108,8 +115,12 @@ def generate_synthetic_control_data(
108115
109116
110117def generate_time_series_data (
111- N = 100 , treatment_time = 70 , beta_temp = - 1 , beta_linear = 0.5 , beta_intercept = 3
112- ):
118+ N : int = 100 ,
119+ treatment_time : int = 70 ,
120+ beta_temp : float = - 1 ,
121+ beta_linear : float = 0.5 ,
122+ beta_intercept : float = 3 ,
123+ ) -> pd .DataFrame :
113124 """
114125 Generates interrupted time series example data
115126
@@ -155,7 +166,7 @@ def generate_time_series_data(
155166 return df
156167
157168
158- def generate_time_series_data_seasonal (treatment_time ) :
169+ def generate_time_series_data_seasonal (treatment_time : pd . Timestamp ) -> pd . DataFrame :
159170 """
160171 Generates 10 years of monthly data with seasonality
161172 """
@@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
183194 return df
184195
185196
186- def generate_time_series_data_simple (treatment_time , slope = 0.0 ):
197+ def generate_time_series_data_simple (
198+ treatment_time : pd .Timestamp , slope : float = 0.0
199+ ) -> pd .DataFrame :
187200 """Generate simple interrupted time series data, with no seasonality or temporal
188201 structure.
189202 """
@@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205218 return df
206219
207220
208- def generate_did ():
221+ def generate_did () -> pd . DataFrame :
209222 """
210223 Generate Difference in Differences data
211224
@@ -223,8 +236,14 @@ def generate_did():
223236
224237 # local functions
225238 def outcome (
226- t , control_intercept , treat_intercept_delta , trend , Δ , group , post_treatment
227- ):
239+ t : np .ndarray ,
240+ control_intercept : float ,
241+ treat_intercept_delta : float ,
242+ trend : float ,
243+ Δ : float ,
244+ group : np .ndarray ,
245+ post_treatment : np .ndarray ,
246+ ) -> np .ndarray :
228247 """Compute the outcome of each unit"""
229248 return (
230249 control_intercept
@@ -257,8 +276,8 @@ def outcome(
257276
258277
259278def generate_regression_discontinuity_data (
260- N = 100 , true_causal_impact = 0.5 , true_treatment_threshold = 0.0
261- ):
279+ N : int = 100 , true_causal_impact : float = 0.5 , true_treatment_threshold : float = 0.0
280+ ) -> pd . DataFrame :
262281 """
263282 Generate regression discontinuity example data
264283
@@ -272,12 +291,12 @@ def generate_regression_discontinuity_data(
272291 ... ) # doctest: +SKIP
273292 """
274293
275- def is_treated (x ) :
294+ def is_treated (x : np . ndarray ) -> np . ndarray :
276295 """Check if x was treated"""
277296 return np .greater_equal (x , true_treatment_threshold )
278297
279- def impact (x ) :
280- """Assign true_causal_impact to all treaated entries"""
298+ def impact (x : np . ndarray ) -> np . ndarray :
299+ """Assign true_causal_impact to all treated entries"""
281300 y = np .zeros (len (x ))
282301 y [is_treated (x )] = true_causal_impact
283302 return y
@@ -289,8 +308,11 @@ def impact(x):
289308
290309
291310def generate_ancova_data (
292- N = 200 , pre_treatment_means = np .array ([10 , 12 ]), treatment_effect = 2 , sigma = 1
293- ):
311+ N : int = 200 ,
312+ pre_treatment_means : np .ndarray = np .array ([10 , 12 ]),
313+ treatment_effect : float = 2 ,
314+ sigma : float = 1 ,
315+ ) -> pd .DataFrame :
294316 """
295317 Generate ANCOVA example data
296318
@@ -310,7 +332,7 @@ def generate_ancova_data(
310332 return df
311333
312334
313- def generate_geolift_data ():
335+ def generate_geolift_data () -> pd . DataFrame :
314336 """Generate synthetic data for a geolift example. This will consists of 6 untreated
315337 countries. The treated unit `Denmark` is a weighted combination of the untreated
316338 units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +382,7 @@ def generate_geolift_data():
360382 return df
361383
362384
363- def generate_multicell_geolift_data ():
385+ def generate_multicell_geolift_data () -> pd . DataFrame :
364386 """Generate synthetic data for a geolift example. This will consists of 6 untreated
365387 countries. The treated unit `Denmark` is a weighted combination of the untreated
366388 units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +444,9 @@ def generate_multicell_geolift_data():
422444# -----------------
423445
424446
425- def generate_seasonality (n = 12 , amplitude = 1 , length_scale = 0.5 ):
447+ def generate_seasonality (
448+ n : int = 12 , amplitude : float = 1 , length_scale : float = 0.5
449+ ) -> np .ndarray :
426450 """Generate monthly seasonality by sampling from a Gaussian process with a
427451 Gaussian kernel, using numpy code"""
428452 # Generate the covariance matrix
@@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436460 return seasonality
437461
438462
439- def periodic_kernel (x1 , x2 , period = 1 , length_scale = 1 , amplitude = 1 ):
463+ def periodic_kernel (
464+ x1 : np .ndarray ,
465+ x2 : np .ndarray ,
466+ period : float = 1 ,
467+ length_scale : float = 1 ,
468+ amplitude : float = 1 ,
469+ ) -> np .ndarray :
440470 """Generate a periodic kernel for gaussian process"""
441471 return amplitude ** 2 * np .exp (
442472 - 2 * np .sin (np .pi * np .abs (x1 - x2 ) / period ) ** 2 / length_scale ** 2
443473 )
444474
445475
446- def create_series (n = 52 , amplitude = 1 , length_scale = 2 , n_years = 4 , intercept = 3 ):
476+ def create_series (
477+ n : int = 52 ,
478+ amplitude : float = 1 ,
479+ length_scale : float = 2 ,
480+ n_years : int = 4 ,
481+ intercept : float = 3 ,
482+ ) -> np .ndarray :
447483 """
448484 Returns numpy tile with generated seasonality data repeated over
449485 multiple years
0 commit comments