//
// Syd: rock-solid application kernel
// src/sigset.rs: Portable sigset that can handle reserved signals
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::ops::{
    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Sub, SubAssign,
};

use libc::c_int;
use nix::sys::signal::Signal;
use serde::{Serialize, Serializer};

/// Raw signal-mask type wrapped by [`SydSigSet`].
///
/// sigset is 128 bits on MIPS, and 64 bits on other architectures.
/// Bit `n - 1` of the mask represents signal `n`.
#[expect(non_camel_case_types)]
#[cfg(any(
    target_arch = "mips",
    target_arch = "mips32r6",
    target_arch = "mips64",
    target_arch = "mips64r6"
))]
pub type sigset_t = u128;

/// Raw signal-mask type wrapped by [`SydSigSet`].
///
/// sigset is 128 bits on MIPS, and 64 bits on other architectures.
/// Bit `n - 1` of the mask represents signal `n`.
#[expect(non_camel_case_types)]
#[cfg(not(any(
    target_arch = "mips",
    target_arch = "mips32r6",
    target_arch = "mips64",
    target_arch = "mips64r6"
)))]
pub type sigset_t = u64;

/// A sigset that can handle reserved signals.
///
/// Wraps a raw [`sigset_t`] bitmask in which bit `n - 1` stands for
/// signal `n`. Signals are added and removed by plain bit manipulation on
/// this mask. NOTE(review): presumably this is what lets it represent
/// signals the C library reserves for itself — confirm against callers.
///
/// `Default` is the empty set (mask `0`). The derived `Eq`, `Ord` and
/// `Hash` impls compare/hash the raw mask.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SydSigSet(pub sigset_t);

impl BitOr for SydSigSet {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        SydSigSet(self.0 | rhs.0)
    }
}

impl BitOrAssign for SydSigSet {
    fn bitor_assign(&mut self, rhs: Self) {
        self.0 |= rhs.0;
    }
}

impl BitAnd for SydSigSet {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        SydSigSet(self.0 & rhs.0)
    }
}

impl BitAndAssign for SydSigSet {
    fn bitand_assign(&mut self, rhs: Self) {
        self.0 &= rhs.0;
    }
}

impl BitXor for SydSigSet {
    type Output = Self;

    fn bitxor(self, rhs: Self) -> Self::Output {
        SydSigSet(self.0 ^ rhs.0)
    }
}

impl BitXorAssign for SydSigSet {
    fn bitxor_assign(&mut self, rhs: Self) {
        self.0 ^= rhs.0;
    }
}

impl Not for SydSigSet {
    type Output = Self;

    fn not(self) -> Self::Output {
        SydSigSet(!self.0)
    }
}

impl Sub for SydSigSet {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        SydSigSet(self.0 & !rhs.0)
    }
}

impl SubAssign for SydSigSet {
    fn sub_assign(&mut self, rhs: Self) {
        self.0 &= !rhs.0;
    }
}

impl std::fmt::Display for SydSigSet {
    /// Formats the set as a comma-separated list of signal numbers in
    /// ascending order (e.g. `2, 14, 15`), or `?` when no signal in
    /// `1..=SIGRTMAX` is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SIGRTMAX is itself a valid signal, so the range must include it.
        let mut present = (1..=libc::SIGRTMAX()).filter(|&sig| self.contains(sig));

        let Some(first) = present.next() else {
            return f.write_str("?");
        };

        // Write straight into the formatter rather than building and
        // joining an intermediate Vec<String>.
        write!(f, "{first}")?;
        for sig in present {
            write!(f, ", {sig}")?;
        }
        Ok(())
    }
}

impl SydSigSet {
    /// Create a new SydSigSet.
    pub fn new(mask: sigset_t) -> Self {
        Self(mask)
    }

    /// Returns `true` if the set is empty.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if the set contains the given signal `sig`.
    #[expect(clippy::arithmetic_side_effects)]
    #[expect(clippy::cast_sign_loss)]
    pub fn contains(&self, sig: c_int) -> bool {
        if sig < 1 {
            return false;
        }
        let bit = (sig - 1) as sigset_t;
        (self.0 & (1 << bit)) != 0
    }

    /// Returns `true` if this set intersects with the given set (i.e.,
    /// they share any common signals).
    pub fn intersects(&self, other: Self) -> bool {
        (self.0 & other.0) != 0
    }

    /// Add a signal to the set.
    #[expect(clippy::arithmetic_side_effects)]
    #[expect(clippy::cast_sign_loss)]
    pub fn add(&mut self, sig: c_int) {
        if sig < 1 {
            return; // ignore invalid signals
        }
        let bit = (sig - 1) as sigset_t;
        self.0 |= 1 << bit;
    }

    /// Remove a signal from the set.
    #[expect(clippy::arithmetic_side_effects)]
    #[expect(clippy::cast_sign_loss)]
    pub fn del(&mut self, sig: c_int) {
        if sig < 1 {
            return; // ignore invalid signals
        }
        let bit = (sig - 1) as sigset_t;
        self.0 &= !(1 << bit);
    }

    /// Add all signals from another SydSigSet to this one.
    pub fn add_set(&mut self, set: Self) {
        self.0 |= set.0;
    }

    /// Remove all signals present in `set` from `self`.
    pub fn del_set(&mut self, set: Self) {
        self.0 &= !set.0;
    }
}

impl Serialize for SydSigSet {
    /// Serializes the set as a sequence of signal names in ascending order.
    ///
    /// Signals known to `nix::sys::signal::Signal` use its name (e.g.
    /// `SIGINT`); any other signal number `n` is rendered as `SIG{n}`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // SIGRTMAX is itself a valid signal, so the range must include it.
        // Collect into a Vec (rather than streaming) so the serializer still
        // receives a sequence with a known length.
        let names: Vec<String> = (1..=libc::SIGRTMAX())
            .filter(|&sig| self.contains(sig))
            .map(|sig| {
                Signal::try_from(sig)
                    .map(|known| known.to_string())
                    .unwrap_or_else(|_| format!("SIG{sig}"))
            })
            .collect();

        names.serialize(serializer)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Highest signal number on this host. It is a real, usable signal, so
    /// the loops below use the inclusive range `1..=max_signal()`.
    fn max_signal() -> c_int {
        libc::SIGRTMAX()
    }

    #[test]
    fn test_sigset_empty_set() {
        let set = SydSigSet::new(0);
        // No signals should be contained, up to and including SIGRTMAX.
        for sig_num in 1..=max_signal() {
            assert!(
                !set.contains(sig_num),
                "Empty set should not contain any signal"
            );
        }

        // Two empty sets share no signals, so they do not intersect.
        let empty2 = SydSigSet::new(0);
        assert!(
            !set.intersects(empty2),
            "Empty set should not intersect with another empty set"
        );
    }

    #[test]
    fn test_sigset_single_signal() {
        let sigalrm = libc::SIGALRM; // commonly 14
        let mask = (1 as sigset_t) << (sigalrm - 1);
        let set = SydSigSet::new(mask);

        // Should contain SIGALRM only.
        assert!(set.contains(sigalrm), "Set should contain SIGALRM");
        for sig_num in 1..=max_signal() {
            if sig_num != sigalrm {
                assert!(!set.contains(sig_num), "Only SIGALRM should be set");
            }
        }

        // A non-empty set does not intersect the empty set.
        let empty = SydSigSet::new(0);
        assert!(
            !set.intersects(empty),
            "Single-signal set should not intersect an empty set"
        );

        // A non-empty set always intersects itself.
        assert!(set.intersects(set), "Set should intersect with itself");
    }

    #[test]
    fn test_sigset_multiple_signals() {
        let signals = [libc::SIGINT, libc::SIGALRM, libc::SIGTERM]; // e.g. [2, 14, 15]
        let mut mask: sigset_t = 0;
        for &sig in &signals {
            mask |= (1 as sigset_t) << ((sig - 1) as sigset_t);
        }
        let set = SydSigSet::new(mask);

        // Check contains
        for &sig in &signals {
            assert!(set.contains(sig), "Set should contain signal {}", sig);
        }

        // Check that others are not contained
        for sig_num in 1..=max_signal() {
            if !signals.contains(&sig_num) {
                assert!(!set.contains(sig_num), "Only INT, ALRM, TERM should be set");
            }
        }

        // Intersect with a set that has one of those signals
        let single = SydSigSet::new((1 as sigset_t) << ((libc::SIGINT - 1) as sigset_t));
        assert!(
            set.intersects(single),
            "Should intersect since both contain SIGINT"
        );

        // Intersect with a set that shares no signals.
        // SIGHUP (1) is not among `signals`, guarded below to be safe.
        let unrelated_mask = (1 as sigset_t) << ((libc::SIGHUP - 1) as sigset_t);
        let unrelated_set = SydSigSet::new(unrelated_mask);
        if !signals.contains(&libc::SIGHUP) {
            assert!(
                !set.intersects(unrelated_set),
                "Should not intersect if no signals in common"
            );
        }

        // Test add and del by starting empty and building the set
        let mut dynamic_set = SydSigSet::new(0);
        for &sig in &signals {
            dynamic_set.add(sig);
            assert!(
                dynamic_set.contains(sig),
                "Signal {} should now be contained",
                sig
            );
        }

        // Remove one signal and ensure it's gone
        dynamic_set.del(libc::SIGALRM);
        assert!(
            !dynamic_set.contains(libc::SIGALRM),
            "SIGALRM should be removed"
        );
        assert!(dynamic_set.contains(libc::SIGINT), "Other signals remain");
        assert!(dynamic_set.contains(libc::SIGTERM), "Other signals remain");
    }

    #[test]
    fn test_sigset_all_signals() {
        let max_sig = max_signal();
        let mut mask: sigset_t = 0;
        for sig_num in 1..=max_sig {
            mask |= (1 as sigset_t) << ((sig_num - 1) as sigset_t);
        }
        let set = SydSigSet::new(mask);

        // Should contain all signals up to and including SIGRTMAX
        for sig_num in 1..=max_sig {
            assert!(set.contains(sig_num), "All signals should be contained");
        }

        // Removing a signal from a full set
        let mut copy_set = set;
        copy_set.del(libc::SIGINT);
        assert!(
            !copy_set.contains(libc::SIGINT),
            "SIGINT should be removed from the full set"
        );
        // Others remain
        for sig_num in 1..=max_sig {
            if sig_num != libc::SIGINT {
                assert!(
                    copy_set.contains(sig_num),
                    "All others should still be present"
                );
            }
        }

        // Intersects with partial sets
        let single_set = SydSigSet::new((1 as sigset_t) << ((libc::SIGTERM - 1) as sigset_t));
        assert!(
            set.intersects(single_set),
            "Full set intersects with any non-empty set"
        );
    }

    #[test]
    fn test_sigset_bits_beyond_rtm() {
        // Set every bit of the mask (64, or 128 on MIPS), including any
        // bits past SIGRTMAX that do not correspond to a signal.
        let set = SydSigSet::new(sigset_t::MAX);

        // Contains all signals up to and including SIGRTMAX
        for sig_num in 1..=max_signal() {
            assert!(
                set.contains(sig_num),
                "All signals up to SIGRTMAX should be contained"
            );
        }
        // Bits beyond SIGRTMAX are meaningless; the only requirement is
        // that having them set does not cause a panic.
    }

    #[test]
    fn test_sigset_invalid_signals() {
        let mut set = SydSigSet::new(0);

        // Adding an invalid signal (<1) should do nothing
        set.add(0);
        set.add(-1);
        assert!(!set.contains(0), "Invalid signal should not be contained");
        assert!(!set.contains(-1), "Invalid signal should not be contained");
        assert!(!set.contains(1), "We never added a valid signal");

        // Removing an invalid signal does nothing
        set.del(0);
        set.del(-1);
        assert!(
            !set.contains(1),
            "No signals should be added or removed by invalid ops"
        );
    }

    #[test]
    fn test_sigset_intersects() {
        // Create two sets with partial overlap
        // Set A: SIGINT, SIGALRM
        // Set B: SIGALRM, SIGTERM
        let set_a_mask = ((1 as sigset_t) << ((libc::SIGINT - 1) as sigset_t))
            | ((1 as sigset_t) << ((libc::SIGALRM - 1) as sigset_t));
        let set_b_mask = ((1 as sigset_t) << ((libc::SIGALRM - 1) as sigset_t))
            | ((1 as sigset_t) << ((libc::SIGTERM - 1) as sigset_t));

        let set_a = SydSigSet::new(set_a_mask);
        let set_b = SydSigSet::new(set_b_mask);

        // They both share SIGALRM
        assert!(set_a.intersects(set_b), "Sets should intersect on SIGALRM");

        // Create a set that does not share any signals with A or B
        let set_c = SydSigSet::new((1 as sigset_t) << ((libc::SIGHUP - 1) as sigset_t));
        if libc::SIGHUP != libc::SIGINT
            && libc::SIGHUP != libc::SIGALRM
            && libc::SIGHUP != libc::SIGTERM
        {
            assert!(
                !set_a.intersects(set_c),
                "A and C should not intersect if distinct signals"
            );
            assert!(
                !set_b.intersects(set_c),
                "B and C should not intersect if distinct signals"
            );
        }
    }

    #[test]
    fn test_sigset_operators() {
        // A = {SIGINT, SIGALRM}, B = {SIGALRM, SIGTERM}
        let mut a = SydSigSet::new(0);
        a.add(libc::SIGINT);
        a.add(libc::SIGALRM);
        let mut b = SydSigSet::new(0);
        b.add(libc::SIGALRM);
        b.add(libc::SIGTERM);

        let union = a | b;
        assert!(union.contains(libc::SIGINT));
        assert!(union.contains(libc::SIGALRM));
        assert!(union.contains(libc::SIGTERM));

        let inter = a & b;
        assert!(inter.contains(libc::SIGALRM));
        assert!(!inter.contains(libc::SIGINT));
        assert!(!inter.contains(libc::SIGTERM));

        let sym = a ^ b;
        assert!(sym.contains(libc::SIGINT));
        assert!(sym.contains(libc::SIGTERM));
        assert!(!sym.contains(libc::SIGALRM));

        let diff = a - b;
        assert!(diff.contains(libc::SIGINT));
        assert!(!diff.contains(libc::SIGALRM));
        assert!(!diff.contains(libc::SIGTERM));

        assert!((a & !a).is_empty(), "A set and its complement are disjoint");

        // Assigning forms agree with their binary counterparts.
        let mut c = a;
        c |= b;
        assert_eq!(c, union);
        let mut c = a;
        c &= b;
        assert_eq!(c, inter);
        let mut c = a;
        c ^= b;
        assert_eq!(c, sym);
        let mut c = a;
        c -= b;
        assert_eq!(c, diff);

        // add_set / del_set mirror union / difference.
        let mut c = a;
        c.add_set(b);
        assert_eq!(c, union);
        let mut c = a;
        c.del_set(b);
        assert_eq!(c, diff);
    }

    #[test]
    fn test_sigset_display() {
        assert_eq!(SydSigSet::new(0).to_string(), "?");

        let mut set = SydSigSet::new(0);
        set.add(libc::SIGTERM);
        set.add(libc::SIGINT);
        // Signals are listed in ascending numeric order.
        assert_eq!(
            set.to_string(),
            format!("{}, {}", libc::SIGINT, libc::SIGTERM)
        );
    }
}
