===================
== Thomas Pinder ==
===================
Bayesian ML, Causal Inference, and JAX

Bayesian Synthetic Control

A NumPyro implementation of Bayesian Synthetic Control Methods.

Synthetic Control Methods (SCM) were first introduced in Abadie and Gardeazabal (2003) and formalised in Abadie, Diamond, and Hainmueller (2010). In the latter of these publications, the authors investigated the impact of California’s Proposition 99 legislation on cigarette sales. In this notebook I implement a Bayesian variant of SCMs and apply it to the same dataset. We shall see how the original SCM can easily be mapped into a Bayesian framework, before conducting posterior inference and exploring how treatment effects and their corresponding uncertainty may be estimated.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import jax
import jax.random as jr
import matplotlib.pyplot as plt
import numpyro as npy
from numpyro import distributions as npy_dist
from jaxtyping import Float, Array
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
import seaborn as sns
import pandas as pd
import matplotlib as mpl
from numpyro.infer import Predictive


# Run NumPyro on the CPU and expose 8 host devices so that the 8 MCMC
# chains sampled later in the notebook can run in parallel.
jax.config.update("jax_platform_name", "cpu")
npy.set_host_device_count(8)
# Single PRNG key used for sampling throughout the notebook.
key = jr.PRNGKey(123)
# Global seaborn/matplotlib theme applied to every figure below.
sns.set_theme(
    context="notebook",
    font="serif",
    style="whitegrid",
    palette="deep",
    rc={"figure.figsize": (8, 3), "axes.spines.top": False, "axes.spines.right": False},
)
# Default colour cycle, passed to the plotting helpers for consistent colours.
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
Show plotting code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import jax.numpy as jnp
import matplotlib as mpl


def plot_effect(
    trace: dict[str, Float[Array, "..."]],
    control_units: pd.DataFrame,
    treated_unit: pd.DataFrame,
    treatment_date: int,
    /,
    cols: list[str] | None = None,
):
    """Two-panel summary of the estimated treatment effect.

    Top panel: the observed treated series against the posterior
    counterfactual (median and 95% interval), with the post-intervention gap
    shaded. Bottom panel: the pointwise difference observed minus
    counterfactual median ("causal impact").

    Args:
        trace: Posterior samples containing "intercept" and "weights"
            (weights assumed (samples, D) — note the transpose below; TODO
            confirm against the model's sample shapes).
        control_units: DataFrame of control-unit outcomes over the full period.
        treated_unit: Single-column DataFrame of the treated unit, full period.
        treatment_date: Index value at which the treatment begins.
        cols: Optional colour cycle; defaults to matplotlib's current cycle.
    """
    if cols is None:
        cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
    full_index = treated_unit.index
    treated_post = treated_unit.loc[treated_unit.index >= treatment_date]
    num_post = len(treated_post)

    # One counterfactual trajectory per posterior sample.
    counterfactual = trace["intercept"] + control_units.values @ trace["weights"].T
    # Median and central 95% interval across posterior samples (axis=1).
    counterfactual_lower, counterfactual_mean, counterfactual_upper = jnp.percentile(
        counterfactual, q=jnp.array([2.5, 50, 97.5]), axis=1
    )
    # Median counterfactual restricted to the post-treatment period.
    counterfactual_post = counterfactual_mean[-num_post:]

    fig, axes = plt.subplots(figsize=(8, 6), nrows=2, sharex=True, sharey=False)
    ax0, ax1 = axes.ravel()
    ax0.axvline(x=treatment_date, label="Treatment Date", color=cols[2], linestyle="--")
    ax0.plot(full_index, treated_unit, "x", color="gray", label="Observed")
    ax0.plot(
        full_index,
        counterfactual_mean,
        color=cols[1],
        linewidth=2,
        label="Counterfactual",
    )
    # 95% credible band around the counterfactual.
    ax0.fill_between(
        full_index, counterfactual_lower, counterfactual_upper, color=cols[1], alpha=0.4
    )
    # Shade the gap between observed and counterfactual after treatment.
    ax0.fill_between(
        treated_post.index,
        treated_post.values.squeeze(),
        counterfactual_post,
        color=cols[0],
        alpha=0.3,
    )
    ax0.plot(
        treated_post.index,
        treated_post,
        color=cols[0],
        linewidth=2,
        label="Post-Intervention",
    )
    ax0.legend(loc="best")
    ax0.set(ylabel="Cigarette Sales")

    # Bottom panel: pointwise effect (observed minus counterfactual median).
    ax1.axhline(y=0.0, color="grey")
    ax1.axvline(
        x=treatment_date,
        color=cols[2],
        label="Treatment Date",
        linestyle="--",
    )
    ax1.plot(
        treated_unit.index,
        treated_unit.values.squeeze() - counterfactual_mean,
        color=cols[1],
        label="Effect",
        linewidth=2,
    )
    # Uncertainty band for the effect, derived from the counterfactual interval.
    ax1.fill_between(
        full_index,
        treated_unit.values.squeeze() - counterfactual_upper,
        treated_unit.values.squeeze() - counterfactual_lower,
        color=cols[1],
        alpha=0.4,
    )
    # Shade the post-treatment effect relative to zero.
    ax1.fill_between(
        treated_post.index,
        treated_post.values.squeeze() - counterfactual_post,
        jnp.zeros_like(counterfactual_post),
        color=cols[0],
        alpha=0.3,
    )
    ax1.set(ylabel="Cigarette Sales", title="Causal Impact")

    fig.tight_layout()


def plot_prior_predictive(
    prior_samples: dict[str, Float[Array, "..."]],
    treated_pre: Float[Array, "N 1"],
    pre_index: pd.Index,
    /,
    cols: list[str] | None = None,
):
    """Overlay prior-predictive draws on the observed pre-treatment series.

    Each sampled trajectory in ``prior_samples["obs"]`` is drawn faintly, with
    the observed treated-unit series plotted on top for comparison.
    """
    palette = cols if cols is not None else mpl.rcParams["axes.prop_cycle"].by_key()["color"]

    # Faint line per prior draw first, then the observed series on top.
    plt.plot(pre_index, prior_samples["obs"].T, alpha=0.1, color=palette[0], linewidth=0.5)
    plt.plot(pre_index, treated_pre, color=palette[1], linewidth=3, label="Observed")
    plt.gca().set(
        title="Prior Predictive: Pre-treatment",
        xlabel="Time",
        ylabel="Cigarette Sales",
    )
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()


def plot_posterior_predictive(
    ppc_pre: dict[str, Float[Array, "..."]],
    ppc_post: dict[str, Float[Array, "..."]],
    treated_pre: Float[Array, "N 1"],
    treated_post: Float[Array, "M 1"],
    pre_index: pd.Index,
    post_index: pd.Index,
    treatment_date: int,
    /,
    cols: list[str] | None = None,
):
    """Plot posterior predictive check for both pre and post treatment periods.

    Left panel: posterior-predictive mean and 95% interval against the
    observed pre-treatment data (in-sample fit). Right panel: the model's
    counterfactual prediction against the actual post-treatment outcomes.

    Args:
        ppc_pre: Posterior-predictive samples for the pre-treatment period;
            must contain an "obs" array with samples along axis 0.
        ppc_post: As above, for the post-treatment period.
        treated_pre: Observed treated-unit outcomes before treatment.
        treated_post: Observed treated-unit outcomes after treatment.
        pre_index: x-axis values for the pre-treatment panel.
        post_index: x-axis values for the post-treatment panel.
        treatment_date: Index value at which the treatment begins.
        cols: Optional colour cycle; defaults to matplotlib's current cycle.
    """
    if cols is None:
        cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
    # Posterior-predictive mean and central 95% interval, pre-treatment.
    ppc_pre_mean = ppc_pre["obs"].mean(axis=0)
    ppc_pre_lower = jnp.percentile(ppc_pre["obs"], 2.5, axis=0)
    ppc_pre_upper = jnp.percentile(ppc_pre["obs"], 97.5, axis=0)

    # Same summaries for the post-treatment (counterfactual) draws.
    ppc_post_mean = ppc_post["obs"].mean(axis=0)
    ppc_post_lower = jnp.percentile(ppc_post["obs"], 2.5, axis=0)
    ppc_post_upper = jnp.percentile(ppc_post["obs"], 97.5, axis=0)

    _fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Pre-treatment panel
    axes[0].fill_between(
        pre_index,
        ppc_pre_lower,
        ppc_pre_upper,
        alpha=0.3,
        color=cols[0],
        label="95% CI",
    )
    axes[0].plot(
        pre_index, ppc_pre_mean, color=cols[0], linewidth=2, label="Posterior mean"
    )
    axes[0].plot(
        pre_index,
        treated_pre,
        "o-",
        color=cols[1],
        linewidth=2,
        label="Observed",
        markersize=3,
    )
    axes[0].set_title("Posterior Predictive Check: Pre-treatment")
    axes[0].legend()
    axes[0].set_xlabel("Year")
    axes[0].set_ylabel("Cigarette Sales")

    # Post-treatment panel
    axes[1].fill_between(
        post_index,
        ppc_post_lower,
        ppc_post_upper,
        alpha=0.3,
        color=cols[0],
        label="95% CI (counterfactual)",
    )
    axes[1].plot(
        post_index,
        ppc_post_mean,
        color=cols[0],
        linewidth=2,
        label="Synthetic California",
    )
    axes[1].plot(
        post_index,
        treated_post,
        "o-",
        color=cols[1],
        linewidth=2,
        label="Actual California",
        markersize=3,
    )
    # Mark the intervention so the divergence is easy to read off.
    axes[1].axvline(
        treatment_date, color=cols[2], linestyle="--", alpha=0.5, label="Prop 99"
    )
    axes[1].set_title("Treatment Effect: Actual vs Counterfactual")
    axes[1].legend()
    axes[1].set_xlabel("Year")
    axes[1].set_ylabel("Cigarette Sales")

    plt.tight_layout()
    plt.show()


def plot_comparison(
    treated_unit: pd.DataFrame,
    control_units: pd.DataFrame,
    counterfactual_lower: Float[Array, "N"],
    counterfactual_mean: Float[Array, "N"],
    counterfactual_upper: Float[Array, "N"],
    treatment_date: int,
    /,
    cols: list[str] | None = None,
):
    """Compare the treated unit against its synthetic counterfactual.

    Control units are drawn faintly in the background; the counterfactual
    mean is shown with its credible band, and the treatment date is marked.
    """
    palette = mpl.rcParams["axes.prop_cycle"].by_key()["color"] if cols is None else cols

    _fig, ax = plt.subplots()
    # Background context: every control series in faint grey.
    ax.plot(control_units, color="grey", alpha=0.2, label="Controls")
    ax.plot(treated_unit, color=palette[0], linewidth=2, label="Treated")
    ax.plot(
        treated_unit.index,
        counterfactual_mean,
        color=palette[1],
        linewidth=2,
        label="Counterfactual",
    )
    ax.fill_between(
        treated_unit.index,
        counterfactual_lower,
        counterfactual_upper,
        color=palette[1],
        alpha=0.5,
    )
    ax.axvline(treatment_date, label="Treatment Date", color=palette[2], linestyle="--")

    # Every control series registers its own "Controls" entry; keep only the
    # last handle seen per label so the legend shows each label once.
    handles, labels = ax.get_legend_handles_labels()
    deduped = {label: handle for label, handle in zip(labels, handles, strict=False)}
    ax.legend(deduped.values(), deduped.keys(), loc="best")


def clean_legend(ax):
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles, strict=False))
    ax.legend(by_label.values(), by_label.keys(), loc="best")
    return ax

Model Specification

The idea of SCMs is to create a synthetic equivalent of the treated unit, a counterfactual, by learning a weighted combination of untreated control units. This counterfactual should approximate the behaviour of the treated unit in the pre-treatment period. The trajectory of the counterfactual in the post-treatment period can then be interpreted as “what would have happened in the treated unit, had a treatment not been applied”. The difference between the treated unit’s value and the counterfactual then serves as the treatment effect.

Let $Y_{1t}$ be the outcome for the treated unit (California) at time $t$, and let $\mathbf{Y}_{0t}$ be the vector of outcomes for the $D$ control units at time $t$. $\hat{Y}_{1t}$ is then the counterfactual unit. An ordinary SCM can then be written as

$$ \hat{Y}_{1t} = \mathbf{Y}_{0t}\mathbf{w} + \epsilon, \tag{1} $$

where $\mathbf{w}\in\Delta^{D}$, with $\Delta^{D}$ the $D$-dimensional probability simplex. The practical implication of this is that the weights of our counterfactual are constrained to be strictly positive and sum to 1; a form of regularisation that prevents overfitting.

A Bayesian variant of such a model can then be easily inferred by mapping the model in (1) to a Bayesian linear regression model where the weights are drawn from a Dirichlet distribution. Use of the Dirichlet distribution satisfies the need for our weights to belong to the simplex, and by setting an appropriate hyperprior on the distribution’s concentration parameter, we may even control how sparse our weights are. As the concentration parameter $c$ tends to 0 from above, the samples drawn from the Dirichlet distribution become increasingly sparse, leading to fewer units contributing to the counterfactual unit’s response. This aids interpretability, often improves performance by functioning as a form of regularisation, and can be helpful in a practical setting where there is an increasing cost for each additional unit that is included in the experiment.

We may write the model as

$$ \hat{Y}_{1t} \sim \mathcal{N}(\mu_t, \sigma^2) $$

where the mean, $\mu_t$, represents the synthetic control:

$$ \mu_t = \alpha + \mathbf{Y}_{0t}\mathbf{w}\,. $$

The parameters of this model are assigned the priors:

$$ \begin{align} \mathbf{c} &\sim \text{Gamma}(0.5, 0.5) \\ \mathbf{w} &\sim \text{Dirichlet}(\mathbf{c}) \\ \alpha &\sim \mathcal{N}(0, 10) \\ \sigma &\sim \text{HalfNormal}(1) \end{align} $$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def model(
    control_units: Float[Array, "N D"], treated_unit: Float[Array, "N 1"] | None = None
):
    """Bayesian synthetic control: simplex-weighted regression on the controls.

    Args:
        control_units: (N, D) array of control-unit outcomes.
        treated_unit: Optional treated-unit outcomes (the MCMC call below
            passes a squeezed (N,) array — the "N 1" annotation may be stale;
            TODO confirm). When given, the "obs" site is conditioned on it;
            when None, "obs" is sampled (prior/posterior predictive draws).
    """
    num_units = jnp.shape(control_units)[1]
    # A single shared concentration, broadcast to all D units; smaller values
    # push the Dirichlet draws towards sparser weight vectors.
    concentration = npy.sample("concentration", npy_dist.Gamma(0.5, 0.5)) * jnp.ones(
        num_units
    )
    # Simplex-constrained weights: non-negative and summing to one.
    weights = npy.sample("weights", npy_dist.Dirichlet(concentration=concentration))

    intercept = npy.sample("intercept", npy_dist.Normal(0, 10))
    # Synthetic control: affine combination of the control units.
    counterfactual = intercept + jnp.matmul(control_units, weights)

    # Observation noise scale.
    noise = npy.sample("noise", npy_dist.HalfNormal(scale=1.0))
    # Conditioning via a handler (rather than `obs=`) keeps one signature
    # usable for both inference (data given) and predictive sampling (None).
    with npy.handlers.condition(data={"obs": treated_unit}):
        npy.sample("obs", npy_dist.Normal(counterfactual, noise))
1
# Per-capita cigarette sales by US state and year (Abadie et al. Prop 99 data).
data = pd.read_csv("datasets/california_smoking.csv")
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
def split(
    df: pd.DataFrame, treatment_time: int
) -> tuple[Float[Array, "Nx D"], Float[Array, "Ny D"]]:
    """Split ``df`` into pre- and post-treatment arrays at ``treatment_time``.

    Rows with index strictly before ``treatment_time`` form the pre-treatment
    block; all remaining rows form the post-treatment block.
    """
    before_treatment = df.index < treatment_time
    pre = jnp.asarray(df.loc[before_treatment].values)
    post = jnp.asarray(df.loc[~before_treatment].values)
    return pre, post


# Proposition 99 came into effect in California in 1988.
treatment_date = 1988

# Control units: every state except California, one column per state.
controls = data.loc[data.state != "California", :].pivot(
    index="year", columns="state", values="cigsale"
)
# Treated unit: California's cigarette sales indexed by year.
treated = data.loc[data.state == "California", ["year", "cigsale"]].set_index("year")
1
2
3
4
5
# Raw data: all control states (faint) against California, with the
# treatment date marked.
fig, ax = plt.subplots()
ax.plot(controls, color=cols[1], label="Controls", alpha=0.4)
ax.plot(treated, color=cols[0], label="Treated")
ax.axvline(treatment_date, label="Treatment Date", color=cols[2], linestyle="--")
clean_legend(ax)

Data Preparation

We use the california_smoking.csv dataset, which contains per-capita cigarette sales for 39 US states from 1970 to 2000. The treated unit is California, where Proposition 99 was passed in 1988 to increase cigarette taxes.

The data is partitioned into a treated unit (California) and a set of control units (the other 38 states). The data for all units is then split into a pre-treatment period (1970-1987) and a post-treatment period (1988-2000).

1
2
# Pre/post-treatment splits; only the pre-treatment data is used for fitting.
controls_pre, controls_post = split(controls, treatment_date)
treated_pre, _ = split(treated, treatment_date)

Model Fitting and Evaluation

With the model specified and the data prepared, we can perform inference to find the posterior distribution of the parameters. We use a No-U-Turn Sampler (NUTS), which is a self-tuning variant of Hamiltonian Monte Carlo (HMC), to draw samples from the posterior.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# NUTS sampler; 8 chains run in parallel across the host devices set up above.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1500,
    num_chains=8,
    chain_method="parallel",
    progress_bar=False,
)

# Fit on the pre-treatment period only; squeeze (N, 1) -> (N,) so the
# observed values match the likelihood's shape.
mcmc.run(
    key,
    control_units=controls_pre,
    treated_unit=treated_pre.squeeze(),
)
trace = mcmc.get_samples()

Prior Predictive Check

Before fitting the model, it is good practice to perform a prior predictive check. This involves simulating data from the model using only the prior distributions. By comparing the generated data to the actual observed data, we can assess whether the priors are reasonable. The plot below shows that our priors are quite weak, allowing for a wide range of possible outcomes, which is a safe starting point.

1
2
3
4
5
6
7
# Prior predictive draws: treated_unit is omitted, so "obs" is sampled
# from the prior rather than conditioned on data.
prior_predictive = Predictive(model, num_samples=1000, return_sites=["obs"])
prior_samples = prior_predictive(key, control_units=controls_pre)
# Convert year integers to datetimes purely for nicer x-axis formatting.
pre_index = pd.to_datetime(
    treated.loc[treated.index < treatment_date].index, format="%Y"
)

plot_prior_predictive(prior_samples, treated_pre, pre_index, cols=cols)

Posterior Predictive Check

After fitting the model on the pre-treatment data, we perform a posterior predictive check. This involves generating data from the model using parameter values from the posterior distribution. The left-hand panel below shows that the fitted model provides a good description of the pre-treatment data. The right-hand panel shows the model’s counterfactual prediction for the post-treatment period, which is what we will use to estimate the treatment effect.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Posterior predictive draws, conditioned on the fitted posterior samples.
posterior_predictive = Predictive(model, trace)

# In-sample check over the pre-treatment period.
# NOTE(review): the same `key` is reused here, below, and for the MCMC run —
# consider jr.split(key) for independent randomness; verify this is intended.
ppc_pre = posterior_predictive(
    key,
    control_units=controls_pre,
    treated_unit=None,
)

# Counterfactual draws for the post-treatment period.
ppc_post = posterior_predictive(
    key,
    control_units=controls_post,
    treated_unit=None,
)

post_index = treated.loc[treated.index >= treatment_date].index
treated_post = treated.loc[treated.index >= treatment_date].values

plot_posterior_predictive(
    ppc_pre,
    ppc_post,
    treated_pre,
    treated_post,
    pre_index,
    post_index,
    treatment_date,
    cols=cols,
)

Effect Estimation

Now that we have the posterior distribution of the model parameters, we can estimate the causal effect of Proposition 99. We do this by generating a counterfactual outcome for California in the post-treatment period. This is the outcome that would have been observed had the intervention not occurred. The counterfactual is generated by taking the posterior predictive distribution for the post-treatment period.

The treatment effect is the difference between the observed outcome and the counterfactual outcome. A negative effect suggests that Proposition 99 led to a decrease in cigarette sales.

1
2
3
4
# One counterfactual trajectory per posterior sample.
# NOTE(review): despite the name, this uses `controls` and so spans the FULL
# study period, not just post-treatment — confirm the name is intentional.
counterfactual_post = trace["intercept"] + controls.values @ trace["weights"].T
# Median and central 95% interval across posterior samples (axis=1).
counterfactual_lower, counterfactual_mean, counterfactual_upper = jnp.percentile(
    counterfactual_post, q=jnp.array([2.5, 50, 97.5]), axis=1
)
1
2
3
4
5
6
7
8
9
# Treated unit vs synthetic counterfactual with its 95% credible band.
plot_comparison(
    treated,
    controls,
    counterfactual_lower,
    counterfactual_mean,
    counterfactual_upper,
    treatment_date,
    cols=cols,
)
1
2
# Two-panel effect plot: counterfactual fit and estimated causal impact.
plot_effect(trace, controls, treated, treatment_date)
plt.show()