// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <limits>

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/rocm/tensor/gelu.h"
#include "core/providers/rocm/tensor/gelu_impl.h"

namespace onnxruntime {
namespace rocm {

// Registers Gelu<T> with the ROCm execution provider as the typed kernel for
// the `Gelu` operator in kOnnxDomain, version 20, constrained to tensors of
// element type T. MayInplace(0, 0) declares input 0 / output 0 as an in-place
// candidate pair.
//
// The macro is local to this file and is #undef'd after its last use so the
// generic name does not leak into whatever is compiled after it.
#define REGISTER_KERNEL_TYPED(T)                                 \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                 \
      Gelu,                                                      \
      kOnnxDomain,                                               \
      20,                                                        \
      T,                                                         \
      kRocmExecutionProvider,                                    \
      (*KernelDefBuilder::Create())                              \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .MayInplace(0, 0),                                     \
      Gelu<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

#undef REGISTER_KERNEL_TYPED

// Runs Gelu over input 0 and writes the result to output 0, which is created
// with the same shape as the input.
//
// The kernel launched depends on approximation_algorithm_:
//   - "tanh": LaunchFastGeluKernel, invoked without a bias term.
//   - "none": LaunchGeluKernel.
//
// Returns:
//   - Status::OK() without launching anything when the input has zero elements.
//   - INVALID_ARGUMENT when the input has no dimensions, when the "tanh" path
//     is asked to process more elements than fit in an int, or when
//     approximation_algorithm_ is neither "tanh" nor "none".
//   - Otherwise, whatever Status the launched kernel helper returns.
template <typename T>
Status Gelu<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() < 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
  }

  // Request the output before the empty-input early return so that output 0
  // is still produced for zero-element inputs.
  Tensor* output = context->Output(0, input->Shape());

  const int64_t input_length = input->Shape().Size();
  if (input_length == 0) {
    return Status::OK();
  }

  using HipT = typename ToHipType<T>::MappedType;

  if (approximation_algorithm_ == "tanh") {
    // The element count is handed to LaunchFastGeluKernel as an int. Reject
    // counts that do not fit instead of silently truncating them in the cast.
    if (input_length > static_cast<int64_t>(std::numeric_limits<int>::max())) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 0 has ", input_length,
                             " elements, which exceeds the maximum of ", std::numeric_limits<int>::max(),
                             " supported when approximation_algorithm is tanh");
    }

    return LaunchFastGeluKernel<HipT>(GetDeviceProp(),
                                       Stream(context),
                                       static_cast<int>(input_length),
                                       0 /* no bias */,
                                       reinterpret_cast<const HipT*>(input->Data<T>()),
                                       nullptr /* no bias */,
                                       reinterpret_cast<HipT*>(output->MutableData<T>()),
                                       use_half2_);
  } else if (approximation_algorithm_ == "none") {
    return LaunchGeluKernel<HipT>(Stream(context),
                                   reinterpret_cast<const HipT*>(input->Data<T>()),
                                   reinterpret_cast<HipT*>(output->MutableData<T>()),
                                   static_cast<size_t>(input_length));
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
}

}  // namespace rocm

#ifndef DISABLE_CONTRIB_OPS
namespace contrib::rocm {
// Registers the contrib-domain `Gelu` operator (kMSDomain, version 1) with the
// ROCm execution provider for tensors of element type T. No separate contrib
// implementation is defined here: the registration points at
// onnxruntime::rocm::Gelu<T>, the same kernel class used for the kOnnxDomain
// registration. MayInplace(0, 0) declares input 0 / output 0 as an in-place
// candidate pair.
//
// The macro is file-local and is #undef'd once all element types are registered.
#define REGISTER_CONTRIB_KERNEL_TYPED(T)                         \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                 \
      Gelu,                                                      \
      kMSDomain,                                                 \
      1,                                                         \
      T,                                                         \
      kRocmExecutionProvider,                                    \
      (*KernelDefBuilder::Create())                              \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .MayInplace(0, 0),                                     \
      onnxruntime::rocm::Gelu<T>);

REGISTER_CONTRIB_KERNEL_TYPED(float)
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(double)

#undef REGISTER_CONTRIB_KERNEL_TYPED
}  // namespace contrib::rocm
#endif

}  // namespace onnxruntime
