Skip to content

Latest commit

 

History

History
231 lines (178 loc) · 6.41 KB

Quickstart - Synthetic difference in differences.md

File metadata and controls

231 lines (178 loc) · 6.41 KB
title hide_title status
Quickstart - Synthetic difference in differences
true
stable

Scalable Synthetic Difference in Differences

This sample notebook aims to show readers how to use SynapseML's DiffInDiffEstimator, SyntheticControlEstimator and SyntheticDiffInDiffEstimator to estimate the causal effect of a treatment on a particular outcome.

In this sample notebook, we will use the California smoking cessation program example to demonstrate usage of the SyntheticDiffInDiff Estimator. The goal of the analysis is to estimate the effect of increased cigarette taxes on smoking in California.

from pyspark.sql.types import *
from synapse.ml.causal import (
    DiffInDiffEstimator,
    SyntheticControlEstimator,
    SyntheticDiffInDiffEstimator,
)
from matplotlib import pyplot as plt
from matplotlib import style
import pandas as pd
import numpy as np

spark.sparkContext.setLogLevel("INFO")
style.use("ggplot")

We will select 5 columns from the dataset: state, year, cigsale, california, after_treatment.

df = (
    spark.read.option("header", True)
    .option("inferSchema", True)
    .csv("wasbs://[email protected]/smoking.csv")
    .select("state", "year", "cigsale", "california", "after_treatment")
)
display(df)

First, we use the DiffInDiffEstimator to estimate the causal effect with regular difference in differences method. We set the treatment indicator column to "california", set post-treatment indicator column to "after_treatment", and set the outcome column to "cigsale".

estimator1 = DiffInDiffEstimator(
    treatmentCol="california", postTreatmentCol="after_treatment", outcomeCol="cigsale"
)

model1 = estimator1.fit(df)

print("[Diff in Diff] treatment effect: {}".format(model1.treatmentEffect))
print("[Diff in Diff] standard error: {}".format(model1.standardError))

The treatment effect estimated by difference in differences should be -27.349.

Next, we use SyntheticControlEstimator to synthesize a control unit and use the synthetic control to estimate the causal effect. To create the synthetic control unit, we need to set the column which indicates the time when each outcome is measured, and the column which indicates the unit for which the outcome is measured.

estimator2 = SyntheticControlEstimator(
    timeCol="year",
    unitCol="state",
    treatmentCol="california",
    postTreatmentCol="after_treatment",
    outcomeCol="cigsale",
    maxIter=5000,
    numIterNoChange=50,
    tol=1e-4,
    stepSize=1.0,
)

model2 = estimator2.fit(df)

print("[Synthetic Control] treatment effect: {}".format(model2.treatmentEffect))
print("[Synthetic Control] standard error: {}".format(model2.standardError))

The treatment effect estimated by synthetic control should be about -19.354.

Internally, a constrained least square regression is used to solve the unit weights for the synthetic control, and we can plot the loss history.

lossHistory = pd.Series(np.array(model2.lossHistoryUnitWeights))

plt.plot(lossHistory[2000:])
plt.title("loss history - unit weights")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

print("Mimimal loss: {}".format(lossHistory.min()))

We can also visualize the synthetic control and compare it with the treated unit.

sc_weights = model2.unitWeights.toPandas().set_index("state")
pdf = df.toPandas()
sc = (
    pdf.query("~california")
    .pivot(index="year", columns="state", values="cigsale")
    .dot(sc_weights)
)
plt.plot(sc, label="Synthetic Control")
plt.plot(sc.index, pdf.query("california")["cigsale"], label="California", color="C1")

plt.title("Synthetic Control Estimation")
plt.ylabel("Cigarette Sales")
plt.vlines(
    x=1988,
    ymin=40,
    ymax=140,
    line,
    lw=2,
    label="Proposition 99",
    color="black",
)
plt.legend()

Lastly, we use SyntheticDiffInDiffEstimator to estimate the causal effect.

estimator3 = SyntheticDiffInDiffEstimator(
    timeCol="year",
    unitCol="state",
    treatmentCol="california",
    postTreatmentCol="after_treatment",
    outcomeCol="cigsale",
    maxIter=5000,
    numIterNoChange=50,
    tol=1e-4,
    stepSize=1.0,
)

model3 = estimator3.fit(df)

print("[Synthetic Diff in Diff] treatment effect: {}".format(model3.treatmentEffect))
print("[Synthetic Diff in Diff] standard error: {}".format(model3.standardError))

The treatment effect estimated by synthetic control should be about -15.554.

Again, we can plot the loss history from the optimizer used to solve the unit weights and the time weights.

lossHistory = pd.Series(np.array(model3.lossHistoryUnitWeights))

plt.plot(lossHistory[1000:])
plt.title("loss history - unit weights")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

print("Mimimal loss: {}".format(lossHistory.min()))
lossHistory = pd.Series(np.array(model3.lossHistoryTimeWeights))

plt.plot(lossHistory[1000:])
plt.title("loss history - time weights")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

print("Mimimal loss: {}".format(lossHistory.min()))

Here we plot the synthetic diff in diff estimate together with the time weights.

unit_weights = model3.unitWeights.toPandas().set_index("state")
unit_intercept = model3.unitIntercept

time_weights = model3.timeWeights.toPandas().set_index("year")
time_intercept = model3.timeIntercept

pdf = df.toPandas()
pivot_df_control = pdf.query("~california").pivot(
    index="year", columns="state", values="cigsale"
)
pivot_df_treat = pdf.query("california").pivot(
    index="year", columns="state", values="cigsale"
)
sc_did = pivot_df_control.values @ unit_weights.values
treated_mean = pivot_df_treat.mean(axis=1)
fig, (ax1, ax2) = plt.subplots(
    2, 1, figsize=(15, 8), sharex=True, gridspec_kw={"height_ratios": [3, 1]}
)
fig.suptitle("Synthetic Diff in Diff Estimation")

ax1.plot(
    pivot_df_control.mean(axis=1), lw=3, color="C1", ls="dashed", label="Control Avg."
)
ax1.plot(treated_mean, lw=3, color="C0", label="California")
ax1.plot(
    pivot_df_control.index,
    sc_did,
    label="Synthetic Control (SDID)",
    color="C1",
    alpha=0.8,
)
ax1.set_ylabel("Cigarette Sales")
ax1.vlines(
    1989,
    treated_mean.min(),
    treated_mean.max(),
    color="black",
    ls="dotted",
    label="Prop. 99",
)
ax1.legend()

ax2.bar(time_weights.index, time_weights["value"], color="skyblue")
ax2.set_ylabel("Time Weights")
ax2.set_xlabel("Time")
ax2.vlines(1989, 0, 1, color="black", ls="dotted")