#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AP52.2 / Single-Image URL Cypher

Goal:
  - URL-specialized compact packet
  - encrypted fixed 128-bit packet
  - packet is embedded in the Gray-Scott reaction-diffusion generation conditions (initial V bias and k-map)
  - no metadata / LSB / alpha / overlay / tile / QR / grid carrier
  - PASS only when the URL is recovered exactly and the verifier matches

Default target:
  https://masato-lab.pages.dev/portfolio/multiverse

Dependencies:
  pip install numpy pillow
"""

from __future__ import annotations

import argparse
import hashlib
import hmac
import itertools
import math
import os
from dataclasses import dataclass
from typing import Iterable, List, Optional, Sequence, Tuple

import numpy as np
from PIL import Image

# Packet layout constants.
VERSION = 2  # written into the 4-bit version header field
BIT_COUNT = 128  # total packet size in bits (plaintext == ciphertext length)
MAX_SLUG = 16  # slug characters carried per packet (5 bits each)
VERIFIER_BITS = 24  # truncated HMAC tag length in bits

# URL-specialized dictionaries. This is intentionally not a general UTF-8 carrier.
# Their indices are packed into the header: prefix id in 3 bits, path id in 4 bits.
PREFIXES = ["https://masato-lab.pages.dev/"]
PATH_CLASSES = [
    "",  # root + slug
    "tools",  # /tools or /tools/<slug>
    "portfolio",  # /portfolio/<slug>
    "peripheral-memory",  # /peripheral-memory/<slug>
    "about",  # /about or /about/<slug>
]

# 5-bit URL slug charset for AP52.2: exactly 32 symbols (a-z, "-", 0-4).
# URLs outside this alphabet should fall back to the multipage/long carrier path.
SLUG32 = "abcdefghijklmnopqrstuvwxyz" + "-01234"

DEFAULT_URL = "https://masato-lab.pages.dev/portfolio/multiverse"
DEFAULT_KEY = "MASATO TURING LABYRINTH TRUE"


@dataclass
class PacketInfo:
    """Result of packing one URL: dictionary ids, slug and the three bit arrays."""

    url: str  # URL as passed through normalize_url by build_packet
    prefix_id: int  # index into PREFIXES
    path_id: int  # index into PATH_CLASSES
    slug: str  # trailing slug encoded with the SLUG32 alphabet
    plain_bits: np.ndarray  # plaintext packet bits (0/1)
    encrypted_bits: np.ndarray  # plaintext XOR keystream
    verifier_bits: np.ndarray  # truncated HMAC verifier bits


@dataclass
class DecodeResult:
    """Outcome of decode_image: status string plus the winning candidate, if any."""

    status: str  # "PASS", "FAIL_URL_MISMATCH" or "FAIL"
    url: Optional[str]  # recovered URL, None on plain FAIL
    verifier_ok: bool  # embedded verifier matched the recovered URL
    exact_match: bool  # recovered URL equals the expected URL (or none expected)
    flips: Tuple[int, ...]  # bit indices flipped by the beam search
    orientation: int  # +1 / -1 readout polarity, 0 when nothing decoded
    confidence_margin: float  # smallest |score| among the flipped (or all) bits
    encrypted_bits: Optional[np.ndarray]  # candidate ciphertext that decoded
    plain_bits: Optional[np.ndarray]  # its decrypted plaintext


def sha256(data: bytes) -> bytes:
    """Return the raw 32-byte SHA-256 digest of *data*."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.digest()


def int_to_bits(value: int, width: int) -> List[int]:
    """Return the low *width* bits of *value*, most significant bit first."""
    bits: List[int] = []
    for position in reversed(range(width)):
        bits.append((value >> position) & 1)
    return bits


def bits_to_int(bits: Sequence[int]) -> int:
    """Fold an MSB-first bit sequence back into an integer."""
    acc = 0
    for bit in map(int, bits):
        acc = acc << 1 | bit
    return acc


def bytes_to_bits(data: bytes) -> np.ndarray:
    """Expand *data* into a uint8 array of bits, MSB of each byte first."""
    as_bytes = np.frombuffer(data, dtype=np.uint8)
    unpacked = np.unpackbits(as_bytes, bitorder="big")
    return unpacked.astype(np.uint8)


def stream_bits(key: str, n_bits: int) -> np.ndarray:
    """Return *n_bits* keystream bits derived from *key*.

    The stream is the concatenation of SHA-256("AP52.2|stream|<key>|<counter>")
    for counter = 0, 1, 2, ..., truncated to the requested length.
    """
    bits_per_block = 8 * 32  # one SHA-256 digest is 32 bytes
    block_count = -(-n_bits // bits_per_block)  # ceiling division; <= 0 for n_bits <= 0
    keystream = b"".join(
        sha256(f"AP52.2|stream|{key}|{counter}".encode("utf-8"))
        for counter in range(block_count)
    )
    return bytes_to_bits(keystream)[:n_bits]


def verifier_for_url(url: str, key: str) -> np.ndarray:
    """Return the first VERIFIER_BITS bits of HMAC-SHA256(key, "AP52.2|verifier|" + url)."""
    secret = key.encode("utf-8")
    message = ("AP52.2|verifier|" + url).encode("utf-8")
    tag = hmac.new(secret, message, hashlib.sha256).digest()
    return bytes_to_bits(tag)[:VERIFIER_BITS]


def normalize_url(url: str) -> str:
    """Return *url* with surrounding whitespace removed (no other normalization)."""
    cleaned = url.strip()
    return cleaned


def split_url(url: str) -> Tuple[int, int, str]:
    """Split a dictionary URL into ``(prefix_id, path_id, slug)``.

    The first path segment selects a PATH_CLASSES entry (ids >= 1); anything
    else is treated as a root slug (path id 0).

    The packet's verifier is an HMAC over the exact URL string, while the
    decoder can only rebuild the canonical form produced by ``url_from_parts``.
    A URL such as ``.../portfolio/`` (trailing slash) or one with doubled
    slashes would therefore encode fine but never decode. Such URLs are
    rejected here so that callers take the fallback path instead.

    Raises:
        ValueError: if no dictionary prefix matches, or the URL is not in the
            canonical form that ``url_from_parts`` reconstructs.
    """
    url = normalize_url(url)
    for prefix_id, prefix in enumerate(PREFIXES):
        if not url.startswith(prefix):
            continue
        rest = url[len(prefix) :].strip("/")
        # Default: root + slug (also covers the bare prefix, where rest == "").
        path_id, slug = 0, rest
        if rest:
            first, _, tail = rest.partition("/")
            for candidate_id, path_name in enumerate(PATH_CLASSES):
                # Id 0 is the empty root class; it is the fallback, never a match.
                if candidate_id != 0 and first == path_name:
                    path_id, slug = candidate_id, tail
                    break
        canonical = url_from_parts(prefix_id, path_id, slug)
        if canonical != url:
            raise ValueError(
                f"URL is not in canonical AP52.2 form (would decode as {canonical!r}). "
                "Use multipage/long fallback."
            )
        return prefix_id, path_id, slug
    raise ValueError("URL is outside AP52.2 dictionary prefix. Use multipage/long fallback.")


def url_from_parts(prefix_id: int, path_id: int, slug: str) -> str:
    """Rebuild a URL from dictionary ids and a slug (inverse of ``split_url``).

    Raises:
        ValueError: if either id is outside its dictionary. Negative ids are
            rejected too; otherwise Python's negative indexing would silently
            select the wrong entry.
    """
    if not (0 <= prefix_id < len(PREFIXES) and 0 <= path_id < len(PATH_CLASSES)):
        raise ValueError("invalid dictionary id")
    prefix = PREFIXES[prefix_id]
    path = PATH_CLASSES[path_id]
    if path == "":
        # Root class: the slug follows the prefix directly.
        return prefix + slug
    if slug:
        return prefix + path + "/" + slug
    return prefix + path


def slug_to_values(slug: str) -> List[int]:
    """Map *slug* to SLUG32 codes, zero-padded to exactly MAX_SLUG entries.

    Raises:
        ValueError: if the slug is longer than MAX_SLUG or uses a character
            outside SLUG32.
    """
    if len(slug) > MAX_SLUG:
        raise ValueError(f"slug too long for AP52.2 single-image packet: {len(slug)} > {MAX_SLUG}")
    codes: List[int] = []
    for ch in slug:
        code = SLUG32.find(ch)
        if code < 0:
            raise ValueError(f"slug character {ch!r} is outside AP52.2 compact charset: {SLUG32!r}")
        codes.append(code)
    codes.extend([0] * (MAX_SLUG - len(codes)))
    return codes


def values_to_slug(vals: Sequence[int], length: int) -> str:
    """Turn the first *length* SLUG32 codes of *vals* back into a slug string.

    Raises:
        ValueError: if *length* exceeds MAX_SLUG or a code is outside SLUG32.
    """
    if length > MAX_SLUG:
        raise ValueError("invalid slug length")
    alphabet_size = len(SLUG32)
    out: List[str] = []
    for code in vals[:length]:
        if not 0 <= code < alphabet_size:
            raise ValueError("invalid slug code")
        out.append(SLUG32[code])
    return "".join(out)


def make_plain_bits(url: str, key: str) -> Tuple[np.ndarray, int, int, str]:
    """Build the 128-bit plaintext packet for *url*.

    Layout (MSB first): version(4) | prefix_id(3) | path_id(4) | slug_len(5) |
    MAX_SLUG x slug code(5) | flags(4) | reserved(4) | verifier(VERIFIER_BITS).

    Returns:
        ``(bits, prefix_id, path_id, slug)`` with ``bits`` as a uint8 array.

    Raises:
        ValueError: from split_url / slug_to_values when the URL does not fit.
        RuntimeError: if the assembled packet is not BIT_COUNT bits long.
    """
    url = normalize_url(url)
    prefix_id, path_id, slug = split_url(url)
    codes = slug_to_values(slug)

    fields = [(VERSION, 4), (prefix_id, 3), (path_id, 4), (len(slug), 5)]
    fields += [(code, 5) for code in codes]
    fields += [(0, 4), (0, 4)]  # flags / reserved, then reserved
    bits: List[int] = [bit for value, width in fields for bit in int_to_bits(value, width)]
    bits += verifier_for_url(url, key).astype(int).tolist()

    if len(bits) != BIT_COUNT:
        raise RuntimeError(f"internal packet size mismatch: {len(bits)} bits")
    return np.array(bits, dtype=np.uint8), prefix_id, path_id, slug


def encrypt_bits(plain_bits: np.ndarray, key: str) -> np.ndarray:
    """XOR *plain_bits* with the key-derived keystream of the same length."""
    keystream = stream_bits(key, len(plain_bits))
    return plain_bits.astype(np.uint8) ^ keystream


def decrypt_bits(encrypted_bits: np.ndarray, key: str) -> np.ndarray:
    """Undo encrypt_bits: XOR with the same keystream is its own inverse."""
    recovered = encrypt_bits(encrypted_bits, key)
    return recovered


def parse_plain_bits(plain_bits: np.ndarray, key: str) -> Tuple[str, bool]:
    """Parse a decrypted packet back into its URL.

    Layout (MSB first): version(4) | prefix_id(3) | path_id(4) | slug_len(5) |
    MAX_SLUG x slug code(5) | flags(4) | reserved(4) | verifier(VERIFIER_BITS).

    Returns:
        ``(url, verifier_ok)``. ``verifier_ok`` is True when the embedded
        verifier equals ``verifier_for_url(url, key)``.

    Raises:
        ValueError: wrong packet length, version mismatch, non-zero
            flags/reserved bits, or an invalid slug length, slug code or
            dictionary id.
    """
    b = [int(x) for x in plain_bits.tolist()]
    # Validate the *input* length up front. The fixed-width reads below always
    # advance ``p`` to BIT_COUNT, so without this a short packet would be parsed
    # from truncated slices and extra trailing bits would be silently ignored.
    if len(b) != BIT_COUNT:
        raise ValueError("packet length mismatch")
    p = 0
    version = bits_to_int(b[p : p + 4]); p += 4
    if version != VERSION:
        raise ValueError("version mismatch")
    prefix_id = bits_to_int(b[p : p + 3]); p += 3
    path_id = bits_to_int(b[p : p + 4]); p += 4
    slug_len = bits_to_int(b[p : p + 5]); p += 5
    vals = []
    for _ in range(MAX_SLUG):
        vals.append(bits_to_int(b[p : p + 5])); p += 5
    flags = bits_to_int(b[p : p + 4]); p += 4
    reserved = bits_to_int(b[p : p + 4]); p += 4
    if flags != 0 or reserved != 0:
        raise ValueError("reserved bits are not zero")
    verifier = np.array(b[p : p + VERIFIER_BITS], dtype=np.uint8); p += VERIFIER_BITS
    # Internal consistency check: the field widths above must sum to BIT_COUNT.
    if p != BIT_COUNT:
        raise ValueError("packet length mismatch")
    slug = values_to_slug(vals, slug_len)
    url = url_from_parts(prefix_id, path_id, slug)
    verifier_ok = bool(np.array_equal(verifier, verifier_for_url(url, key)))
    return url, verifier_ok


def build_packet(url: str, key: str) -> PacketInfo:
    """Pack and encrypt *url* under *key*, returning every intermediate artifact.

    Raises:
        ValueError: propagated from make_plain_bits when the URL does not fit.
    """
    plain, prefix_id, path_id, slug = make_plain_bits(url, key)
    cleaned_url = normalize_url(url)
    return PacketInfo(
        url=cleaned_url,
        prefix_id=prefix_id,
        path_id=path_id,
        slug=slug,
        plain_bits=plain,
        encrypted_bits=encrypt_bits(plain, key),
        verifier_bits=verifier_for_url(cleaned_url, key),
    )


def seeded_rng(seed_text: str) -> np.random.Generator:
    """Return a NumPy Generator seeded from the first 8 SHA-256 bytes of *seed_text*."""
    digest = sha256(seed_text.encode("utf-8"))
    seed_value = int.from_bytes(digest[:8], byteorder="big", signed=False)
    return np.random.default_rng(seed_value)


def make_basis(size: int, bit_count: int, key: str) -> np.ndarray:
    """Full-field smooth basis: every bit affects the whole reaction field.

    Each of the *bit_count* patterns is the sum of six oblique sine waves
    (keyed random angle, frequency in [2, 18) and phase), normalized to zero
    mean and unit RMS. No local region is assigned to any bit, which is what
    keeps the carrier free of cells/tiles.

    Returns:
        float32 array of shape ``(bit_count, size, size)``.
    """
    two_pi = 2.0 * math.pi
    axis = np.linspace(-1.0, 1.0, size, endpoint=False, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(axis, axis)
    basis = np.empty((bit_count, size, size), dtype=np.float32)
    for bit_index in range(bit_count):
        rng = seeded_rng(f"AP52.2|basis|{key}|{bit_index}")
        pattern = np.zeros((size, size), dtype=np.float32)
        for _wave in range(6):
            # Draw order (angle, frequency, phase) defines the basis; the
            # decoder must regenerate exactly the same patterns.
            angle = rng.uniform(0.0, two_pi)
            freq = rng.uniform(2.0, 18.0)
            phase = rng.uniform(0.0, two_pi)
            kx, ky = math.cos(angle) * freq, math.sin(angle) * freq
            pattern += np.sin(two_pi * (kx * grid_x + ky * grid_y) + phase).astype(np.float32)
        pattern -= float(pattern.mean())
        pattern /= float(np.sqrt(np.mean(pattern * pattern)) + 1e-9)
        basis[bit_index] = pattern
    return basis


def lowpass_noise(size: int, key: str) -> np.ndarray:
    """Keyed Gaussian-low-passed white noise, standardized to mean 0 / std 1.

    Used only as the organic seed cloud; it does not depend on packet bits.
    """
    rng = seeded_rng(f"AP52.2|base-seed|{key}|{size}")
    white = rng.normal(0.0, 1.0, (size, size)).astype(np.float32)
    freqs = np.fft.fftfreq(size)
    fx, fy = np.meshgrid(freqs, freqs)
    radius_sq = fx * fx + fy * fy
    gaussian = np.exp(-(radius_sq / (0.045 ** 2)))
    field = np.fft.ifft2(np.fft.fft2(white) * gaussian).real.astype(np.float32)
    field -= float(field.mean())
    field /= float(field.std() + 1e-9)
    return field


def encrypted_field(encrypted_bits: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Superpose basis patterns with +/-1 signs from the bits; standardize the sum.

    Bit 1 contributes its pattern with sign +1 and bit 0 with sign -1. The
    result is shifted to zero mean and scaled to unit standard deviation.
    """
    n_bits = len(encrypted_bits)
    polarity = 2.0 * encrypted_bits.astype(np.float32) - 1.0
    mixed = np.tensordot(polarity, basis, axes=(0, 0)) / math.sqrt(n_bits)
    field = mixed.astype(np.float32)
    field -= float(field.mean())
    field /= float(field.std() + 1e-9)
    return field


def gray_scott_generate(
    encrypted_bits: np.ndarray,
    key: str,
    size: int = 384,
    steps: int = 1000,
    base_f: float = 0.0290,
    base_k: float = 0.0565,
    delta_k: float = 0.0020,
    v_bias: float = 0.030,
) -> Tuple[Image.Image, np.ndarray]:
    """Render the encrypted packet as a Gray-Scott reaction-diffusion image.

    The packet never touches pixels directly. It is projected onto the keyed
    full-field basis (``make_basis`` / ``encrypted_field``), and that field only
    perturbs the simulation's generation conditions: an additive bias on the
    initial V concentration and a spatially varying kill rate ``k``.

    Args:
        encrypted_bits: encrypted packet bits (0/1), one per basis pattern.
        key: secret key; seeds both the basis and the organic seed cloud.
        size: side length of the square grid and of the output image.
        steps: number of explicit update steps (dt = 1, periodic boundaries via np.roll).
        base_f: uniform feed rate F.
        base_k: mean kill rate k.
        delta_k: amplitude of the packet-driven k-map perturbation.
        v_bias: amplitude of the packet-driven initial V bias.

    Returns:
        ``(image, basis)``: an 8-bit grayscale image of V (2nd-98th percentile
        contrast stretch) and the basis used to build the carrier field.
    """
    basis = make_basis(size, len(encrypted_bits), key)
    E = encrypted_field(encrypted_bits, basis)
    seed = lowpass_noise(size, key)

    U = np.ones((size, size), dtype=np.float32)
    V = np.zeros((size, size), dtype=np.float32)

    # Organic seed cloud: no bit positions, no tiles.
    mask = seed > 1.0
    U[mask] = 0.50
    V[mask] = 0.25

    # The encrypted packet enters as generation conditions:
    # - global V initial bias
    # - global k-map perturbation
    # Both are smooth full-field carriers, not post-render pixels.
    V += v_bias * np.tanh(E / 1.5).astype(np.float32)
    V = np.clip(V, 0.0, 1.0)
    kmap = (base_k + delta_k * np.tanh(E / 2.0)).astype(np.float32)

    Du = 0.16
    Dv = 0.08
    for t in range(int(steps)):
        # 5-point Laplacian with wrap-around (periodic) boundaries.
        Lu = np.roll(U, 1, 0) + np.roll(U, -1, 0) + np.roll(U, 1, 1) + np.roll(U, -1, 1) - 4.0 * U
        Lv = np.roll(V, 1, 0) + np.roll(V, -1, 0) + np.roll(V, 1, 1) + np.roll(V, -1, 1) - 4.0 * V
        uvv = U * V * V
        U += Du * Lu - uvv + base_f * (1.0 - U)
        V += Dv * Lv + uvv - (base_f + kmap) * V
        # Clamp at t = 0, 128, 256, ... to keep the fields in [0, 1].
        if (t & 127) == 0:
            U = np.clip(U, 0.0, 1.0)
            V = np.clip(V, 0.0, 1.0)

    # Export visible reaction field only. No metadata payload is written.
    lo, hi = np.percentile(V, [2.0, 98.0])
    img = np.clip((V - lo) / (hi - lo + 1e-9), 0.0, 1.0)
    arr = (img * 255.0 + 0.5).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode "L". Passing mode= explicitly is
    # deprecated since Pillow 11.3, so it is omitted.
    return Image.fromarray(arr), basis


def image_to_field(path: str, size: int) -> np.ndarray:
    """Load *path* as a zero-mean, unit-variance grayscale field of ``size x size``.

    The image is converted to 8-bit luminance ("L"), bicubically resized when
    its dimensions differ, scaled to [0, 1] and then standardized.
    """
    # The context manager closes the file handle deterministically. convert()
    # returns an independent, fully loaded copy that outlives the handle.
    with Image.open(path) as src:
        im = src.convert("L")
    if im.size != (size, size):
        im = im.resize((size, size), Image.Resampling.BICUBIC)
    arr = np.asarray(im).astype(np.float32) / 255.0
    arr -= float(arr.mean())
    arr /= float(arr.std() + 1e-9)
    return arr


def raw_read_bits_from_image(path: str, key: str, size: int, bit_count: int = BIT_COUNT) -> Tuple[np.ndarray, np.ndarray]:
    """Correlate the image against each keyed basis pattern.

    Returns:
        ``(hard_bits, scores)``. ``scores[i]`` is the mean product of the
        standardized image and basis pattern ``i``, and ``hard_bits`` is
        ``scores > 0``. The Gray-Scott readout is empirically often
        polarity-inverted, so the decoder also tries the complement.
    """
    field = image_to_field(path, size)
    basis = make_basis(size, bit_count, key)
    pixel_count = float(size * size)
    correlations = np.tensordot(basis, field, axes=([1, 2], [0, 1])) / pixel_count
    hard_bits = (correlations > 0.0).astype(np.uint8)
    return hard_bits, correlations.astype(np.float32)


def candidate_bits_with_beam(
    raw: np.ndarray,
    scores: np.ndarray,
    max_flip: int = 4,
    weak_count: int = 18,
) -> Iterable[Tuple[int, Tuple[int, ...], np.ndarray, float]]:
    """Yield ``(orientation, flips, candidate_bits, margin)`` packet proposals.

    For each polarity (+1 = raw bits, -1 = complemented bits), the unmodified
    candidate comes first. It is followed by every combination of 1..max_flip
    flips drawn from the *weak_count* lowest-|score| bit positions. The caller's
    verifier decides which candidate (if any) is genuine; this function only
    proposes. ``margin`` is the smallest |score| among the flipped bits, or
    over all bits for the unmodified candidate.
    """
    weakest = [int(i) for i in np.argsort(np.abs(scores))[:weak_count]]
    for orientation in (+1, -1):
        inverted = orientation == -1
        base = (1 - raw).astype(np.uint8) if inverted else raw.copy()
        oriented_scores = -scores if inverted else scores
        yield orientation, (), base.copy(), float(np.min(np.abs(oriented_scores)))
        flip_sets = itertools.chain.from_iterable(
            itertools.combinations(weakest, r) for r in range(1, max_flip + 1)
        )
        for combo in flip_sets:
            cand = base.copy()
            for idx in combo:
                cand[idx] ^= 1
            yield orientation, combo, cand, min(abs(float(scores[i])) for i in combo)


def decode_image(
    path: str,
    key: str,
    expected_url: Optional[str] = None,
    size: int = 384,
    max_flip: int = 4,
    weak_count: int = 18,
) -> DecodeResult:
    """Recover the URL packet from an image's pixels.

    Candidates from ``candidate_bits_with_beam`` are decrypted and parsed in
    order. Outcomes:

    * ``"PASS"``: the first candidate whose verifier matches and whose URL
      equals *expected_url* (or any verified candidate when *expected_url*
      is None).
    * ``"FAIL_URL_MISMATCH"``: no PASS, but at least one candidate verified;
      the first such candidate is reported.
    * ``"FAIL"``: no candidate verified.
    """
    raw, scores = raw_read_bits_from_image(path, key, size, BIT_COUNT)
    wanted = normalize_url(expected_url) if expected_url is not None else None
    best_invalid: Optional[DecodeResult] = None
    for orientation, flips, enc_bits, margin in candidate_bits_with_beam(
        raw, scores, max_flip=max_flip, weak_count=weak_count
    ):
        plain = decrypt_bits(enc_bits, key)
        try:
            url, verifier_ok = parse_plain_bits(plain, key)
        except ValueError:
            # Malformed candidate (version, reserved bits, ids, slug codes):
            # not a packet, try the next one. Other exceptions are real bugs
            # and are deliberately not swallowed.
            continue
        if not verifier_ok:
            continue
        found = dict(
            url=url,
            verifier_ok=True,
            flips=tuple(int(x) for x in flips),
            orientation=orientation,
            confidence_margin=margin,
            encrypted_bits=enc_bits,
            plain_bits=plain,
        )
        if wanted is None or normalize_url(url) == wanted:
            return DecodeResult(status="PASS", exact_match=True, **found)
        if best_invalid is None:
            best_invalid = DecodeResult(status="FAIL_URL_MISMATCH", exact_match=False, **found)
    if best_invalid is not None:
        return best_invalid
    return DecodeResult(
        status="FAIL",
        url=None,
        verifier_ok=False,
        exact_match=False,
        flips=tuple(),
        orientation=0,
        confidence_margin=0.0,
        encrypted_bits=None,
        plain_bits=None,
    )


def save_png_without_payload(img: Image.Image, path: str) -> None:
    """Write *img* as an optimized PNG containing pixel data only."""
    # pnginfo is deliberately not passed, so no text/metadata chunks are added.
    options = {"optimize": True}
    img.save(path, "PNG", **options)


def print_packet(packet: PacketInfo) -> None:
    """Print a human-readable summary of *packet* to stdout."""

    def row(label: str, value: object) -> None:
        # Labels are padded so every colon lines up in column 13.
        print(f"  {label:<11}: {value}")

    print("AP52.2 packet")
    row("url", packet.url)
    row("prefix id", packet.prefix_id)
    row("path id", f"{packet.path_id} ({PATH_CLASSES[packet.path_id]})")
    row("slug", repr(packet.slug))
    row("bits", f"{len(packet.encrypted_bits)} encrypted bits")
    row("verifier", "".join(str(int(bit)) for bit in packet.verifier_bits))


def _print_decode_result(res: DecodeResult) -> None:
    """Print the result lines shared by --decode mode and the self-decode check."""
    print(f"  status     : {res.status}")
    print(f"  url        : {res.url}")
    print(f"  verifier OK: {res.verifier_ok}")
    print(f"  exact match: {res.exact_match}")
    print(f"  orientation: {res.orientation}")
    print(f"  flips      : {res.flips}")


def main() -> None:
    """Command-line entry point: encode a URL into a PNG (then self-decode), or decode one.

    Exit status: 0 on PASS, 1 when decoding does not PASS (in both --decode
    and encode/self-decode modes), 2 when the URL does not fit the packet.
    """
    ap = argparse.ArgumentParser(description="AP52.2 single-image URL cypher")
    ap.add_argument("--url", default=DEFAULT_URL, help="target URL")
    ap.add_argument("--key", default=DEFAULT_KEY, help="secret key / seed")
    ap.add_argument("--out", default="ap52_2_single_image_url.png", help="output PNG")
    ap.add_argument("--size", type=int, default=384, help="square image size")
    ap.add_argument("--steps", type=int, default=1000, help="Gray-Scott steps")
    ap.add_argument("--delta-k", type=float, default=0.0020, help="encoded k-map amplitude")
    ap.add_argument("--v-bias", type=float, default=0.030, help="encoded initial V bias")
    ap.add_argument("--decode", default=None, help="decode an existing PNG instead of generating")
    ap.add_argument("--expected", default=None, help="expected URL for strict PASS")
    ap.add_argument("--max-flip", type=int, default=4, help="beam search max bit flips")
    ap.add_argument("--weak-count", type=int, default=18, help="weak score candidate pool")
    args = ap.parse_args()

    if args.decode:
        expected = args.expected if args.expected is not None else args.url
        res = decode_image(
            args.decode,
            key=args.key,
            expected_url=expected,
            size=args.size,
            max_flip=args.max_flip,
            weak_count=args.weak_count,
        )
        print("AP52.2 decode")
        print(f"  image      : {args.decode}")
        _print_decode_result(res)
        # Signal failure to callers/scripts, consistent with the self-decode path.
        if res.status != "PASS":
            raise SystemExit(1)
        return

    try:
        packet = build_packet(args.url, args.key)
    except ValueError as e:
        print("AP52.2 single-image packet cannot hold this URL.")
        print(f"reason: {e}")
        print("fallback: use AP52 multipage/long-carrier path for this URL.")
        raise SystemExit(2)

    print_packet(packet)
    img, _basis = gray_scott_generate(
        packet.encrypted_bits,
        key=args.key,
        size=args.size,
        steps=args.steps,
        delta_k=args.delta_k,
        v_bias=args.v_bias,
    )
    save_png_without_payload(img, args.out)
    print(f"saved      : {args.out}")

    # Self-decode strictly from saved PNG pixels.
    res = decode_image(
        args.out,
        key=args.key,
        expected_url=packet.url,
        size=args.size,
        max_flip=args.max_flip,
        weak_count=args.weak_count,
    )
    print("self decode")
    _print_decode_result(res)
    if res.status != "PASS":
        raise SystemExit(1)


if __name__ == "__main__":
    main()
