#!/usr/bin/env python3
"""
Air Quality Data Fetcher
Fetches EPA Air Quality System (AQS) historical data.

Datasets:
- EPA AQS API for historical AQI trends
- Air quality measurements from thousands of monitoring stations

Usage:
    python fetch_air_quality_data.py
    USE_CACHED_DATA=true python fetch_air_quality_data.py

Note: Requires EPA AQS API key. Get at: https://aqs.epa.gov/data/api/signup
      Store in .env file as AQS_API_KEY=your_key_here
"""

import os
import json
import requests
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from dotenv import load_dotenv

# Load environment variables from a local .env file (AQS_API_KEY, AQS_API_EMAIL).
load_dotenv()

# Configuration: repo root is three directory levels up from this file.
BASE_DIR: Path = Path(__file__).parent.parent.parent
DATA_DIR: Path = BASE_DIR / 'data' / 'environmental'  # processed CSV/JSON outputs
CACHE_DIR: Path = BASE_DIR / 'tools' / 'fetchers' / 'cache' / 'air_quality'  # raw API response cache

# Ensure directories exist
DATA_DIR.mkdir(exist_ok=True, parents=True)
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# API configuration
AQS_API_KEY: str = os.getenv('AQS_API_KEY', '')
# NOTE(review): the default email is a placeholder; the AQS API expects the
# email address used at signup — confirm AQS_API_EMAIL is set in .env.
AQS_API_EMAIL: str = os.getenv('AQS_API_EMAIL', 'test@aqs.api')
# Opt-in cache use: set USE_CACHED_DATA=true to reuse recent responses.
USE_CACHED: bool = os.getenv('USE_CACHED_DATA', 'false').lower() == 'true'

# EPA AQS API base URL
AQS_BASE_URL: str = "https://aqs.epa.gov/data/api"


class AirQualityDataFetcher:
    """Fetches EPA Air Quality System (AQS) annual AQI data.

    Wraps the public AQS REST API (https://aqs.epa.gov/data/api), caches raw
    JSON responses on disk, and post-processes the results into CSV/JSON
    files under DATA_DIR for downstream use.
    """

    def __init__(self):
        # Directories and credentials come from module-level configuration.
        self.cache_dir = CACHE_DIR
        self.data_dir = DATA_DIR
        self.api_key = AQS_API_KEY
        self.email = AQS_API_EMAIL

    def _get_cache_path(self, data_type: str) -> Path:
        """Return the cache file path for *data_type*, stamped with today's date."""
        date_str = datetime.now().strftime('%Y%m%d')
        return self.cache_dir / f"{data_type}_{date_str}.json"

    def _is_cache_valid(self, cache_path: Path, max_age_hours: int = 168) -> bool:
        """Return True if *cache_path* exists and is newer than *max_age_hours* (default: 7 days)."""
        if not cache_path.exists():
            return False
        file_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return file_age < timedelta(hours=max_age_hours)

    def _fetch_with_cache(self, data_type: str, fetch_func, max_age_hours: int = 168):
        """Return data for *data_type*, serving from cache when enabled and fresh.

        Args:
            data_type: Logical dataset name (used to build the cache filename).
            fetch_func: Zero-argument callable returning a JSON-serializable value.
            max_age_hours: Maximum cache age accepted (default 7 days).

        Empty results are NOT written to the cache, so a failed fetch does not
        poison the cache for the rest of the day.
        """
        cache_path = self._get_cache_path(data_type)

        if USE_CACHED and self._is_cache_valid(cache_path, max_age_hours):
            print(f"   Using cached data from {cache_path}")
            with open(cache_path) as f:
                return json.load(f)

        print("   Fetching fresh data...")
        data = fetch_func()
        if data:  # only cache successful (non-empty) fetches
            with open(cache_path, 'w') as f:
                json.dump(data, f, indent=2)
            print(f"   Cached to {cache_path}")
        return data

    def _make_aqs_request(self, endpoint: str, params: dict) -> list:
        """GET ``{AQS_BASE_URL}/{endpoint}`` and return the response's 'Data' list.

        Credentials are injected into *params* (the dict is mutated in place).
        Returns [] on any API or network error — best-effort by design, so one
        bad year/state does not abort a whole multi-request run.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError(
                "EPA AQS API key not found. "
                "Get key at https://aqs.epa.gov/data/api/signup "
                "and set AQS_API_KEY in .env file"
            )

        url = f"{AQS_BASE_URL}/{endpoint}"
        params['email'] = self.email
        params['key'] = self.api_key

        try:
            response = requests.get(url, params=params, timeout=60)
            response.raise_for_status()
            payload = response.json()

            # 'Header' is a list; guard against it being absent or empty
            # (indexing [0] unconditionally would raise IndexError).
            header = (payload.get('Header') or [{}])[0]
            if header.get('status') == 'Success':
                return payload.get('Data', [])
            print(f"   API Error: {header.get('error', 'Unknown error')}")
            return []

        except Exception as e:
            # Deliberately broad: log and continue with an empty result.
            print(f"   ✗ Error making AQS API request: {e}")
            return []

    def fetch_annual_aqi_summary(self, start_year=2020, end_year=2024):
        """
        Fetch annual AQI summary by CBSA (Core-Based Statistical Area).

        This provides summary statistics for major metropolitan areas, one
        request per year in [start_year, end_year], combined into one list.
        """
        print(f"\n🌡️ Fetching Annual AQI Summary ({start_year}-{end_year})...")

        def fetch():
            all_data = []

            for year in range(start_year, end_year + 1):
                print(f"   Fetching {year}...", end='', flush=True)

                params = {
                    'bdate': f"{year}0101",
                    'edate': f"{year}1231",
                }

                # NOTE(review): the AQS annualData/byCBSA service also expects
                # 'param' (pollutant code) and 'cbsa' parameters — confirm this
                # request succeeds; if the API rejects it we log "No data".
                data = self._make_aqs_request('annualData/byCBSA', params)

                if data:
                    all_data.extend(data)
                    print(f" ✓ {len(data)} records")
                else:
                    print(" ✗ No data")

            return all_data

        return self._fetch_with_cache('annual_aqi_summary', fetch)

    def fetch_state_aqi_summary(self, states=None, start_year=2020, end_year=2024):
        """
        Fetch annual AQI summary by state.

        Args:
            states: List of state FIPS codes (default: major states)
            start_year: Start year
            end_year: End year
        """
        # Default to major states if not specified
        if states is None:
            states = [
                '06',  # California
                '36',  # New York
                '48',  # Texas
                '12',  # Florida
                '17',  # Illinois
                '42',  # Pennsylvania
                '53',  # Washington
                '04',  # Arizona
            ]

        print(f"\n🗺️ Fetching State AQI Summary ({len(states)} states, {start_year}-{end_year})...")

        def fetch():
            all_data = []

            # One request per (state, year) pair; failures are logged and skipped.
            for state_code in states:
                for year in range(start_year, end_year + 1):
                    print(f"   State {state_code}, {year}...", end='', flush=True)

                    params = {
                        'bdate': f"{year}0101",
                        'edate': f"{year}1231",
                        'state': state_code
                    }

                    data = self._make_aqs_request('annualData/byState', params)

                    if data:
                        all_data.extend(data)
                        print(f" ✓ {len(data)} records")
                    else:
                        print(" ✗ No data")

            return all_data

        return self._fetch_with_cache('state_aqi_summary', fetch)

    def process_aqi_trends(self, annual_data):
        """Project annual rows down to key AQI columns and save them as CSV.

        Args:
            annual_data: List of row dicts from the AQS annual-data endpoints.

        Returns:
            The processed DataFrame, or None when *annual_data* is empty.
        """
        print("\n📊 Processing AQI Trends...")

        if not annual_data:
            print("   ✗ No data to process")
            return None

        df = pd.DataFrame(annual_data)

        # Columns we want in the output, in preferred order. Keep only those
        # actually present so API schema drift cannot raise KeyError.
        wanted = [
            'cbsa', 'cbsa_name', 'state', 'year',
            'good_days', 'moderate_days', 'usg_days',  # USG = Unhealthy for Sensitive Groups
            'unhealthy_days', 'very_unhealthy_days', 'hazardous_days',
            'max_aqi', 'median_aqi'
        ]
        if 'cbsa' in df.columns:
            trends = df[[c for c in wanted if c in df.columns]].copy()
        else:
            trends = df

        # Good-air share of monitored days. BUGFIX: the original looked up
        # 'days_with_aqi' on the projected frame, where the column never
        # survives the selection above, so it always fell back to 365. Read it
        # from the source frame (indexes align), masking zero-day rows to
        # avoid division by zero.
        if 'good_days' in trends.columns:
            if 'days_with_aqi' in df.columns:
                total_days = df['days_with_aqi'].where(df['days_with_aqi'] > 0)
            else:
                total_days = 365
            trends['good_air_percentage'] = (trends['good_days'] / total_days * 100).round(2)

        output_path = self.data_dir / 'aqi_trends_by_metro.csv'
        trends.to_csv(output_path, index=False)

        print(f"   ✓ Saved AQI trends to {output_path}")
        return trends

    def generate_aqi_timeline(self, state_data):
        """Aggregate state rows into a per-state/per-year timeline (CSV + JSON).

        Args:
            state_data: List of row dicts from the annualData/byState endpoint.

        Returns:
            The aggregated DataFrame, or None when input is empty or the
            grouping keys ('state', 'year') are missing.
        """
        print("\n📈 Generating AQI Timeline...")

        if not state_data:
            print("   ✗ No data to process")
            return None

        df = pd.DataFrame(state_data)

        # Without the grouping keys we cannot build a timeline at all.
        if not {'state', 'year'}.issubset(df.columns):
            print("   ✗ Missing 'state'/'year' columns; cannot build timeline")
            return None

        # Aggregate only the metric columns actually present (schema-drift safe).
        agg_spec = {
            'good_days': 'sum',
            'moderate_days': 'sum',
            'unhealthy_days': 'sum',
            'max_aqi': 'max',
            'median_aqi': 'mean'
        }
        agg_spec = {col: fn for col, fn in agg_spec.items() if col in df.columns}
        timeline = df.groupby(['state', 'year']).agg(agg_spec).reset_index()

        output_path = self.data_dir / 'aqi_timeline_by_state.csv'
        timeline.to_csv(output_path, index=False)

        # Also save as JSON for easier web use. BUGFIX: use pandas' own JSON
        # writer — to_dict('records') yields numpy scalar types (numpy.int64)
        # that stdlib json.dump cannot serialize.
        json_path = self.data_dir / 'aqi_timeline_by_state.json'
        json_path.write_text(timeline.to_json(orient='records', indent=2))

        print(f"   ✓ Saved AQI timeline to {output_path}")
        print(f"   ✓ Saved AQI timeline to {json_path}")
        return timeline

    def save_metadata(self, record_count, years_covered):
        """Write a descriptive metadata JSON alongside the generated datasets.

        Args:
            record_count: Total number of raw records fetched.
            years_covered: Human-readable year range, e.g. "2019-2023".
        """
        metadata = {
            'dataset_name': 'EPA Air Quality Historical Data',
            'last_updated': datetime.now().strftime('%Y-%m-%d'),
            'source': 'EPA Air Quality System (AQS) API',
            'api_url': 'https://aqs.epa.gov/data/api',
            'record_count': record_count,
            'years_covered': years_covered,
            'fields': {
                'cbsa': 'Core-Based Statistical Area code',
                'cbsa_name': 'Metropolitan area name',
                'state': 'State FIPS code',
                'year': 'Calendar year',
                'good_days': 'Days with AQI 0-50 (Good)',
                'moderate_days': 'Days with AQI 51-100 (Moderate)',
                'usg_days': 'Days with AQI 101-150 (Unhealthy for Sensitive Groups)',
                'unhealthy_days': 'Days with AQI 151-200 (Unhealthy)',
                'very_unhealthy_days': 'Days with AQI 201-300 (Very Unhealthy)',
                'hazardous_days': 'Days with AQI 301+ (Hazardous)',
                'max_aqi': 'Maximum AQI recorded',
                'median_aqi': 'Median AQI value'
            },
            'notes': (
                'Air Quality Index (AQI) data from EPA monitoring stations. '
                'AQI is calculated for five major pollutants: ground-level ozone, '
                'particulate matter, carbon monoxide, sulfur dioxide, and nitrogen dioxide. '
                'Data aggregated by metropolitan area (CBSA) and state.'
            ),
            'api_key_required': True,
            'api_signup_url': 'https://aqs.epa.gov/data/api/signup'
        }

        meta_path = self.data_dir / 'air_quality_metadata.json'
        with open(meta_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"\n✓ Metadata saved to {meta_path}")


def main():
    """Entry point: fetch, process, and persist EPA air-quality datasets.

    Exits early (with setup instructions) when no AQS API key is configured.
    """
    print("=" * 60)
    print("EPA AIR QUALITY DATA FETCHER")
    print("=" * 60)

    fetcher = AirQualityDataFetcher()

    # Without an API key no request can succeed; bail out with guidance.
    if not fetcher.api_key:
        print("\n⚠️  No EPA AQS API key found!")
        print("    Get a free API key at: https://aqs.epa.gov/data/api/signup")
        print("    Then add to .env file: AQS_API_KEY=your_key_here")
        print("\n    Skipping data fetch...")
        return

    # Fetch the last 5 complete calendar years (the current year is partial).
    current_year = datetime.now().year
    start_year = current_year - 5
    end_year = current_year - 1  # Complete years only

    annual_data = fetcher.fetch_annual_aqi_summary(start_year, end_year)

    if annual_data:
        # Process metro-level trends; the result is written to disk, so the
        # return value is not needed here (unused local removed).
        fetcher.process_aqi_trends(annual_data)

    # Fetch state-level data
    state_data = fetcher.fetch_state_aqi_summary(start_year=start_year, end_year=end_year)

    if state_data:
        # Generate the per-state timeline files.
        fetcher.generate_aqi_timeline(state_data)

        # Save metadata describing both fetched datasets.
        fetcher.save_metadata(
            record_count=len(annual_data) + len(state_data),
            years_covered=f"{start_year}-{end_year}"
        )

    print("\n" + "=" * 60)
    print("✓ Air quality data fetching complete")
    print("=" * 60)


if __name__ == "__main__":
    main()
