#!/usr/bin/env python3
"""
Fetch retail store locations using OpenStreetMap Overpass API
Focus on LA County as case study, expandable to other counties

Stores tracked:
- Whole Foods Market
- Starbucks
- Dollar General / Dollar Tree / Family Dollar
- McDonald's
- Walmart
- Costco
"""

import requests
import json
import time
from pathlib import Path
from typing import List, Dict, Any

# LA County bounding box (approximate)
# Format: [south, west, north, east] in decimal degrees
# NOTE(review): this box extends somewhat beyond the county's true borders,
# so results may include stores from neighboring counties — confirm whether
# exact county membership matters downstream.
LA_COUNTY_BBOX = [33.7, -118.7, 34.8, -117.6]

# Split large bboxes into smaller chunks to avoid Overpass API timeouts
# (the default grid is 2x2 = 4 chunks; the LA County fetch below passes
# grid_size=3 for a 3x3 grid = 9 chunks)
def split_bbox_into_chunks(bbox: List[float], grid_size: int = 2) -> List[List[float]]:
    """
    Divide a bounding box into a grid of smaller boxes.

    Keeping individual Overpass queries small avoids server-side timeouts
    when covering a large region.

    Args:
        bbox: [south, west, north, east]
        grid_size: Divisions per axis (2 means a 2x2 grid, i.e. 4 chunks)

    Returns:
        Sub-boxes in the same [south, west, north, east] format, ordered
        row by row starting from the south-west corner.
    """
    south, west, north, east = bbox
    cell_height = (north - south) / grid_size
    cell_width = (east - west) / grid_size

    return [
        [
            south + row * cell_height,
            west + col * cell_width,
            south + (row + 1) * cell_height,
            west + (col + 1) * cell_width,
        ]
        for row in range(grid_size)
        for col in range(grid_size)
    ]

# Store brand identifiers in OpenStreetMap
# Each store can have multiple query strategies: wikidata (can be list), name patterns, shop tag
# - 'wikidata': brand:wikidata tag value(s); the most reliable OSM match
# - 'name_patterns': case-insensitive patterns matched against the "name" tag
# - 'shop_tag': value matched against the OSM "shop" key to narrow name matches
#   NOTE(review): 'fast_food' (McDonald's) and 'coffee' (Starbucks) are queried
#   here as shop=... values, but in OSM those places are commonly tagged with
#   amenity=fast_food / amenity=cafe instead — the name+shop strategy may miss
#   them and the wikidata strategy does the real work. Confirm against live data.
STORE_QUERIES = {
    'whole_foods': {
        'name': 'Whole Foods Market',
        'wikidata': ['Q1809448'],  # Whole Foods Market
        'name_patterns': ['Whole Foods', 'Whole Foods Market', 'Wholefoods'],
        'shop_tag': 'supermarket'
    },
    'starbucks': {
        'name': 'Starbucks',
        'wikidata': ['Q37158'],  # Starbucks
        'name_patterns': ['Starbucks'],
        'shop_tag': 'coffee'
    },
    'dollar_general': {
        'name': 'Dollar General',
        'wikidata': ['Q145168'],  # Dollar General
        'name_patterns': ['Dollar General'],
        'shop_tag': 'variety_store'
    },
    'dollar_tree': {
        'name': 'Dollar Tree',
        'wikidata': ['Q5289230', 'Q5287284'],  # Dollar Tree (multiple possible IDs)
        'name_patterns': ['Dollar Tree'],
        'shop_tag': 'variety_store'
    },
    'family_dollar': {
        'name': 'Family Dollar',
        'wikidata': ['Q5433101'],  # Family Dollar
        'name_patterns': ['Family Dollar'],
        'shop_tag': 'variety_store'
    },
    'mcdonalds': {
        'name': "McDonald's",
        'wikidata': ['Q38076'],  # McDonald's
        'name_patterns': ["McDonald's", 'McDonalds'],
        'shop_tag': 'fast_food'
    },
    'walmart': {
        'name': 'Walmart',
        'wikidata': ['Q483551'],  # Walmart
        'name_patterns': ['Walmart', 'Wal-Mart', 'Walmart Supercenter'],
        'shop_tag': 'supermarket'
    },
    'costco': {
        'name': 'Costco',
        'wikidata': ['Q715583'],  # Costco
        'name_patterns': ['Costco', 'Costco Wholesale'],
        'shop_tag': 'wholesale'
    }
}

# Public Overpass endpoint — rate-limited (see 429 retry handling in fetch_store_locations)
OVERPASS_URL = "https://overpass-api.de/api/interpreter"
# Output directory for JSON results; created on first save
OUTPUT_DIR = Path("data/retail_locations")

def build_overpass_query(bbox: List[float], store_type: str, config: Dict[str, Any]) -> str:
    """
    Build a comprehensive Overpass QL query for one store type in a bounding box.

    Combines up to three strategies in a single union:
      1. brand:wikidata tag match (most reliable) for every configured ID
      2. case-insensitive name regex filtered by the shop tag
      3. case-insensitive name regex alone (when no shop tag is configured)

    Args:
        bbox: [south, west, north, east] coordinates
        store_type: Store identifier (e.g., 'whole_foods'); not used in the
            query text itself, kept for interface stability
        config: Store configuration with wikidata, name_patterns, shop_tag

    Returns:
        Overpass QL query string
    """
    south, west, north, east = bbox
    wikidata_ids = config.get('wikidata', [])
    if isinstance(wikidata_ids, str):
        wikidata_ids = [wikidata_ids]
    name_patterns = config.get('name_patterns', [])
    shop_tag = config.get('shop_tag', '')

    def escape_pattern(pattern: str) -> str:
        # Escape backslash and double quote so the pattern stays a valid
        # Overpass QL double-quoted string. (The previous version only
        # replaced apostrophes, which are harmless, while leaving the
        # actually dangerous characters unescaped.) Regex metacharacters
        # are deliberately left intact so patterns can use them.
        return pattern.replace('\\', '\\\\').replace('"', '\\"')

    # Build a comprehensive query using a union of multiple strategies
    query_parts = []

    # Strategy 1: Wikidata brand tag (most reliable) — try all provided IDs
    for wikidata_id in wikidata_ids:
        if wikidata_id:
            query_parts.append(f'node["brand:wikidata"="{wikidata_id}"]({south},{west},{north},{east});')
            query_parts.append(f'way["brand:wikidata"="{wikidata_id}"]({south},{west},{north},{east});')

    # Strategy 2: Name pattern matching with shop tag filter (narrower search)
    if name_patterns and shop_tag:
        for pattern in name_patterns:
            escaped_pattern = escape_pattern(pattern)
            query_parts.append(f'node["shop"="{shop_tag}"]["name"~"{escaped_pattern}",i]({south},{west},{north},{east});')
            query_parts.append(f'way["shop"="{shop_tag}"]["name"~"{escaped_pattern}",i]({south},{west},{north},{east});')

    # Strategy 3: Name pattern matching without shop tag (broader search)
    elif name_patterns:
        for pattern in name_patterns:
            escaped_pattern = escape_pattern(pattern)
            query_parts.append(f'node["name"~"{escaped_pattern}",i]({south},{west},{north},{east});')
            query_parts.append(f'way["name"~"{escaped_pattern}",i]({south},{west},{north},{east});')

    # Combine all strategies into one union; "out center" makes ways return
    # a representative point alongside node coordinates.
    joined = '\n'.join('  ' + part for part in query_parts)
    query = f"""
    [out:json][timeout:60];
    (
      {joined}
    );
    out center;
    """

    return query

def fetch_store_locations(bbox: List[float], store_type: str, config: Dict[str, Any], max_retries: int = 2) -> List[Dict[str, Any]]:
    """
    Fetch store locations from Overpass API with retry logic

    Retries on HTTP 429 (rate limit) and 504 (gateway timeout) with a
    linearly increasing back-off. Any other request failure is reported
    and an empty list is returned rather than raising.

    Args:
        bbox: Bounding box coordinates [south, west, north, east]
        store_type: Store identifier
        config: Store configuration with name, wikidata, name_patterns, shop_tag
        max_retries: Number of retries on timeout/rate limit

    Returns:
        List of store locations with lat, lon, name and address fields
        ('' when the corresponding OSM address tag is absent)
    """
    query = build_overpass_query(bbox, store_type, config)

    for attempt in range(max_retries + 1):
        try:
            response = requests.post(OVERPASS_URL, data={'data': query}, timeout=120)
            response.raise_for_status()
            data = response.json()

            locations = []
            for element in data.get('elements', []):
                # Nodes carry coordinates directly; ways only have a
                # representative point because the query uses "out center".
                if element['type'] == 'node':
                    lat = element['lat']
                    lon = element['lon']
                elif element['type'] == 'way' and 'center' in element:
                    lat = element['center']['lat']
                    lon = element['center']['lon']
                else:
                    continue

                tags = element.get('tags', {})  # fetch once; reused below
                locations.append({
                    'store_type': store_type,
                    'store_name': config['name'],
                    'lat': lat,
                    'lon': lon,
                    'osm_id': element.get('id'),
                    'name': tags.get('name', config['name']),
                    'address': tags.get('addr:street', ''),
                    'city': tags.get('addr:city', ''),
                    'postcode': tags.get('addr:postcode', '')
                })

            print(f"✓ Found {len(locations)} {config['name']} locations")
            return locations

        except requests.exceptions.HTTPError as e:
            if e.response.status_code in (429, 504) and attempt < max_retries:
                # Rate limit or timeout - wait and retry
                wait_time = 5 * (attempt + 1)
                print(f"⏸ Rate limit/timeout, retrying in {wait_time}s...", end=" ")
                time.sleep(wait_time)
                continue
            print(f"✗ Error fetching {config['name']}: {e}")
            return []
        except requests.exceptions.RequestException as e:
            print(f"✗ Error fetching {config['name']}: {e}")
            return []

    # Defensive fallback; every loop path above returns or continues.
    return []

def fetch_all_stores_in_region(bbox: List[float], region_name: str, use_chunks: bool = True) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetch every configured store type within a region.

    Large bounding boxes are optionally split into a 3x3 grid of
    sub-queries; per-chunk results are merged and de-duplicated by OSM ID,
    with short sleeps between requests to respect the public API.

    Args:
        bbox: Bounding box coordinates [south, west, north, east]
        region_name: Name of region (e.g., "LA County")
        use_chunks: If True, split the bbox into smaller chunks to avoid timeouts

    Returns:
        Dictionary mapping store types to location lists
    """
    print(f"\n🔍 Fetching retail locations in {region_name}...")
    print(f"   Bounding box: {bbox}")

    # 3x3 grid = 9 chunks gives better reliability than one big query
    chunks = split_bbox_into_chunks(bbox, grid_size=3) if use_chunks else [bbox]
    if use_chunks:
        print(f"   Split into {len(chunks)} chunks to avoid API timeouts")

    all_locations = {}

    for store_type, config in STORE_QUERIES.items():
        print(f"\n   Querying {config['name']}...", end=" ")

        deduped: List[Dict[str, Any]] = []
        seen_ids = set()  # OSM IDs already collected (drops cross-chunk duplicates)

        for chunk_number, chunk in enumerate(chunks, start=1):
            if len(chunks) > 1:
                print(f"\n     Chunk {chunk_number}/{len(chunks)}...", end=" ")

            for loc in fetch_store_locations(chunk, store_type, config):
                osm_id = loc.get('osm_id')
                if osm_id not in seen_ids:
                    seen_ids.add(osm_id)
                    deduped.append(loc)

            # Brief pause between chunk requests (rate limiting)
            if chunk_number < len(chunks):
                time.sleep(1)

        all_locations[store_type] = deduped
        print(f" Total: {len(deduped)}")

        # Longer pause between store types
        time.sleep(2)

    return all_locations

def save_locations_to_json(locations: Dict[str, List[Dict[str, Any]]], region_name: str) -> Path:
    """
    Save location data to a JSON file under OUTPUT_DIR.

    The file contains summary statistics plus a flat list of all stores:
    {'region', 'total_stores', 'by_type', 'stores'}.

    Args:
        locations: Dictionary mapping store type -> list of location dicts
        region_name: Name of region (used to build the output filename)

    Returns:
        Path to saved file
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Flatten the per-type mapping into one list for easier downstream use
    all_stores = []
    for store_list in locations.values():
        all_stores.extend(store_list)

    # Summary statistics alongside the raw store list
    summary = {
        'region': region_name,
        'total_stores': len(all_stores),
        'by_type': {k: len(v) for k, v in locations.items()},
        'stores': all_stores
    }

    filename = OUTPUT_DIR / f"{region_name.lower().replace(' ', '_')}_retail_locations.json"
    with open(filename, 'w') as f:
        json.dump(summary, f, indent=2)

    # Bug fix: previously printed the literal text "(unknown)" instead of the path
    print(f"\n✓ Saved {len(all_stores)} total locations to {filename}")
    return filename

def generate_summary_stats(locations: Dict[str, List[Dict[str, Any]]]) -> None:
    """Print a per-store-type count table, largest footprint first, plus a total."""
    divider = "=" * 60
    print("\n" + divider)
    print("RETAIL LOCATION SUMMARY")
    print(divider)

    # Sort store types by how many locations were found, descending
    by_count = sorted(locations.items(), key=lambda item: len(item[1]), reverse=True)
    for store_type, store_list in by_count:
        display_name = STORE_QUERIES[store_type]['name']
        print(f"  {display_name:25} {len(store_list):4} locations")

    total = sum(len(store_list) for store_list in locations.values())
    print(f"\n  {'TOTAL':25} {total:4} locations")
    print(divider)

def main():
    """Entry point: fetch, summarize, and persist LA County retail locations."""
    print("Retail Location Data Fetcher")
    print("Using OpenStreetMap Overpass API\n")

    # LA County is the initial case study
    locations = fetch_all_stores_in_region(LA_COUNTY_BBOX, "LA County")
    generate_summary_stats(locations)
    save_locations_to_json(locations, "LA County")

    # Expansion point — other counties follow the same pattern, e.g.:
    # COOK_COUNTY_BBOX = [41.5, -88.3, 42.2, -87.5]  # Chicago
    # cook_locations = fetch_all_stores_in_region(COOK_COUNTY_BBOX, "Cook County")

if __name__ == "__main__":
    main()
