#!/usr/bin/env python3
# coding: utf-8
#
# ioctlgen: Generate rust code from strace ioctls
# Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
# SPDX-License-Identifier: GPL-3.0

import argparse
import os
import re
import sys

# Maps each ScmpArch::<Variant> name to the directory under strace's src/linux/
# whose ioctl tables describe that variant.
#
# X32 is listed with its own "x32" directory, but seeds_for_variant reads the
# X32 tables from x86_64/ instead.
# NOTE(review): every MIPS ABI variant shares the single "mips" directory and
# the same seed files. expand_includes does not evaluate #if/#ifdef, so if that
# directory selects ABI-specific tables via conditionals, they will be merged.
# Confirm this against the strace tree.
ARCH_DIR = dict(
    X8664="x86_64",
    X86="i386",
    X32="x32",
    Aarch64="aarch64",
    Arm="arm",
    M68k="m68k",
    # MIPS: all ABIs / endiannesses -> one directory
    Mips="mips",
    Mipsel="mips",
    Mips64="mips",
    Mips64N32="mips",
    Mipsel64="mips",
    Mipsel64N32="mips",
    # PowerPC
    Ppc="powerpc",
    Ppc64="powerpc64",
    Ppc64Le="powerpc64le",
    Riscv64="riscv64",
    # s390
    S390="s390",
    S390X="s390x",
    Loongarch64="loongarch64",
)

# Only expand the seed headers that belong to each variant/personality.
# Paths are relative to src/linux/<archdir>; X32 seeds live under x86_64/.
def seeds_for_variant(variant, archdir):
    """Return the ``(seed_archdir, seed_file)`` pairs to expand for *variant*.

    The x86_64 directory carries per-personality tables (0 = native x86_64,
    2 = x32). X32 therefore reads the personality-2 files from ``x86_64/``
    and ignores *archdir*.

    Every other variant reads the personality-0 files (``ioctls_inc0.h`` and
    ``ioctls_arch0.h``) from its own *archdir*. Compat tables shipped by
    64-bit directories (aarch64, powerpc64, s390x) are not used. The 32-bit
    variants are generated from their native 32-bit directories (arm/,
    powerpc/, s390/) instead.
    """
    if variant == "X32":
        return [("x86_64", "ioctls_inc2.h"), ("x86_64", "ioctls_arch2.h")]
    return [(archdir, seed) for seed in ("ioctls_inc0.h", "ioctls_arch0.h")]

# Include handling: matches a quoted `#include "path"` directive that sits on a
# line of its own (re.MULTILINE makes ^/$ match per line). Group 1 is the path.
# Angle-bracket includes (`#include <...>`) are not matched, so
# expand_includes leaves them in the text untouched.
INCLUDE_RE = re.compile(r'^[ \t]*#\s*include\s*"([^"]+)"\s*$', re.MULTILINE)

# Comment stripping: C block comments (non-greedy; re.DOTALL lets them span
# lines) and C++ line comments up to, but not including, the newline.
CSTYLE_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL)
CPPCOMMENT_RE = re.compile(r"//[^\n]*")

# Entry lines in strace's ioctl tables:
# { "hdr", "NAME", DIR, TYPE_NR, SIZE },
# Capture groups:
#   1 = header path
#   2 = ioctl name
#   3 = DIR field, either _IOC_* tokens or a number (parse_dir_numeric interprets it)
#   4 = TYPE_NR, a hex (0x...) or decimal literal
#   5 = SIZE, a hex (0x...) or decimal literal
# re.VERBOSE lets the pattern carry the inline `#` comments below. The pattern
# contains no '.', so re.DOTALL has no effect here; `[^,]` and `\s` already
# match newlines, so an entry may span several lines.
ENTRY_RE = re.compile(
    r"""
    \{\s*
    "([^"]+)"\s*,\s*                   # header path (group 1)
    "([^"]+)"\s*,\s*                   # NAME (group 2)
    ([^,]*?(?:_IOC_[^,]*|\b0x[0-9a-fA-F]+\b|\b\d+\b))\s*,  # DIR tokens or numeric (group 3)
    \s*(0x[0-9a-fA-F]+|\d+)\s*,\s*     # TYPE_NR (group 4)
    (0x[0-9a-fA-F]+|\d+)               # SIZE (group 5)
    \s*\}
    """,
    re.VERBOSE | re.DOTALL,
)

def fail(msg):
    """Print ``error: <msg>`` to stderr and terminate with exit status 2."""
    print("error: " + msg, file=sys.stderr)
    raise SystemExit(2)

def warn(msg):
    """Print a non-fatal warning, prefixed with ``warn: ``, to stderr."""
    line = "warn: " + msg
    print(line, file=sys.stderr)

def note(msg):
    """Print an informational message, unprefixed, to stderr."""
    stream = sys.stderr
    print(msg, file=stream)

def resolve_include(including, inc, linux_dir, archdir):
    """Locate the file named by an ``#include "inc"`` found in *including*.

    Candidates are tried in this order, and the first existing regular file
    wins:

      1. *inc* itself, when it is an absolute path;
      2. relative to the directory containing *including*;
      3. relative to ``<linux_dir>/<archdir>``;
      4. relative to *linux_dir* (shared trees such as ``32/``, ``64/``,
         ``generic/``).

    Raises ``FileNotFoundError(inc)`` when no candidate exists.
    """
    candidates = [inc] if os.path.isabs(inc) else []
    candidates += [
        os.path.join(os.path.dirname(including), inc),
        os.path.join(linux_dir, archdir, inc),
        os.path.join(linux_dir, inc),
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(inc)

def expand_includes(path, archdir, linux_dir, seen):
    """Return the text of *path* with its quoted ``#include`` lines inlined.

    *seen* is a set of real paths that is updated in place. A file already in
    it expands to the empty string, which breaks include cycles and also
    drops repeated includes.

    Unreadable files trigger a warning and expand to "". An include that
    resolve_include cannot find is replaced by a C comment naming it.
    Conditional directives (``#if`` and friends) are not evaluated, so every
    include matched by INCLUDE_RE is expanded.
    """
    real = os.path.realpath(path)
    if real in seen:
        return ""
    seen.add(real)

    try:
        with open(real, "r", encoding="utf-8", errors="ignore") as fh:
            text = fh.read()
    except Exception as exc:
        warn("cannot read %s: %s" % (real, exc))
        return ""

    pieces = []
    cursor = 0
    for match in INCLUDE_RE.finditer(text):
        pieces.append(text[cursor:match.start()])
        name = match.group(1)
        try:
            target = resolve_include(real, name, linux_dir, archdir)
        except FileNotFoundError:
            pieces.append('/* include "%s" not found while expanding %s */\n' % (name, real))
        else:
            pieces.append(expand_includes(target, archdir, linux_dir, seen))
        cursor = match.end()
    pieces.append(text[cursor:])
    return "".join(pieces)

# One left-to-right scanner for comments *and* string literals. Because all
# three constructs are alternatives of a single pattern, whichever starts first
# wins. A "/*" inside a // comment or inside a string literal can therefore no
# longer swallow real code up to the next "*/".
_COMMENT_OR_STRING_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'  # double-quoted string literal (kept verbatim)
    r"|/\*.*?\*/"           # block comment (may span lines)
    r"|//[^\n]*",           # line comment (up to, not including, the newline)
    re.DOTALL,
)

def _replace_comment(match):
    """Keep string literals, turn block comments into one space, drop line comments."""
    text = match.group(0)
    if text.startswith('"'):
        return text
    if text.startswith("/*"):
        # As in C translation phase 3: a comment becomes a single space, so
        # the tokens on either side never fuse together.
        return " "
    return ""

def strip_comments(s):
    """Return *s* with C (``/* */``) and C++ (``//``) comments removed.

    The text is scanned once from left to right. Comment markers inside a
    string literal, or inside another comment, are therefore not mistaken
    for comment starts.

    Block comments are replaced by a single space. Line comments are removed
    up to (not including) the newline. An unterminated ``/*`` is not treated
    as a comment.
    """
    return _COMMENT_OR_STRING_RE.sub(_replace_comment, s)

# ScmpArch variants whose ioctl encoding uses the PowerPC/MIPS layout.
_PPC_MIPS_IOC_VARIANTS = frozenset((
    "Ppc", "Ppc64", "Ppc64Le",
    "Mips", "Mipsel", "Mips64", "Mips64N32", "Mipsel64", "Mipsel64N32",
))

def dir_mapping_variant(variant):
    """Return ``(direction_tokens, sizebits, dirbits)`` for *variant*.

    *direction_tokens* maps the ``_IOC_`` suffixes NONE/READ/WRITE to their
    numeric values. PowerPC and MIPS variants use 13 size bits, 3 direction
    bits and NONE=1, READ=2, WRITE=4. Every other variant uses the
    asm-generic layout: 14 size bits, 2 direction bits and NONE=0, WRITE=1,
    READ=2. A fresh dict is returned on each call.
    """
    if variant not in _PPC_MIPS_IOC_VARIANTS:
        # asm-generic layout
        return {"NONE": 0, "WRITE": 1, "READ": 2}, 14, 2
    # PowerPC / MIPS layout
    return {"NONE": 1, "READ": 2, "WRITE": 4}, 13, 3

def parse_dir_numeric(variant, field):
    """Return the numeric direction value for the DIR column of a table entry.

    If *field* is a single integer literal (any base ``int(x, 0)`` accepts),
    it is used as-is. Otherwise *field* is treated as ``|``-joined terms,
    with parentheses ignored:

    * each ``_IOC_<NAME>`` term is looked up in the variant's direction table
      from dir_mapping_variant;
    * any other term must be an integer literal.

    Unrecognised terms are reported through warn() and contribute nothing.
    The result is masked to 32 bits.
    """
    literal = field.strip()
    try:
        return int(literal, 0) & 0xFFFFFFFF
    except ValueError:
        pass  # not a single number; fall through to token parsing

    token_values = dir_mapping_variant(variant)[0]

    def term_value(term):
        # Named direction token known for this variant?
        if term.startswith("_IOC_"):
            bit = token_values.get(term[len("_IOC_"):].upper())
            if bit is not None:
                return bit
        # Otherwise it has to be a numeric literal.
        try:
            return int(term, 0)
        except ValueError:
            warn("unknown dir token '%s' in '%s' (arch %s); ignoring" % (term, field, variant))
            return 0

    terms = literal.replace("(", "").replace(")", "").split("|")
    result = 0
    for term in (piece.strip() for piece in terms):
        if term:
            result |= term_value(term)
    return result & 0xFFFFFFFF

def parse_entries(expanded, variant):
    """Extract ``(hdr, name, dir_num, type_nr, size)`` tuples from table text.

    Comments are stripped first. Every ENTRY_RE match is then decoded:

    * DIR goes through parse_dir_numeric;
    * TYPE_NR is masked to 16 bits (type in bits 15..8, nr in bits 7..0);
    * SIZE is masked to 32 bits.

    Entries whose numbers fail to parse with ``int(x, 0)`` (for example, a
    zero-padded decimal) are skipped silently. Results keep source order.
    """
    text = strip_comments(expanded)
    entries = []
    for match in ENTRY_RE.finditer(text):
        hdr, name, dir_field, type_nr_str, size_str = match.groups()
        try:
            dir_num = parse_dir_numeric(variant, dir_field)
            type_nr = int(type_nr_str, 0) & 0xFFFF
            size = int(size_str, 0) & 0xFFFFFFFF
        except ValueError:
            continue
        entries.append((hdr, name, dir_num, type_nr, size))
    return entries

def header_priority(hdr):
    """Rank the header an entry came from; a higher rank wins conflicts.

    Ranks:

    * ``asm-generic/`` -> 0
    * ``asm/`` (arch-specific) -> 3
    * ``linux/`` -> 2
    * anything else -> 1

    A prefix matches at the start of the path or after any ``/``.
    Backslashes are normalised to ``/`` first. The checks run in the order
    listed, so the first matching prefix decides.
    """
    path = hdr.replace("\\", "/")
    for prefix, rank in (("asm-generic/", 0), ("asm/", 3), ("linux/", 2)):
        if path.startswith(prefix) or "/" + prefix in path:
            return rank
    return 1

def compute_ioctl_value(variant, dir_num, type_nr, size):
    """Pack the fields of one entry into a 32-bit ioctl request number.

    The variant's layout comes from dir_mapping_variant:

    * type/nr occupies bits 0..15;
    * size starts at bit 16 and is *sizebits* wide;
    * direction sits immediately above size and is *dirbits* wide.

    Each field is masked to its width before being combined.
    """
    _, size_bits, dir_bits = dir_mapping_variant(variant)
    size_shift = 16                     # type (8 bits) + nr (8 bits) below
    dir_shift = size_shift + size_bits  # direction sits right above size
    type_field = type_nr & 0xFFFF
    size_field = (size & ((1 << size_bits) - 1)) << size_shift
    dir_field = (dir_num & ((1 << dir_bits) - 1)) << dir_shift
    return (type_field | size_field | dir_field) & 0xFFFFFFFF

def gather_variant(linux_dir, variant, archdir):
    """Collect sorted ``(name, request_number)`` pairs for one ScmpArch variant.

    Each seed table for the variant is expanded, parsed and encoded with the
    variant's ioctl layout. Seed files that do not exist are skipped.

    When a name appears more than once with different encodings, the entry
    from the strictly higher-priority header (see header_priority) wins. On
    equal priority the first-seen value is kept, so the result is
    deterministic.

    Returns a list sorted by name, or ``[]`` after a warning when no entry
    was found.
    """
    seen = set()
    # name -> (full, prio, hdr)
    chosen = {}

    for seed_archdir, seed_file in seeds_for_variant(variant, archdir):
        seed_path = os.path.join(linux_dir, seed_archdir, seed_file)
        if not os.path.isfile(seed_path):
            # ok if a seed is missing for some older/newer trees
            continue
        expanded = expand_includes(seed_path, seed_archdir, linux_dir, seen)
        for hdr, name, dir_num, type_nr, size in parse_entries(expanded, variant):
            full = compute_ioctl_value(variant, dir_num, type_nr, size)
            prio = header_priority(hdr)

            prev = chosen.get(name)
            if prev is None:
                chosen[name] = (full, prio, hdr)
                continue

            prev_full, prev_prio, _ = prev
            if full == prev_full:
                # Same encoding: keep the value, but remember the strongest
                # header that vouched for it. Without this, a later
                # mid-priority header with a different value could override
                # a value an asm/ header already confirmed.
                if prio > prev_prio:
                    chosen[name] = (prev_full, prio, hdr)
                continue

            # Different encoding: a strictly higher-priority header wins.
            # On equal priority the existing entry is kept for determinism.
            if prio > prev_prio:
                chosen[name] = (full, prio, hdr)

    if not chosen:
        warn("skip %s: no data (seeds not found or empty)" % variant)
        return []

    # Names are unique dict keys, so ordering by name alone is already total.
    return sorted((name, entry[0]) for name, entry in chosen.items())

def emit_rust(out_dir, variant, pairs):
    """Write ``ioctls_<variant in lower case>.rs`` into *out_dir*.

    The file defines ``static IOCTL_ARCH_<VARIANT>: IoctlList`` as a slice of
    ``("NAME", 0x<hex>)`` tuples, in the order given by *pairs*. IoctlList is
    presumably declared by the Rust crate that includes the generated file.

    Names are written unescaped; this assumes they contain no ``"`` or
    ``\\``, which holds for C identifiers. The path and entry count are then
    reported via note().
    """
    rs_path = os.path.join(out_dir, "ioctls_%s.rs" % variant.lower())
    with open(rs_path, "w", encoding="utf-8") as out:
        out.write("// This file was automatically generated from strace sources!\n")
        out.write("// vim: set ro :\n\n")
        out.write("static IOCTL_ARCH_%s: IoctlList = &[\n" % ascii_upper(variant))
        out.writelines('    ("%s", 0x%x),\n' % (name, full) for name, full in pairs)
        out.write("];\n")
    note("ok: wrote %s (%d entries)" % (rs_path, len(pairs)))

# Code-point table mapping ASCII 'a'..'z' to 'A'..'Z'; nothing else is touched.
_ASCII_UPPER_TABLE = str.maketrans(
    "abcdefghijklmnopqrstuvwxyz", "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)

def ascii_upper(s):
    """Upper-case only the ASCII letters a-z in *s*.

    Unlike ``str.upper()``, non-ASCII characters such as 'ß' pass through
    unchanged.
    """
    return s.translate(_ASCII_UPPER_TABLE)

def main():
    """Command-line entry point: ``ioctlgen STRACE_SOURCE_DIR OUTPUT_DIR``.

    Checks that STRACE_SOURCE_DIR contains ``src/linux`` (exiting via fail()
    otherwise) and creates OUTPUT_DIR if needed. It then writes one Rust file
    per ScmpArch variant in ARCH_DIR, in sorted variant order, skipping
    variants that produce no entries.
    """
    parser = argparse.ArgumentParser(description="Generate per-arch Rust arrays from strace ioctl tables")
    parser.add_argument("strace_source_dir")
    parser.add_argument("output_dir")
    opts = parser.parse_args()

    strace_root = os.path.realpath(opts.strace_source_dir)
    linux_dir = os.path.join(strace_root, "src", "linux")
    if not os.path.isdir(linux_dir):
        fail("'%s' does not look like a strace source tree (missing src/linux)" % strace_root)

    out_dir = os.path.realpath(opts.output_dir)
    os.makedirs(out_dir, exist_ok=True)

    for variant in sorted(ARCH_DIR):
        pairs = gather_variant(linux_dir, variant, ARCH_DIR[variant])
        if pairs:
            emit_rust(out_dir, variant, pairs)
        else:
            note("skip: %s (no data)" % variant)

    note("done: outputs in %s" % out_dir)

if __name__ == "__main__":
    main()
