#!/usr/bin/env python3
"""
Harvey Browser Sync Server
Listens on 127.0.0.1:7429 for batched page visits sent by the browser
extension and merges them into ~/.openclaw/workspace/memory/browser-context.md.
Keeps about the last 7 days of context; older day sections are pruned
automatically on each sync.
stdlib only — no pip installs.
"""

import json
import os
import sys
import re
import datetime
import http.server
import socketserver
from urllib.parse import urlparse

# ─── Config ──────────────────────────────────────────────────────────────────

PORT = 7429  # TCP port the server binds to in main()
BIND_HOST = '127.0.0.1'  # localhost only
# Directory holding the output file; created at startup if it does not exist.
MEMORY_DIR = os.path.expanduser('~/.openclaw/workspace/memory')
# Markdown file rewritten after every successful POST /sync.
OUTPUT_FILE = os.path.join(MEMORY_DIR, 'browser-context.md')
# Retention window in days: prune_old_sections() drops day sections dated
# before today - KEEP_DAYS (a section dated exactly on the cutoff is kept).
KEEP_DAYS = 7

# ─── Helpers ──────────────────────────────────────────────────────────────────

def log(msg):
    """Write *msg* to stdout prefixed with the current local time, then flush."""
    now = datetime.datetime.now()
    print(f'[{now:%Y-%m-%d %H:%M:%S}] {msg}', flush=True)

def format_duration(ms):
    """Convert a duration in milliseconds to a compact human-readable label.

    Returns whole seconds below one minute ('45s'), whole minutes below one
    hour ('12m'), and hours with one decimal otherwise ('1.5h').

    ``None``, negative, NaN and infinite inputs are treated as 0 so a bad
    value from a client never crashes formatting (``int(nan)`` raises).

    Minutes are rounded *before* the unit is chosen, so e.g. 59.6 minutes
    becomes '1.0h' rather than the out-of-range '60m'.
    """
    # NaN fails both comparisons, so it is caught by the same check.
    if ms is None or not (0 <= ms < float('inf')):
        ms = 0
    secs = ms / 1000
    if secs < 60:
        return f'{int(secs)}s'
    whole_mins = round(secs / 60)
    if whole_mins < 60:
        return f'{whole_mins}m'
    hours = secs / 60 / 60
    return f'{hours:.1f}h'

def extract_display_url(url):
    """Return a short, human-readable form of *url*: host followed by path.

    - A leading 'www.' is removed from the host. Only a prefix is removed,
      so hosts such as 'awww.example.com' stay intact.
    - 'user:password@' credentials are dropped so they never reach the
      context file; a ':port' suffix is kept.
    - A trailing '/' is removed, and paths longer than 40 characters are
      truncated with '…'.

    Returns the first 60 characters of *url* if it cannot be parsed.
    ``None`` yields ''.
    """
    if not isinstance(url, str):
        url = '' if url is None else str(url)
    try:
        parsed = urlparse(url)
    except ValueError:  # e.g. unbalanced '[' / ']' in an IPv6 host
        return url[:60]

    # Drop any userinfo ('user:pass@') part of the authority.
    host = parsed.netloc.rpartition('@')[2]
    if host[:4].lower() == 'www.':
        host = host[4:]

    path = parsed.path.rstrip('/')
    if not path:
        return host
    if len(path) > 40:
        path = path[:40] + '…'
    return f'{host}{path}'

def today_str():
    """Return today's local calendar date formatted as 'YYYY-MM-DD'."""
    current = datetime.date.today()
    return str(current)

def parse_date_from_heading(heading):
    """Return the first YYYY-MM-DD found in *heading* as a date, else None.

    Example heading: '## 2026-03-21'. A date-shaped string that is not a real
    calendar date (e.g. '2026-02-30') also yields None.
    """
    found = re.search(r'(\d{4}-\d{2}-\d{2})', heading)
    if found is None:
        return None
    try:
        return datetime.date.fromisoformat(found.group(1))
    except ValueError:
        return None

# ─── Markdown management ──────────────────────────────────────────────────────

def load_context():
    """Load the context file as a dict: 'YYYY-MM-DD' -> list of body lines.

    A section starts at any line beginning with '## ' that contains a
    YYYY-MM-DD date. Lines before the first such heading (the '# ...' title)
    are ignored. Returns {} when OUTPUT_FILE does not exist.

    - Trailing blank lines of every section are removed. write_context()
      emits its own blank separator after each section, so keeping them
      would add one more blank line per section on every rewrite.
    - Repeated headings for the same date are merged rather than letting
      the later one silently replace the earlier one.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}

    date_re = re.compile(r'\d{4}-\d{2}-\d{2}')
    sections = {}
    current_lines = None  # list receiving lines of the section being read

    # errors='replace': one stray invalid byte (e.g. from a manual edit)
    # must not make every later sync fail while decoding the file.
    with open(OUTPUT_FILE, 'r', encoding='utf-8', errors='replace') as f:
        for raw_line in f:
            line = raw_line.rstrip('\n')
            heading = date_re.search(line) if line.startswith('## ') else None
            if heading:
                current_lines = sections.setdefault(heading.group(0), [])
            elif current_lines is not None:
                current_lines.append(line)

    for lines in sections.values():
        while lines and not lines[-1].strip():
            lines.pop()

    return sections

def prune_old_sections(sections):
    """Delete day sections dated before today minus KEEP_DAYS.

    Keys that are not valid ISO dates are left untouched. Each removal is
    logged. *sections* is modified in place and also returned.
    """
    oldest_kept = datetime.date.today() - datetime.timedelta(days=KEEP_DAYS)

    def is_stale(key):
        try:
            return datetime.date.fromisoformat(key) < oldest_kept
        except ValueError:
            return False

    # Collect first: deleting while iterating the dict would raise.
    stale_keys = [key for key in sections if is_stale(key)]
    for key in stale_keys:
        del sections[key]
        log(f'Pruned old section: {key}')
    return sections

# Shape of an entry line written by add_visits_to_sections, e.g.
#   - [14:05] **Page title** ([example.com/page](https://example.com/page)) — 4m ⭐ (revisit)
# The full URL is stored as the markdown link target so the entry can be
# matched to later visits of the same URL after the file is reloaded.
# The greedy '.*' makes the match bind to the LAST '](url)) — ' on the line,
# so a title that happens to contain that text cannot confuse parsing.
_ENTRY_RE = re.compile(
    r'^- \[(?P<time>\d{2}:\d{2}|\?\?:\?\?)\] .*\]\((?P<url>[^()\s]+)\)\) — '
    r'(?P<dur>\d+(?:\.\d+)?[smh])'
)
_ENTRY_TIME_RE = re.compile(r'\[(\d{2}:\d{2})\]')
_LINK_UNSAFE_RE = re.compile(r'[()\s]')

def _link_target(url):
    """Percent-encode parentheses and whitespace in *url* for a markdown link.

    The result is also the per-day dedup key for the URL.
    """
    return _LINK_UNSAFE_RE.sub(
        lambda m: ''.join(f'%{b:02X}' for b in m.group(0).encode('utf-8')),
        url,
    )

def _one_line(text):
    """Collapse whitespace runs (including newlines) into single spaces."""
    return ' '.join(text.split())

def _visit_local_datetime(visit):
    """Parse visit['timestamp'] (ISO 8601, trailing 'Z' allowed) as local time.

    Aware timestamps are converted to this machine's local zone. This keeps
    the date bucket consistent with today_str() and prune_old_sections(),
    which both use the local date. Naive timestamps are used as-is.
    Returns None if the timestamp is missing or cannot be parsed.
    """
    raw = visit.get('timestamp')
    if not isinstance(raw, str):
        return None
    try:
        dt = datetime.datetime.fromisoformat(raw.replace('Z', '+00:00'))
        return dt.astimezone() if dt.tzinfo is not None else dt
    except (ValueError, OverflowError, OSError):
        return None

def _visit_duration_ms(visit):
    """Return visit['timeSpentMs'] if it is a finite, non-negative number, else 0."""
    value = visit.get('timeSpentMs', 0)
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return 0
    # NaN fails both comparisons; json.loads accepts NaN/Infinity by default.
    return value if 0 <= value < float('inf') else 0

def _entry_sort_key(line):
    """Sort key: the first [HH:MM] label in *line*, or '00:00' if there is none."""
    m = _ENTRY_TIME_RE.search(line)
    return m.group(1) if m else '00:00'

def add_visits_to_sections(sections, visits):
    """
    Merge new visits into the sections dict (mutated in place and returned).

    Visits are grouped by their local calendar date. Today's date is used
    when the timestamp is missing or invalid.

    Within a date, entries are keyed by URL. A visit to a URL already listed
    that day:
    - adds its time to the existing entry,
    - keeps the first-seen [HH:MM] label,
    - is marked ' ⭐ (revisit)'.

    Bullet lines that cannot be keyed by URL are preserved verbatim, e.g.
    lines written by older versions of this script, which stored only the
    display URL. Visits that are not objects, or have no non-empty 'url'
    string, are skipped. Each touched day is re-sorted by time label.
    """
    by_date = {}
    for visit in visits:
        if not isinstance(visit, dict):
            continue
        url = visit.get('url')
        if not isinstance(url, str) or not url.strip():
            continue
        local_dt = _visit_local_datetime(visit)
        date_str = local_dt.strftime('%Y-%m-%d') if local_dt is not None else today_str()
        by_date.setdefault(date_str, []).append((visit, local_dt))

    for date_str, day_visits in by_date.items():
        entries = {}  # link target (dedup key) -> rendered line
        unkeyed = []  # bullet lines without a parseable URL; kept as-is
        for line in sections.get(date_str, []):
            match = _ENTRY_RE.match(line)
            if match:
                entries[match.group('url')] = line
            elif line.startswith('- '):
                unkeyed.append(line)
            # Blank lines and the '_(no activity)_' placeholder are dropped;
            # write_context() regenerates the section layout.

        for visit, local_dt in day_visits:
            url = visit['url'].strip()
            key = _link_target(url)
            time_label = local_dt.strftime('%H:%M') if local_dt is not None else '??:??'
            total_ms = _visit_duration_ms(visit)
            revisit_marker = ''

            previous = entries.get(key)
            if previous is not None:
                # Revisit: accumulate time and keep the original time label.
                old = _ENTRY_RE.match(previous)
                if old:
                    total_ms += parse_duration_to_ms(old.group('dur'))
                    time_label = old.group('time')
                revisit_marker = ' ⭐ (revisit)'

            title = visit.get('title')
            title = _one_line(title) if isinstance(title, str) else ''
            if not title:
                title = _one_line(str(visit.get('domain') or url))
            # Truncate long titles
            if len(title) > 70:
                title = title[:70] + '…'

            display_url = _one_line(extract_display_url(url))
            entries[key] = (
                f'- [{time_label}] **{title}** ([{display_url}]({key})) — '
                f'{format_duration(total_ms)}{revisit_marker}'
            )

        sections[date_str] = sorted(unkeyed + list(entries.values()), key=_entry_sort_key)

    return sections

def parse_duration_to_ms(dur_str):
    """Turn a format_duration() label such as '30s', '4m' or '1.5h' into milliseconds.

    Strings without an 's', 'm' or 'h' suffix yield 0. A known suffix with a
    non-numeric prefix raises ValueError from float().
    """
    seconds_per_unit = {'h': 3600, 'm': 60, 's': 1}
    unit = dur_str[-1:]
    if unit not in seconds_per_unit:
        return 0
    return int(float(dur_str[:-1]) * seconds_per_unit[unit] * 1000)

def write_context(sections):
    """Render *sections* as markdown and atomically replace OUTPUT_FILE.

    Layout:
    - Dates are written newest first, each under a '## YYYY-MM-DD' heading.
    - Each heading is followed by the day's lines, or '_(no activity)_' when
      there are none, then exactly one blank separator line. Trailing blank
      lines already present in a section are dropped, so blank lines do not
      accumulate across rewrites.

    The text goes to a temporary sibling file that is moved into place with
    os.replace(). A crash mid-write therefore never leaves a truncated file.
    """
    os.makedirs(MEMORY_DIR, exist_ok=True)

    sorted_dates = sorted(sections.keys(), reverse=True)

    lines = [f'# Browser Context — Last {KEEP_DAYS} Days', '']
    entry_count = 0
    for date_str in sorted_dates:
        day_lines = list(sections[date_str])
        while day_lines and not day_lines[-1].strip():
            day_lines.pop()
        lines.append(f'## {date_str}')
        if day_lines:
            lines.extend(day_lines)
        else:
            lines.append('_(no activity)_')
        lines.append('')
        # Count only bullet entries, not blank/placeholder lines.
        entry_count += sum(1 for line in day_lines if line.startswith('- '))

    tmp_path = OUTPUT_FILE + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))
        os.replace(tmp_path, OUTPUT_FILE)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    log(f'Wrote {OUTPUT_FILE} ({len(sorted_dates)} days, {entry_count} entries)')

# ─── HTTP Handler ─────────────────────────────────────────────────────────────

class SyncHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler for the browser-extension sync server.

    Routes (the query string is ignored when routing):
      GET     /health  -> 200 {"status": "ok", "output": OUTPUT_FILE}
      POST    /sync    -> merge {"visits": [...]} into the context file
      OPTIONS (any)    -> CORS preflight allowing POST with Content-Type
    Any other path gets 404.
    """

    # Bodies larger than this are rejected (413) instead of being buffered in
    # memory. Generous for a batch of visits; raise it if batches get bigger.
    MAX_BODY_BYTES = 5 * 1024 * 1024

    # Origin schemes used by browser-extension pages and background workers.
    EXTENSION_ORIGIN_SCHEMES = ('chrome-extension', 'moz-extension', 'safari-web-extension')

    # Hosts whose http(s) pages may also post (local tooling).
    LOCAL_HOSTS = ('localhost', '127.0.0.1', '::1')

    def log_message(self, format, *args):
        # Suppress default access logs; we do our own
        pass

    def send_json(self, code, data):
        """Send *data* as a JSON response with HTTP status *code*."""
        body = json.dumps(data).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        # Allow requests from browser extension (localhost origin)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()
        self.wfile.write(body)

    def _route(self):
        """Return the request path without its query string or fragment."""
        return urlparse(self.path).path

    def _origin_allowed(self):
        """Return True if the request's Origin may write to the context file.

        Any web page open in the browser can POST to 127.0.0.1: a
        "simple" cross-origin request needs no CORS preflight. CORS headers
        only decide whether the page may *read* the reply. Without this
        check, arbitrary sites could inject text into the context file, which
        is presumably consumed as agent memory.

        Accepted:
        - requests without an Origin header (curl, scripts);
        - browser-extension origins;
        - http(s) pages on localhost.

        Everything else is rejected, including 'null', which sandboxed
        iframes on any site can send.
        """
        origin = self.headers.get('Origin')
        if origin is None:
            return True
        try:
            parsed = urlparse(origin.strip())
            scheme, host = parsed.scheme.lower(), parsed.hostname
        except ValueError:
            return False
        if scheme in self.EXTENSION_ORIGIN_SCHEMES:
            return True
        return scheme in ('http', 'https') and host in self.LOCAL_HOSTS

    def _content_length(self):
        """Return the declared body length: 0 if absent, None if invalid or negative."""
        try:
            length = int(self.headers.get('Content-Length', 0))
        except (TypeError, ValueError):
            return None
        return length if length >= 0 else None

    def do_OPTIONS(self):
        """Handle CORS preflight from browser extension."""
        self.send_response(200)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        self.send_header('Content-Length', '0')
        self.end_headers()

    def do_GET(self):
        """Serve the /health probe; every other path is 404."""
        if self._route() == '/health':
            self.send_json(200, {'status': 'ok', 'output': OUTPUT_FILE})
        else:
            self.send_json(404, {'error': 'not found'})

    def do_POST(self):
        """Handle POST /sync: validate the request, then merge its visits into the file.

        Client errors (bad origin, bad length, bad JSON, wrong shape) get 4xx.
        Only failures while updating the file get 500.
        """
        if self._route() != '/sync':
            self.send_json(404, {'error': 'not found'})
            return

        if not self._origin_allowed():
            log(f'Rejected sync from origin {self.headers.get("Origin")!r}')
            self.send_json(403, {'error': 'origin not allowed'})
            return

        length = self._content_length()
        if length is None:
            self.send_json(400, {'error': 'invalid Content-Length'})
            return
        if length == 0:
            self.send_json(400, {'error': 'empty body'})
            return
        if length > self.MAX_BODY_BYTES:
            self.send_json(413, {'error': 'body too large'})
            return

        try:
            payload = json.loads(self.rfile.read(length).decode('utf-8'))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            log(f'JSON parse error: {e}')
            self.send_json(400, {'error': f'invalid JSON: {e}'})
            return

        if not isinstance(payload, dict):
            self.send_json(400, {'error': 'body must be a JSON object'})
            return
        visits = payload.get('visits', [])
        if not isinstance(visits, list):
            self.send_json(400, {'error': 'visits must be a list'})
            return

        log(f'Received {len(visits)} visits from extension')

        try:
            sections = load_context()
            sections = prune_old_sections(sections)
            sections = add_visits_to_sections(sections, visits)
            write_context(sections)
        except Exception as e:
            log(f'Error processing request: {e}')
            self.send_json(500, {'error': str(e)})
            return

        self.send_json(200, {'ok': True, 'processed': len(visits)})

# ─── Main ─────────────────────────────────────────────────────────────────────

class ReusableTCPServer(socketserver.TCPServer):
    """Single-threaded TCPServer that sets SO_REUSEADDR before binding.

    This lets the sync server restart immediately on the same port instead
    of failing while the previous socket is still in TIME_WAIT.
    """

    allow_reuse_address = True

def main():
    """Run the sync server in the foreground until interrupted (Ctrl-C).

    Creates MEMORY_DIR, binds BIND_HOST:PORT and handles requests one at a
    time, since ReusableTCPServer adds no threading mix-in. If the port
    cannot be bound (e.g. another instance is already running), logs the
    reason and exits with status 1 instead of dumping a traceback.
    """
    os.makedirs(MEMORY_DIR, exist_ok=True)

    log(f'Harvey Browser Sync Server starting on {BIND_HOST}:{PORT}')
    log(f'Output: {OUTPUT_FILE}')

    try:
        server = ReusableTCPServer((BIND_HOST, PORT), SyncHandler)
    except OSError as e:
        log(f'Could not bind {BIND_HOST}:{PORT}: {e}')
        sys.exit(1)

    with server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            log('Shutting down.')

if __name__ == '__main__':
    main()
