diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py index 061c8d73f..65ec3ffd7 100644 --- a/tinygrad/runtime/autogen/cuda.py +++ b/tinygrad/runtime/autogen/cuda.py @@ -1,7 +1,7 @@ # mypy: ignore-errors import ctypes from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR -dll = DLL('cuda', 'cuda') +dll = DLL('cuda', '@driverLink@/lib/libcuda.so') cuuint32_t = ctypes.c_uint32 cuuint64_t = ctypes.c_uint64 CUdeviceptr_v2 = ctypes.c_uint64 diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py index 88085c45b..90518d403 100644 --- a/tinygrad/runtime/autogen/nvrtc.py +++ b/tinygrad/runtime/autogen/nvrtc.py @@ -2,7 +2,7 @@ import ctypes from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR import sysconfig -dll = DLL('nvrtc', 'nvrtc', f'/usr/local/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib') +dll = DLL('nvrtc','@cuda_nvrtc@/lib/libnvrtc.so') nvrtcResult = CEnum(ctypes.c_uint32) NVRTC_SUCCESS = nvrtcResult.define('NVRTC_SUCCESS', 0) NVRTC_ERROR_OUT_OF_MEMORY = nvrtcResult.define('NVRTC_ERROR_OUT_OF_MEMORY', 1) diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py index 8f71a9255..fdbf01bad 100644 --- a/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad/runtime/support/compiler_cuda.py @@ -43,7 +43,7 @@ def cuda_disassemble(lib:bytes, arch:str): class CUDACompiler(Compiler): def __init__(self, arch:str, cache_key:str="cuda"): self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}'] - self.compile_options += [f"-I{CUDA_PATH}/include"] if CUDA_PATH else ["-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"] + self.compile_options += ["-I@cuda_cudart@/include"] nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int()))) if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal") super().__init__(f"compile_{cache_key}_{self.arch}")