#!/usr/bin/env python3
"""
Baby Names Data Fetcher
Fetches Social Security Administration baby names data (1880-2024).

Datasets:
- SSA National Baby Names (1880-2024)
- 98,400+ unique names with historical trends
- State-level data available

Usage:
    python fetch_baby_names.py
    USE_CACHED_DATA=true python fetch_baby_names.py

Note: No API key required. Downloads public CSV files from SSA.
"""

import os
import zipfile
import json
import requests
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configuration
BASE_DIR = Path(__file__).parent.parent.parent
DATA_DIR = BASE_DIR / 'data' / 'cultural'
CACHE_DIR = BASE_DIR / 'tools' / 'fetchers' / 'cache' / 'baby_names'

# Ensure directories exist
DATA_DIR.mkdir(exist_ok=True, parents=True)
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Use cached data flag
USE_CACHED = os.getenv('USE_CACHED_DATA', 'false').lower() == 'true'

# SSA URLs
SSA_NATIONAL_URL = "https://www.ssa.gov/oact/babynames/names.zip"


class BabyNamesFetcher:
    """Fetches SSA Baby Names data."""

    def __init__(self):
        self.cache_dir = CACHE_DIR
        self.data_dir = DATA_DIR

    def _get_cache_path(self, data_type: str) -> Path:
        """Generate cache path stamped with today's date."""
        date_str = datetime.now().strftime('%Y%m%d')
        return self.cache_dir / f"{data_type}_{date_str}.zip"

    def _latest_cache(self, data_type: str):
        """Return the most recent cache file for this data type, or None."""
        candidates = sorted(self.cache_dir.glob(f"{data_type}_*.zip"))
        return candidates[-1] if candidates else None

    def _is_cache_valid(self, cache_path, max_age_hours: int = 168) -> bool:
        """Check if cache exists and is recent enough (default: 7 days)."""
        if cache_path is None or not cache_path.exists():
            return False
        file_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return file_age < timedelta(hours=max_age_hours)

    def _fetch_with_cache(self, data_type: str, fetch_func, max_age_hours: int = 168):
        """Fetch data, reusing the newest cache file when USE_CACHED_DATA is set.

        Looking up the newest existing cache (rather than only today's dated
        path) lets a file downloaded yesterday still satisfy the 7-day window.
        """
        cache_path = self._latest_cache(data_type)

        if USE_CACHED and self._is_cache_valid(cache_path, max_age_hours):
            print(f"   Using cached data from {cache_path}")
            return cache_path

        print("   Fetching fresh data...")
        path = fetch_func()
        if path:
            print(f"   Cached to {path}")
        return path

    def download_national_names(self):
        """
        Download SSA national baby names dataset (1880-2024).

        Returns the local path to a ZIP of yearly CSVs (yobXXXX.txt).
        """
        print("\n👶 Downloading SSA National Baby Names...")

        def fetch():
            cache_path = self._get_cache_path('ssa_national_names')

            try:
                response = requests.get(SSA_NATIONAL_URL, timeout=60)
                response.raise_for_status()

                with open(cache_path, 'wb') as f:
                    f.write(response.content)

                print(f"   ✓ Downloaded {len(response.content) / 1024 / 1024:.2f} MB")
                return cache_path

            except Exception as e:
                print(f"   ✗ Error downloading national names: {e}")
                return None

        return self._fetch_with_cache('ssa_national_names', fetch)

    def process_national_names(self, zip_path):
        """
        Process national names ZIP file into consolidated dataset.

        Args:
            zip_path: Path to SSA names.zip file
        """
        print("\n📊 Processing National Names...")

        if not zip_path or not Path(zip_path).exists():
            print("   ✗ ZIP file not found")
            return None

        all_names = []

        try:
            with zipfile.ZipFile(zip_path, 'r') as z:
                # Process each yearly file (yobXXXX.txt)
                year_files = [f for f in z.namelist() if f.startswith('yob') and f.endswith('.txt')]

                print(f"   Processing {len(year_files)} year files...")

                for filename in sorted(year_files):
                    # Extract year from filename (yob1880.txt -> 1880)
                    year = int(filename.replace('yob', '').replace('.txt', ''))

                    # Read CSV (Name,Sex,Count)
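                    # Rows look like "Mary,F,7065" (name, sex, count); the files have no header row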
                    with z.open(filename) as f:
                        df = pd.read_csv(f, names=['name', 'sex', 'count'])
                        df['year'] = year
                        all_names.append(df)

                # Concatenate all years
                df_all = pd.concat(all_names, ignore_index=True)

                # Calculate total by name and sex across all years
                name_totals = df_all.groupby(['name', 'sex'])['count'].sum().reset_index()
                name_totals.columns = ['name', 'sex', 'total_count']

                # Merge totals back
                df_all = df_all.merge(name_totals, on=['name', 'sex'])
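                # df_all columns are now: name, sex, count, year, total_count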

                # Sort by year and count
                df_all = df_all.sort_values(['year', 'count'], ascending=[True, False])

                print(f"   ✓ Processed {len(df_all):,} name-year records")
                print(f"   ✓ Total unique names: {df_all['name'].nunique():,}")

                # Save full dataset
                output_path = self.data_dir / 'baby_names_century.csv'
                df_all.to_csv(output_path, index=False)
                print(f"   ✓ Saved to {output_path}")

                return df_all

        except Exception as e:
            print(f"   ✗ Error processing national names: {e}")
            return None

    def analyze_name_trends(self, df):
        """
        Analyze name popularity trends over time.

        Args:
            df: DataFrame with baby names data
        """
        print("\n📈 Analyzing Name Trends...")

        if df is None or df.empty:
            print("   ✗ No data to analyze")
            return None

        trends = []

        # Top 20 names of all time
        top_names = (
            df.groupby(['name', 'sex'])['count']
            .sum()
            .reset_index()
            .sort_values('count', ascending=False)
            .head(20)
        )

        for _, row in top_names.iterrows():
            name = row['name']
            sex = row['sex']

            # Get yearly data for this name
            name_data = df[(df['name'] == name) & (df['sex'] == sex)].sort_values('year')

            # Cast to plain ints so json.dump works even when pandas returns numpy scalars
            timeline = [
                {'year': int(r.year), 'count': int(r.count)}
                for r in name_data[['year', 'count']].itertuples(index=False)
            ]

            # Calculate peak year
            peak_row = name_data.loc[name_data['count'].idxmax()]
            peak_year = int(peak_row['year'])
            peak_count = int(peak_row['count'])

            # Total births per decade for this name
            decades = {}
            for year in range(1880, 2030, 10):
                decade_data = name_data[(name_data['year'] >= year) & (name_data['year'] < year + 10)]
                if not decade_data.empty:
                    decades[f"{year}s"] = int(decade_data['count'].sum())

            trends.append({
                'name': name,
                'sex': sex,
                'total_count': int(row['count']),
                'peak_year': peak_year,
                'peak_count': peak_count,
                'decades': decades,
                'timeline': timeline
            })

        output_path = self.data_dir / 'name_popularity_timeline.json'
        with open(output_path, 'w') as f:
            json.dump(trends, f, indent=2)

        print(f"   ✓ Saved top 20 name trends to {output_path}")
        return trends
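
    # Illustrative shape of one entry in name_popularity_timeline.json
    # (values below are invented for clarity, not real SSA figures):
    #   {"name": "James", "sex": "M", "total_count": 5200000,
    #    "peak_year": 1947, "peak_count": 94000,
    #    "decades": {"1880s": 38000, "1890s": 41000, ...},
    #    "timeline": [{"year": 1880, "count": 5900}, ...]}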

    def analyze_modern_trends(self, df, start_year=2010):
        """
        Analyze modern naming trends (2010-present).

        Args:
            df: DataFrame with baby names data
            start_year: Start year for modern analysis (default: 2010)
        """
        print(f"\n🔥 Analyzing Modern Trends ({start_year}+)...")

        if df is None or df.empty:
            print("   ✗ No data to analyze")
            return None

        modern = df[df['year'] >= start_year].copy()

        # Rising stars (names with biggest percentage increase)
        name_first_year = modern.groupby(['name', 'sex'])['year'].min().reset_index()
        name_first_year.columns = ['name', 'sex', 'first_year']

        # Yearly totals per name/sex, tagged with each name's first year in the window
        name_trends = modern.groupby(['name', 'sex', 'year'])['count'].sum().reset_index()
        name_trends = name_trends.merge(name_first_year, on=['name', 'sex'])

        rising = []
        for (name, sex), group in name_trends.groupby(['name', 'sex']):
            if len(group) >= 3:  # Need at least 3 years of data
                group = group.sort_values('year')
                first_count = group['count'].iloc[0]
                last_count = group['count'].iloc[-1]

                if first_count > 0:
                    pct_change = ((last_count - first_count) / first_count) * 100
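                    # e.g. first_count=200, last_count=500 gives pct_change = 150.0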

                    if pct_change > 100:  # At least doubled
                        rising.append({
                            'name': name,
                            'sex': sex,
                            'first_year': int(group['first_year'].iloc[0]),
                            'first_count': int(first_count),
                            'last_year': int(group['year'].max()),
                            'last_count': int(last_count),
                            'percent_change': round(pct_change, 1)
                        })

        rising_sorted = sorted(rising, key=lambda x: x['percent_change'], reverse=True)[:50]

        output_path = self.data_dir / 'rising_names_modern.json'
        with open(output_path, 'w') as f:
            json.dump(rising_sorted, f, indent=2)

        print(f"   ✓ Saved top 50 rising names to {output_path}")
        return rising_sorted
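
    # Illustrative shape of one entry in rising_names_modern.json (invented values):
    #   {"name": "Luna", "sex": "F", "first_year": 2010, "first_count": 800,
    #    "last_year": 2024, "last_count": 3500, "percent_change": 337.5}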

    def save_metadata(self, total_records, unique_names, year_range):
        """Save dataset metadata."""
        metadata = {
            'dataset_name': 'Social Security Baby Names',
            'last_updated': datetime.now().strftime('%Y-%m-%d'),
            'source': 'Social Security Administration',
            'source_url': 'https://www.ssa.gov/oact/babynames/limits.html',
            'record_count': {
                'total_name_year_records': total_records,
                'unique_names': unique_names
            },
            'year_range': year_range,
            'fields': {
                'name': 'First name',
                'sex': 'Sex (M/F)',
                'year': 'Birth year',
                'count': 'Number of babies with this name in this year',
                'total_count': 'Total count across all years',
                'peak_year': 'Year with most occurrences',
                'peak_count': 'Count in peak year'
            },
            'notes': (
                'SSA baby names data from birth certificates. '
                'Names are included if they appear at least 5 times in a given year. '
                'Data covers 1880-2024 (updated annually in May). '
                'Includes 98,400+ unique names across 145 years.'
            ),
            'methodology': (
                'SSA collects baby names from Social Security card applications. '
                'Names with fewer than 5 occurrences in a year are excluded for privacy. '
                'Data is from U.S. citizens and permanent residents only.'
            )
        }

        meta_path = self.data_dir / 'baby_names_metadata.json'
        with open(meta_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"\n✓ Metadata saved to {meta_path}")


def main():
    print("=" * 60)
    print("BABY NAMES DATA FETCHER")
    print("=" * 60)

    fetcher = BabyNamesFetcher()

    # Download national names dataset
    zip_path = fetcher.download_national_names()

    if zip_path:
        # Process into consolidated dataset
        df = fetcher.process_national_names(zip_path)

        if df is not None and not df.empty:
            # Analyze historical trends
            trends = fetcher.analyze_name_trends(df)

            # Analyze modern trends
            rising = fetcher.analyze_modern_trends(df, start_year=2010)

            # Save metadata
            # Save metadata (derive the year range from the data rather than hardcoding it)
            fetcher.save_metadata(
                total_records=len(df),
                unique_names=df['name'].nunique(),
                year_range=f"{df['year'].min()}-{df['year'].max()}"
            )

    print("\n" + "=" * 60)
    print("✓ Baby names data fetching complete")
    print("=" * 60)


if __name__ == "__main__":
    main()
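
# Downstream usage sketch (assumes the script has been run once and the CSV exists):
#   import pandas as pd
#   df = pd.read_csv('data/cultural/baby_names_century.csv')
#   emma = df[(df['name'] == 'Emma') & (df['sex'] == 'F')]
#   print(emma[['year', 'count']].tail())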
