1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def model(
    control: Float[Array, "Nc D"],
    treated: Float[Array, "Nt D"],
    scale_multiplier: float = 2.0,
    scale_bounds: tuple[float, float] = (0.1, 10),
) -> None:
    """Two-group comparison model with Student-t likelihoods.

    Each group gets its own location and scale; both groups share one
    degrees-of-freedom parameter ``nu``.

    Priors:
        * ``control_mean`` / ``treated_mean``: Normal centred on the mean of
          all observations (both groups, all elements), with standard
          deviation ``scale_multiplier`` times the pooled standard deviation.
        * ``control_scale`` / ``treated_scale``: Uniform over ``scale_bounds``.
        * ``nu_minus_one``: ``Exponential(1 / 29)``; ``nu = nu_minus_one + 1``
          so the degrees of freedom are always greater than one.

    Deterministic sites recorded in the trace:
        * ``nu_log10``: ``log10(nu)``.
        * ``diff_of_means``: ``treated_mean - control_mean``.
        * ``diff_of_scales``: ``treated_scale - control_scale``.
        * ``level_effect``: ``diff_of_means`` divided by the root mean square
          of the two group scales.
        * ``relative_effect``: ``diff_of_means / control_mean``.

    Args:
        control: Observations for the control group.
        treated: Observations for the treated group. It is concatenated with
            ``control`` along the first axis, so trailing dimensions must
            match.
        scale_multiplier: Factor applied to the pooled standard deviation to
            widen the prior on the group means.
        scale_bounds: ``(low, high)`` bounds of the Uniform prior on each
            group scale.

    Returns:
        None. The function only registers sample and deterministic sites.
    """
    # Pool both groups once; the statistics below are taken over every element.
    pooled = jnp.concatenate([control, treated])
    pooled_mean = pooled.mean()
    pooled_scale = scale_multiplier * pooled.std()

    # Sampling order is kept fixed so RNG consumption per site is unchanged.
    control_mean = npy.sample(
        "control_mean", npy_dist.Normal(loc=pooled_mean, scale=pooled_scale)
    )
    treated_mean = npy.sample(
        "treated_mean", npy_dist.Normal(loc=pooled_mean, scale=pooled_scale)
    )
    control_scale = npy.sample("control_scale", npy_dist.Uniform(*scale_bounds))
    treated_scale = npy.sample("treated_scale", npy_dist.Uniform(*scale_bounds))

    # Shift by one so the Student-t degrees of freedom stay above 1.
    nu_minus_one = npy.sample("nu_minus_one", npy_dist.Exponential(1.0 / 29.0))
    nu = nu_minus_one + 1.0
    # Previously computed and discarded; recorded so it shows up in the trace.
    npy.deterministic("nu_log10", jnp.log10(nu))

    # Likelihoods: same nu for both groups, group-specific loc and scale.
    npy.sample(
        "control",
        npy_dist.StudentT(df=nu, loc=control_mean, scale=control_scale),
        obs=control,
    )
    npy.sample(
        "treated",
        npy_dist.StudentT(df=nu, loc=treated_mean, scale=treated_scale),
        obs=treated,
    )

    # Derived quantities of interest.
    difference_of_means = npy.deterministic(
        "diff_of_means", treated_mean - control_mean
    )
    npy.deterministic("diff_of_scales", treated_scale - control_scale)
    npy.deterministic(
        "level_effect",
        difference_of_means / jnp.sqrt((control_scale**2 + treated_scale**2) / 2),
    )
    # NOTE: unstable when control_mean is close to zero.
    npy.deterministic("relative_effect", difference_of_means / control_mean)