// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

/// Launches `moe_gemm` with the given host arguments, prints one performance
/// line to stdout and returns the time reported by `moe_gemm`.
///
/// The printed line labels the returned time as milliseconds and derives two
/// figures from it:
///   - "TFlops": 2 * M * N * K operations divided by the time,
///   - "GB/s":   bytes of A (M x K), B (N x K) and C (M x N) divided by the time.
///
/// @param n_warmup forwarded into the `ck_tile::stream_config` passed to `moe_gemm`
/// @param n_repeat forwarded into the `ck_tile::stream_config` passed to `moe_gemm`
/// @param args     host argument bundle; only `M`, `N` and `K` are read here,
///                 the whole object is forwarded to `moe_gemm`
/// @return the value returned by `moe_gemm`
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind kind,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          typename MoeHostArgs>
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
    // Widen the problem extents once so every product below is evaluated in std::size_t.
    const auto extent_m = static_cast<std::size_t>(args.M);
    const auto extent_n = static_cast<std::size_t>(args.N);
    const auto extent_k = static_cast<std::size_t>(args.K);

    // One multiply-add per (m, n, k) triple, counted as two operations.
    const std::size_t total_flop = std::size_t(2) * extent_m * extent_n * extent_k;

    // Bytes of the A, B and C matrices.
    const std::size_t bytes_a     = sizeof(ADataType) * extent_m * extent_k;
    const std::size_t bytes_b     = sizeof(BDataType) * extent_n * extent_k;
    const std::size_t bytes_c     = sizeof(CDataType) * extent_m * extent_n;
    const std::size_t total_bytes = bytes_a + bytes_b + bytes_c;

    // NOTE(review): the positional stream_config fields other than the warmup and
    // repeat counts are opaque from this file — confirm their meaning before editing.
    const float elapsed_ms = moe_gemm<FlatmmConfig,
                                      ADataType,
                                      BDataType,
                                      DsDatatype,
                                      AccDataType,
                                      CDataType,
                                      ALayout,
                                      BLayout,
                                      DsLayout,
                                      ELayout,
                                      kind,
                                      CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    // With the time in milliseconds, /1e9 yields tera-ops per second and /1e6 yields GB/s.
    const float tflops     = static_cast<float>(total_flop) / 1.E9 / elapsed_ms;
    const float gb_per_sec = total_bytes / 1.E6 / elapsed_ms;

    const std::string op_name{"Moe Gemm"};

    std::cout << "Perf: " << std::setw(10) << elapsed_ms << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return elapsed_ms;
}

/// Runs one MoE flatmm example end to end: parses the command line, builds the
/// host tensors and routing tables, uploads them, launches the kernel through
/// `invoke_moe_gemm`, and optionally compares the result against
/// `ck_tile::reference_moe_gemm_gpu`.
///
/// @tparam PrecType     selects A/B/C/accumulator types via `GemmBasicTypeConfig`
/// @tparam FlatmmConfig kernel configuration; `M_Tile` is used as the M block size
/// @tparam kind         MoE GEMM flavour; `kFFN_gemm2` is the only non-"input" GEMM here,
///                      and `kFFN_gemm1_gate_up` halves the output N
/// @param argc, argv    command line forwarded to `create_args`
/// @param a_layout, b_layout, c_layout layout tags (B must be column-major, enforced below)
/// @return -1 when argument parsing fails or a size argument is not positive;
///         otherwise the validation flag converted to int (1 = pass or validation
///         skipped, 0 = mismatch)
template <typename PrecType,
          typename FlatmmConfig,
          ck_tile::MoeFlatmmKind kind,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
                                      char* argv[],
                                      const ALayout a_layout                  = ALayout{},
                                      const BLayout b_layout                  = BLayout{},
                                      [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);

    if(!result)
    {
        return -1;
    }

    using ADataType   = typename GemmBasicTypeConfig<PrecType>::ADataType;
    using BDataType   = typename GemmBasicTypeConfig<PrecType>::BDataType;
    using CDataType   = typename GemmBasicTypeConfig<PrecType>::CDataType;
    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;

    constexpr int ScaleGranularityM = 1;
    constexpr int ScaleGranularityN = 1;

    const ck_tile::index_t N          = arg_parser.get_int("N");
    const ck_tile::index_t K          = arg_parser.get_int("K");
    ck_tile::index_t stride_A         = arg_parser.get_int("stride_A");
    ck_tile::index_t stride_B         = arg_parser.get_int("stride_B");
    ck_tile::index_t stride_C         = arg_parser.get_int("stride_C");
    const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
    const ck_tile::index_t topk       = arg_parser.get_int("TopK");
    const ck_tile::index_t warmup     = arg_parser.get_int("warmup");
    const ck_tile::index_t repeat     = arg_parser.get_int("repeat");
    const ck_tile::index_t experts    = arg_parser.get_int("experts");

    // The routing-table setup below divides by values derived from these three
    // arguments; reject non-positive input instead of dividing by zero.
    if(num_tokens <= 0 || topk <= 0 || experts <= 0)
    {
        std::cerr << "Error: NumTokens, TopK and experts must all be positive." << std::endl;
        return -1;
    }

    // TODO: replace the magic declaration
    const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;

    // NOTE(review): despite its name this is (num_tokens rounded up to MPerBlock) * topk,
    // i.e. a row count rather than a tile count — kept as in the original, confirm intent.
    ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
    ck_tile::index_t valid_tile_num  = sorted_tile_num;
    ck_tile::index_t sorted_size     = sorted_tile_num * MPerBlock;

    const ck_tile::index_t M       = sorted_tile_num * MPerBlock;
    const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;

    static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
    constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;

    // Input GEMMs read num_tokens rows of A and write num_tokens * topk rows of C;
    // the second GEMM is the mirror image.
    stride_A = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));

    auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
    // B holds the weights of all experts, concatenated along N.
    auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
        is_row_major(b_layout)
            ? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
            : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));

    ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);

    // The kernel consumes the shuffled copy of B; the reference consumes the original.
    auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);

    std::cout << "moe_flatmm:" //
              << "\n  num_experts: " << experts << "\n  num_tokens: " << num_tokens
              << "\n  topk: " << topk << "\n  sorted_tile_num: " << sorted_tile_num
              << "\n  a_m_k: " << a_m_k_tensor.mDesc << "\n  b_k_n: " << b_k_n_tensor.mDesc
              << "\n  b_shuffle: " << b_shuffle_host.mDesc << "\n  c_m_n: " << c_m_n_tensor.mDesc
              << std::endl;

    ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};

    a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
    b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_tensor.SetZero();

    const void* p_a         = a_m_k_dev_buf.GetDeviceBuffer();
    const void* p_b_origin  = b_origin_dev_buf.GetDeviceBuffer();
    const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
    void* p_c               = c_m_n_dev_buf.GetDeviceBuffer();

    // TODO: malloc and init sorted tokens and max tokens buffer

    ck_tile::HostTensor<ck_tile::index_t> expert_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<AccDataType> expert_weight(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> max_token_id(
        ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));

    ck_tile::HostTensor<AccDataType> per_token_scale(
        ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(
        ck_tile::HostTensorDescriptor({N * experts}, {1}));

    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);

    // for verification only, no need to satisfy weight normalization
    ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);

    ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());

    max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
    // The assignment above replaces the host storage with a 10-element vector, while
    // max_token_id_dev was sized from the descriptor (1 + sorted_tile_num elements).
    // The upload below is handed only a pointer, so restore the descriptor-sized length
    // (zero-filling any tail) to keep it from reading past the end of the host vector.
    max_token_id.mData.resize(max_token_id.get_element_space_size_in_bytes() /
                              sizeof(ck_tile::index_t));
    // int eids[]         = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}

    // Assign consecutive runs of entries to experts 0, 1, 2, ... in order.
    const ck_tile::index_t tiles_per_expert = (valid_tile_num + experts - 1) / experts;
    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = i / tiles_per_expert;
    }

    const ck_tile::index_t total_routed = num_tokens * topk;
    int token_per_tile                  = (total_routed + valid_tile_num - 1) / valid_tile_num;
    // int token_per_tile = num_tokens * topk / valid_tile_num;
    int tokenid = 0;
    // sorted_token_ids.mData[0] = 0;
    for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
    {
        int tile_off = i % MPerBlock;
        if(tile_off < token_per_tile && tokenid < total_routed)
        {
            // Packed entry: tokenid % num_tokens in the low bits, tokenid / num_tokens
            // shifted into bits 24 and up (presumably the top-k slot — confirm with kernel).
            sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
            tokenid++;
        }
        else
        {
            // Filler slot: num_tokens is one past the last valid token index.
            sorted_token_ids.mData[i] = num_tokens;
        }
    }

    sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
    expert_ids_dev.ToDevice(expert_ids.data());
    max_token_id_dev.ToDevice(max_token_id.data());
    expert_weight_dev.ToDevice(expert_weight.data());
    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    const ck_tile::index_t* p_sorted_token_ids_dev =
        static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_expert_ids_dev =
        static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_max_token_id_dev =
        static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
    const AccDataType* p_sorted_expert_weight_dev =
        static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());

    using MoeFlatmmArgs =
        ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<ScaleGranularityM>,
                                   ck_tile::FlatmmScalePointer<ScaleGranularityN>>;

    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};

    MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
                            p_sorted_expert_weight_dev,
                            p_expert_ids_dev,
                            p_max_token_id_dev,
                            p_a,
                            p_b_shuffle,
                            p_c,
                            num_tokens,
                            experts,
                            topk,
                            1, // k_batch
                            M,
                            N,
                            K,
                            stride_A,
                            stride_B,
                            stride_C,
                            per_token_scale_dev_ptr,
                            per_channel_scale_dev_ptr};

    invoke_moe_gemm<FlatmmConfig,
                    ADataType,
                    BDataType,
                    ck_tile::tuple<>,
                    AccDataType,
                    CDataType,
                    ALayout,
                    BLayout,
                    ck_tile::tuple<>,
                    CLayout,
                    kind>(warmup, repeat, gemm_desc);

    c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("validate"))
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
                                            outputN,
                                            stride_C,
                                            is_row_major(CLayout{})));

        c_m_n_host_ref.SetZero();

        std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
            std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());

        c_m_n_ref_buf->SetZero();

        // The reference reads the un-shuffled B buffer and the same routing tables.
        ck_tile::reference_moe_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout,
                                        static_cast<int>(kind),
                                        ck_tile::moe::MoeSilu>(
            p_sorted_token_ids_dev,
            p_expert_ids_dev,
            p_max_token_id_dev,
            static_cast<const ADataType*>(p_a),
            static_cast<const BDataType*>(p_b_origin),
            static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
            p_sorted_expert_weight_dev,
            num_tokens,
            MPerBlock,
            topk,
            M,
            N,
            K,
            stride_A,
            stride_B,
            stride_C,
            1,
            1,
            K,
            static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
            static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));

        // Download the reference first: inspecting c_m_n_host_ref before this copy
        // would only ever see the zeros written by SetZero() above.
        c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        // NOTE(review): the computed tolerances are not used by the comparison below,
        // which relies on the fixed thresholds — decide whether to switch over.
        [[maybe_unused]] const auto rtol_atol =
            calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
                K, 1 /*kbatch*/, max_accumulated_value);

        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
        const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;

        pass = ck_tile::check_err(
            c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
