/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 */

#pragma once

// std
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
// project
#include "am_common.h"
#include "dump_file.h"

#include "fabric.h"
#include "port_data.h"
#include "reservation_manager.h"
#include "sub_tree_score.h"

// Forward declarations of types this header names without requiring their
// full definition at this point. Grouped by class-key, alphabetical in each group.
struct AggNodeQuotaAllocInfo;
struct FabricProviderCallbackContext;
struct JobData;

class AggNode;
class ParallelTreeFinder;
class ReservationInfo;
class ReservationManager;
class TreeEdge;

// Flat list of JobSubTreeInfo entries.
using JobSubTreeInfoVec = std::vector<JobSubTreeInfo>;
// List of JobSubTreeInfoVec lists (SharpJob uses it for its multicast trees structure).
using JobSubTreeInfoVecVec = std::vector<JobSubTreeInfoVec>;
// Ordered map from an AggNodeFabricInfo pointer to the AggNodeQuotaAllocInfo kept for it.
using MapAnToQuotaAllocInfo = std::map<const AggNodeFabricInfo*, AggNodeQuotaAllocInfo>;
// (iterator, bool) pair; same shape as the result of inserting a single element into MapAnToQuotaAllocInfo.
using MapAnToQuotaAllocInfoInsertRes = std::pair<MapAnToQuotaAllocInfo::iterator, bool>;
// List of (AggNodeFabricInfo, AggNodeQuotaAllocInfo) pointer pairs; both pointees are const.
using VectorANFabricInfoAndANQuotaInfoPair = std::vector<std::pair<AggNodeFabricInfo const*, AggNodeQuotaAllocInfo const*>>;
// Ordered map from a string key to a sharp_quota value.
// NOTE(review): the alias name suggests the key is a textual port key - confirm against callers.
using MapPortKeyToQuota = std::map<string, sharp_quota>;
// Unordered set of port_key_t values.
using USetPortKeys = std::unordered_set<port_key_t>;

// ports set indexed by rail id
using VecSetPortData = std::vector<SetPortDataConstPtr>;
// List of VecPortDataConstPtr lists.
// NOTE(review): presumably also indexed by rail id, like VecSetPortData - confirm.
using VecVecPortData = std::vector<VecPortDataConstPtr>;
// List of VecSetPortData; SharpJob indexes it as [plane][rail].
using VecPlaneVecSetPortData = std::vector<VecSetPortData>;
// List of VecVecPortData.
// NOTE(review): name suggests [plane][rail] indexing, mirroring VecPlaneVecSetPortData - confirm.
using VecPlaneVeRailVecPortData = std::vector<VecVecPortData>;

// Three-valued answer: negative, affirmative, or an error condition.
// The numeric values are the ones the compiler assigns by default (0, 1, 2);
// they are only spelled out here for readability.
enum class Confirmation
{
    NO = 0,
    YES = 1,
    ERROR_STATE = 2,
};

// Per-node record kept in MapAnToQuotaAllocInfo: the quota held on the node,
// the trees that use it, and its lock / fabric-configuration flags.
struct AggNodeQuotaAllocInfo
{
    Quota m_quota;

    // list of tree_ids on this node
    ListTreeIds m_tree_ids;

    // list of tree_ids that are root on this node
    ListTreeIds m_root_tree_ids;

    // prevent lock on node (unless used by SAT)
    bool m_prevent_lock = false;

    // whether the node is configured on fabric (yes/no/error)
    Confirmation m_is_configured = Confirmation::NO;
};

// State and operations of a single job.
//
// A SharpJob keeps the job identifiers, the hosts / ports that take part in the
// job (ports are stored per plane and per rail), the sub-tree information of
// every tree allocated to the job (regular and multicast), and the quota
// allocation information kept per AggNodeFabricInfo. It also declares the
// operations that allocate, configure, dump and clean those resources; their
// implementations live outside this header.
//
// The class holds reference members and a const member, so instances are not
// assignable.
class SharpJob
{
    ReservationManager& m_reservation_manager_;
    ParallelTreeFinder& m_tree_finder_;

    SharpExtJobId m_external_job_id_;   // external job Id
    sharp_job_id_t m_sharp_job_id_;

    uint8_t m_priority_;

    bool m_multicast_enabled_;   // multicast of llt operation result

    // Tells whether we should update the relevant reservation when this job is deleted.
    // Reasons to avoid an update can be:
    // 1. The job is deleted because of reservation deletion, and the reservation is already deleted.
    // 2. The job is invalid, it is cleaned right after creation.
    bool m_should_update_reservation_at_job_delete;

    string m_job_file_path;
    string m_job_mc_trees_file_path_;

    // Job may contain hosts list or ports list
    SetHostInfoPtr m_hosts_;
    VecPlaneVecSetPortData m_ports_by_rail_;   // indexed as [plane][rail]

    JobSubTreeInfoVec m_sub_tree_info_vec_;
    JobSubTreeInfoVecVec m_mc_sub_tree_info_vec_;   // multicast trees structure
    MapAnToQuotaAllocInfo m_an_to_quota_alloc_info_;

    uint8_t m_trees_number_;   // not include multicast trees
    SetTreeIds m_tree_ids_;    // not include multicast trees
    SetAggNodeFabricPtr m_agg_nodes_;
    SetAggNodeFabricPtr m_mc_agg_nodes_;

    // TODO remove m_num_channels_per_conn_ from SharpJob
    uint8_t m_num_channels_per_conn_;

    const smx_ep m_sharpd_ep_;   // sharpd_0 address

    // sharpd_0 connection id, used for keep-alive
    // when keep-alive is enabled we open connection to sharpd_0 until the job ends
    // During the job keep-alive checks if the connection is still alive
    int m_sharpd_conn_id_;
    // Compared against SharpTimestamp::min() by IsJobInfoSent().
    SharpTimestamp m_job_info_send_time;
    bool m_job_info_reply_received;

    bool m_reproducible;

    uint64_t m_job_key_;

    sharp_job_state m_job_state_;
    bool m_configure_fp19_;
    bool m_configure_bfloat19_;

    // True when the client request that rmc will be supported by every tree in this Job
    bool m_rmc_supported_on_all_trees_ = false;

    uint8_t m_req_num_trees_;

    file_utils::DumpFile* m_resource_alloc_dump_file_ptr_;

    // Captured when the object is constructed.
    SharpTimestamp m_start_time_{SharpTimestamp::clock::now()};

    // Rounding mode used by this job - needed for conflict detection
    sharp_round_mode m_round_mode_;

    int FindCoalescingRails(const sharp_begin_job& begin_job_data,
                            vector<uint8_t>& coalescing_rails,
                            VecVecPortData& vec_rail_to_vec_port_data,
                            uint32_t& num_guids_per_rail);

    sharp_am_status CoalesceRails(VecVecPortData& vec_rail_to_vec_port_data, uint32_t first_index, uint32_t second_index);

    sharp_am_status CoalesceValidRails(VecVecPortData& vec_rail_to_vec_port_data, vector<uint8_t>& coalescing_rails);

    void SetJobPriority(uint8_t priority);

   public:
    SharpJob(ReservationManager& reservation_manager,
             ParallelTreeFinder& tree_finder,
             const SharpExtJobId& external_job_id,
             sharp_job_id_t sharp_job_id,
             bool enable_mcast,
             const smx_ep* ep,
             uint64_t job_key,
             file_utils::DumpFile* resource_alloc_dump_file);

    // Stores a copy of the given path in m_job_file_path.
    void SetJobFilePath(const string& job_file_path) { m_job_file_path = job_file_path; }
    // Stores a copy of the given path in m_job_mc_trees_file_path_.
    // The argument is only read, so it is taken by const reference.
    void SetJobMcTreesDumpFilePath(const string& job_mc_trees_file_path) { m_job_mc_trees_file_path_ = job_mc_trees_file_path; }
    int AddHostInfo(const char* hostname);

    sharp_am_status AnalizeRailsInfo(const sharp_begin_job& begin_job_data);

    int SetTreesSpanningEndpoints(VecPortDataConstPtr& vec_port_data, uint8_t rail, uint8_t plane);

    static int ParseHostlistCallback(const char* hostname, void* arg);

    sharp_am_status Init(const sharp_begin_job& begin_job_data);

    string GetName() const;
    const SharpExtJobId& GetExternalJobId() const { return m_external_job_id_; }
    sharp_job_id_t GetSharpJobId() const { return m_sharp_job_id_; }
    void SetSharpJobId(sharp_job_id_t sharp_job_id) { m_sharp_job_id_ = sharp_job_id; }

    // Both accessors expose the reservation_key field of the external job id.
    // The returned reference / pointer refers to storage owned by this object.
    const string& GetReservationKey() const { return m_external_job_id_.reservation_key; }
    const char* GetReservationKeyCharPtr() const { return m_external_job_id_.reservation_key.c_str(); }

    uint8_t GetJobPriority() const { return m_priority_; }
    // Returns the address of the stored sharpd_0 end point (never null).
    const smx_ep* GetSharpdZeroAddress() const { return &m_sharpd_ep_; }
    int GetSharpdConnId() const { return m_sharpd_conn_id_; }
    void SetSharpdConnId(int sharpd_conn_id) { m_sharpd_conn_id_ = sharpd_conn_id; }

    bool IsMulticastEnabled() const { return m_multicast_enabled_; }

    bool ShouldUpdateReservationAtDelete() const { return m_should_update_reservation_at_job_delete; }

    // Sat Job is a job that has both LLT and SAT trees, it's requested by the client with `SHARP_JOB_REQ_FEATURE_MASK_SAT` parameter
    // Returns true when g_fabric reports at least one of the job's tree ids as a SAT tree id.
    bool IsSatJob() const
    {
        return std::any_of(m_tree_ids_.begin(),
                           m_tree_ids_.end(),
                           [](const sharp_trees_t& tree_id) { return g_fabric.IsSatTreeId(tree_id); });
    }

    // If a job is deleted during reservation removal, we want to make sure we wont try to update the reservation
    void MarkDeletedAtReservationRemoval() { m_should_update_reservation_at_job_delete = false; }
    void UnmarkDeletedAtReservationRemoval() { m_should_update_reservation_at_job_delete = true; }

    void DisableMulticast() { m_multicast_enabled_ = false; }

    JobSubTreeInfo* GetJobSubTreeInfoForTreeId(uint16_t tree_id);

    void ResizeSubTreeInfoVec(uint8_t num_trees) { m_sub_tree_info_vec_.resize(num_trees); }
    void ResizeMulticastSubTreeInfoVecVec(uint8_t num_vec) { m_mc_sub_tree_info_vec_.resize(num_vec); }
    // Precondition: num_vec is a valid index into the outer multicast vector (no bounds check is done).
    void ResizeMulticastSubTreeInfoVec(uint8_t num_vec, uint8_t num_mc_trees) { m_mc_sub_tree_info_vec_[num_vec].resize(num_mc_trees); }
    // Size of the inner multicast vector at index num_vec, or 0 when num_vec is out of range.
    // The size is truncated to uint8_t.
    uint8_t GetMulticastSubTreeInfoSize(uint8_t num_vec) const
    {
        return (GetMulticastSubTreeInfoVecVecSize() <= num_vec ? 0 : static_cast<uint8_t>(m_mc_sub_tree_info_vec_[num_vec].size()));
    }
    // Size of the outer multicast vector, truncated to uint8_t.
    uint8_t GetMulticastSubTreeInfoVecVecSize() const { return static_cast<uint8_t>(m_mc_sub_tree_info_vec_.size()); }

    // Precondition for the indexed accessors below: the index is in range (no bounds check is done).
    JobSubTreeInfo& GetSubTreeInfo(uint8_t tree_number) { return m_sub_tree_info_vec_[tree_number]; }
    const JobSubTreeInfo& GetSubTreeInfo(uint8_t tree_number) const { return m_sub_tree_info_vec_[tree_number]; }
    // get the next unused SubTreeInfo
    // (the entry at index m_trees_number_; the vector must already be large enough)
    JobSubTreeInfo& GetSubTreeInfo() { return m_sub_tree_info_vec_[m_trees_number_]; }

    JobSubTreeInfo& GetMulticastSubTreeInfo(uint8_t num_vec, uint8_t mc_tree_number)
    {
        return m_mc_sub_tree_info_vec_[num_vec][mc_tree_number];
    }

    int ClearJobSubTreeInfo(sharp_trees_t tree_id);

    bool IsReproducible() const { return m_reproducible; }

    void SetJobState(sharp_job_state state);

    sharp_job_state GetJobState() const { return m_job_state_; }

    void CommitSubTreeInfo(const sharp_trees_t tree_id, const SetAggNodeFabricPtr& agg_nodes = {});

    uint8_t GetTreesNumber() const { return m_trees_number_; }
    const SetTreeIds& GetTreeIds() const { return m_tree_ids_; }
    const SetAggNodeFabricPtr& GetAggNodes() const { return m_agg_nodes_; }
    const SetAggNodeFabricPtr& GetMulticastAggNodes() const { return m_mc_agg_nodes_; }

    u_int32_t GetHostNumber() const { return static_cast<u_int32_t>(m_hosts_.size()); }
    const SetHostInfoPtr& GetHosts() const { return m_hosts_; }
    // Precondition: plane_index and rail are valid indices (no bounds check is done).
    const SetPortDataConstPtr& GetPorts(uint8_t rail, uint8_t plane_index) const { return m_ports_by_rail_[plane_index][rail]; }
    // Number of rails stored for the given plane, truncated to uint8_t.
    // Precondition: plane is a valid index (no bounds check is done).
    uint8_t GetNumberOfRails(uint8_t plane = 0) const { return static_cast<uint8_t>(m_ports_by_rail_[plane].size()); }
    // Number of planes, truncated to uint8_t.
    uint8_t GetNumberOfPlanes() const { return static_cast<uint8_t>(m_ports_by_rail_.size()); }
    // Tree index <-> (rail, plane) conversions. All three use the rail count of plane 0,
    // i.e. tree_index == plane_index * rails + rail_index.
    // Precondition: plane 0 exists and has at least one rail (otherwise this divides by zero).
    uint8_t GetRailIndex(uint8_t tree_index) const { return static_cast<uint8_t>(tree_index % GetNumberOfRails()); }
    uint8_t GetPlaneIndex(uint8_t tree_index) const { return static_cast<uint8_t>(tree_index / GetNumberOfRails()); }
    uint8_t GetTreeIndex(uint8_t rail_index, uint8_t plane_index) const
    {
        return static_cast<uint8_t>(plane_index * GetNumberOfRails() + rail_index);
    }

    uint32_t GetPortsNumberForRail(uint8_t rail, uint8_t plane);
    uint32_t GetPortsNumberForAllRails();
    // uint32_t GetPortsNumberForAllRailsPlanes();

    // Similar to std all_of and for_each, the following methods provide a way to execute a callback on all HCA ports
    // AllOf will execute as long as the callback returns true
    // ForEach will execute on all ports
    bool AllOfPorts(const std::function<bool(const PortData*)>& callback) const;
    void ForEachPort(const std::function<void(const PortData*)>& callback) const;

    u_int32_t GetConnectionsNumber() const;
    bool IsSAT() const;
    // True once the send timestamp differs from SharpTimestamp::min().
    bool IsJobInfoSent() const { return m_job_info_send_time != SharpTimestamp::min(); }
    std::chrono::seconds GetTotalRunTime() const;
    const SharpTimestamp& GetJobInfoTimeStamp() const { return m_job_info_send_time; }
    bool IsJobInfoReplyReceived() const { return m_job_info_reply_received; }
    void SetJobInfoSendTimeStamp(bool job_info_send);
    void SetJobInfoReplyReceived(bool job_info_reply_received) { m_job_info_reply_received = job_info_reply_received; }

    sharp_round_mode GetRoundMode() const { return m_round_mode_; }

    sharp_am_status FabricTreesConfig();
    sharp_am_status FabricQuotaConfig();
    void FabricJobResourceCleanup();
    void FabricJobResourceCleanupV2();
    void DisconnectSatJobFromTrees(bool cleanup_v2 = false);
    void CleanJobSatTrees(bool cleanup_v2 = false);
    void CleanSharpJob(bool cleanup_v2 = false);
    void CleanJobTreeConfigOnFabric();
    void DisableJobTreeNodes();

    sharp_am_status FabricMulticastConfig(bool clear_config);
    int CopyMulticastPersistentJobInfo(persistent_job_info* job_info);

    void IncreaseJobLoadOnAggNodes();

    void GetSubTreesBFS(ListTreeNodePtr* sub_tree_nodes_list);
    void GetSubTreeNodesBFS(ListTreeNodePtr* sub_tree_nodes_list, struct JobSubTreeInfo& sub_tree_info);
    void CreateJobInfoFiles(const string& persistent_path);
    void DeleteJobInfoFiles(bool should_print = true);

    uint8_t GetNumChannelsPerConn() const { return m_num_channels_per_conn_; }
    void AddTreeId(sharp_trees_t tree_id);
    void AddAggNode(AggNodeFabricInfo* p_agg_node);
    void AddMulticastAggNode(AggNodeFabricInfo* p_agg_node);
    int SetTreeRootForJob();
    int SetQuotaAllocForJob();
    bool IsExclusiveLockUsedByTree(sharp_trees_t tree_id);
    int CopyInfoFromJobData(const JobData* job_data);

    int AllocateQuota(JobSubTreeInfo& job_sub_tree_info, uint8_t child_index_per_port);

    void FreeQuota(const bool is_job_end);
    void FreeJobResource(const bool is_job_end);
    int AllocateSat(JobSubTreeScore& result, JobSubTreeInfo& job_sub_tree_info_llt);

    void ModifyAvailableTreeIDsInAllSubTrees(bool is_available);

    bool IsAnyTreeNodeOnJob();

    sharp_am_status AllocateJobTreeResource(const JobResource& job_resource, const sharp_begin_job& begin_Job_data, uint8_t tree_index);

    void UpdateJobDataQpcOpts(sharp_job_data& job_data) const;
    void UpdateJobData(sharp_job_data& job_data);
    static void FreeJobData(sharp_job_data& job_data);
    void PrepareJobTreesInfoMessage(sharp_job_trees_info& job_info);
    int ReconstructTrees();
    int ReconstructMulticastTrees();
    int UpdateSubTreesInfo();
    void RestorePeerTreeIds(bool is_sat);

    void PrintResourceAllocationSummary();
    void CleanJobsDataRestoredDuringSeamlessRestart();
    void GetAnPortKeys(USetPortKeys& an_port_keys) const;
    void GetAnPortKeysForMulticast(USetPortKeys& an_port_keys) const;

    void DumpJobTrees(file_utils::DumpFile*) const;

   private:
    sharp_am_status AllocateJobMulticastResource(const JobResource& job_resource,
                                                 const sharp_begin_job& begin_Job_data,
                                                 uint8_t max_radix,
                                                 uint8_t tree_index);
    uint8_t GetNumOfRequiredMulticastTrees(uint8_t max_radix, uint16_t compute_ports_number, uint8_t max_mc_trees);

    sharp_am_status AddComputesToMulticastTrees(const SetPortDataConstPtr& compute_ports, uint8_t tree_index);
    sharp_am_status AddComputesToMulticastTrees(MultisetPortDataByPeerRange& range, uint8_t tree_index);

    void AddJobToMulticastTree(JobSubTreeInfo& job_mc_sub_tree_info);
    void RemoveJobFromMulticastTrees();
    void RemoveJobFromMulticastTrees(JobSubTreeInfoVec& mc_sub_tree_info_vec);
    void RemoveJobFromMulticastTree(JobSubTreeInfo& job_mc_sub_tree_info);

    VectorANFabricInfoAndANQuotaInfoPair GetANFabricInfoAndANQuotaInfoPairsSortedByPortRankAndGuid(
        const MapPortKeyToQuota& port_key_to_quota = {});
    void WriteResourcesAllocDetailsToFile(char const* const resource_msg_format,
                                          AggNodeFabricInfo const* const an_fabric_info,
                                          const uint32_t qps1,
                                          const uint32_t qps2,
                                          const uint32_t qps3,
                                          const uint32_t buffers1,
                                          const uint32_t buffers2,
                                          const uint32_t buffers3,
                                          const uint32_t osts1,
                                          const uint32_t osts2,
                                          const uint32_t osts3,
                                          const uint32_t groups1,
                                          const uint32_t groups2,
                                          const uint32_t groups3);
    int DumpJobResourceFree();
    int DumpJobResourceAllocation(const uint16_t tree_id, const MapPortKeyToQuota& port_key_to_quota);

    void SetComputePorts();

    void DeleteJobInfoFile(string& file_path, bool should_print = true);
    void DumpJobMulticastTreesState();
    void UpdateJobDataTreeMlids(sharp_tree& sharp_tree, uint32_t mc_tree);

    void PrepareJobTreeInfoMessage(sharp_job_tree_info& job_tree_info, JobSubTreeInfo& job_tree, bool xdr_job);

    int UpdateMulticastTables(SetAggNodeFabricPtr& mc_agg_nodes, bool clear_config);
    void GetMulticastSubTreeNodesBFS(ListTreeNodePtr* sub_tree_nodes_list, JobSubTreeInfo& sub_tree_info, ListTreeNodePtr& nodes_queue);

    ////////                 Callback          ///////////
    // The "m_dataN" notes below list what each callback expects to find in the
    // fields of the FabricProviderCallbackContext it receives.

    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    // m_data3 = *AggNodeQuotaAllocInfo
    void JobResourceCleanupCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);
    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    void DisconnectJobTreeQpCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);
    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    void CleanJobTreeCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    void SatCleanupCallback(FabricProviderCallbackContext* p_context,
                            int rec_status,
                            void* p_data,
                            const char* caller_func_name,
                            int next_state);

    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    // m_data3 = *AggNodeQuotaAllocInfo
    void FabricQuotaConfigCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);

    // m_data1 = int operation_status;
    // m_data2 = *AggNode
    // m_data3 = *AggNodeQuotaAllocInfo
    void FabricTreeToJobBindCallback(FabricProviderCallbackContext* p_context, int rec_status, void* p_data);
    //////////////////////////////////////////////////////
};
