//
// Syd: rock-solid application kernel
// src/rng.rs: OS Random Number Generator (RNG) interface
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

//! Set of functions to manage the OS Random Number Generator (RNG)

use std::{
    ops::RangeInclusive,
    os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd},
};

use libc::{c_int, dup3, GRND_RANDOM};
use nix::{
    errno::Errno,
    fcntl::{OFlag, ResolveFlag},
    sys::resource::{getrlimit, Resource},
    unistd::{close, UnlinkatFlags},
    NixPath,
};

use crate::{
    cookie::safe_unlinkat,
    fs::{is_active_fd, retry_on_eintr, safe_open},
    path::{XPath, XPathBuf, PATH_MAX},
};

/// Scope guard that switches off pthread cancellation for the calling thread.
///
/// [`CancelGuard::acquire`] calls pthread_setcancelstate(3) with
/// `PTHREAD_CANCEL_DISABLE` and remembers the state that was in effect
/// before; dropping the guard calls pthread_setcancelstate(3) again to put
/// that remembered state back. pthread_setcancelstate(3) acts on the calling
/// thread, so the guard is meant to be dropped on the thread that created it.
#[must_use = "hold the guard to keep cancellation disabled"]
pub struct CancelGuard(c_int);

// Cancellation state values passed to pthread_setcancelstate(3).
const _PTHREAD_CANCEL_ENABLE: c_int = 0;
const PTHREAD_CANCEL_DISABLE: c_int = 1;

// The libc crate does not declare this function yet, so bind it directly.
extern "C" {
    fn pthread_setcancelstate(state: c_int, oldstate: *mut c_int) -> c_int;
}

impl CancelGuard {
    /// Turn off cancellation for the current thread.
    ///
    /// On success the returned guard holds the previous cancellation state
    /// and restores it when dropped.
    ///
    /// # Errors
    ///
    /// Returns the error code reported by pthread_setcancelstate(3).
    pub fn acquire() -> Result<Self, Errno> {
        let mut previous: c_int = 0;

        // SAFETY: `previous` is a live, writable c_int for the whole call and
        // PTHREAD_CANCEL_DISABLE is a valid state argument. The call only
        // changes the calling thread's cancellation flag.
        let rc = unsafe { pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &mut previous) };

        // pthread functions report failure through the return value, not errno.
        match rc {
            0 => Ok(Self(previous)),
            code => Err(Errno::from_raw(code)),
        }
    }
}

impl Drop for CancelGuard {
    fn drop(&mut self) {
        let previous = self.0;

        // SAFETY: `previous` is the state value captured by acquire() for
        // this thread. The old-state pointer may be NULL because the value
        // being replaced is not needed.
        unsafe {
            pthread_setcancelstate(previous, std::ptr::null_mut());
        }
    }
}

/// Unsigned integer types that [`randint`] can sample uniformly.
///
/// All arithmetic the sampler needs is exposed as checked operations, so
/// sampling is done in the type itself without widening to a larger type.
pub trait RandUint: Copy + Ord {
    /// The value `0`.
    const ZERO: Self;
    /// The value `1`.
    const ONE: Self;
    /// The largest representable value.
    const MAX: Self;

    /// Produce a value spanning the whole domain of `Self`, drawn from the
    /// OS RNG using exactly as many bytes as `Self` is wide.
    fn rand_from_os() -> Result<Self, Errno>;

    /// `self + rhs`, or `None` if the sum overflows.
    fn checked_add(self, rhs: Self) -> Option<Self>;
    /// `self - rhs`, or `None` if the difference underflows.
    fn checked_sub(self, rhs: Self) -> Option<Self>;
    /// `self * rhs`, or `None` if the product overflows.
    fn checked_mul(self, rhs: Self) -> Option<Self>;

    /// Euclidean quotient `self / rhs`, or `None` when `rhs` is zero.
    fn div_euclid_opt(self, rhs: Self) -> Option<Self>;
}

// Implements `RandUint` for each listed primitive unsigned integer type.
macro_rules! impl_rand_uint {
    ($($ty:ty),* $(,)?) => {
        $(
            impl RandUint for $ty {
                const ZERO: Self = 0;
                const ONE: Self = 1;
                const MAX: Self = <$ty>::MAX;

                #[inline]
                fn rand_from_os() -> Result<Self, Errno> {
                    // Exactly one word of this type's width, decoded in
                    // native byte order.
                    let mut raw = [0u8; std::mem::size_of::<$ty>()];
                    fillrandom(&mut raw).map(|()| <$ty>::from_ne_bytes(raw))
                }

                // The method calls below resolve to the inherent integer
                // methods of the same name, which take precedence over the
                // trait methods being defined here, so they do not recurse.

                #[inline]
                fn checked_add(self, rhs: Self) -> Option<Self> {
                    self.checked_add(rhs)
                }

                #[inline]
                fn checked_sub(self, rhs: Self) -> Option<Self> {
                    self.checked_sub(rhs)
                }

                #[inline]
                fn checked_mul(self, rhs: Self) -> Option<Self> {
                    self.checked_mul(rhs)
                }

                #[inline]
                fn div_euclid_opt(self, rhs: Self) -> Option<Self> {
                    // For unsigned types this is `None` exactly when rhs == 0.
                    self.checked_div_euclid(rhs)
                }
            }
        )*
    };
}
impl_rand_uint!(u8, u16, u32, u64, u128, usize);

/// Return a uniform random unsigned integer in the inclusive range,
/// using OS randomness with rejection sampling.
///
/// The full domain (`0..=MAX`) is served directly from the OS RNG.
/// Any other range is sampled by drawing full-width values and rejecting
/// the few values at the top of the domain that would bias the result.
///
/// # Errors
///
/// - `EINVAL` if `lo >= hi`: inverted and one-point ranges are rejected.
/// - `EOVERFLOW` if an internal checked operation fails. This should not
///   happen for valid ranges.
/// - Any error returned by the OS RNG.
pub fn randint<T>(range: RangeInclusive<T>) -> Result<T, Errno>
where
    T: RandUint,
{
    let (lo, hi) = range.into_inner();

    // Reject inverted or one-point ranges as invalid input.
    if lo >= hi {
        return Err(Errno::EINVAL);
    }

    // Full-domain path returns raw OS bytes for exact type width.
    if lo == T::ZERO && hi == T::MAX {
        return T::rand_from_os();
    }

    // Compute span = (hi - lo) + 1 with checked ops to avoid overflow.
    // The full domain was handled above, so span <= MAX here.
    let span = hi
        .checked_sub(lo)
        .ok_or(Errno::EOVERFLOW)?
        .checked_add(T::ONE)
        .ok_or(Errno::EOVERFLOW)?;

    // x mod span, using only checked ops and Euclidean division
    // (the trait has no remainder operator).
    let rem = |x: T| -> Result<T, Errno> {
        let q = x.div_euclid_opt(span).ok_or(Errno::EOVERFLOW)?;
        let qspan = q.checked_mul(span).ok_or(Errno::EOVERFLOW)?;
        x.checked_sub(qspan).ok_or(Errno::EOVERFLOW)
    };

    // The type has MAX + 1 values. The largest multiple of `span` that
    // fits maps evenly onto the range; only the `reject` values above it
    // must be redrawn.
    //
    // reject = (MAX + 1) mod span, computed as ((MAX mod span) + 1) mod span
    // so that MAX + 1 is never formed. When span divides MAX + 1 (e.g. a
    // power of two), nothing needs to be rejected at all.
    let mut reject = rem(T::MAX)?
        .checked_add(T::ONE)
        .ok_or(Errno::EOVERFLOW)?;
    if reject == span {
        reject = T::ZERO;
    }
    let accept_max = T::MAX.checked_sub(reject).ok_or(Errno::EOVERFLOW)?;

    // Draw until r <= accept_max so the mapping is unbiased. At least half
    // of all draws are accepted because reject < span <= MAX.
    loop {
        let r = T::rand_from_os()?;
        if r <= accept_max {
            let off = rem(r)?;
            return lo.checked_add(off).ok_or(Errno::EOVERFLOW);
        }
    }
}

/// Return a random unprivileged port number using the OS RNG.
///
/// The port is drawn uniformly from `1025..=65535`.
#[inline]
pub fn randport() -> Result<u16, Errno> {
    randint(1025..=u16::MAX)
}

/// Get secure bytes using the OS random number generator.
///
/// Returns a freshly allocated vector of `size` random bytes.
///
/// # Errors
///
/// - `EINVAL` if `size` is zero.
/// - `ENOMEM` if the buffer cannot be allocated.
/// - Any error from [`fillrandom`].
pub fn getrandom(size: usize) -> Result<Vec<u8>, Errno> {
    // A zero-length request is a common mistake; treat it as invalid input.
    if size == 0 {
        return Err(Errno::EINVAL);
    }

    // Reserve fallibly so an oversized request yields ENOMEM
    // instead of aborting the process.
    let mut bytes: Vec<u8> = Vec::new();
    bytes.try_reserve(size).map_err(|_| Errno::ENOMEM)?;
    bytes.resize(size, 0);

    fillrandom(&mut bytes)?;
    Ok(bytes)
}

/// Fill `buf` entirely with bytes from the OS random number generator.
///
/// Calls getrandom(2) with `GRND_RANDOM`, looping over short reads and
/// retrying on `EINTR`. Pthread cancellation is disabled for the duration
/// of the call.
///
/// # Errors
///
/// - `EINVAL` if `buf` is empty.
/// - Any error from disabling cancellation or from getrandom(2).
pub fn fillrandom(buf: &mut [u8]) -> Result<(), Errno> {
    // An empty buffer is a common mistake; reject it up front.
    if buf.is_empty() {
        return Err(Errno::EINVAL);
    }
    let total = buf.len();

    // Cancellation stays disabled until `_nocancel` goes out of scope,
    // including on early error returns.
    let _nocancel = CancelGuard::acquire()?;

    let mut filled: usize = 0;
    while filled < total {
        let rest = &mut buf[filled..];
        let want = rest.len();
        let dst = rest.as_mut_ptr().cast();

        let got = retry_on_eintr(|| {
            // SAFETY: `dst` points to `want` writable bytes inside `buf`.
            Errno::result(unsafe { libc::getrandom(dst, want, GRND_RANDOM) })
        })?;
        let got: usize = got.try_into().or(Err(Errno::EINVAL))?;

        filled = filled.checked_add(got).ok_or(Errno::EOVERFLOW)?;
    }

    Ok(())
}

/// Duplicate the file descriptor to a random fd.
///
/// The target descriptor number is drawn uniformly from
/// `7..=min(RLIMIT_NOFILE soft limit - 1, 0x10000)`. Slots that are already
/// in use are skipped and another number is drawn.
///
/// Valid flags:
/// - O_EXCL: closes oldfd after successful duplication.
/// - All other flags are passed to dup3(2), ie O_CLOEXEC.
///
/// # Errors
///
/// - `EMFILE` if the candidate range is empty, or dup3(2) returns it.
/// - Any other non-transient dup3(2) error, returned immediately. Examples
///   are `EBADF` for an invalid `oldfd` and `EINVAL` for invalid flags.
/// - `EBADF` if no free slot was found after as many attempts as there
///   are candidate descriptors.
pub fn duprand(oldfd: RawFd, mut flags: OFlag) -> Result<RawFd, Errno> {
    let range_start = 7u64;
    // getrlimit returns (soft, hard); the soft limit bounds valid fds.
    let (range_end, _) = getrlimit(Resource::RLIMIT_NOFILE)?;
    #[expect(clippy::unnecessary_cast)]
    let range_end = range_end.saturating_sub(1) as u64;

    // Cap to a sane maximum because sufficiently big values
    // of the limit tend to return ENOMEM.
    let range_end = range_end.min(0x10000);
    if range_end <= range_start {
        return Err(Errno::EMFILE);
    }
    let range = range_start..=range_end;

    // Close old fd if O_EXCL is given,
    // pass the rest of the flags to dup3.
    let close_old = flags.contains(OFlag::O_EXCL);
    flags.remove(OFlag::O_EXCL);

    // To make this file descriptor harder to spot by an
    // attacker we duplicate it to a random fd number.
    for _ in range.clone() {
        #[expect(clippy::cast_possible_truncation)]
        let fd_rand = randint(range.clone())? as RawFd;

        // SAFETY: fd_rand lies within [7, 0x10000], so it is never -1.
        // It is only used as a number for the liveness probe and
        // as the dup3(2) target below.
        let fd_rand = unsafe { BorrowedFd::borrow_raw(fd_rand) };

        // Skip the source fd itself, since dup3(2) rejects oldfd == newfd
        // with EINVAL. Also skip slots that are already open.
        // This is arguably subject to race but since this is solely
        // used for fds at startup, we don't really care.
        if fd_rand.as_raw_fd() == oldfd || is_active_fd(fd_rand) {
            continue;
        }

        match retry_on_eintr(|| {
            // SAFETY: dup3(2) takes plain integers; it validates oldfd and
            // flags itself and reports failures via errno.
            Errno::result(unsafe { dup3(oldfd, fd_rand.as_raw_fd(), flags.bits()) })
        }) {
            Ok(_) => {
                if close_old {
                    // Best effort: the duplicate is already in place.
                    let _ = close(oldfd);
                }
                return Ok(fd_rand.as_raw_fd());
            }
            // EBUSY: the slot was claimed concurrently; try another number.
            Err(Errno::EBUSY) => {}
            // Anything else (EBADF, EINVAL, EMFILE, ...) does not depend on
            // which slot was chosen, so retrying would only waste syscalls.
            Err(errno) => return Err(errno),
        }
    }

    Err(Errno::EBADF)
}

/// Create a unique temporary file in `dirfd`, unlink it, and return its
/// file descriptor.
///
/// If `dirfd` supports the `O_TMPFILE` operation, an unnamed temporary
/// file is created with `O_TMPFILE|O_EXCL` and `prefix` is not used.
///
/// Otherwise a name is built from `prefix` followed by an implementation
/// defined number of random alphanumeric characters. Unlike libc's
/// mkstemp(3), the template does not have to end with any number of `X`
/// characters. The file is created with `O_CREAT|O_EXCL` and then
/// unlinked.
///
/// `prefix` must not start with the `/` character. `prefix` plus the random
/// suffix must not be longer than `PATH_MAX` bytes. `prefix` may be empty.
///
/// # Errors
///
/// - `EINVAL` if `prefix` is absolute. This is checked on the fallback
///   path only.
/// - `ENAMETOOLONG` if `prefix` plus the random suffix exceeds `PATH_MAX`.
/// - `EEXIST` if every attempt collided with an existing file.
/// - Any other error from opening or unlinking the file.
pub fn mkstempat<Fd: AsFd>(dirfd: Fd, prefix: &XPath) -> Result<OwnedFd, Errno> {
    const MAX_TCOUNT: usize = 8;
    const SUFFIX_LEN: usize = 128;
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

    // Step 1: Attempt to use O_TMPFILE|O_EXCL which is safer.
    let mut flags = OFlag::O_TMPFILE | OFlag::O_EXCL | OFlag::O_RDWR;
    match safe_open(&dirfd, XPath::from_bytes(b"."), flags, ResolveFlag::empty()) {
        Ok(fd) => return Ok(fd),
        Err(Errno::EISDIR | Errno::ENOENT | Errno::EOPNOTSUPP) => {}
        Err(errno) => return Err(errno),
    }

    // Step 2: Fallback to random name generation.
    flags.remove(OFlag::O_TMPFILE);
    flags.insert(OFlag::O_CREAT);
    if prefix.is_absolute() {
        return Err(Errno::EINVAL);
    } else if prefix.len().saturating_add(SUFFIX_LEN) > PATH_MAX {
        // The generated name is prefix + SUFFIX_LEN bytes long; reject it
        // before building an over-long path. This also guarantees that the
        // `prefix.len() + SUFFIX_LEN` below cannot overflow.
        return Err(Errno::ENAMETOOLONG);
    }

    let mut attempts = 0;
    let mut rng_data = [0u8; SUFFIX_LEN];
    #[expect(clippy::arithmetic_side_effects)]
    loop {
        attempts += 1;
        if attempts > MAX_TCOUNT {
            // Too many collisions.
            return Err(Errno::EEXIST);
        }

        // Fill with random bytes.
        fillrandom(&mut rng_data)?;

        // Map bytes to characters. 256 is not a multiple of 62, so a few
        // characters are slightly more likely. That is harmless for name
        // uniqueness, because O_EXCL catches any collision.
        let mut base = XPathBuf::with_capacity(prefix.len() + SUFFIX_LEN);
        base.append_bytes(prefix.as_bytes());
        for &b in &rng_data {
            base.append_byte(CHARSET[(b as usize) % CHARSET.len()]);
        }

        match safe_open(&dirfd, &base, flags, ResolveFlag::empty()) {
            Ok(fd) => {
                safe_unlinkat(dirfd, &base, UnlinkatFlags::NoRemoveDir)?;
                return Ok(fd);
            }
            Err(Errno::EEXIST) => {
                // Try again with a new random sequence.
                continue;
            }
            Err(errno) => return Err(errno),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::*;

    // Check basic API functions for sanity

    #[test]
    fn test_fillrandom() {
        assert_eq!(fillrandom(&mut []), Err(Errno::EINVAL));

        assert_eq!(fillrandom(&mut [0u8; 257]), Ok(()));
    }

    #[test]
    fn test_getrandom() {
        assert_eq!(getrandom(0), Err(Errno::EINVAL));

        let result = getrandom(257);
        assert!(result.is_ok(), "result:{result:?}");
    }

    // Test helpers

    /// Draw one value from `lo..=hi`, panicking if randint fails.
    fn draw<T: RandUint + Debug>(lo: T, hi: T) -> T {
        match randint::<T>(lo..=hi) {
            Ok(v) => v,
            Err(e) => panic!("randint failed for [{:?},{:?}] -> {:?}", lo, hi, e),
        }
    }

    /// Draw `n` values from `lo..=hi`.
    fn sample<T: RandUint + Debug>(lo: T, hi: T, n: usize) -> Vec<T> {
        (0..n).map(|_| draw::<T>(lo, hi)).collect()
    }

    /// True if every element of `xs` lies in `lo..=hi`.
    fn all_in_range<T: RandUint + Debug>(xs: &[T], lo: T, hi: T) -> bool {
        xs.iter().all(|&v| v >= lo && v <= hi)
    }

    /// Pearson chi-square statistic of `counts` against a uniform expectation.
    fn chi2(counts: &[u64]) -> f64 {
        let k = counts.len() as f64;
        let n: u64 = counts.iter().copied().sum();
        let exp = (n as f64) / k;
        counts.iter().fold(0.0, |acc, &c| {
            let d = (c as f64) - exp;
            acc + (d * d) / exp
        })
    }

    /// Conservative chi-square critical value for `df >= 1` degrees of freedom.
    ///
    /// Uses the Wilson-Hilferty approximation
    /// `df * (1 - a + z*sqrt(a))^3` with `a = 2/(9*df)` and z = 6.0.
    ///
    /// The plain normal approximation `df + z*sqrt(2*df)` ignores the right
    /// skew of the chi-square distribution, which matters for small df. For
    /// df = 1 it gives a cap of ~9.5, which a perfectly uniform source
    /// exceeds about 0.2% of the time.
    fn chi2_cap(df: usize) -> f64 {
        let d = df as f64;
        let a = 2.0 / (9.0 * d);
        let z = 6.0;
        d * (1.0 - a + z * a.sqrt()).powi(3)
    }

    /// Convert a test value to usize, panicking if it does not fit.
    fn to_usize<T>(x: T) -> usize
    where
        usize: TryFrom<T>,
    {
        usize::try_from(x)
            .ok()
            .expect("test value must fit in usize")
    }

    /// Histogram of `xs` over buckets `0..span`; values outside are ignored.
    fn hist_zero_span<T>(span: T, xs: &[T]) -> Vec<u64>
    where
        T: RandUint,
        usize: TryFrom<T>,
    {
        let m = to_usize(span);
        let mut h = vec![0u64; m];
        for &v in xs {
            match usize::try_from(v) {
                Ok(i) if i < m => h[i] = h[i].checked_add(1).unwrap(),
                _ => {}
            }
        }
        h
    }

    /// `r mod span` using only `RandUint` operations.
    fn rem_euclid_opt<T: RandUint + Debug>(r: T, span: T) -> Option<T> {
        let q = r.div_euclid_opt(span)?;
        let qspan = q.checked_mul(span)?;
        r.checked_sub(qspan)
    }

    // API checks

    // Generates tests asserting that randint rejects `lo..=hi` with EINVAL.
    macro_rules! einval_tests {
        ($($name:ident: $t:ty, $lo:expr, $hi:expr;)*) => {$(
            #[test]
            fn $name() {
                assert!(matches!(randint::<$t>($lo..=$hi), Err(Errno::EINVAL)));
            }
        )*};
    }

    // Inverted ranges.
    einval_tests! {
        test_randint_invalid_u8: u8, 200, 100;
        test_randint_invalid_u16: u16, 5000, 4999;
        test_randint_invalid_u32: u32, 42, 41;
        test_randint_invalid_u64: u64, 999, 998;
        test_randint_invalid_u128: u128, 500, 499;
        test_randint_invalid_usize: usize, 100, 99;
    }

    // One-point ranges.
    einval_tests! {
        test_randint_onepoint_u8: u8, 77, 77;
        test_randint_onepoint_u16: u16, 31337, 31337;
        test_randint_onepoint_u32: u32, 1_000_000, 1_000_000;
        test_randint_onepoint_u64: u64, 123456789, 123456789;
        test_randint_onepoint_u128: u128, 999, 999;
        test_randint_onepoint_usize: usize, 4242, 4242;
    }

    // Generates tests drawing `n` values from `lo..=hi` and checking bounds.
    macro_rules! inbounds_tests {
        ($($name:ident: $t:ty, $lo:expr, $hi:expr, $n:expr;)*) => {$(
            #[test]
            fn $name() {
                let (lo, hi): ($t, $t) = ($lo, $hi);
                let xs = sample::<$t>(lo, hi, $n);
                assert!(all_in_range(&xs, lo, hi));
            }
        )*};
    }

    // In-bounds (full domain and near max):
    inbounds_tests! {
        test_randint_fulldomain_u8_inbounds: u8, u8::MIN, u8::MAX, 4096;
        test_randint_fulldomain_u16_inbounds: u16, u16::MIN, u16::MAX, 2048;
        test_randint_fulldomain_u32_inbounds: u32, u32::MIN, u32::MAX, 2048;
        test_randint_fulldomain_u64_inbounds: u64, u64::MIN, u64::MAX, 1024;
        test_randint_fulldomain_u128_inbounds: u128, u128::MIN, u128::MAX, 256;
        test_randint_fulldomain_usize_inbounds: usize, usize::MIN, usize::MAX, 1024;
        test_randint_u8_nearmax_inbounds: u8, u8::MAX.saturating_sub(15), u8::MAX, 2000;
        test_randint_u16_nearmax_inbounds: u16, u16::MAX.saturating_sub(1023), u16::MAX, 4000;
        test_randint_u32_nearmax_inbounds: u32, u32::MAX.saturating_sub(1000), u32::MAX, 3000;
        test_randint_u64_nearmax_inbounds: u64, u64::MAX.saturating_sub(1000), u64::MAX, 3000;
        test_randint_u128_nearmax_inbounds: u128, u128::MAX.saturating_sub(1000), u128::MAX, 2000;
        test_randint_usize_nearmax_inbounds: usize, usize::MAX.saturating_sub(1000), usize::MAX, 3000;
    }

    // Uniformity (rejection path).
    // Caps are computed from df to avoid random failures.

    /// Assert that `n` draws from `0..span` pass a chi-square uniformity check.
    fn check_uniform<T>(span: T, n: usize)
    where
        T: RandUint + Debug,
        usize: TryFrom<T>,
    {
        let hi = span.checked_sub(T::ONE).unwrap();
        let xs = sample::<T>(T::ZERO, hi, n);
        let h = hist_zero_span(span, &xs);
        let df = to_usize(span).checked_sub(1).unwrap();
        let c2 = chi2(&h);
        let cap = chi2_cap(df);
        assert!(
            c2 < cap,
            "{} chi2 {} cap {} df {} span {:?} n {}",
            std::any::type_name::<T>(),
            c2,
            cap,
            df,
            span,
            n
        );
    }

    // Generates uniformity tests for the given type, span and sample count.
    macro_rules! uniform_tests {
        ($($name:ident: $t:ty, $span:expr, $n:expr;)*) => {$(
            #[test]
            fn $name() {
                check_uniform::<$t>($span, $n);
            }
        )*};
    }

    // u8 spans:
    // Mix of primes, composites, powers of two.
    uniform_tests! {
        test_randint_u8_uniform_span_2: u8, 2, 20_000;
        test_randint_u8_uniform_span_3: u8, 3, 20_000;
        test_randint_u8_uniform_span_4: u8, 4, 20_000;
        test_randint_u8_uniform_span_5: u8, 5, 20_000;
        test_randint_u8_uniform_span_6: u8, 6, 20_000;
        test_randint_u8_uniform_span_7: u8, 7, 20_000;
        test_randint_u8_uniform_span_8: u8, 8, 30_000;
        test_randint_u8_uniform_span_9: u8, 9, 20_000;
        test_randint_u8_uniform_span_10: u8, 10, 30_000;
        test_randint_u8_uniform_span_11: u8, 11, 30_000;
        test_randint_u8_uniform_span_12: u8, 12, 30_000;
        test_randint_u8_uniform_span_13: u8, 13, 30_000;
        test_randint_u8_uniform_span_14: u8, 14, 30_000;
        test_randint_u8_uniform_span_15: u8, 15, 30_000;
        test_randint_u8_uniform_span_16: u8, 16, 40_000;
        test_randint_u8_uniform_span_17: u8, 17, 30_000;
        test_randint_u8_uniform_span_18: u8, 18, 30_000;
        test_randint_u8_uniform_span_19: u8, 19, 30_000;
        test_randint_u8_uniform_span_20: u8, 20, 40_000;
        test_randint_u8_uniform_span_31: u8, 31, 40_000;
        test_randint_u8_uniform_span_37: u8, 37, 40_000;
        test_randint_u8_uniform_span_43: u8, 43, 40_000;
        test_randint_u8_uniform_span_50: u8, 50, 50_000;
    }

    // Other integer types:
    // Selected spans that exercise rejection math.
    uniform_tests! {
        test_randint_u16_uniform_span_7: u16, 7, 20_000;
        test_randint_u16_uniform_span_31: u16, 31, 30_000;
        test_randint_u16_uniform_span_61: u16, 61, 30_000;
        test_randint_u16_uniform_span_127: u16, 127, 30_000;
        test_randint_u16_uniform_span_257: u16, 257, 30_000;
        test_randint_u32_uniform_span_7: u32, 7, 20_000;
        test_randint_u32_uniform_span_31: u32, 31, 30_000;
        test_randint_u32_uniform_span_61: u32, 61, 30_000;
        test_randint_u32_uniform_span_127: u32, 127, 30_000;
        test_randint_u64_uniform_span_7: u64, 7, 20_000;
        test_randint_u64_uniform_span_31: u64, 31, 30_000;
        test_randint_u64_uniform_span_61: u64, 61, 30_000;
        test_randint_u64_uniform_span_100: u64, 100, 30_000;
        test_randint_u128_uniform_span_7: u128, 7, 20_000;
        test_randint_u128_uniform_span_31: u128, 31, 30_000;
        test_randint_u128_uniform_span_61: u128, 61, 30_000;
        test_randint_usize_uniform_span_7: usize, 7, 20_000;
        test_randint_usize_uniform_span_31: usize, 31, 30_000;
        test_randint_usize_uniform_span_61: usize, 61, 30_000;
    }

    // Shift invariance (u8)

    fn check_shift_inv_u8(lo: u8, span: u8, n: usize) {
        let hi = lo.saturating_add(span.checked_sub(1).unwrap());
        let a = sample::<u8>(lo, hi, n);
        let b = sample::<u8>(0, span.checked_sub(1).unwrap(), n);

        let mut ha = vec![0u64; span as usize];
        for &x in &a {
            let idx = rem_euclid_opt::<u8>(x, span).unwrap() as usize;
            ha[idx] = ha[idx].checked_add(1).unwrap();
        }
        let hb = hist_zero_span(span, &b);

        let c2a = chi2(&ha);
        let c2b = chi2(&hb);
        let cap = chi2_cap((span as usize) - 1);
        assert!(
            c2a < cap && c2b < cap,
            "shift inv chi2a {} chi2b {} cap {}",
            c2a,
            c2b,
            cap
        );

        let mut tv = 0.0;
        for i in 0..(span as usize) {
            let pa = (ha[i] as f64) / (n as f64);
            let pb = (hb[i] as f64) / (n as f64);
            tv += (pa - pb).abs();
        }
        assert!(tv < 0.25, "total variation {} too large", tv);
    }

    #[test]
    fn test_randint_u8_shift_lo0_span10() {
        check_shift_inv_u8(0, 10, 30_000);
    }

    #[test]
    fn test_randint_u8_shift_lo50_span10() {
        check_shift_inv_u8(50, 10, 30_000);
    }

    #[test]
    fn test_randint_u8_shift_lo120_span10() {
        check_shift_inv_u8(120, 10, 30_000);
    }

    #[test]
    fn test_randint_u8_shift_lo200_span10() {
        check_shift_inv_u8(200, 10, 30_000);
    }

    #[test]
    fn test_randint_u8_shift_lo240_span10() {
        check_shift_inv_u8(240, 10, 30_000);
    }

    // Crypto sanity (byte stream)

    #[test]
    fn test_randint_u8_monobit_ok() {
        let n = 100_000usize;
        let xs = sample::<u8>(0, u8::MAX, n);
        let mut ones = [0u64; 8];
        for &v in &xs {
            for b in 0..8 {
                if ((v >> b) & 1) == 1 {
                    ones[b] = ones[b].checked_add(1).unwrap();
                }
            }
        }
        for &c in &ones {
            let p = (c as f64) / (n as f64);
            assert!(((0.49f64)..=(0.51f64)).contains(&p));
        }
    }

    #[test]
    fn test_randint_u8_runs_reasonable() {
        let n = 100_000usize;
        let xs = sample::<u8>(0, u8::MAX, n);
        let median = 128u8;
        let mut runs: u64 = 0;
        let mut last = None;
        for &b in &xs {
            let bit = b >= median;
            if last.map(|x| x != bit).unwrap_or(true) {
                runs = runs.checked_add(1).unwrap();
                last = Some(bit);
            }
        }
        let exp = 1.0 + (n as f64) / 2.0;
        let r = runs as f64;
        assert!(((exp * 0.95)..=(exp * 1.05)).contains(&r));
    }

    #[test]
    fn test_randint_u8_serial_corr_small() {
        let n = 100_000usize;
        let xs = sample::<u8>(0, u8::MAX, n);
        let mean = (xs.iter().map(|&v| v as f64).sum::<f64>()) / (n as f64);
        let mut num = 0.0;
        let mut denx = 0.0;
        let mut deny = 0.0;
        for w in xs.windows(2) {
            let x = w[0] as f64 - mean;
            let y = w[1] as f64 - mean;
            num += x * y;
            denx += x * x;
            deny += y * y;
        }
        let corr = num / (denx.sqrt() * deny.sqrt());
        assert!(corr.abs() < 0.03);
    }

    // Representative windows (bounds only)
    inbounds_tests! {
        test_randint_u8_window_low: u8, 1, 9, 4000;
        test_randint_u8_window_mid: u8, 120, 140, 4000;
        test_randint_u8_window_high: u8, 240, 255, 4000;
        test_randint_u16_window_small: u16, 100, 200, 4000;
        test_randint_u32_window_small: u32, 1_000, 1_999, 4000;
        test_randint_u64_window_small: u64, 10_000, 10_999, 4000;
        test_randint_u128_window_small: u128, 1_000, 1_999, 1000;
        test_randint_usize_window_small: usize, 100, 200, 4000;
    }

    // Span equals MAX (not full domain): still in-bounds
    inbounds_tests! {
        test_randint_u8_span_eq_max_inbounds: u8, 1, u8::MAX, 4096;
        test_randint_u16_span_eq_max_inbounds: u16, 1, u16::MAX, 4096;
        test_randint_u32_span_eq_max_inbounds: u32, 1, u32::MAX, 4096;
        test_randint_u64_span_eq_max_inbounds: u64, 1, u64::MAX, 4096;
        test_randint_u128_span_eq_max_inbounds: u128, 1, u128::MAX, 1024;
        test_randint_usize_span_eq_max_inbounds: usize, 1, usize::MAX, 4096;
    }
}
