===================
== Thomas Pinder ==
===================
Bayesian ML, Causal Inference, and JAX


Bayesian Estimation Supersedes the T-Test in NumPyro

A NumPyro implementation of "Bayesian estimation supersedes the t-test".

import jax
import jax.random as jr
import matplotlib.pyplot as plt
import numpyro as npy
from numpyro import distributions as npy_dist
from pydantic.dataclasses import dataclass
from pydantic import ConfigDict
from jaxtyping import Float, Array, PRNGKeyArray
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
import numpy as np
import arviz as az
import seaborn as sns
import pandas as pd
import typing as tp
import matplotlib.ticker as mtick

jax.config.update("jax_platform_name", "cpu")
npy.set_host_device_count(8)
key = jr.PRNGKey(123)
sns.set_theme(
    context="notebook",
    font="serif",
    style="whitegrid",
    palette="deep",
    rc={"figure.figsize": (6, 3), "axes.spines.top": False, "axes.spines.right": False},
)

This notebook is a NumPyro implementation of "Bayesian estimation supersedes the t-test" (BEST; Kruschke, 2013). Rather than reducing a two-group comparison to a single p-value, BEST places a Student-t likelihood over each group, with separate means and scales and a shared normality parameter, and reports full posterior distributions over the difference of means, the difference of scales, and the effect size.

Simulate Data

Define Simulation Parameters

@dataclass(frozen=True, kw_only=True)
class SimulationParameters:
    relative_treatment_effect: float
    control_response: float
    control_response_noise: float = 0.5
    treatment_response_noise: float = 0.5
    num_control_units: int = 200
    num_treatment_units: int = 200

    @property
    def treatment_response(self) -> float:
        return self.control_response * (1 + self.relative_treatment_effect)

    @property
    def level_treatment_effect(self) -> float:
        return self.control_response * self.relative_treatment_effect

    @property
    def treatment_params(self) -> tuple[int, float, float]:
        return (
            self.num_treatment_units,
            self.treatment_response,
            self.treatment_response_noise,
        )

    @property
    def control_params(self) -> tuple[int, float, float]:
        return (
            self.num_control_units,
            self.control_response,
            self.control_response_noise,
        )


sim_params = SimulationParameters(relative_treatment_effect=0.01, control_response=10.0)
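
As a quick check of the derived properties: with a control response of 10.0 and a 1% relative effect, the treated mean is 10.1 and the raw (level) difference is 0.1.

print(sim_params.treatment_response)       # 10.0 * (1 + 0.01) = 10.1
print(sim_params.level_treatment_effect)   # 10.0 * 0.01 = 0.1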

Simulate Measurement Data

@dataclass(frozen=True, kw_only=True, config=ConfigDict(arbitrary_types_allowed=True))
class Dataset:
    control: Float[Array, "Nc D"]
    treated: Float[Array, "Nt D"]
    true_effect: float

    def unpack(self) -> tuple[Float[Array, "Nc D"], Float[Array, "Nt D"]]:
        return self.control, self.treated

    @property
    def pandas(self) -> pd.DataFrame:
        # Note: hstack assumes equal group sizes (Nc == Nt), as in the defaults above.
        df = pd.DataFrame(jnp.hstack(self.unpack()), columns=["Control", "Treated"])
        return df.melt().rename(columns={"variable": "Group"})


def simulate_data(key: PRNGKeyArray, parameters: SimulationParameters) -> Dataset:
    n_c, mu_c, sigma_c = parameters.control_params
    n_t, mu_t, sigma_t = parameters.treatment_params

    key_c, key_t = jr.split(key, 2)
    y_c = mu_c + sigma_c * jr.normal(key=key_c, shape=(n_c, 1))
    y_t = mu_t + sigma_t * jr.normal(key=key_t, shape=(n_t, 1))
    return Dataset(
        control=y_c, treated=y_t, true_effect=parameters.relative_treatment_effect
    )
data = simulate_data(key, sim_params)

fig, ax = plt.subplots()
sns.kdeplot(data=data.pandas, x="value", hue="Group", fill=True, ax=ax)
ax.set(xlabel="Measured Value", ylabel="Density")
plt.show()
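
For reference, a classical Welch's t-test on the same draw gives a frequentist baseline to compare against the Bayesian estimates below. A minimal sketch, assuming SciPy is available (it is not among the imports above):

from scipy import stats

y_c, y_t = (np.asarray(a).squeeze() for a in data.unpack())
t_stat, p_value = stats.ttest_ind(y_t, y_c, equal_var=False)  # Welch's t-test
print(f"t = {t_stat:.3f}, p = {p_value:.3f}")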

Inference

Model Specification

def model(
    control: Float[Array, "Nc D"],
    treated: Float[Array, "Nt D"],
    scale_multiplier: float = 2.0,
    scale_bounds: tuple[float, float] = (0.1, 10),
):
    pooled = jnp.concatenate([control, treated])
    pooled_mean = pooled.mean()
    pooled_scale = scale_multiplier * pooled.std()

    control_mean = npy.sample(
        "control_mean", npy_dist.Normal(loc=pooled_mean, scale=pooled_scale)
    )
    treated_mean = npy.sample(
        "treated_mean", npy_dist.Normal(loc=pooled_mean, scale=pooled_scale)
    )

    control_scale = npy.sample("control_scale", npy_dist.Uniform(*scale_bounds))
    treated_scale = npy.sample("treated_scale", npy_dist.Uniform(*scale_bounds))

    # Shifted exponential prior on the normality parameter, nu ~ 1 + Exp(mean 29),
    # following Kruschke (2013); recorded on the log10 scale for diagnostics.
    nu_minus_one = npy.sample("nu_minus_one", npy_dist.Exponential(1.0 / 29.0))
    nu = nu_minus_one + 1.0
    nu_log10 = npy.deterministic("nu_log10", jnp.log10(nu))

    npy.sample(
        "control",
        npy_dist.StudentT(df=nu, loc=control_mean, scale=control_scale),
        obs=control,
    )
    npy.sample(
        "treated",
        npy_dist.StudentT(df=nu, loc=treated_mean, scale=treated_scale),
        obs=treated,
    )

    difference_of_means = npy.deterministic(
        "diff_of_means", treated_mean - control_mean
    )
    difference_of_scales = npy.deterministic(
        "diff_of_scales", treated_scale - control_scale
    )
    # Standardised effect size (Cohen's d under a pooled scale), per Kruschke (2013).
    level_effect = npy.deterministic(
        "level_effect",
        difference_of_means / jnp.sqrt((control_scale**2 + treated_scale**2) / 2),
    )
    relative_effect = npy.deterministic(
        "relative_effect", difference_of_means / control_mean
    )
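
Before sampling, the model's structure can be inspected visually. A minimal sketch using npy.render_model, which assumes the optional graphviz dependency is installed:

npy.render_model(
    model,
    model_kwargs={"control": data.control, "treated": data.treated},
    render_distributions=True,
)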

Sampling Routine

def get_sampler_fn(
    control: Float[Array, "Nc D"], treated: Float[Array, "Nt D"]
) -> tp.Callable[[PRNGKeyArray], dict[str, Float[Array, "..."]]]:
    def _run_mcmc(
        key: PRNGKeyArray,
        /,
        num_samples: int = 5000,
        num_warmup: int = 2500,
    ) -> dict[str, Float[Array, "..."]]:
        kernel = NUTS(model)
        mcmc = MCMC(
            kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            chain_method="parallel",
            progress_bar=False,
        )

        mcmc.run(
            key,
            control=control,
            treated=treated,
            extra_fields=("diverging", "num_steps"),
        )
        return {**mcmc.get_samples(), **mcmc.get_extra_fields()}

    return _run_mcmc
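
An alternative to pmap-ing single-chain runs below is to let MCMC manage multiple chains itself. A minimal sketch, assuming the host-device count set at the top of the notebook:

mcmc = MCMC(
    NUTS(model),
    num_warmup=2500,
    num_samples=5000,
    num_chains=4,
    chain_method="parallel",
)
mcmc.run(key, control=data.control, treated=data.treated)
samples = mcmc.get_samples(group_by_chain=True)  # arrays of shape (chains, draws, ...)

Both routes produce per-chain arrays in the (chain, draw) layout that ArviZ expects.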

Perform Inference

num_devices = jax.local_device_count()
sampler_keys = jr.split(key, num_devices)

run_mcmc = get_sampler_fn(*data.unpack())
traces = jax.pmap(run_mcmc)(sampler_keys)  # one chain per device; leaves have shape (num_devices, num_samples)
sampler_stats_names = ["diverging", "num_steps"]
param_names = [k for k in traces.keys() if k not in sampler_stats_names]

posterior_samples = {k: v for k, v in traces.items() if k in param_names}
sample_stats = {k: v for k, v in traces.items() if k in sampler_stats_names}

idata = az.from_dict(posterior=posterior_samples, sample_stats=sample_stats)
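
Before plotting, it is worth confirming the sampler behaved well. A short sketch of the usual checks with ArviZ:

print(az.summary(idata, var_names=["diff_of_means", "diff_of_scales", "relative_effect"]))
print("divergences:", int(np.asarray(sample_stats["diverging"]).sum()))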
efig, eax = plt.subplots(figsize=(6, 3))
az.plot_posterior(idata, var_names=["relative_effect"], ref_val=0.0, ax=eax)
eax.xaxis.set_major_formatter(mtick.PercentFormatter(decimals=0, xmax=1))
plt.show()
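
The posterior can also be queried directly; for example, the probability that the relative effect is positive:

p_positive = float((traces["relative_effect"] > 0).mean())
print(f"P(relative effect > 0) = {p_positive:.2%}")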