🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services ℹ️ About ✉️ Contact View Pricing Plans from $10

Survival Analysis: Kaplan-Meier, Cox Regression, Hazards

Data Science Interview Premium / Survival Analysis ⭐ Premium

Advertisement

GOOGLE & AMAZON INTERVIEW QUESTION

Survival Analysis: Kaplan-Meier, Cox Regression, Hazards

Time-to-Event Analysis & Duration Modeling

The Interview Question

ℹ️

Question: You're analyzing customer churn for a subscription service:

  • Dataset: 100K customers with signup date, churn date, and features
  • Requirements: Predict time to churn, identify churn factors, estimate survival curves
  • Challenge: Censored data (some customers haven't churned yet)

Walk through your survival analysis approach:

  1. How do you handle censored data properly?
  2. How do you estimate survival curves using Kaplan-Meier?
  3. How do you build a Cox proportional hazards model?
  4. How do you validate and interpret your model?

Detailed Answer

1. Survival Analysis Fundamentals

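Survival analysis models time-to-event data under censoring. For some subjects (here, customers who are still active) the event has not been observed by the end of the observation window, so we only know that their time-to-event exceeds the time observed so far. Treating those subjects as events, or dropping them, biases every estimate.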

import pandas as pd
import numpy as np
from lifelines import (
    KaplanMeierFitter, 
    CoxPHFitter, 
    WeibullFitter,
    ExponentialFitter,
    LogNormalFitter
)
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines.plotting import plot_lifetimes
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

class SurvivalAnalysisFramework:
    """Prepare and summarise right-censored time-to-event data.

    Attributes:
        data: private copy of the input frame; ``prepare_survival_data`` adds
            'duration' and 'event' columns to it.
        duration_col, event_col: column names supplied by the caller.
            ``describe_censoring`` reads ``event_col``, so pass 'event' here if
            you want it to describe the column built by ``prepare_survival_data``.
        results: dict reserved for analysis outputs (not populated here).
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data.copy()  # never mutate the caller's frame
        self.duration_col = duration_col
        self.event_col = event_col
        self.results = {}

    def prepare_survival_data(self, start_date_col, end_date_col, event_indicator_col,
                              observation_end=None):
        """Derive a 'duration' (days) and an 'event' (0/1) column.

        Censored subjects (customers who have not churned) usually have no end
        date. Their follow-up must run to the end of the observation window.
        Leaving the end date missing would give NaN durations for exactly the
        subjects that carry the censoring information.

        Args:
            start_date_col: column holding the start date (e.g. signup).
            end_date_col: column holding the event date; may be missing for
                censored rows.
            event_indicator_col: column that is truthy when the event occurred.
            observation_end: censoring date (when the data was extracted).
                When None, the latest date present in either date column is
                used as an approximation of that cutoff.

        Returns:
            The internal DataFrame with 'duration' and 'event' columns added.
        """
        start = pd.to_datetime(self.data[start_date_col])
        end = pd.to_datetime(self.data[end_date_col])

        if observation_end is None:
            observation_end = pd.concat([start, end]).max()
        else:
            observation_end = pd.Timestamp(observation_end)
        # Censored rows are followed up until the observation cutoff.
        end = end.fillna(observation_end)

        # Negative spans (end before start) are data errors; floor them at 0.
        self.data['duration'] = (end - start).dt.days.clip(lower=0)

        # Event indicator (1 = event occurred, 0 = censored)
        self.data['event'] = self.data[event_indicator_col].astype(int)

        n_total = len(self.data)
        n_events = self.data['event'].sum()
        n_censored = n_total - n_events
        denom = n_total or 1  # avoid 0/0 on an empty frame

        print(f"Survival Data Summary:")
        print(f"  Total subjects: {n_total}")
        print(f"  Events (churned): {n_events} ({n_events/denom*100:.1f}%)")
        print(f"  Censored (active): {n_censored} ({n_censored/denom*100:.1f}%)")
        print(f"  Median duration: {self.data['duration'].median():.0f} days")

        return self.data

    def describe_censoring(self):
        """Count right-censored vs. observed events in ``self.event_col``.

        Returns:
            dict with 'right_censored' (rows where the column == 0), 'events'
            (rows == 1) and 'censoring_rate' (share of rows == 0).
        """
        is_censored = self.data[self.event_col] == 0
        censoring_info = {
            'right_censored': is_censored.sum(),
            'events': (self.data[self.event_col] == 1).sum(),
            'censoring_rate': is_censored.mean()
        }

        # NOTE(review): informative censoring (censoring related to the outcome)
        # is not checked here; it cannot be detected from the indicator alone.

        print(f"\nCensoring Information:")
        print(f"  Right-censored: {censoring_info['right_censored']}")
        print(f"  Events observed: {censoring_info['events']}")
        print(f"  Censoring rate: {censoring_info['censoring_rate']:.2%}")

        return censoring_info

# Example usage
# framework = SurvivalAnalysisFramework(data, 'duration', 'event')  # prepare_survival_data writes 'duration'/'event'
# framework.prepare_survival_data('signup_date', 'churn_date', 'has_churned')

2. Kaplan-Meier Estimation

class KaplanMeierAnalyzer:
    """Kaplan-Meier survival estimation, pooled or per group.

    ``fit()`` estimates the pooled curve into ``self.fitter``;
    ``fit(group_col=...)`` estimates one curve per non-missing group value into
    ``self.group_fitters``. Methods that read ``self.fitter`` (survival
    probabilities, median, survival table, pooled plot) need the pooled fit.
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.fitter = KaplanMeierFitter()
        self.fitted = False          # True after any fit (pooled or grouped)
        self.overall_fitted = False  # True only once self.fitter holds the pooled curve
        self.group_fitters = {}
        self.results = {}

    def fit(self, group_col=None):
        """Fit the pooled curve, or one curve per value of ``group_col``.

        Returns:
            Pooled: dict with 'median_survival', 'survival_function' and
            'confidence_intervals'.
            Grouped: {group: {'median_survival', 'survival_function'}}.
        """
        if group_col is None:
            self.fitter.fit(
                self.data[self.duration_col],
                self.data[self.event_col],
                label='Overall'
            )
            self.fitted = True
            self.overall_fitted = True

            results = {
                'median_survival': self.fitter.median_survival_time_,
                'survival_function': self.fitter.survival_function_,
                'confidence_intervals': self.fitter.confidence_interval_survival_function_
            }

        else:
            # Skip missing labels: an `== NaN` mask selects no rows, so the
            # fitter would otherwise be handed an empty sample.
            self.group_fitters = {}
            for group in self.data[group_col].dropna().unique():
                mask = self.data[group_col] == group
                fitter = KaplanMeierFitter()
                fitter.fit(
                    self.data.loc[mask, self.duration_col],
                    self.data.loc[mask, self.event_col],
                    label=str(group)
                )
                self.group_fitters[group] = fitter

            self.fitted = True
            results = {group: {
                'median_survival': fitter.median_survival_time_,
                'survival_function': fitter.survival_function_
            } for group, fitter in self.group_fitters.items()}

        self.results = results
        return results

    def compare_groups(self, group_col):
        """Test whether survival differs across the values of ``group_col``.

        Uses the two-sample log-rank test for exactly two groups and the
        multivariate log-rank test otherwise. Rows with a missing group label
        are excluded.

        Raises:
            ValueError: fewer than two non-missing groups.
        """
        subset = self.data[self.data[group_col].notna()]
        groups = subset[group_col].unique()

        if len(groups) < 2:
            raise ValueError(
                f"Need at least two groups in '{group_col}' to compare, found {len(groups)}"
            )

        if len(groups) == 2:
            group1 = subset[subset[group_col] == groups[0]]
            group2 = subset[subset[group_col] == groups[1]]

            result = logrank_test(
                group1[self.duration_col], group2[self.duration_col],
                event_observed_A=group1[self.event_col],
                event_observed_B=group2[self.event_col]
            )
            test_name = 'Log-rank'
        else:
            # Signature is (event_durations, groups, event_observed): durations first.
            result = multivariate_logrank_test(
                subset[self.duration_col],
                subset[group_col],
                subset[self.event_col]
            )
            test_name = 'Multivariate log-rank'

        return {
            'test': test_name,
            'statistic': result.test_statistic,
            'p_value': result.p_value,
            'significant': result.p_value < 0.05
        }

    def calculate_survival_probabilities(self, time_points):
        """Pooled survival probability and risk (1 - survival) at each time point.

        Returns an empty dict until the pooled curve has been fitted; a
        group-only fit leaves ``self.fitter`` unfitted.
        """
        survival_probs = {}
        if not self.overall_fitted:
            return survival_probs

        for t in time_points:
            surv_prob = self.fitter.predict(t)
            survival_probs[t] = {
                'survival_probability': surv_prob,
                'risk': 1 - surv_prob
            }

        return survival_probs

    def median_survival_time(self):
        """Median survival time of the pooled curve."""
        return self.fitter.median_survival_time_

    def visualize(self, group_col=None, figsize=(10, 6)):
        """Plot the pooled curve, or every group curve when ``group_col`` is given."""

        fig, ax = plt.subplots(figsize=figsize)

        if group_col is None:
            self.fitter.plot_survival_function(ax=ax, ci_show=True)
        else:
            for fitter in self.group_fitters.values():
                fitter.plot_survival_function(ax=ax, ci_show=True)

        ax.set_xlabel('Time (days)')
        ax.set_ylabel('Survival Probability')
        ax.set_title('Kaplan-Meier Survival Curves')
        ax.grid(True, alpha=0.3)
        ax.legend()

        plt.tight_layout()
        plt.savefig('kaplan_meier_curves.png', dpi=150, bbox_inches='tight')
        plt.show()

    def survival_table(self):
        """Pooled survival probabilities every 30 time units from 0 (exclusive of the max)."""
        return self.fitter.survival_function_at_times(
            np.arange(0, self.data[self.duration_col].max(), 30)
        )

# Example usage
# km_analyzer = KaplanMeierAnalyzer(data, 'duration', 'event')
# km_analyzer.fit()
# km_analyzer.fit(group_col='plan_type')
# comparison = km_analyzer.compare_groups('plan_type')
# survival_probs = km_analyzer.calculate_survival_probabilities([30, 90, 180, 365])
# km_analyzer.visualize(group_col='plan_type')

3. Cox Proportional Hazards Model

class CoxPHAnalyzer:
    """Cox Proportional Hazards regression with automatic dummy encoding.

    ``CoxPHFitter`` only accepts numeric inputs, so non-numeric covariates
    (e.g. a ``plan_type`` string column) are one-hot encoded. The first
    category is dropped as the reference level. The same encoding is re-applied
    to every frame passed to the ``predict_*`` methods.
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.model = CoxPHFitter()
        self.fitted = False
        self.covariates = []       # raw covariate names used in the last fit
        self.design_columns = []   # encoded column names the model was fit on
        self.fit_data = None       # exact frame passed to CoxPHFitter.fit

    def _design_matrix(self, frame, fit=False):
        """Return the covariates of ``frame`` as a numeric design matrix.

        With ``fit=True`` the encoding is learned (first category of each
        non-numeric column dropped as the reference level) and recorded in
        ``self.design_columns``. Otherwise every level is encoded and the result
        is aligned to the recorded columns. Reference-level and unseen categories
        then become all-zero, i.e. they are scored as the reference level.
        """
        if isinstance(frame, pd.Series):
            # A single observation passed as a row.
            frame = pd.DataFrame([frame]).infer_objects()
        raw = frame[self.covariates]
        if fit:
            design = pd.get_dummies(raw, drop_first=True, dtype=float)
            self.design_columns = list(design.columns)
            return design
        design = pd.get_dummies(raw, dtype=float)
        return design.reindex(columns=self.design_columns, fill_value=0.0)

    def fit(self, covariates, penalizer=0.01):
        """Fit the Cox PH model on complete cases.

        Args:
            covariates: column names to use as predictors. The duration and
                event columns are ignored if listed: using them would duplicate
                columns in the fit frame and leak the outcome into the model.
            penalizer: L2 penalty strength passed to ``CoxPHFitter.fit``.

        Returns:
            dict of coefficients, hazard ratios, p-values, CIs, concordance,
            log-likelihood and partial AIC (keys use the encoded column names).
        """
        self.covariates = [c for c in covariates
                           if c not in (self.duration_col, self.event_col)]

        # Complete cases only; the same frame is reused for assumption checks.
        raw = self.data[[self.duration_col, self.event_col] + self.covariates].dropna()
        design = self._design_matrix(raw, fit=True)
        fit_data = pd.concat([raw[[self.duration_col, self.event_col]], design], axis=1)

        self.model.fit(
            fit_data,
            duration_col=self.duration_col,
            event_col=self.event_col,
            penalizer=penalizer  # L2 regularization
        )

        self.fitted = True
        self.fit_data = fit_data

        results = {
            'coefficients': self.model.params_,
            'hazard_ratios': self.model.hazard_ratios_,
            'p_values': self.model.summary['p'],
            'confidence_intervals': self.model.confidence_intervals_,
            'concordance_index': self.model.concordance_index_,
            'log_likelihood': self.model.log_likelihood_,
            'AIC': self.model.AIC_partial_
        }

        return results

    def summarize(self):
        """Print concordance, LR-test p-value, partial AIC and the coefficient table."""

        if not self.fitted:
            print("Model not fitted yet")
            return

        summary = self.model.summary

        print("Cox Proportional Hazards Model Summary")
        print("=" * 70)
        print(f"Concordance index: {self.model.concordance_index_:.4f}")
        print(f"Log-likelihood ratio test p-value: {self.model.log_likelihood_ratio_test().p_value:.6f}")
        print(f"AIC: {self.model.AIC_partial_:.2f}")
        print("\nCoefficients:")
        print(summary)

        return summary

    def hazard_ratios(self):
        """Describe each hazard ratio as a percentage change in hazard.

        For a dummy column, "one unit increase" means belonging to that
        category rather than the reference level.
        """

        if not self.fitted:
            print("Model not fitted yet")
            return

        interpretation = {}
        for covariate, ratio in self.model.hazard_ratios_.items():
            if ratio > 1:
                interpretation[covariate] = {
                    'hazard_ratio': ratio,
                    'interpretation': f'One unit increase in {covariate} increases '
                                    f'hazard by {(ratio - 1) * 100:.1f}%',
                    'direction': 'increases risk'
                }
            elif ratio < 1:
                interpretation[covariate] = {
                    'hazard_ratio': ratio,
                    'interpretation': f'One unit increase in {covariate} decreases '
                                    f'hazard by {(1 - ratio) * 100:.1f}%',
                    'direction': 'decreases risk'
                }
            else:
                interpretation[covariate] = {
                    'hazard_ratio': ratio,
                    'interpretation': f'{covariate} has no effect on hazard',
                    'direction': 'no effect'
                }

        return interpretation

    def check_proportional_hazards(self):
        """Run lifelines' proportional-hazards checks on the exact training frame.

        Reuses ``self.fit_data`` (complete cases, encoded) so the check sees the
        same rows and columns as the fitted model.
        """

        if not self.fitted:
            print("Model not fitted yet")
            return

        return self.model.check_assumptions(
            self.fit_data,
            p_value_threshold=0.05,
            show_plots=False
        )

    def predict_survival_function(self, new_data, times):
        """Predicted survival curves at ``times`` for each row of ``new_data``."""

        if not self.fitted:
            print("Model not fitted yet")
            return

        return self.model.predict_survival_function(
            self._design_matrix(new_data), times=times
        )

    def predict_partial_hazard(self, new_data):
        """Predicted partial hazard (relative risk score) for each row of ``new_data``."""

        if not self.fitted:
            print("Model not fitted yet")
            return

        return self.model.predict_partial_hazard(self._design_matrix(new_data))

    def visualize_coefficients(self, figsize=(10, 6)):
        """Plot coefficients with 95% CIs and the corresponding hazard ratios."""

        if not self.fitted:
            print("Model not fitted yet")
            return

        fig, axes = plt.subplots(1, 2, figsize=figsize)

        summary = self.model.summary
        coefficients = summary['coef']
        conf_int = summary[['coef lower 95%', 'coef upper 95%']]

        y_pos = range(len(coefficients))

        axes[0].barh(y_pos, coefficients.values, xerr=[
            coefficients.values - conf_int['coef lower 95%'].values,
            conf_int['coef upper 95%'].values - coefficients.values
        ], capsize=5)
        axes[0].axvline(x=0, color='gray', linestyle='--')
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(coefficients.index)
        axes[0].set_xlabel('Coefficient')
        axes[0].set_title('Cox Model Coefficients')
        axes[0].grid(True, alpha=0.3)

        hr = self.model.hazard_ratios_
        axes[1].barh(y_pos, hr.values)
        axes[1].axvline(x=1, color='gray', linestyle='--')
        axes[1].set_yticks(y_pos)
        axes[1].set_yticklabels(hr.index)
        axes[1].set_xlabel('Hazard Ratio')
        axes[1].set_title('Hazard Ratios (exp(coef))')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('cox_model_coefficients.png', dpi=150, bbox_inches='tight')
        plt.show()

# Example usage
# cox_analyzer = CoxPHAnalyzer(data, 'duration', 'event')
# covariates = ['age', 'plan_type', 'usage_minutes', 'support_tickets']
# results = cox_analyzer.fit(covariates)
# cox_analyzer.summarize()
# hr_interpretation = cox_analyzer.hazard_ratios()
# cox_analyzer.check_proportional_hazards()
# cox_analyzer.visualize_coefficients()

4. Parametric Survival Models

class ParametricSurvivalModels:
    """Fit, compare and query univariate parametric survival distributions.

    Fitted models are stored in ``self.models`` under the keys 'exponential',
    'weibull' and 'log_normal'.
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.models = {}

    def _fit(self, name, fitter):
        """Fit ``fitter`` on the stored durations/events and register it as ``name``."""
        fitter.fit(self.data[self.duration_col], self.data[self.event_col])
        self.models[name] = fitter
        return fitter

    def fit_exponential(self):
        """Fit an exponential model (constant hazard)."""
        return self._fit('exponential', ExponentialFitter())

    def fit_weibull(self):
        """Fit a Weibull model."""
        return self._fit('weibull', WeibullFitter())

    def fit_log_normal(self):
        """Fit a log-normal model."""
        return self._fit('log_normal', LogNormalFitter())

    def compare_models(self):
        """Tabulate AIC/BIC/log-likelihood of the fitted models, sorted by AIC.

        Returns an empty frame with the same columns when nothing is fitted.
        """
        columns = ['model', 'AIC', 'BIC', 'log_likelihood']
        rows = [
            {
                'model': name,
                'AIC': model.AIC_,
                'BIC': model.BIC_,
                'log_likelihood': model.log_likelihood_
            }
            for name, model in self.models.items()
        ]

        if rows:
            comparison_df = pd.DataFrame(rows, columns=columns).sort_values('AIC')
        else:
            comparison_df = pd.DataFrame(columns=columns)

        print("Model Comparison (lower AIC is better):")
        print("=" * 50)
        print(comparison_df)

        return comparison_df

    def predict_survival(self, model_name, times):
        """Survival probabilities S(t) at ``times`` from a fitted model, or None."""

        if model_name not in self.models:
            print(f"Model {model_name} not fitted")
            return None

        # Univariate lifelines fitters expose *_at_times, not predict_* methods.
        return self.models[model_name].survival_function_at_times(times)

    def predict_hazard(self, model_name, times):
        """Hazard h(t) at ``times`` from a fitted model, or None."""

        if model_name not in self.models:
            print(f"Model {model_name} not fitted")
            return None

        return self.models[model_name].hazard_at_times(times)

    def visualize_model_comparison(self, figsize=(10, 6)):
        """Plot survival and hazard curves of every fitted model side by side."""

        fig, axes = plt.subplots(1, 2, figsize=figsize)

        # Start at t=1: Weibull/log-normal hazards can be infinite or undefined at t=0.
        times = np.arange(1, self.data[self.duration_col].max() + 1, 1)

        for name, model in self.models.items():
            axes[0].plot(times, model.survival_function_at_times(times), label=name)

        axes[0].set_xlabel('Time (days)')
        axes[0].set_ylabel('Survival Probability')
        axes[0].set_title('Survival Functions Comparison')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        for name, model in self.models.items():
            axes[1].plot(times, model.hazard_at_times(times), label=name)

        axes[1].set_xlabel('Time (days)')
        axes[1].set_ylabel('Hazard Rate')
        axes[1].set_title('Hazard Functions Comparison')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('parametric_models_comparison.png', dpi=150, bbox_inches='tight')
        plt.show()

# Example usage
# param_models = ParametricSurvivalModels(data, 'duration', 'event')
# param_models.fit_exponential()
# param_models.fit_weibull()
# param_models.fit_log_normal()
# comparison = param_models.compare_models()
# param_models.visualize_model_comparison()

💡

Pro Tip: Cox PH is semi-parametric and doesn't require specifying the baseline hazard distribution. Use parametric models when you need to extrapolate survival curves beyond observed data.

5. Real-World Application: Churn Prediction

class ChurnPredictor:
    """End-to-end churn analysis: Kaplan-Meier curves, Cox model, risk segments.

    Note: ``analyze_churn`` adds 'risk_score' and 'risk_segment' columns to
    ``self.data``, which is the caller's DataFrame (it is not copied).
    """

    _RISK_LABELS = ['Low Risk', 'Medium Risk', 'High Risk', 'Very High Risk']

    def __init__(self, data):
        self.data = data
        self.km_analyzer = None
        self.cox_analyzer = None
        # Columns read by generate_insights; analyze_churn overwrites them with
        # the names it was called with.
        self.duration_col = 'tenure_days'
        self.event_col = 'churned'

    def analyze_churn(self, customer_features, duration_col='tenure_days', event_col='churned'):
        """Run KM curves, a Cox model and quartile risk segmentation.

        Returns:
            ``self.data`` with 'risk_score' and 'risk_segment' columns added.
        """
        self.duration_col = duration_col
        self.event_col = event_col

        print("Step 1: Kaplan-Meier Analysis")
        self.km_analyzer = KaplanMeierAnalyzer(self.data, duration_col, event_col)

        # Overall survival curve
        self.km_analyzer.fit()

        # By customer segment
        if 'plan_type' in self.data.columns:
            self.km_analyzer.fit(group_col='plan_type')
            comparison = self.km_analyzer.compare_groups('plan_type')
            print(f"  Plan type comparison p-value: {comparison['p_value']:.4f}")

        print("\nStep 2: Cox Proportional Hazards Model")
        self.cox_analyzer = CoxPHAnalyzer(self.data, duration_col, event_col)
        cox_results = self.cox_analyzer.fit(customer_features)

        print(f"  Concordance index: {cox_results['concordance_index']:.4f}")

        print("\nStep 3: Hazard Ratio Interpretation")
        for feature, info in self.cox_analyzer.hazard_ratios().items():
            print(f"  {feature}: {info['interpretation']}")

        print("\nStep 4: Risk Segmentation")
        risk_scores = self.cox_analyzer.predict_partial_hazard(self.data[customer_features])
        self.data['risk_score'] = risk_scores
        self.data['risk_segment'] = self._assign_risk_segments(self.data['risk_score'])

        print("\nRisk Segment Distribution:")
        print(self.data['risk_segment'].value_counts())

        return self.data

    def _assign_risk_segments(self, scores):
        """Map each score to its quartile label, tolerating tied scores.

        ``pd.qcut`` raises when quartile edges coincide (common with discrete
        covariates). This uses the same ``<=`` edge rule as ``_get_risk_segment``.
        Missing scores stay missing.
        """
        edges = scores.quantile([0.25, 0.5, 0.75]).to_numpy()
        # side='left': a score equal to an edge falls into the lower segment.
        codes = np.searchsorted(edges, scores.to_numpy(), side='left')
        codes = np.where(scores.isna().to_numpy(), -1, codes)
        return pd.Series(
            pd.Categorical.from_codes(codes, categories=self._RISK_LABELS, ordered=True),
            index=scores.index,
        )

    def generate_insights(self):
        """Summarise risk segments and churn-increasing Cox covariates.

        Reads the duration/event columns recorded by ``analyze_churn``;
        'monthly_revenue' is aggregated only when that column exists.
        """
        insights = {
            'key_findings': [],
            'recommendations': []
        }

        agg_spec = {
            'risk_score': 'mean',
            self.event_col: 'mean',
            self.duration_col: 'mean',
        }
        if 'monthly_revenue' in self.data.columns:
            agg_spec['monthly_revenue'] = 'sum'
        risk_analysis = self.data.groupby('risk_segment', observed=False).agg(agg_spec)

        insights['key_findings'].append(
            f"Very High Risk customers have {risk_analysis.loc['Very High Risk', self.event_col]:.1%} churn rate"
        )

        for feature, info in self.cox_analyzer.hazard_ratios().items():
            if info['direction'] == 'increases risk':
                insights['key_findings'].append(
                    f"{feature} increases churn risk: {info['interpretation']}"
                )

        insights['recommendations'].extend([
            "Focus retention efforts on Very High Risk segment",
            "Investigate factors driving churn in Cox model",
            "Implement early warning system using risk scores",
            "A/B test retention interventions by risk segment"
        ])

        return insights

    def predict_individual_churn(self, customer_data, time_horizon=30):
        """Churn probability within ``time_horizon`` days for the first row of ``customer_data``.

        Returns:
            dict with 'churn_probability_{time_horizon}d' (the key is
            'churn_probability_30d' for the default horizon),
            'churn_probability', 'time_horizon_days', 'risk_score' and
            'risk_segment'.
        """
        survival_probs = self.cox_analyzer.predict_survival_function(
            customer_data,
            times=[time_horizon]
        )
        # Rows are times, columns are customers: first time, first customer.
        churn_probability = 1 - survival_probs.iloc[0].values[0]

        risk_score = self.cox_analyzer.predict_partial_hazard(customer_data)
        score = np.asarray(risk_score).ravel()[0]

        return {
            f'churn_probability_{time_horizon}d': churn_probability,
            'churn_probability': churn_probability,
            'time_horizon_days': time_horizon,
            'risk_score': score,
            'risk_segment': self._get_risk_segment(score)
        }

    def _get_risk_segment(self, risk_score):
        """Label one score against the quartiles of ``self.data['risk_score']``."""
        percentiles = self.data['risk_score'].quantile([0.25, 0.5, 0.75])

        if risk_score <= percentiles[0.25]:
            return 'Low Risk'
        elif risk_score <= percentiles[0.5]:
            return 'Medium Risk'
        elif risk_score <= percentiles[0.75]:
            return 'High Risk'
        else:
            return 'Very High Risk'

# Example usage
# churn_predictor = ChurnPredictor(data)
# features = ['age', 'monthly_revenue', 'support_tickets', 'plan_type']  # not 'tenure_days': it is the duration column
# analyzed_data = churn_predictor.analyze_churn(features)
# insights = churn_predictor.generate_insights()
# individual_pred = churn_predictor.predict_individual_churn(new_customer_data)

6. Common Follow-Up Questions

Follow-up 1: How do you handle time-varying covariates?

def time_varying_cox_model(data, id_col, start_col, end_col, event_col,
                           time_varying_covariates=None):
    """Fit a Cox model with time-varying covariates on long-format data.

    Each row of ``data`` is one (start, end] interval for one subject.

    Args:
        data: long-format DataFrame.
        id_col, start_col, end_col, event_col: bookkeeping column names.
        time_varying_covariates: covariate columns to model. When None, every
            other column in ``data`` is passed to the fitter as a covariate.

    Returns:
        The fitted ``CoxTimeVaryingFitter`` (its summary is also printed).
    """
    from lifelines import CoxTimeVaryingFitter

    if time_varying_covariates is not None:
        bookkeeping = [id_col, start_col, end_col, event_col]
        covariates = [c for c in time_varying_covariates if c not in bookkeeping]
        # Restrict the frame so stray columns (names, dates, ...) are not modelled.
        data = data[bookkeeping + covariates]

    ctv = CoxTimeVaryingFitter()
    ctv.fit(
        data,
        id_col=id_col,
        start_col=start_col,
        stop_col=end_col,
        event_col=event_col
    )

    print(ctv.summary)

    return ctv

# Example: Model where usage changes over time
# time_varying_data = create_long_format(data, covariates=['usage_minutes', 'support_tickets'])
# tv_model = time_varying_cox_model(time_varying_data, 'customer_id', 'start', 'end', 'churned')

Follow-up 2: How do you validate survival models?

def validate_survival_model(model, data, duration_col, event_col, covariates,
                            test_size=0.2, random_state=42):
    """Hold-out validation of a survival regression model via the concordance index.

    Args:
        model: object with ``fit(df, duration_col=..., event_col=...)`` and
            ``predict_partial_hazard(df)`` (e.g. a lifelines ``CoxPHFitter``).
        data: DataFrame containing the duration, event and covariate columns.
        duration_col, event_col: outcome column names.
        covariates: predictor columns; the outcome columns are ignored if listed.
        test_size, random_state: forwarded to ``train_test_split``.

    Returns:
        dict with 'concordance_index' and a text 'interpretation'.
    """
    from lifelines.utils import concordance_index
    from sklearn.model_selection import train_test_split

    covariates = [c for c in covariates if c not in (duration_col, event_col)]
    # Fit and score on the same columns: fitting on the whole frame would let
    # the model use columns that are not in `covariates` at prediction time.
    model_data = data[[duration_col, event_col] + covariates].dropna()

    train_data, test_data = train_test_split(
        model_data, test_size=test_size, random_state=random_state
    )

    model.fit(train_data, duration_col=duration_col, event_col=event_col)

    risk_scores = model.predict_partial_hazard(test_data[covariates])

    c_index = concordance_index(
        test_data[duration_col],
        -risk_scores,  # Negate because higher risk = lower survival
        test_data[event_col]
    )

    # NOTE(review): calibration (predicted vs. observed risk) is not assessed here.

    results = {
        'concordance_index': c_index,
        'interpretation': f'C-index of {c_index:.3f} indicates ' +
                         ('good' if c_index > 0.7 else 'fair' if c_index > 0.6 else 'poor') +
                         ' discrimination'
    }

    return results

# Example
# validation_results = validate_survival_model(
#     CoxPHFitter(), data, 'duration', 'event', ['age', 'usage', 'tickets']
# )

Company-Specific Tips

ℹ️

Google Tips:

  • Google values survival analysis for user behavior modeling
  • Know how to handle large-scale survival data
  • Understand how to extend survival analysis to multiple events
  • Be comfortable with competing risks models

Amazon Tips:

  • Amazon uses survival analysis for subscription churn
  • Know how to build real-time churn prediction systems
  • Understand how to design retention interventions
  • Be familiar with customer lifetime value estimation

Quiz Section


Related Topics

Advertisement