#!/usr/bin/env python3
"""
Test script for Django psycopg connection pool fork bug.

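Starts gunicorn with multiple ASGI workers, sends concurrent requests,
and counts the requests that fail (e.g. with protocol corruption errors).
Then repeats with a single worker as a control test, and once more with
multiple workers plus a gunicorn post_fork hook that clears Django's
connection pools (the workaround).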

Usage:
    1. Ensure PostgreSQL is running with the test database (see README.md)
    2. python manage.py migrate
    3. python test_fork_bug.py
"""

import concurrent.futures
import json
import os
import signal
import subprocess
import sys
import time
import urllib.request

# host:port passed to gunicorn's --bind and used to build the request URLs.
GUNICORN_BIND = "127.0.0.1:8765"
# Total number of HTTP requests sent per test run.
NUM_REQUESTS = 150
# Size of the client-side thread pool that issues those requests.
CONCURRENT_CLIENTS = 20


def start_gunicorn(workers, extra_args=None):
    """Start gunicorn and wait until it's accepting requests.

    Args:
        workers: Number of worker processes, passed as ``--workers``.
        extra_args: Optional list of additional command-line arguments
            appended to the gunicorn invocation.

    Returns:
        The ``subprocess.Popen`` handle of the running gunicorn master.

    Raises:
        RuntimeError: If gunicorn exits while starting up, or does not
            answer a request within the polling window (40 attempts,
            0.5 s apart).
    """
    log_path = "/tmp/gunicorn_test.log"
    cmd = [
        sys.executable, "-m", "gunicorn",
        "repro.asgi:application",
        "--worker-class", "asgi",
        f"--workers={workers}",
        f"--bind={GUNICORN_BIND}",
        "--timeout=30",
        "--log-level=error",
        "--preload",
    ]
    if extra_args:
        cmd.extend(extra_args)

    def read_log_excerpt():
        # stderr is redirected to a file, so communicate() can never return
        # it; read the file back to get a useful failure message.
        try:
            with open(log_path, errors="replace") as log:
                return log.read()[:500]
        except OSError:
            return ""

    # The child inherits its own copy of the descriptor, so the parent's
    # handle can be closed as soon as the process has been spawned.
    with open(log_path, "w") as log_file:
        proc = subprocess.Popen(
            cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            # Nothing reads the child's stdout; a pipe could fill up and
            # block the server, so discard the output instead.
            stdout=subprocess.DEVNULL,
            stderr=log_file,
        )

    # Poll until ready
    for _ in range(40):
        time.sleep(0.5)
        try:
            with urllib.request.urlopen(
                f"http://{GUNICORN_BIND}/query/", timeout=3
            ):
                pass
            return proc
        except Exception:
            # Any failure just means "not ready yet", unless the process
            # has already died.
            if proc.poll() is not None:
                print(f"  gunicorn exited early: {read_log_excerpt()}")
                raise RuntimeError("gunicorn failed to start")

    proc.kill()
    proc.wait()
    raise RuntimeError(f"gunicorn did not become ready: {read_log_excerpt()}")


def stop_gunicorn(proc):
    """Gracefully stop gunicorn.

    Asks the process to terminate, waits up to 15 seconds for it to exit,
    and kills it if it is still running after that.

    Args:
        proc: The ``subprocess.Popen`` handle returned by ``start_gunicorn``.
    """
    # Popen.terminate() does nothing if the process has already been reaped,
    # whereas os.kill() on the stale PID would raise ProcessLookupError or
    # signal an unrelated process that reused the PID.
    proc.terminate()
    try:
        proc.wait(timeout=15)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait(timeout=5)

def fire_requests(total, concurrency):
    """Send concurrent HTTP requests and collect results/errors.

    Args:
        total: Number of requests to send.
        concurrency: Number of client threads issuing requests in parallel.

    Returns:
        A ``(results, errors)`` tuple: ``results`` holds the decoded JSON
        body of every successful request, ``errors`` holds ``str(exc)`` for
        every request that raised.
    """
    results = []
    errors = []

    def do_request(_):
        # Close the response explicitly so each request releases its socket
        # instead of leaving that to garbage collection.
        with urllib.request.urlopen(
            f"http://{GUNICORN_BIND}/query/", timeout=15
        ) as resp:
            return json.loads(resp.read())

    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as pool:
        futures = [pool.submit(do_request, i) for i in range(total)]
        for f in concurrent.futures.as_completed(futures):
            try:
                results.append(f.result())
            except Exception as e:
                # Failures are the data this script collects; record them
                # rather than aborting the run.
                errors.append(str(e))

    return results, errors


def run_test(label, workers, extra_args=None):
    """Run one scenario against a fresh gunicorn instance and report on it.

    Prints a header, starts gunicorn with the given worker count, fires
    NUM_REQUESTS requests using CONCURRENT_CLIENTS threads, always stops
    gunicorn afterwards, and prints the outcome.

    Args:
        label: Heading printed for this scenario.
        workers: Number of gunicorn workers to start.
        extra_args: Optional extra gunicorn command-line arguments.

    Returns:
        The number of requests that failed.
    """
    rule = "─" * 60
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  workers={workers}, requests={NUM_REQUESTS}, concurrency={CONCURRENT_CLIENTS}")
    print(rule)

    server = start_gunicorn(workers, extra_args)
    try:
        responses, failures = fire_requests(NUM_REQUESTS, CONCURRENT_CLIENTS)
    finally:
        # Shut the server down even if sending requests blew up.
        stop_gunicorn(server)

    failure_count = len(failures)
    distinct_pids = sorted({body["worker_pid"] for body in responses})
    print(f"  Responses OK : {len(responses)}")
    print(f"  Errors       : {failure_count}")
    print(f"  Worker PIDs  : {distinct_pids}")

    if failures:
        print("  Sample errors:")
        for raw in failures[:8]:
            # Flatten to a single, truncated line for readability
            one_line = raw.replace("\n", " ")[:150]
            print(f"    • {one_line}")
        remaining = failure_count - 8
        if remaining > 0:
            print(f"    … and {remaining} more")

    return failure_count


def main():
    """Run the three scenarios and print a summary with a verdict.

    Scenarios, in order: three workers (expected to fail), one worker
    (control), and three workers with a temporary gunicorn config whose
    ``post_fork`` hook clears Django's connection pools (workaround).
    The temporary config file is removed after the third run.
    """
    print("=" * 60)
    print("  Django psycopg pool + gunicorn fork() bug reproduction")
    print("=" * 60)

    # Test 1: multiple workers — should trigger the bug
    errors_multi = run_test(
        "Test 1: MULTIPLE workers (expect errors)",
        workers=3,
    )

    # Pause between runs before starting gunicorn on the same address again.
    time.sleep(2)

    # Test 2: single worker — control, should be clean
    errors_single = run_test(
        "Test 2: SINGLE worker (control)",
        workers=1,
    )

    time.sleep(2)

    # Test 3: multiple workers with post_fork workaround
    # Create a temp gunicorn config with the fix
    conf_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_gunicorn_fix.py")
    with open(conf_path, "w") as f:
        f.write(
            'def post_fork(server, worker):\n'
            '    from django.db.backends.postgresql.base import DatabaseWrapper\n'
            '    DatabaseWrapper._connection_pools.clear()\n'
            '    worker.log.info("Cleared connection pools after fork")\n'
        )
    try:
        errors_fixed = run_test(
            "Test 3: MULTIPLE workers + post_fork pool clear (workaround)",
            workers=3,
            extra_args=["--config", conf_path],
        )
    finally:
        os.unlink(conf_path)

    # Summary
    print(f"\n{'='*60}")
    print(f"  SUMMARY")
    print(f"{'='*60}")
    print(f"  Test 1 (3 workers, shared pool) : {errors_multi} errors")
    print(f"  Test 2 (1 worker, control)      : {errors_single} errors")
    print(f"  Test 3 (3 workers, pool cleared) : {errors_fixed} errors")

    if errors_multi > 0 and errors_single == 0 and errors_fixed == 0:
        print(f"\n  ✗ BUG CONFIRMED")
        print(f"    Multiple forked workers sharing the pool causes {errors_multi} errors.")
        print(f"    Single worker and post_fork clearing both work correctly.")
    elif errors_multi == 0:
        print(f"\n  ⚠ Bug not triggered in this run (it's timing-dependent).")
        print(f"    Try increasing NUM_REQUESTS / CONCURRENT_CLIENTS, or run")
        print(f"    under load. In production with real ASGI views, it triggers")
        print(f"    reliably on every page load with multiple workers.")
    else:
        # Test 1 failed, but so did the control and/or the workaround run;
        # say so explicitly instead of ending without any verdict.
        print("\n  ? INCONCLUSIVE")
        print("    The control and/or workaround run also produced errors, so the")
        print("    failures cannot be attributed to the shared pool alone.")
    print()


# Run the reproduction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
