#!/usr/bin/env python3
"""
Healthcare Deserts Data Collection Script

Collects data on rural hospital closures, physician shortages, and healthcare access
across U.S. counties. Supports before/after analysis of proposed Medicaid cuts.

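Data Sources:
- Census Bureau ACS 5-year API: population, poverty, income-to-poverty ratios,
  and health insurance coverage (uninsured rate used as a County Health
  Rankings proxy)
- CMS Provider Data Catalog: Hospital General Information (downloaded and
  cached; not yet merged into the county output)
- County Health Rankings: physician density, preventable deaths (manual download required)
- HRSA Health Professional Shortage Areas (manual download required)
- Chartis Center rural hospital closures (manual collection required; not fetched by this script)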

Usage:
    python fetch_healthcare_data.py

Environment Variables:
    CENSUS_API_KEY: Census Bureau API key
    USE_CACHED_DATA: true/false (default: false)

Output:
    data/healthcare_desert_merged.csv - County-level healthcare access metrics
    data/healthcare_desert_metadata.json - Collection timestamps and sources
"""

import os
import sys
import json
import time
from datetime import datetime
from pathlib import Path

import pandas as pd
import requests
from dotenv import load_dotenv

# Pick up CENSUS_API_KEY / USE_CACHED_DATA from a .env file when one exists.
load_dotenv()

# ---- Runtime configuration (environment-driven) ----
CENSUS_API_KEY = os.environ.get('CENSUS_API_KEY', '')
USE_CACHED_DATA = 'true' == os.environ.get('USE_CACHED_DATA', 'false').lower()
CENSUS_YEAR = 2022  # ACS 5-year vintage queried by every Census fetcher

# ---- Output locations, anchored to this script's folder (not the CWD) ----
BASE_DIR = Path(__file__).parent
DATA_DIR, CACHE_DIR = BASE_DIR / 'data', BASE_DIR / 'cache'

# Make sure both folders exist before anything writes into them.
DATA_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

# Provenance record: fetchers fill in `sources` / `record_counts`,
# and main() serializes it to JSON next to the merged CSV.
metadata = dict(
    collection_date=datetime.now().isoformat(),
    sources=dict(),
    record_counts=dict(),
)


def fetch_county_health_rankings():
    """
    Fetch county uninsured rates as a stand-in for County Health Rankings data.

    Note: CHR doesn't have a public API. This uses Census ACS table B27001
    (health insurance coverage status by sex by age) as a proxy for the
    uninsured-rate metric. Full CHR data requires manual download from:
    https://www.countyhealthrankings.org/explore-health-rankings/rankings-data-documentation

    Returns:
        DataFrame with columns fips (5-character string), county_name,
        total_pop (B27001 universe), uninsured_pop and uninsured_rate
        (percent). An empty DataFrame if the Census request fails.
    """
    print("\n📊 Fetching County Health Rankings proxy data...")

    cache_file = CACHE_DIR / 'county_health_rankings.csv'
    if USE_CACHED_DATA and cache_file.exists():
        print(f"   Using cached data from {cache_file}")
        # Read FIPS as text so leading zeros ('01001') survive the CSV round trip
        df = pd.read_csv(cache_file, dtype={'fips': str})
        metadata['sources']['county_health_rankings'] = 'cached'
        return df

    # B27001 has no single "uninsured" cell: it is split by sex x 9 age groups,
    # each laid out as (group total, with coverage, no coverage). Male groups
    # occupy _003.._029 and female groups _031.._057, so the "No health
    # insurance coverage" cells are every third variable from _005 and _033.
    uninsured_vars = [
        f'B27001_{i:03d}E' for i in list(range(5, 30, 3)) + list(range(33, 58, 3))
    ]

    url = f"https://api.census.gov/data/{CENSUS_YEAR}/acs/acs5"
    params = {
        'get': ','.join(['NAME', 'B27001_001E'] + uninsured_vars),  # Universe total + uninsured cells
        'for': 'county:*',
        'key': CENSUS_API_KEY
    }

    try:
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: response body was not valid JSON
        print(f"   ✗ Error fetching Census data: {e}")
        return pd.DataFrame()

    # First row is the header returned by the Census API
    df = pd.DataFrame(data[1:], columns=data[0])
    df['fips'] = df['state'] + df['county']

    df = df.rename(columns={
        'NAME': 'county_name',
        'B27001_001E': 'total_pop'
    })

    df['total_pop'] = pd.to_numeric(df['total_pop'], errors='coerce')
    uninsured_cells = df[uninsured_vars].apply(pd.to_numeric, errors='coerce')
    # min_count: if any cell is missing the county total is unknown (NaN), not undercounted
    df['uninsured_pop'] = uninsured_cells.sum(axis=1, min_count=len(uninsured_vars))
    df['uninsured_rate'] = (df['uninsured_pop'] / df['total_pop'] * 100).round(2)

    df = df[['fips', 'county_name', 'total_pop', 'uninsured_pop', 'uninsured_rate']]
    df.to_csv(cache_file, index=False)

    metadata['sources']['county_health_rankings'] = (
        f"Census ACS {CENSUS_YEAR} B27001 (proxy for uninsured rate)"
    )
    metadata['record_counts']['county_health_rankings'] = len(df)

    print(f"   ✓ Collected data for {len(df)} counties")
    return df


def fetch_poverty_data():
    """
    Fetch county poverty rates from the Census ACS 5-year API (table B17001).

    Returns:
        DataFrame with columns fips (5-character string), county_name and
        poverty_rate (percent of the poverty-status universe below the
        poverty level). An empty DataFrame if the Census request fails.
    """
    print("\n💰 Fetching poverty data...")

    cache_file = CACHE_DIR / 'poverty_data.csv'
    if USE_CACHED_DATA and cache_file.exists():
        print(f"   Using cached data from {cache_file}")
        # Read FIPS as text so leading zeros survive the CSV round trip
        df = pd.read_csv(cache_file, dtype={'fips': str})
        metadata['sources']['poverty'] = 'cached'
        return df

    url = f"https://api.census.gov/data/{CENSUS_YEAR}/acs/acs5"
    params = {
        'get': 'NAME,B17001_001E,B17001_002E',  # Poverty-status universe, below poverty
        'for': 'county:*',
        'key': CENSUS_API_KEY
    }

    try:
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: response body was not valid JSON
        print(f"   ✗ Error fetching poverty data: {e}")
        return pd.DataFrame()

    # First row is the header returned by the Census API
    df = pd.DataFrame(data[1:], columns=data[0])
    df['fips'] = df['state'] + df['county']

    df = df.rename(columns={
        'NAME': 'county_name',
        'B17001_001E': 'total_pop_poverty',
        'B17001_002E': 'poverty_pop'
    })

    df['total_pop_poverty'] = pd.to_numeric(df['total_pop_poverty'], errors='coerce')
    df['poverty_pop'] = pd.to_numeric(df['poverty_pop'], errors='coerce')
    df['poverty_rate'] = (df['poverty_pop'] / df['total_pop_poverty'] * 100).round(2)

    df = df[['fips', 'county_name', 'poverty_rate']]
    df.to_csv(cache_file, index=False)

    metadata['sources']['poverty'] = f"Census ACS {CENSUS_YEAR}"
    metadata['record_counts']['poverty'] = len(df)

    print(f"   ✓ Collected poverty data for {len(df)} counties")
    return df


def fetch_medicaid_enrollment():
    """
    Estimate Medicaid eligibility and enrollment by county.

    Note: Medicaid.gov doesn't provide county-level enrollment via API.
    This uses ACS income-to-poverty ratios (table B17002) as a proxy for
    Medicaid eligibility.

    Method (rough proxy, not an official estimate):
      * eligible ~= people under 1.25x FPL, plus a linearly interpolated
        share (0.13 / 0.25 = 52%) of the 1.25-1.49x bracket, approximating
        the 138% FPL expansion threshold. This assumes people are spread
        evenly within that bracket.
      * enrolled ~= eligible x an assumed take-up rate: 80% in expansion
        states, 45% otherwise.
      * Rows whose state code is not in the expansion table (e.g. territories,
        if the API returns them) get expansion_state = NaN and no enrollment
        estimate, instead of silently being treated as expansion states.

    Returns:
        DataFrame with columns fips, county_name, state, expansion_state,
        medicaid_eligible_est, medicaid_enrolled_est (nullable integers).
        An empty DataFrame if the Census request fails.
    """
    print("\n🏥 Estimating Medicaid enrollment...")

    cache_file = CACHE_DIR / 'medicaid_enrollment.csv'
    if USE_CACHED_DATA and cache_file.exists():
        print(f"   Using cached data from {cache_file}")
        # Read FIPS / state codes as text so leading zeros survive the round trip
        df = pd.read_csv(cache_file, dtype={'fips': str, 'state': str})
        metadata['sources']['medicaid'] = 'cached'
        return df

    # ACA Medicaid expansion status by state FIPS code (as of 2024).
    # Non-expansion: AL, FL, GA, KS, MS, SC, TN, TX, WI, WY.
    expansion_states = {
        '01': False,  # Alabama
        '02': True,   # Alaska
        '04': True,   # Arizona
        '05': True,   # Arkansas (expanded 2014 via private-option waiver, now ARHOME)
        '06': True,   # California
        '08': True,   # Colorado
        '09': True,   # Connecticut
        '10': True,   # Delaware
        '11': True,   # DC
        '12': False,  # Florida
        '13': False,  # Georgia
        '15': True,   # Hawaii
        '16': True,   # Idaho
        '17': True,   # Illinois
        '18': True,   # Indiana
        '19': True,   # Iowa
        '20': False,  # Kansas
        '21': True,   # Kentucky
        '22': True,   # Louisiana
        '23': True,   # Maine
        '24': True,   # Maryland
        '25': True,   # Massachusetts
        '26': True,   # Michigan
        '27': True,   # Minnesota
        '28': False,  # Mississippi
        '29': True,   # Missouri
        '30': True,   # Montana
        '31': True,   # Nebraska
        '32': True,   # Nevada
        '33': True,   # New Hampshire
        '34': True,   # New Jersey
        '35': True,   # New Mexico
        '36': True,   # New York
        '37': True,   # North Carolina
        '38': True,   # North Dakota
        '39': True,   # Ohio
        '40': True,   # Oklahoma
        '41': True,   # Oregon
        '42': True,   # Pennsylvania
        '44': True,   # Rhode Island
        '45': False,  # South Carolina
        '46': True,   # South Dakota
        '47': False,  # Tennessee
        '48': False,  # Texas
        '49': True,   # Utah
        '50': True,   # Vermont
        '51': True,   # Virginia
        '53': True,   # Washington
        '54': True,   # West Virginia
        '55': False,  # Wisconsin (not expanded; waiver covers adults only to 100% FPL)
        '56': False,  # Wyoming
    }

    # ACS B17002 (ratio of income to poverty level): the brackets needed to reach 1.38x FPL
    bracket_columns = {
        'B17002_001E': 'total_pop',
        'B17002_002E': 'ratio_under_0_50',
        'B17002_003E': 'ratio_0_50_to_0_74',
        'B17002_004E': 'ratio_0_75_to_0_99',
        'B17002_005E': 'ratio_1_00_to_1_24',
        'B17002_006E': 'ratio_1_25_to_1_49',
    }
    url = f"https://api.census.gov/data/{CENSUS_YEAR}/acs/acs5"
    params = {
        'get': ','.join(['NAME', *bracket_columns]),
        'for': 'county:*',
        'key': CENSUS_API_KEY
    }

    try:
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: response body was not valid JSON
        print(f"   ✗ Error fetching Medicaid proxy data: {e}")
        return pd.DataFrame()

    # First row is the header returned by the Census API
    df = pd.DataFrame(data[1:], columns=data[0])
    df['fips'] = df['state'] + df['county']
    df = df.rename(columns={'NAME': 'county_name', **bracket_columns})

    for col in bracket_columns.values():
        df[col] = pd.to_numeric(df[col], errors='coerce')

    below_125 = df[[
        'ratio_under_0_50', 'ratio_0_50_to_0_74', 'ratio_0_75_to_0_99', 'ratio_1_00_to_1_24',
    ]].sum(axis=1, min_count=4)
    # Linear interpolation into the 1.25-1.49 bracket to approximate the 138% FPL cutoff
    share_up_to_138 = (1.38 - 1.25) / (1.50 - 1.25)
    eligible = (below_125 + df['ratio_1_25_to_1_49'] * share_up_to_138).round(0)
    df['medicaid_eligible_est'] = eligible.astype('Int64')

    df['expansion_state'] = df['state'].map(expansion_states)

    # Assumed take-up rates; unknown expansion status maps to NaN (no estimate)
    take_up = df['expansion_state'].map({True: 0.80, False: 0.45})
    # Floor division mirrors int() truncation for these non-negative counts
    df['medicaid_enrolled_est'] = ((eligible * take_up) // 1).astype('Int64')

    df = df[['fips', 'county_name', 'state', 'expansion_state', 'medicaid_eligible_est', 'medicaid_enrolled_est']]
    df.to_csv(cache_file, index=False)

    metadata['sources']['medicaid'] = f"Census ACS {CENSUS_YEAR} (proxy estimates)"
    metadata['record_counts']['medicaid'] = len(df)

    print(f"   ✓ Estimated Medicaid enrollment for {len(df)} counties")
    return df


def fetch_rural_urban_classification():
    """
    Classify counties as rural or urban/suburban by total population.

    This is a simplified proxy, not an official Census urban/rural or USDA
    Rural-Urban Continuum classification. A county with fewer than 50,000
    residents (ACS table B01003 total population) is labeled rural. A county
    whose population is missing gets rural=False / 'Urban/Suburban'.

    Returns:
        DataFrame with columns fips (5-character string), county_name,
        rural (bool) and rural_category. An empty DataFrame if the Census
        request fails.
    """
    print("\n🌾 Fetching rural/urban classification...")

    cache_file = CACHE_DIR / 'rural_urban.csv'
    if USE_CACHED_DATA and cache_file.exists():
        print(f"   Using cached data from {cache_file}")
        # Read FIPS as text so leading zeros survive the CSV round trip
        df = pd.read_csv(cache_file, dtype={'fips': str})
        metadata['sources']['rural_urban'] = 'cached'
        return df

    url = f"https://api.census.gov/data/{CENSUS_YEAR}/acs/acs5"
    params = {
        'get': 'NAME,B01003_001E',  # Total population
        'for': 'county:*',
        'key': CENSUS_API_KEY
    }

    try:
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: response body was not valid JSON
        print(f"   ✗ Error fetching rural/urban data: {e}")
        return pd.DataFrame()

    # First row is the header returned by the Census API
    df = pd.DataFrame(data[1:], columns=data[0])
    df['fips'] = df['state'] + df['county']
    df = df.rename(columns={
        'NAME': 'county_name',
        'B01003_001E': 'total_pop'
    })
    df['total_pop'] = pd.to_numeric(df['total_pop'], errors='coerce')

    rural_threshold = 50000  # population only; no density or land-area input
    # NaN < threshold is False, so missing populations fall into 'Urban/Suburban'
    df['rural'] = df['total_pop'] < rural_threshold
    df['rural_category'] = df['rural'].map({True: 'Rural', False: 'Urban/Suburban'})

    df = df[['fips', 'county_name', 'rural', 'rural_category']]
    df.to_csv(cache_file, index=False)

    metadata['sources']['rural_urban'] = "Population-based classification"
    metadata['record_counts']['rural_urban'] = len(df)

    print(f"   ✓ Classified {len(df)} counties")
    return df


def _resolve_cms_csv_url(dataset_id, fallback_url):
    """
    Return the current CSV download URL for a CMS Provider Data dataset.

    Queries the Provider Data Catalog metastore for *dataset_id* and returns
    the first distribution whose downloadURL ends in ``.csv``. Any failure
    (network, HTTP status, non-JSON body, unexpected shape, no CSV listed)
    falls back to *fallback_url*.
    """
    metastore_url = (
        "https://data.cms.gov/provider-data/api/1/metastore/schemas/dataset/items/"
        f"{dataset_id}"
    )
    try:
        response = requests.get(metastore_url, timeout=30)
        response.raise_for_status()
        distributions = response.json().get('distribution', [])
    except (requests.exceptions.RequestException, ValueError, AttributeError) as e:
        print(f"   ⚠️  Could not look up current CMS file URL ({e}); using pinned URL")
        return fallback_url

    for dist in distributions if isinstance(distributions, list) else []:
        download_url = dist.get('downloadURL') if isinstance(dist, dict) else None
        if isinstance(download_url, str) and download_url.lower().endswith('.csv'):
            return download_url
    return fallback_url


def fetch_cms_hospitals():
    """
    Download the CMS "Hospital General Information" file (dataset xubh-q36u).

    The current CSV URL is looked up in the CMS Provider Data Catalog. If that
    fails, a pinned URL is used; it embeds what looks like a publication
    timestamp and may go stale. The county column ('County Name' or
    'County/Parish', whichever is present) is renamed to ``county_name``.
    The result is cached to cache/cms_hospitals.csv.

    Returns:
        DataFrame of hospitals, or an empty DataFrame if the download fails.
    """
    print("\n🏥 Fetching CMS Hospital locations...")

    cache_file = CACHE_DIR / 'cms_hospitals.csv'
    # Identifier-like columns are kept as text so leading zeros (e.g. CCN '010001') survive;
    # dtype entries for columns absent from the file are ignored by read_csv
    text_columns = {'Facility ID': str, 'ZIP Code': str}

    if USE_CACHED_DATA and cache_file.exists():
        print(f"   Using cached data from {cache_file}")
        metadata['sources']['cms_hospitals'] = 'cached'
        return pd.read_csv(cache_file, dtype=text_columns)

    pinned_csv_url = "https://data.cms.gov/provider-data/sites/default/files/resources/1ee6a6e80907bf13663fc7595b228ee1_1730855170/Hospital%20General%20Information.csv"
    csv_url = _resolve_cms_csv_url('xubh-q36u', pinned_csv_url)

    try:
        print(f"   Downloading from CMS: {csv_url}")
        df = pd.read_csv(csv_url, encoding='ISO-8859-1', dtype=text_columns)
    except Exception as e:
        # Best-effort boundary: network, HTTP and parse errors all mean "no hospital data"
        print(f"   ✗ Error fetching CMS data: {e}")
        return pd.DataFrame()

    # rename() ignores keys that are not present, so either header variant works
    df = df.rename(columns={'County Name': 'county_name', 'County/Parish': 'county_name'})
    df.to_csv(cache_file, index=False)

    metadata['sources']['cms_hospitals'] = f"CMS Provider Data: {csv_url}"
    metadata['record_counts']['cms_hospitals'] = len(df)

    print(f"   ✓ Downloaded {len(df)} hospitals")
    return df

def _normalize_fips(df):
    """
    Return *df* with its ``fips`` column as 5-character, zero-padded strings.

    Fresh Census pulls build FIPS as strings ('01001'), but a cached CSV read
    without a dtype yields integers (1001). Mixing the two makes
    ``DataFrame.merge`` raise, and integer FIPS lose their leading zero in the
    output. Empty frames and frames without ``fips`` are returned unchanged.
    """
    if df.empty or 'fips' not in df.columns:
        return df
    normalized = df.copy()
    normalized['fips'] = normalized['fips'].astype(str).str.zfill(5)
    return normalized


def merge_all_data():
    """
    Fetch every source and left-join them onto the county health data by FIPS.

    Adds ``hospital_closure_risk_score`` (0-100). Each of these factors is
    worth 25 points:
      - the county is rural
      - the poverty rate is above 20%
      - the uninsured rate is above 15%
      - the state has not expanded Medicaid
    A factor contributes 0 points when its source is missing or its value is
    missing for a county after the left join. Also adds ``risk_category``
    (Low / Moderate / High / Critical).

    Returns:
        The merged DataFrame, or None if the base health data is unavailable.
    """
    print("\n🔗 Merging all datasets...")

    # Fetch all datasets (each prints its own progress and may read from cache)
    health_df = _normalize_fips(fetch_county_health_rankings())
    poverty_df = _normalize_fips(fetch_poverty_data())
    medicaid_df = _normalize_fips(fetch_medicaid_enrollment())
    rural_df = _normalize_fips(fetch_rural_urban_classification())
    # Hospital data is downloaded/cached here but not merged yet.
    # TODO: join hospitals to counties; presumably needs a county-name/state -> FIPS
    # crosswalk — confirm which location columns the CMS file provides.
    fetch_cms_hospitals()

    if health_df.empty:
        print("   ✗ Cannot merge: health data missing")
        return None

    # Start with health data; every other source is an optional left join
    merged = health_df.copy()

    if not poverty_df.empty:
        merged = merged.merge(poverty_df[['fips', 'poverty_rate']], on='fips', how='left')

    if not medicaid_df.empty:
        merged = merged.merge(
            medicaid_df[['fips', 'expansion_state', 'medicaid_eligible_est', 'medicaid_enrolled_est']],
            on='fips',
            how='left'
        )

    if not rural_df.empty:
        merged = merged.merge(rural_df[['fips', 'rural', 'rural_category']], on='fips', how='left')

    # Risk factors, 25 points each. .eq(True)/.eq(False) treat missing values
    # (NaN from the left joins, or unknown expansion status) as "not flagged"
    # instead of crashing on ~NaN or int(NaN).
    risk_flags = []
    if 'rural' in merged.columns:
        risk_flags.append(merged['rural'].eq(True))
    if 'poverty_rate' in merged.columns:
        risk_flags.append(merged['poverty_rate'] > 20)
    if 'uninsured_rate' in merged.columns:
        risk_flags.append(merged['uninsured_rate'] > 15)
    if 'expansion_state' in merged.columns:
        risk_flags.append(merged['expansion_state'].eq(False))

    merged['hospital_closure_risk_score'] = 0
    for flag in risk_flags:
        merged['hospital_closure_risk_score'] += flag.fillna(False).astype(int) * 25

    # Bins [0,25], (25,50], (50,75], (75,100]: 0 or 1 factor -> Low, 2 -> Moderate, ...
    merged['risk_category'] = pd.cut(
        merged['hospital_closure_risk_score'],
        bins=[0, 25, 50, 75, 100],
        labels=['Low', 'Moderate', 'High', 'Critical'],
        include_lowest=True
    )

    print(f"   ✓ Merged {len(merged)} counties")
    return merged


def main():
    """
    Run the collection pipeline and write the merged CSV and metadata JSON.

    Prints a summary of the merged data and the remaining manual steps.
    Exits with status 1 (via ``sys.exit``) when no Census API key is set and
    cached data is not enabled, or when no data could be collected. This lets
    shell scripts and schedulers detect the failure.
    """
    print("=" * 70)
    print("Healthcare Deserts Data Collection")
    print("=" * 70)

    if not CENSUS_API_KEY:
        print("\n⚠️  WARNING: No Census API key found")
        print("   Set CENSUS_API_KEY in .env file")
        if not USE_CACHED_DATA:
            print("   Census data collection requires an API key (or USE_CACHED_DATA=true)")
            sys.exit(1)
        print("   Continuing with cached data; uncached Census sources will fail")

    print(f"\n📅 Census Year: {CENSUS_YEAR}")
    print(f"🗄️  Use Cached Data: {USE_CACHED_DATA}")

    merged_df = merge_all_data()

    if merged_df is None or merged_df.empty:
        print("\n✗ FAILED: No data collected")
        sys.exit(1)

    output_file = DATA_DIR / 'healthcare_desert_merged.csv'
    merged_df.to_csv(output_file, index=False)

    metadata_file = DATA_DIR / 'healthcare_desert_metadata.json'
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    # Summary statistics
    print("\n" + "=" * 70)
    print("📊 SUMMARY STATISTICS")
    print("=" * 70)
    print(f"Total counties: {len(merged_df)}")

    if 'rural_category' in merged_df.columns:
        print("\nRural/Urban breakdown:")
        print(merged_df['rural_category'].value_counts())

    if 'expansion_state' in merged_df.columns:
        # Count explicit True/False only; missing status is reported separately
        status = merged_df['expansion_state']
        expansion_count = int(status.eq(True).sum())
        non_expansion_count = int(status.eq(False).sum())
        unknown_count = len(merged_df) - expansion_count - non_expansion_count
        print(f"\nMedicaid expansion states: {expansion_count} counties")
        print(f"Non-expansion states: {non_expansion_count} counties")
        if unknown_count:
            print(f"Unknown expansion status: {unknown_count} counties")

    if 'risk_category' in merged_df.columns:
        print("\nHospital closure risk distribution:")
        print(merged_df['risk_category'].value_counts().sort_index())

    print("\n📁 Output files:")
    print(f"   {output_file}")
    print(f"   {metadata_file}")

    print("\n" + "=" * 70)
    print("✓ SUCCESS: Data collection complete!")
    print("=" * 70)

    print("\n📋 NEXT STEPS:")
    print("1. Manual data collection still needed:")
    print("   - HRSA HPSA shapefiles: https://data.hrsa.gov/")
    print("   - Chartis rural hospital closures: https://www.chartis.com/resources/rural-hospital-closures/")
    print("   - CMS Hospital Compare data: https://data.cms.gov/provider-data/")
    print("2. Create HTML story page with D3 visualizations")
    print("3. Build before/after split-screen comparing current vs. projected with cuts")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
