#!/usr/bin/env python3
"""
Contact sheet generator for DJI Pocket LRF proxies.
Reads each LRF (1280x720 H.264), samples 9 evenly distributed frames and
builds a 3x3 grid with a timecode overlay and a header line.
"""
import cv2
import os
import sys
import json
from pathlib import Path

# Directory that main() searches for DJI_*_*_D.LRF proxy files.
SD_DIR = Path("/Volumes/UNTITLED/DCIM/DJI_001")
# Directory the generated take-XXXX.jpg contact sheets are written to.
OUT_DIR = Path("/Users/marvinkuehlmann/source/agentic-ventures/assets/broll-pool/contact-sheets")
# JSON file holding the per-take metadata; lives next to the contact-sheets directory.
META_PATH = OUT_DIR.parent / "metadata.json"

# Grid layout: GRID_COLS * GRID_ROWS frames are sampled per clip.
GRID_COLS = 3
GRID_ROWS = 3
# Pixel size every sampled frame is resized to before it is placed in the grid.
TILE_W = 640
TILE_H = 360
# Height in pixels of the strip above the grid that holds the header text.
HEADER_H = 60
# Gap in pixels between neighbouring tiles and around the grid.
PADDING = 4

# Font used for both the timecode overlays and the header line.
FONT = cv2.FONT_HERSHEY_SIMPLEX

def fmt_tc(seconds: float) -> str:
    """Format a time offset given in seconds as ``MM:SS``.

    The fractional part is truncated. Minutes are not wrapped at 60, so an
    offset of one hour is rendered as ``60:00``.
    """
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"

def fmt_size(bytes_: int) -> str:
    """Return a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached, then printed with one decimal place followed
    directly by the unit, e.g. ``1.5KB``.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = bytes_
    idx = 0
    # Stop at the last unit even if the value is still >= 1024.
    while idx < len(units) - 1 and not value < 1024:
        value = value / 1024
        idx += 1
    return f"{value:.1f}{units[idx]}"

def take_from_name(name: str) -> str:
    """Extract the take field from a DJI file name.

    The name is split on underscores and the third field is returned, e.g.
    ``"DJI_20000601074734_0001_D.LRF"`` yields ``"0001"``. Names with fewer
    than four underscore-separated fields are returned unchanged.
    """
    fields = name.split("_")
    return name if len(fields) < 4 else fields[2]

def make_contact_sheet(lrf_path: Path, mp4_path: Path, take: str, out_path: Path):
    """Render a contact sheet for one LRF proxy and return its metadata.

    Up to GRID_COLS * GRID_ROWS frames are sampled evenly across the clip,
    resized to TILE_W x TILE_H, stamped with a timecode in the bottom-right
    corner and arranged in a grid below a header line. The sheet is written
    to *out_path* as a JPEG.

    Args:
        lrf_path: Proxy video the frames are read from.
        mp4_path: Companion file; only its name and size are used. A missing
            file is reported with size 0.
        take: Take label shown in the header and stored in the metadata.
        out_path: Destination path of the JPEG sheet.

    Returns:
        A dict with the keys ``take``, ``lrf_name``, ``mp4_name``,
        ``mp4_size_bytes``, ``duration_s``, ``fps``, ``frame_count``,
        ``timecodes`` and ``contact_sheet``; or ``None`` if the video could
        not be opened or the sheet could not be written.
    """
    # Function-scope import, done once instead of on every failed frame read.
    import numpy as np

    n = GRID_COLS * GRID_ROWS
    tiles = []
    timecodes = []

    cap = cv2.VideoCapture(str(lrf_path))
    try:
        if not cap.isOpened():
            return None

        # Fall back to 30 fps when the container reports no usable rate
        # (zero, negative or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        # Some containers report a negative frame count; treat that as empty.
        total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
        duration = total_frames / fps

        # Sample at the centre of n equal segments so the first sample is
        # slightly off frame 0, which skips leading black frames.
        if total_frames < n:
            positions = list(range(total_frames))
        else:
            positions = [int(total_frames * (i + 0.5) / n) for i in range(n)]

        for pos in positions:
            shown_pos = pos
            cap.set(cv2.CAP_PROP_POS_FRAMES, pos)
            ret, frame = cap.read()
            if not ret:
                # Retry a few frames earlier and remember which frame is
                # actually displayed so the timecode label matches it.
                shown_pos = max(0, pos - 5)
                cap.set(cv2.CAP_PROP_POS_FRAMES, shown_pos)
                ret, frame = cap.read()

            if ret:
                tile = cv2.resize(frame, (TILE_W, TILE_H))
            else:
                # Both reads failed: black placeholder at the requested position.
                shown_pos = pos
                tile = np.zeros((TILE_H, TILE_W, 3), dtype="uint8")

            tc = fmt_tc(shown_pos / fps)
            timecodes.append(tc)

            # Timecode in the bottom-right corner on a filled black box.
            (tw, th), _ = cv2.getTextSize(tc, FONT, 0.7, 2)
            x = TILE_W - tw - 12
            y = TILE_H - 12
            cv2.rectangle(tile, (x - 6, y - th - 6), (x + tw + 6, y + 6), (0, 0, 0), -1)
            cv2.putText(tile, tc, (x, y), FONT, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
            tiles.append(tile)
    finally:
        # Release the capture even if reading or resizing raised.
        cap.release()

    # Compose grid on an almost-black background.
    grid_w = GRID_COLS * TILE_W + (GRID_COLS + 1) * PADDING
    grid_h = HEADER_H + GRID_ROWS * TILE_H + (GRID_ROWS + 1) * PADDING
    sheet = np.full((grid_h, grid_w, 3), 18, dtype="uint8")

    # Header
    mp4_size = mp4_path.stat().st_size if mp4_path.exists() else 0
    header_text = f"Take {take}  |  {duration:.1f}s  |  MP4 {fmt_size(mp4_size)}  |  {mp4_path.name}"
    cv2.putText(sheet, header_text, (12, 38), FONT, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

    # Tiles, filled row by row; unused slots keep the background colour.
    for idx, tile in enumerate(tiles):
        r, c = divmod(idx, GRID_COLS)
        y0 = HEADER_H + PADDING + r * (TILE_H + PADDING)
        x0 = PADDING + c * (TILE_W + PADDING)
        sheet[y0:y0 + TILE_H, x0:x0 + TILE_W] = tile

    # imwrite signals failure through its return value; do not report
    # metadata for a sheet that was never written.
    if not cv2.imwrite(str(out_path), sheet, [cv2.IMWRITE_JPEG_QUALITY, 88]):
        return None

    return {
        "take": take,
        "lrf_name": lrf_path.name,
        "mp4_name": mp4_path.name,
        "mp4_size_bytes": mp4_size,
        "duration_s": round(duration, 2),
        "fps": round(fps, 2),
        "frame_count": total_frames,
        "timecodes": timecodes,
        "contact_sheet": out_path.name,
    }

def main():
    """Generate contact sheets for takes 0001-0031 and update the metadata file.

    Every matching LRF file in SD_DIR whose sheet does not yet exist in
    OUT_DIR is rendered with make_contact_sheet(). The resulting metadata
    records are merged into META_PATH, keyed by take; a freshly generated
    record replaces an older record for the same take.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Find LRF files for takes 0001-0031. Files whose take field is not a
    # number are ignored instead of aborting the whole run.
    older = []
    for p in sorted(SD_DIR.glob("DJI_*_*_D.LRF")):
        try:
            take_no = int(take_from_name(p.name))
        except ValueError:
            continue
        if 1 <= take_no <= 31:
            older.append(p)
    print(f"Found {len(older)} older LRF files (0001-0031)")

    metadata = []
    for lrf in older:
        take = take_from_name(lrf.name)
        mp4 = lrf.with_suffix(".MP4")
        out = OUT_DIR / f"take-{take}.jpg"
        if out.exists():
            print(f"  [skip] take-{take}.jpg already exists")
            # Any record already stored for this take is kept by the merge below.
            continue
        print(f"  [..] take {take}: {lrf.name}", flush=True)
        meta = make_contact_sheet(lrf, mp4, take, out)
        if meta is None:
            print(f"  [FAIL] take {take} could not be opened")
            continue
        metadata.append(meta)
        print(f"  [OK]   take {take}: {meta['duration_s']}s -> {out.name}")

    # Merge with existing metadata if present. Records are keyed by take and
    # new records are applied last, so a regenerated sheet never keeps stale data.
    by_take = {}
    if META_PATH.exists():
        for record in json.loads(META_PATH.read_text(encoding="utf-8")):
            by_take[record["take"]] = record
    for record in metadata:
        by_take[record["take"]] = record
    merged = sorted(by_take.values(), key=lambda m: m["take"])

    # ensure_ascii=False may emit non-ASCII characters, so pin the encoding.
    META_PATH.write_text(json.dumps(merged, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nMetadata saved: {META_PATH}")
    print(f"Sheets in:     {OUT_DIR}")

if __name__ == "__main__":
    main()
