/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Modifications Copyright© 2019-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include "test_param_fixtures.hpp"
#include "test_utils.hpp"

THRUST_DIAG_PUSH
THRUST_DIAG_SUPPRESS_MSVC(4244 4267) // possible loss of data

template <typename Iterator1, typename Iterator2>
__global__ THRUST_HIP_LAUNCH_BOUNDS_DEFAULT void
simple_copy_on_device(Iterator1 first1, Iterator1 last1, Iterator2 first2)
{
  while (first1 != last1)
  {
    *(first2++) = *(first1++);
  }
}

template <typename Iterator1, typename Iterator2>
void simple_copy(Iterator1 first1, Iterator1 last1, Iterator2 first2)
{
  simple_copy_on_device<<<1, 1>>>(first1, last1, first2);
}

TEST(DereferenceTests, TestDeviceDereferenceDeviceVectorIterator)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  for (auto seed : get_seeds())
  {
    SCOPED_TRACE(testing::Message() << "with seed= " << seed);

    thrust::device_vector<int> input =
      get_random_data<int>(100, get_default_limits<int>::min(), get_default_limits<int>::max(), seed);
    thrust::device_vector<int> output(input.size(), 0);

    simple_copy(input.begin(), input.end(), output.begin());

    ASSERT_EQ(input, output);
  }
}

TEST(DereferenceTests, TestDeviceDereferenceDevicePtr)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  for (auto seed : get_seeds())
  {
    SCOPED_TRACE(testing::Message() << "with seed= " << seed);

    thrust::device_vector<int> input =
      get_random_data<int>(100, get_default_limits<int>::min(), get_default_limits<int>::max(), seed);
    thrust::device_vector<int> output(input.size(), 0);

    thrust::device_ptr<int> _first1 = &input[0];
    thrust::device_ptr<int> _last1  = _first1 + input.size();
    thrust::device_ptr<int> _first2 = &output[0];

    simple_copy(_first1, _last1, _first2);

    ASSERT_EQ(input, output);
  }
}

TEST(DereferenceTests, TestDeviceDereferenceTransformIterator)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  for (auto seed : get_seeds())
  {
    SCOPED_TRACE(testing::Message() << "with seed= " << seed);

    thrust::device_vector<int> input =
      get_random_data<int>(100, get_default_limits<int>::min(), get_default_limits<int>::max(), seed);
    thrust::device_vector<int> output(input.size(), 0);

    simple_copy(thrust::make_transform_iterator(input.begin(), ::internal::identity{}),
                thrust::make_transform_iterator(input.end(), ::internal::identity{}),
                output.begin());

    ASSERT_EQ(input, output);
  }
}

TEST(DereferenceTests, TestDeviceDereferenceTransformIteratorInputConversion)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  thrust::device_vector<int> input = unittest::random_integers<int>(100);
  thrust::device_vector<double> output(input.size(), 0);

  simple_copy(thrust::make_transform_iterator(input.begin(), ::internal::identity{}),
              thrust::make_transform_iterator(input.end(), ::internal::identity{}),
              output.begin());

  ASSERT_EQ(input == output, true);
}

TEST(DereferenceTests, TestDeviceDereferenceTransformIteratorOutputConversion)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  thrust::device_vector<int> input = unittest::random_integers<int>(100);
  thrust::device_vector<double> output(input.size(), 0);

  simple_copy(thrust::make_transform_iterator(input.begin(), ::internal::identity{}),
              thrust::make_transform_iterator(input.end(), ::internal::identity{}),
              output.begin());

  ASSERT_EQ(input == output, true);
}

TEST(DereferenceTests, TestDeviceDereferenceCountingIterator)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  thrust::counting_iterator<int> first(1);
  thrust::counting_iterator<int> last(6);

  thrust::device_vector<int> output(5);

  simple_copy(first, last, output.begin());

  thrust::device_vector<int> ref{1, 2, 3, 4, 5};
  ASSERT_EQ(output, ref);
}

TEST(DereferenceTests, TestDeviceDereferenceTransformedCountingIterator)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

  thrust::counting_iterator<int> first(1);
  thrust::counting_iterator<int> last(6);

  thrust::device_vector<int> output(5);

  simple_copy(thrust::make_transform_iterator(first, thrust::negate<int>()),
              thrust::make_transform_iterator(last, thrust::negate<int>()),
              output.begin());

  thrust::device_vector<int> ref{-1, -2, -3, -4, -5};
  ASSERT_EQ(output, ref);
}

THRUST_DIAG_POP
