Causal Inference: Diff-in-Diff, Propensity Score, RCT

UBER & NETFLIX INTERVIEW QUESTION

Causal Analysis & Experimental Design

The Interview Question

Question: You want to estimate the causal effect of a new pricing policy on ride demand:

  • Treatment: New pricing policy implemented in select cities
  • Outcome: Daily ride demand
  • Challenge: Non-random assignment, confounding factors

Walk through your causal inference approach:

  1. How do you establish causality vs correlation?
  2. How do you use difference-in-differences to estimate the effect?
  3. How do you handle confounding with propensity score matching?
  4. How do you validate your causal estimates?

Detailed Answer

1. Causal Inference Framework

Causal inference aims to estimate the effect of a treatment on an outcome, going beyond correlation to establish causation.

import pandas as pd
import numpy as np
from scipy import stats
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

class CausalInferenceFramework:
    """Framework for causal inference methods"""
    
    def __init__(self, data, treatment_col, outcome_col):
        self.data = data.copy()
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.results = {}
    
    def estimate_ATE_naive(self):
        """Naive comparison (potentially biased)"""
        treatment_mean = self.data[self.data[self.treatment_col] == 1][self.outcome_col].mean()
        control_mean = self.data[self.data[self.treatment_col] == 0][self.outcome_col].mean()
        
        naive_ATE = treatment_mean - control_mean
        
        return {
            'method': 'Naive Comparison',
            'ATE': naive_ATE,
            'treatment_mean': treatment_mean,
            'control_mean': control_mean,
            'warning': 'This estimate may be biased due to confounding'
        }
    
    def potential_outcomes_framework(self):
        """Explain potential outcomes framework"""
        
        framework = {
            'notation': {
                'Y_i(1)': 'Potential outcome for individual i under treatment',
                'Y_i(0)': 'Potential outcome for individual i under control',
                'τ_i = Y_i(1) - Y_i(0)': 'Individual treatment effect',
                'ATE = E[Y(1) - Y(0)]': 'Average Treatment Effect',
                'ATT = E[Y(1) - Y(0) | T=1]': 'Average Treatment Effect on Treated'
            },
            'fundamental_problem': 'We can only observe one potential outcome for each individual',
            'identification_strategies': [
                'Randomized Controlled Trials (RCT)',
                'Difference-in-Differences (DiD)',
                'Regression Discontinuity (RD)',
                'Instrumental Variables (IV)',
                'Propensity Score Methods'
            ]
        }
        
        return framework
    
    def check_common_support(self, covariates):
        """Check common support assumption"""
        treatment = self.data[self.treatment_col]
        control = 1 - treatment
        
        support_checks = {}
        
        for cov in covariates:
            treatment_vals = self.data.loc[treatment == 1, cov]
            control_vals = self.data.loc[control == 1, cov]
            
            # Check overlap
            min_t, max_t = treatment_vals.min(), treatment_vals.max()
            min_c, max_c = control_vals.min(), control_vals.max()
            
            overlap_min = max(min_t, min_c)
            overlap_max = min(max_t, max_c)
            
            support_checks[cov] = {
                'treatment_range': (min_t, max_t),
                'control_range': (min_c, max_c),
                'overlap_range': (overlap_min, overlap_max),
                'has_support': overlap_min < overlap_max,
                'treatment_in_control_range': ((treatment_vals >= min_c) & (treatment_vals <= max_c)).mean(),
                'control_in_treatment_range': ((control_vals >= min_t) & (control_vals <= max_t)).mean()
            }
        
        return support_checks

# Example usage
# framework = CausalInferenceFramework(data, 'treatment', 'demand')
# naive_ATE = framework.estimate_ATE_naive()
# support = framework.check_common_support(['price', 'population', 'income'])
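
The potential outcomes notation above is easier to see on simulated data, where both Y(0) and Y(1) are known for every unit. The following is a minimal sketch, not part of the original answer: the variable names, the `city_size` confounder and all coefficients are assumptions chosen for illustration. It shows that the naive difference in means equals the ATT plus a selection-bias term.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000

# Hypothetical confounder: standardized city size, which drives both
# treatment assignment and baseline demand.
city_size = rng.normal(size=n)

# Potential outcomes: Y(0) without the policy, Y(1) with it.
y0 = 10 + 3 * city_size + rng.normal(size=n)
tau = 2 + city_size                      # heterogeneous individual effect
y1 = y0 + tau

# Non-random assignment: larger cities are more likely to be treated.
p_treat = 1 / (1 + np.exp(-1.5 * city_size))
t = rng.binomial(1, p_treat)

# Fundamental problem: only one potential outcome is observed per unit.
y_obs = np.where(t == 1, y1, y0)

true_ate = tau.mean()
true_att = tau[t == 1].mean()
naive = y_obs[t == 1].mean() - y_obs[t == 0].mean()
selection_bias = y0[t == 1].mean() - y0[t == 0].mean()

print(f"True ATE:        {true_ate:.2f}")
print(f"True ATT:        {true_att:.2f}")
print(f"Naive estimate:  {naive:.2f}")
print(f"Selection bias:  {selection_bias:.2f}")
```

Because treated cities would have had higher demand even without the policy, the naive comparison overstates the effect. In real data `y0` is never observed for treated units, so the bias term cannot be computed directly, which is why an identification strategy is needed.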

2. Difference-in-Differences (DiD)

class DifferenceInDifferences:
    """Difference-in-Differences estimation"""
    
    def __init__(self, data, outcome_col, treatment_col, time_col, group_col):
        self.data = data.copy()
        self.outcome_col = outcome_col
        self.treatment_col = treatment_col
        self.time_col = time_col
        self.group_col = group_col
        self.results = {}  # estimate_did() stores its output here
    
    def estimate_did(self):
        """Estimate DiD effect"""
        
        # Create interaction term
        self.data['did_interaction'] = self.data[self.treatment_col] * self.data[self.time_col]
        
        # DiD regression
        import statsmodels.api as sm
        
        X = self.data[[self.treatment_col, self.time_col, 'did_interaction']]
        X = sm.add_constant(X)
        y = self.data[self.outcome_col]
        
        model = sm.OLS(y, X).fit()
        
        # DiD estimate is the coefficient on the interaction term
        did_estimate = model.params['did_interaction']
        did_se = model.bse['did_interaction']
        did_pvalue = model.pvalues['did_interaction']
        did_ci = model.conf_int().loc['did_interaction']
        
        # Calculate group means for visualization
        group_means = self.data.groupby([self.group_col, self.time_col])[self.outcome_col].mean().unstack()
        
        results = {
            'method': 'Difference-in-Differences',
            'ATE': did_estimate,
            'std_error': did_se,
            'p_value': did_pvalue,
            'ci_95': tuple(did_ci),
            'regression_results': model.summary(),
            'group_means': group_means,
            'interpretation': f'The DiD estimate of {did_estimate:.4f} suggests the treatment '
                            f'caused {"an increase" if did_estimate > 0 else "a decrease"} '
                            f'in {self.outcome_col} of {abs(did_estimate):.4f} units, '
                            f'assuming parallel trends hold.'
        }
        
        self.results['did'] = results
        return results
    
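    def parallel_trends_test(self, pre_period_col=None, period_col=None):
        """Compare pre-period outcome trends of the treated and control groups.

        period_col should hold granular time periods (e.g. week number). It
        defaults to time_col, which only works when time_col is not a binary
        pre/post flag.
        """
        period_col = period_col or self.time_col
        
        if pre_period_col is None:
            # Assume time_col == 0 marks pre-period rows
            pre_data = self.data[self.data[self.time_col] == 0]
        else:
            pre_data = self.data[self.data[pre_period_col] == 1]
        
        # Mean outcome per period for each group
        treatment_trend = pre_data[pre_data[self.treatment_col] == 1].groupby(
            period_col
        )[self.outcome_col].mean()
        
        control_trend = pre_data[pre_data[self.treatment_col] == 0].groupby(
            period_col
        )[self.outcome_col].mean()
        
        # Period-over-period changes; the t-test needs at least two per group
        treatment_changes = treatment_trend.diff().dropna()
        control_changes = control_trend.diff().dropna()
        if len(treatment_changes) < 2 or len(control_changes) < 2:
            raise ValueError('Need at least three pre-treatment periods to compare trends')
        
        # Calculate difference in trends
        trend_difference = treatment_changes.mean() - control_changes.mean()
        
        # Statistical test on the period-over-period changes
        from scipy.stats import ttest_ind
        t_stat, p_value = ttest_ind(treatment_changes, control_changes)
        
        # A large p-value is only absence of evidence against parallel trends
        results = {
            'treatment_trend': treatment_trend,
            'control_trend': control_trend,
            'trend_difference': trend_difference,
            't_statistic': t_stat,
            'p_value': p_value,
            'parallel_trends': p_value > 0.05,
            'interpretation': ('No evidence against parallel trends' if p_value > 0.05
                             else 'Parallel trends assumption may be violated')
        }
        
        return results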
    
    def visualize_did(self):
        """Visualize DiD analysis"""
        
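        # Calculate group means by treatment status (0 = control, 1 = treated),
        # because the second plot indexes the groups by those labels
        means = self.data.groupby([self.treatment_col, self.time_col])[self.outcome_col].mean().unstack()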
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        # Plot 1: Trends
        for group in means.index:
            axes[0].plot(means.columns, means.loc[group], marker='o', label=f'Group {group}')
        
        axes[0].axvline(x=0.5, color='gray', linestyle='--', label='Treatment Start')
        axes[0].set_xlabel('Time')
        axes[0].set_ylabel(self.outcome_col)
        axes[0].set_title('Treatment and Control Group Trends')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Plot 2: DiD visualization
        pre_treatment = means[0]
        post_treatment = means[1]
        
        x = [0, 1]
        axes[1].plot(x, [pre_treatment[0], post_treatment[0]], 
                    'b-o', label='Control', linewidth=2)
        axes[1].plot(x, [pre_treatment[1], post_treatment[1]], 
                    'r-o', label='Treatment', linewidth=2)
        
        # Add counterfactual
        counterfactual = pre_treatment[1] + (post_treatment[0] - pre_treatment[0])
        axes[1].plot([0, 1], [pre_treatment[1], counterfactual], 
                    'r--', label='Counterfactual', alpha=0.5)
        
        # Add DiD arrow
        axes[1].annotate('', xy=(1, post_treatment[1]), xytext=(1, counterfactual),
                        arrowprops=dict(arrowstyle='<->', color='green', lw=2))
        axes[1].text(1.05, (post_treatment[1] + counterfactual) / 2, 
                    'DiD Effect', color='green', fontweight='bold')
        
        axes[1].set_xticks([0, 1])
        axes[1].set_xticklabels(['Pre', 'Post'])
        axes[1].set_ylabel(self.outcome_col)
        axes[1].set_title('Difference-in-Differences')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('did_analysis.png', dpi=150, bbox_inches='tight')
        plt.show()

# Example usage
# did = DifferenceInDifferences(data, 'demand', 'treatment', 'period', 'city')
# results = did.estimate_did()
# parallel_test = did.parallel_trends_test()
# did.visualize_did()
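
To make the link between the regression and the name "difference-in-differences" concrete, the sketch below computes the estimate twice on simulated data: once from the four cell means and once as the interaction coefficient of an OLS fit. It uses NumPy only; the data-generating numbers (baseline 100, a true effect of -3 rides) are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4_000
treated = rng.integers(0, 2, size=n)   # 1 = city group that gets the policy
post = rng.integers(0, 2, size=n)      # 1 = observation after the policy starts

# Hypothetical process: treated cities start 5 rides higher, every city
# gains 2 rides over time, and the policy itself removes 3 rides.
true_effect = -3.0
demand = (100 + 5 * treated + 2 * post
          + true_effect * treated * post
          + rng.normal(0, 4, size=n))


def cell_mean(group, period):
    """Mean demand for one (treated, post) cell."""
    return demand[(treated == group) & (post == period)].mean()


# (change in treated group) minus (change in control group)
did_means = (cell_mean(1, 1) - cell_mean(1, 0)) - (cell_mean(0, 1) - cell_mean(0, 0))

# Same quantity as the interaction coefficient in a saturated regression
X = np.column_stack([np.ones(n), treated, post, treated * post])
beta, *_ = np.linalg.lstsq(X, demand, rcond=None)
did_ols = beta[3]

print(f"DiD from cell means: {did_means:.3f}")
print(f"DiD from regression: {did_ols:.3f}")
```

The two numbers agree up to floating-point error because the regression with both main effects and the interaction fits each of the four cell means exactly. The regression form is still preferable in practice since it yields standard errors and accepts extra covariates.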

3. Propensity Score Methods

class PropensityScoreMethods:
    """Propensity score matching and weighting"""
    
    def __init__(self, data, treatment_col, outcome_col, covariates):
        self.data = data.copy()
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.covariates = covariates
        self.propensity_scores = None
    
    def estimate_propensity_scores(self, method='logistic'):
        """Estimate propensity scores"""
        
        X = self.data[self.covariates]
        treatment = self.data[self.treatment_col]
        
        if method == 'logistic':
            # Logistic regression
            model = LogisticRegression(max_iter=1000, random_state=42)
            model.fit(X, treatment)
            self.propensity_scores = model.predict_proba(X)[:, 1]
        
        elif method == 'random_forest':
            from sklearn.ensemble import RandomForestClassifier
            model = RandomForestClassifier(n_estimators=100, random_state=42)
            model.fit(X, treatment)
            self.propensity_scores = model.predict_proba(X)[:, 1]
        
        self.data['propensity_score'] = self.propensity_scores
        
        return self.propensity_scores
    
    def check_balance(self):
        """Check covariate balance between treated and control units in self.data"""
        
        treatment = self.data[self.data[self.treatment_col] == 1]
        control = self.data[self.data[self.treatment_col] == 0]
        
        balance_results = {}
        
        for cov in self.covariates:
            # Standardized mean difference
            pooled_std = np.sqrt((treatment[cov].var() + control[cov].var()) / 2)
            smd = abs(treatment[cov].mean() - control[cov].mean()) / pooled_std
            
            # Variance ratio
            var_ratio = treatment[cov].var() / control[cov].var() if control[cov].var() > 0 else np.inf
            
            balance_results[cov] = {
                'treatment_mean': treatment[cov].mean(),
                'control_mean': control[cov].mean(),
                'smd': smd,
                'variance_ratio': var_ratio,
                'balanced': smd < 0.1 and 0.8 < var_ratio < 1.25
            }
        
        # Overall balance
        all_smd = [r['smd'] for r in balance_results.values()]
        balance_results['overall'] = {
            'mean_smd': np.mean(all_smd),
            'max_smd': max(all_smd),
            'balanced': np.mean(all_smd) < 0.1
        }
        
        return balance_results
    
    def nearest_neighbor_matching(self, n_neighbors=1, caliper=0.05):
        """Nearest neighbor matching on propensity scores"""
        
        treatment_mask = self.data[self.treatment_col] == 1
        control_mask = self.data[self.treatment_col] == 0
        
        treatment_scores = self.data.loc[treatment_mask, 'propensity_score'].values.reshape(-1, 1)
        control_scores = self.data.loc[control_mask, 'propensity_score'].values.reshape(-1, 1)
        
        # Find nearest neighbors
        nn = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean')
        nn.fit(control_scores)
        
        distances, indices = nn.kneighbors(treatment_scores)
        
        # Apply caliper: keep treated units whose matches all lie within it
        if caliper is not None:
            valid = distances.max(axis=1) <= caliper
        else:
            valid = np.ones(len(treatment_scores), dtype=bool)
        indices = indices[valid]
        distances = distances[valid]
        
        # Calculate matched outcomes (matching is with replacement; each treated
        # unit is compared with the mean outcome of its n_neighbors controls)
        treatment_outcomes = self.data.loc[treatment_mask, self.outcome_col].values[valid]
        control_outcomes = self.data.loc[control_mask, self.outcome_col].values[indices].mean(axis=1)
        
        # ATT estimate
        att = np.mean(treatment_outcomes - control_outcomes)
        
        # Bootstrap standard error
        n_bootstrap = 1000
        bootstrap_atts = []
        
        for _ in range(n_bootstrap):
            boot_idx = np.random.choice(len(treatment_outcomes), len(treatment_outcomes), replace=True)
            boot_att = np.mean(treatment_outcomes[boot_idx] - control_outcomes[boot_idx])
            bootstrap_atts.append(boot_att)
        
        se = np.std(bootstrap_atts)
        
        results = {
            'method': 'Nearest Neighbor Matching',
            'ATT': att,
            'std_error': se,
            'ci_95': (att - 1.96 * se, att + 1.96 * se),
            'n_matched': len(treatment_outcomes),
            'caliper': caliper
        }
        
        return results
    
    def inverse_probability_weighting(self):
        """Inverse probability weighting (IPW)"""
        
        ps = self.data['propensity_score']
        treatment = self.data[self.treatment_col]
        outcome = self.data[self.outcome_col]
        
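        # Calculate weights: 1/ps for treated units, 1/(1 - ps) for controls
        weights = treatment / ps + (1 - treatment) / (1 - ps)
        
        # Per-unit contributions: weighted treated outcomes minus weighted control outcomes
        contributions = treatment * outcome / ps - (1 - treatment) * outcome / (1 - ps)
        
        # ATE estimate: E[T*Y/ps] - E[(1-T)*Y/(1-ps)]
        ate = contributions.mean()
        
        # Standard error (treats the propensity scores as known)
        se = contributions.std() / np.sqrt(len(contributions))
        
        # Stabilized weights (optional): scale by the marginal treatment probability
        p_treated = treatment.mean()
        stabilized_weights = treatment * p_treated / ps + (1 - treatment) * (1 - p_treated) / (1 - ps)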
        
        results = {
            'method': 'Inverse Probability Weighting',
            'ATE': ate,
            'std_error': se,
            'ci_95': (ate - 1.96 * se, ate + 1.96 * se),
            'mean_weight': weights.mean(),
            'max_weight': weights.max(),
            'effective_sample_size': np.sum(weights) ** 2 / np.sum(weights ** 2)
        }
        
        return results
    
    def doubly_robust_estimator(self):
        """Doubly robust estimator"""
        
        treatment = self.data[self.treatment_col]
        outcome = self.data[self.outcome_col]
        ps = self.data['propensity_score']
        
        # Outcome model
        X = self.data[self.covariates]
        
        # Fit outcome model for treatment group
        treat_model = LinearRegression()
        treat_model.fit(X[treatment == 1], outcome[treatment == 1])
        mu1_pred = treat_model.predict(X)
        
        # Fit outcome model for control group
        control_model = LinearRegression()
        control_model.fit(X[treatment == 0], outcome[treatment == 0])
        mu0_pred = control_model.predict(X)
        
        # Doubly robust estimator
        dr = (
            mu1_pred - mu0_pred +
            treatment * (outcome - mu1_pred) / ps -
            (1 - treatment) * (outcome - mu0_pred) / (1 - ps)
        )
        
        ate = dr.mean()
        se = dr.std() / np.sqrt(len(dr))
        
        results = {
            'method': 'Doubly Robust Estimator',
            'ATE': ate,
            'std_error': se,
            'ci_95': (ate - 1.96 * se, ate + 1.96 * se),
            'interpretation': 'Combines outcome modeling and propensity score weighting'
        }
        
        return results
    
    def visualize_propensity_scores(self):
        """Visualize propensity score distributions"""
        
        treatment = self.data[self.data[self.treatment_col] == 1]['propensity_score']
        control = self.data[self.data[self.treatment_col] == 0]['propensity_score']
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        # Distribution comparison
        axes[0].hist(control, alpha=0.5, label='Control', bins=30, density=True)
        axes[0].hist(treatment, alpha=0.5, label='Treatment', bins=30, density=True)
        axes[0].set_xlabel('Propensity Score')
        axes[0].set_ylabel('Density')
        axes[0].set_title('Propensity Score Distribution')
        
        # Common support region
        min_ps = max(treatment.min(), control.min())
        max_ps = min(treatment.max(), control.max())
        
        axes[0].axvspan(min_ps, max_ps, alpha=0.2, color='green', label='Common Support')
        axes[0].legend()  # called after axvspan so the shaded region is listed
        
        # Covariate balance before matching
        balance_before = self.check_balance()
        
        # Placeholder only: these "after" values are random draws, not a real
        # post-matching balance check. Replace them with SMDs computed on the
        # matched sample before relying on this plot.
        balance_after = {cov: {'smd': np.random.uniform(0, 0.1)} for cov in self.covariates}
        
        covs = self.covariates[:5]  # First 5 covariates (skips the 'overall' summary key)
        smds_before = [balance_before[cov]['smd'] for cov in covs]
        smds_after = [balance_after[cov]['smd'] for cov in covs]
        
        y_pos = np.arange(len(covs))
        axes[1].barh(y_pos - 0.2, smds_before, 0.4, label='Before Matching', alpha=0.7)
        axes[1].barh(y_pos + 0.2, smds_after, 0.4, label='After Matching (placeholder)', alpha=0.7)
        axes[1].set_xlabel('Standardized Mean Difference')
        axes[1].set_title('Covariate Balance')
        axes[1].set_yticks(y_pos)
        axes[1].set_yticklabels(covs)
        axes[1].axvline(x=0.1, color='red', linestyle='--', label='Threshold (0.1)')
        axes[1].legend()
        
        plt.tight_layout()
        plt.savefig('propensity_scores.png', dpi=150, bbox_inches='tight')
        plt.show()

# Example usage
# ps_methods = PropensityScoreMethods(data, 'treatment', 'demand', ['price', 'population'])
# ps_methods.estimate_propensity_scores()
# balance = ps_methods.check_balance()
# matching_results = ps_methods.nearest_neighbor_matching()
# ipw_results = ps_methods.inverse_probability_weighting()
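
As a sanity check on inverse probability weighting, the sketch below simulates data with a known effect and a known propensity score, then compares the naive difference in means with two IPW estimates. Everything here is an assumption for illustration (single confounder `x`, true ATE of 2, true scores used instead of estimated ones); it is independent of the class above.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 50_000

x = rng.normal(size=n)                       # observed confounder
ps = 1 / (1 + np.exp(-x))                    # true propensity score, known by construction
t = rng.binomial(1, ps)
y = 2.0 * t + 3.0 * x + rng.normal(size=n)   # true ATE = 2

naive = y[t == 1].mean() - y[t == 0].mean()

# Horvitz-Thompson form: E[T*Y/ps] - E[(1-T)*Y/(1-ps)]
ht = np.mean(t * y / ps) - np.mean((1 - t) * y / (1 - ps))

# Hajek (normalized) form: weights rescaled to sum to one within each arm,
# which is usually more stable when some scores are close to 0 or 1
w1 = t / ps
w0 = (1 - t) / (1 - ps)
hajek = np.sum(w1 * y) / np.sum(w1) - np.sum(w0 * y) / np.sum(w0)

# Effective sample size shrinks as the weights become more unequal
w = w1 + w0
ess = w.sum() ** 2 / np.sum(w ** 2)

print(f"Naive: {naive:.2f}  HT: {ht:.2f}  Hajek: {hajek:.2f}  ESS: {ess:.0f} of {n}")
```

Both weighted estimates should land near 2 while the naive one does not. With estimated propensity scores the same formulas apply, but extreme scores are often clipped or trimmed first, and the standard error should account for the estimation step (for example by bootstrapping the whole procedure).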

Pro Tip: Propensity score methods rely on the "unconfoundedness" assumption: every confounder is observed and included in the propensity model. If an important confounder is missing, the estimate stays biased no matter how well the observed covariates balance.
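
The consequence of a missing confounder can be illustrated with a small simulation. This is a hedged sketch: the variables `x` (observed) and `u` (unobserved), the coefficients and the use of plain regression adjustment instead of matching are all assumptions made to keep the example short.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 20_000

x = rng.normal(size=n)   # observed confounder (e.g. population)
u = rng.normal(size=n)   # confounder the analyst does not observe
t = (0.8 * x + 0.8 * u + rng.normal(size=n) > 0).astype(float)
y = 1.0 * t + 2.0 * x + 2.0 * u + rng.normal(size=n)   # true effect = 1


def ols_effect(covariates):
    """Coefficient on treatment after linear adjustment for the covariates."""
    X = np.column_stack([np.ones(n), t] + covariates)
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return beta[1]


adjust_x_only = ols_effect([x])        # what the analyst can actually do
adjust_x_and_u = ols_effect([x, u])    # infeasible benchmark

print(f"Adjusting for x only:  {adjust_x_only:.2f}")
print(f"Adjusting for x and u: {adjust_x_and_u:.2f}")
```

Adjusting for `x` alone leaves a large bias even though `x` is perfectly handled, which is the situation the tip warns about. No balance diagnostic on observed covariates can reveal this, so the assumption has to be defended with domain knowledge or probed with sensitivity analysis.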

4. Real-World Application: Pricing Policy Evaluation

class PricingPolicyEvaluator:
    """Evaluate pricing policy changes using causal inference"""
    
    def __init__(self, data):
        self.data = data
        self.results = {}
    
    def evaluate_policy_change(self, treatment_cities, control_cities,
                              pre_period, post_period):
        """Complete policy evaluation"""
        
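        # Filter to the evaluated cities and to the pre/post periods
        eval_data = self.data[
            self.data['city'].isin(treatment_cities + control_cities)
            & self.data['period'].isin(list(pre_period) + list(post_period))
        ].copy()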
        
        # Create treatment indicator
        eval_data['treatment'] = eval_data['city'].isin(treatment_cities).astype(int)
        eval_data['post'] = eval_data['period'].isin(post_period).astype(int)
        
        # Method 1: Difference-in-Differences
        did_eval = DifferenceInDifferences(
            eval_data, 'demand', 'treatment', 'post', 'city'
        )
        did_results = did_eval.estimate_did()
        self.results['did'] = did_results
        
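        # Method 2: Propensity Score Matching on post-period rows only
        # (pre-period rows are untreated for every city, so pooling them
        # with post-period rows would dilute the estimate)
        covariates = ['population', 'avg_income', 'competitor_count']
        ps_eval = PropensityScoreMethods(
            eval_data[eval_data['post'] == 1], 'treatment', 'demand', covariates
        )
        ps_eval.estimate_propensity_scores()
        ps_results = ps_eval.nearest_neighbor_matching()
        self.results['psm'] = ps_results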
        
        # Method 3: Synthetic Control (simplified)
        synthetic_results = self.synthetic_control(
            eval_data[eval_data['post'] == 0],
            eval_data[eval_data['post'] == 1],
            treatment_cities
        )
        self.results['synthetic'] = synthetic_results
        
        # Compare methods
        comparison = self.compare_methods()
        
        return self.results, comparison
    
    def synthetic_control(self, pre_data, post_data, treatment_cities):
        """Simplified synthetic control method"""
        
        # Get donor pool (control cities)
        donor_cities = [c for c in post_data['city'].unique() 
                       if c not in treatment_cities]
        
        # A full synthetic control would choose donor weights that minimize
        # the pre-period gap between the treated cities and the weighted
        # donors. This placeholder uses equal weights, so it reduces to a
        # post-period difference in means and leaves pre_data unused.
        synthetic = post_data[post_data['city'].isin(donor_cities)].groupby(
            'period'
        )['demand'].mean()
        
        actual = post_data[post_data['city'].isin(treatment_cities)].groupby(
            'period'
        )['demand'].mean()
        
        # Treatment effect
        effect = actual - synthetic
        
        results = {
            'method': 'Synthetic Control',
            'effect': effect.mean(),
            'interpretation': 'Equal-weight placeholder: a full synthetic control estimates the '
                            'effect from an optimized weighted combination of control units'
        }
        
        return results
    
    def compare_methods(self):
        """Compare results from different methods"""
        
        comparison = pd.DataFrame({
            'Method': ['DiD', 'PSM', 'Synthetic Control'],
            'ATE': [
                self.results['did']['ATE'],
                self.results['psm']['ATT'],
                self.results['synthetic']['effect']
            ]
        })
        
        # Add confidence intervals if available
        if 'ci_95' in self.results['did']:
            comparison['CI_Lower'] = [
                self.results['did']['ci_95'][0],
                self.results['psm']['ci_95'][0],
                np.nan
            ]
            comparison['CI_Upper'] = [
                self.results['did']['ci_95'][1],
                self.results['psm']['ci_95'][1],
                np.nan
            ]
        
        return comparison
    
    def sensitivity_analysis(self):
        """Sensitivity analysis for unobserved confounding"""
        
        # Placeholder heuristic, not actual Rosenbaum bounds: a real analysis
        # would recompute the matched-pair test p-value for each gamma, where
        # gamma is the assumed odds ratio of hidden bias in treatment assignment
        gamma_range = np.arange(1, 3, 0.1)
        
        sensitivity_results = []
        
        for gamma in gamma_range:
            # Scale the DiD estimate by gamma as a rough stand-in for the
            # confounding strength needed to explain away the effect
            required_strength = gamma * abs(self.results['did']['ATE'])
            
            sensitivity_results.append({
                'gamma': gamma,
                'required_confounding_strength': required_strength,
                'robust': required_strength > 1  # Arbitrary threshold
            })
        
        return pd.DataFrame(sensitivity_results)

# Example usage
# evaluator = PricingPolicyEvaluator(data)
# results, comparison = evaluator.evaluate_policy_change(
#     treatment_cities=['City_A', 'City_B'],
#     control_cities=['City_C', 'City_D', 'City_E'],
#     pre_period=['2023-Q1', '2023-Q2'],
#     post_period=['2023-Q3', '2023-Q4']
# )
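
The synthetic control step in the evaluator uses equal donor weights. A sketch of the weight optimization it refers to could look like the following, using `scipy.optimize.minimize` with the SLSQP method on simulated daily demand. The donor series, the true weights and the effect of -8 rides are all made up for illustration, and real applications usually also match on covariates and run placebo checks across donors.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(4)
n_pre, n_post, n_donors = 30, 10, 5
n_days = n_pre + n_post

# Hypothetical donor-city demand (rows = days, columns = donor cities):
# each city has its own level and its own exposure to a shared trend.
trend = np.cumsum(rng.normal(0, 1, size=n_days))
levels = np.array([80.0, 90.0, 100.0, 110.0, 120.0])
loadings = np.array([0.5, 0.8, 1.0, 1.2, 1.5])
donors = levels + np.outer(trend, loadings) + rng.normal(0, 1, size=(n_days, n_donors))

# Treated city behaves like a mix of donors, then the policy shifts demand.
true_weights = np.array([0.5, 0.0, 0.3, 0.0, 0.2])
true_effect = -8.0
treated = donors @ true_weights + rng.normal(0, 0.5, size=n_days)
treated[n_pre:] += true_effect


def pre_period_mse(w):
    """Mean squared pre-period gap between treated and weighted donors."""
    return np.mean((treated[:n_pre] - donors[:n_pre] @ w) ** 2)


# Weights are constrained to be non-negative and to sum to one.
result = minimize(
    pre_period_mse,
    x0=np.full(n_donors, 1 / n_donors),
    method="SLSQP",
    bounds=[(0.0, 1.0)] * n_donors,
    constraints=[{"type": "eq", "fun": lambda w: w.sum() - 1.0}],
)
weights = result.x

synthetic = donors @ weights
effect = (treated[n_post * -1:] - synthetic[n_post * -1:]).mean()

print("Estimated weights:", np.round(weights, 2))
print(f"Estimated effect: {effect:.2f} (simulated truth: {true_effect})")
```

Restricting the weights to the simplex keeps the synthetic city inside the range of the donors, which avoids extrapolation. The quality of the pre-period fit is the main diagnostic: if the weighted donors cannot track the treated city before the policy, the post-period gap is not credible.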

5. Common Follow-Up Questions

Follow-up 1: How do you test the parallel trends assumption?

def test_parallel_trends(data, outcome_col, treatment_col, time_col, group_col):
    """Event-study check of the parallel trends assumption for DiD.

    Assumes time_col holds event time (periods relative to treatment start),
    so pre-treatment periods are negative.
    """
    import statsmodels.api as sm
    
    df = data.copy()
    periods = sorted(df[time_col].unique())
    pre_periods = [t for t in periods if t < 0]
    if len(pre_periods) < 2:
        raise ValueError('Need at least two pre-treatment periods')
    reference = pre_periods[-1]  # last pre-period is the omitted baseline
    
    # One regression with period effects and treatment x period interactions
    X = pd.DataFrame({treatment_col: df[treatment_col].astype(float)}, index=df.index)
    for t in periods:
        if t == reference:
            continue
        in_period = (df[time_col] == t).astype(float)
        X[f'period_{t}'] = in_period
        X[f'interaction_{t}'] = in_period * X[treatment_col]
    X = sm.add_constant(X)
    y = df[outcome_col].astype(float)
    
    # Plain OLS standard errors; with repeated observations per unit they
    # should be clustered by group_col
    model = sm.OLS(y, X).fit()
    conf_int = model.conf_int()
    
    # Pre-period interaction coefficients should be close to zero
    results = []
    for t in pre_periods[:-1]:
        name = f'interaction_{t}'
        results.append({
            'period': t,
            'coefficient': model.params[name],
            'std_error': model.bse[name],
            'ci_lower': conf_int.loc[name, 0],
            'ci_upper': conf_int.loc[name, 1]
        })
    
    results_df = pd.DataFrame(results)
    
    # Plot event study
    plt.figure(figsize=(10, 6))
    plt.errorbar(results_df['period'], results_df['coefficient'],
                yerr=1.96 * results_df['std_error'], fmt='o-', capsize=5)
    plt.axhline(y=0, color='gray', linestyle='--')
    plt.axvline(x=-0.5, color='red', linestyle='--', label='Treatment Start')
    plt.xlabel('Period Relative to Treatment')
    plt.ylabel('Coefficient')
    plt.title('Event Study: Testing Parallel Trends')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('parallel_trends_test.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    return results_df
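
A complementary validation is a placebo test: rerun the DiD on pre-treatment data only, pretending the policy started earlier than it did. If trends are parallel, the placebo estimate should be close to zero. The sketch below uses simulated data; the period layout, the fake start at period 4 and the effect of -3 are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(5)
n_periods, units_per_group = 12, 200
policy_start = 8      # real policy begins at period 8
placebo_start = 4     # fake start date, inside the pre-period
true_effect = -3.0

# Balanced panel: first half of the rows are treated units, second half controls
period = np.tile(np.arange(n_periods), 2 * units_per_group)
treated = np.repeat([1, 0], n_periods * units_per_group)

# Both groups share the same time trend, so parallel trends hold by construction
demand = (50 + 4 * treated + 0.5 * period
          + true_effect * treated * (period >= policy_start)
          + rng.normal(0, 2, size=period.size))


def did(y, group, post):
    """2x2 difference-in-differences from the four cell means."""
    def cell(g, p):
        return y[(group == g) & (post == p)].mean()
    return (cell(1, 1) - cell(1, 0)) - (cell(0, 1) - cell(0, 0))


real_estimate = did(demand, treated, (period >= policy_start).astype(int))

# Placebo: drop all post-policy rows and test a start date that never happened
pre = period < policy_start
placebo_estimate = did(demand[pre], treated[pre],
                       (period[pre] >= placebo_start).astype(int))

print(f"Real DiD estimate:    {real_estimate:.2f}")
print(f"Placebo DiD estimate: {placebo_estimate:.2f}")
```

A placebo estimate that is large relative to its standard error suggests the groups were already diverging before the policy, in which case the real estimate should not be read as causal.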

Follow-up 2: How do you handle staggered adoption in DiD?

def staggered_did(data, outcome_col, treatment_time_col, unit_col, time_col):
    """Handle staggered adoption in DiD"""
    
    # Simplified sketch in the spirit of Callaway and Sant'Anna (2021):
    # each adoption cohort is compared with never-treated units
    
    # Get unique treatment times
    treatment_times = data[treatment_time_col].dropna().unique()
    
    results = []
    
    for g in treatment_times:
        # Group treated at time g
        group_data = data[
            (data[treatment_time_col] == g) | 
            (data[treatment_time_col].isna())
        ].copy()
        
        # Create post indicator
        group_data['post'] = (group_data[time_col] >= g).astype(int)
        
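        # Estimate ATT for this group as a 2x2 DiD against never-treated units
        # (Simplified - would use proper weighting in practice)
        is_treated = group_data[treatment_time_col] == g
        is_post = group_data['post'] == 1
        y = group_data[outcome_col]
        att = (y[is_treated & is_post].mean() - y[is_treated & ~is_post].mean()) - \
              (y[~is_treated & is_post].mean() - y[~is_treated & ~is_post].mean())
        
        results.append({
            'group': g,
            'att': att,
            'n_units': group_data.loc[is_treated, unit_col].nunique()
        })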
    
    # Aggregate ATT
    results_df = pd.DataFrame(results)
    overall_att = np.average(results_df['att'], weights=results_df['n_units'])
    
    return {
        'overall_att': overall_att,
        'group_specific_att': results_df,
        'interpretation': 'Staggered DiD accounts for different treatment timing'
    }
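
The cohort-by-cohort logic can be checked on a simulated panel where each adoption cohort has a different true effect. This is a minimal sketch under strong assumptions (balanced panel, never-treated units available, effects constant after adoption); the cohort sizes, adoption periods and effect sizes are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(6)
n_periods = 8
periods = np.arange(n_periods)

# Adoption period per unit: 100 units adopt at period 3, 100 at period 5,
# and 100 never adopt (encoded as NaN).
adopt = np.concatenate([np.full(100, 3.0), np.full(100, 5.0), np.full(100, np.nan)])
n_units = adopt.size
true_effects = {3: 2.0, 5: 4.0}

# Outcome matrix y[unit, period] with unit effects and a common time trend
effect = np.zeros((n_units, n_periods))
for g, tau in true_effects.items():
    effect += tau * np.outer(adopt == g, periods >= g)
y = (20 + rng.normal(0, 1, size=(n_units, 1)) + 0.5 * periods
     + effect + rng.normal(0, 1, size=(n_units, n_periods)))

never = np.isnan(adopt)


def att_for_cohort(g):
    """2x2 DiD for the cohort adopting at period g vs never-treated units."""
    cohort = adopt == g
    post = periods >= g
    change_treated = y[cohort][:, post].mean() - y[cohort][:, ~post].mean()
    change_control = y[never][:, post].mean() - y[never][:, ~post].mean()
    return change_treated - change_control


att = {g: att_for_cohort(g) for g in true_effects}
cohort_sizes = [(adopt == g).sum() for g in att]
overall_att = np.average(list(att.values()), weights=cohort_sizes)

for g, value in att.items():
    print(f"Cohort adopting at period {g}: ATT = {value:.2f}")
print(f"Size-weighted overall ATT: {overall_att:.2f}")
```

Using only never-treated units as controls avoids comparing late adopters against cities that are already treated, which is the comparison that can bias a single pooled two-way fixed effects regression when effects differ across cohorts.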

Company-Specific Tips

Uber Tips:

  • Uber heavily tests on causal inference for marketplace decisions
  • Know how to evaluate pricing and promotion experiments
  • Understand synthetic control for geographic experiments
  • Be comfortable with instrumental variables

Netflix Tips:

  • Netflix uses causal inference for content recommendations
  • Know how to estimate long-term effects of interventions
  • Understand survival analysis for churn causation
  • Be familiar with Bayesian causal inference
