summaryrefslogtreecommitdiff
path: root/pkgs/development/rocm-modules/clr/test-hiprtc-type-traits.cpp
blob: 39a99f39eae113d471c32b3d0be2697d1e98f6f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include <iostream>
#include <string>

#define CHECK_HIP(expr) do { \
  if ((expr) != hipSuccess) { \
    std::cerr << #expr << " failed" << std::endl; \
    return 1; \
  } \
} while(0)

#define CHECK_HIPRTC(expr) do { \
  hiprtcResult _res = (expr); \
  if (_res != HIPRTC_SUCCESS) { \
    std::cerr << #expr << " failed: " << hiprtcGetErrorString(_res) << std::endl; \
    hiprtcGetProgramLogSize(prog, &log_size); \
    if (log_size > 0) { \
      std::string log(log_size, '\0'); \
      hiprtcGetProgramLog(prog, log.data()); \
      std::cerr << "Compile log:\n" << log << std::endl; \
    } \
    return 1; \
  } \
} while(0)

static const char* kernelSource = R"(
  #include <type_traits>

  extern "C" __global__ void test_kernel(int* out) {
    static_assert(std::is_same<int, std::remove_const<const int>::type>::value,
                  "type_traits not working");
    out[0] = 5;
  }
)";

int main() {
  hiprtcProgram prog;
  size_t log_size = 0;
  CHECK_HIPRTC(hiprtcCreateProgram(&prog, kernelSource, "test.hip", 0, nullptr, nullptr));
  CHECK_HIPRTC(hiprtcCompileProgram(prog, 0, nullptr));

  size_t code_size;
  CHECK_HIPRTC(hiprtcGetCodeSize(prog, &code_size));
  std::string code(code_size, '\0');
  CHECK_HIPRTC(hiprtcGetCode(prog, code.data()));
  hiprtcDestroyProgram(&prog);

  hipModule_t module;
  hipFunction_t kernel;
  CHECK_HIP(hipModuleLoadData(&module, code.data()));
  CHECK_HIP(hipModuleGetFunction(&kernel, module, "test_kernel"));

  int* d_out;
  int h_out = 0;
  CHECK_HIP(hipMalloc(&d_out, sizeof(int)));
  void* args[] = { &d_out };
  CHECK_HIP(hipModuleLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, nullptr, args, nullptr));
  CHECK_HIP(hipMemcpy(&h_out, d_out, sizeof(int), hipMemcpyDeviceToHost));

  if (h_out != 5) {
    std::cerr << "Kernel output mismatch: expected 5, got " << h_out << std::endl;
    return 1;
  }

  std::cout << "HIPRTC type_traits test passed (output=" << h_out << ")" << std::endl;
  return 0;
}