Skip to content

Commit

Permalink
an attempt for cross-platform (Windows) support
Browse files Browse the repository at this point in the history
  • Loading branch information
didzis committed Feb 10, 2024
1 parent f10a7b4 commit 8be9f23
Showing 1 changed file with 34 additions and 17 deletions.
51 changes: 34 additions & 17 deletions cuda-loader.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
#include <stdio.h>

#include <dlfcn.h>

#include <cuda.h>

#ifdef _WIN32
#include <windows.h>
#define CUDA_DRIVER_LIBRARY "nvcuda.dll"
typedef HMODULE lib_handle;
#define load_library(name) LoadLibrary(name)
#define load_symbol(lib, symbol) GetProcAddress(lib, symbol)
#define close_library(lib) FreeLibrary(lib)
#else
#include <dlfcn.h>
#define CUDA_DRIVER_LIBRARY "libcuda.so"
#define CUDA_DRIVER_LIBRARY_ALT "libcuda.so.1"
typedef void* lib_handle;
#define load_library(name) dlopen(name, RTLD_NOW)
#define load_symbol(lib, symbol) dlsym(lib, symbol)
#define close_library(lib) dlclose(lib)
#endif


typedef CUresult (*cuDeviceGet_pt)(CUdevice *device, int ordinal);
typedef CUresult (*cuDeviceGetAttribute_pt)(int *pi, CUdevice_attribute attrib, CUdevice dev);
Expand All @@ -29,40 +44,42 @@ cuMemSetAccess_pt _cuMemSetAccess = NULL;

int load_libcuda(void) {

static void * libcuda = NULL;
static lib_handle libcuda = NULL;

if (libcuda == (void*)1)
if (libcuda == (lib_handle)1)
return 0;

if (libcuda != NULL)
return 1;

libcuda = dlopen("libcuda.so", RTLD_NOW);
libcuda = load_library(CUDA_DRIVER_LIBRARY);

#ifdef CUDA_DRIVER_LIBRARY_ALT
if (libcuda == NULL) {
libcuda = dlopen("libcuda.so.1", RTLD_NOW);
libcuda = load_library(CUDA_DRIVER_LIBRARY_ALT);
}
#endif

if (libcuda != NULL) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpedantic"
_cuDeviceGet = (cuDeviceGet_pt)dlsym(libcuda, "cuDeviceGet");
_cuDeviceGetAttribute = (cuDeviceGetAttribute_pt)dlsym(libcuda, "cuDeviceGetAttribute");
_cuGetErrorString = (cuGetErrorString_pt)dlsym(libcuda, "cuGetErrorString");
_cuMemGetAllocationGranularity = (cuMemGetAllocationGranularity_pt)dlsym(libcuda, "cuMemGetAllocationGranularity");
_cuMemCreate = (cuMemCreate_pt)dlsym(libcuda, "cuMemCreate");
_cuMemAddressReserve = (cuMemAddressReserve_pt)dlsym(libcuda, "cuMemAddressReserve");
_cuMemMap = (cuMemMap_pt)dlsym(libcuda, "cuMemMap");
_cuMemRelease = (cuMemRelease_pt)dlsym(libcuda, "cuMemRelease");
_cuMemSetAccess = (cuMemSetAccess_pt)dlsym(libcuda, "cuMemSetAccess");
_cuDeviceGet = (cuDeviceGet_pt)load_symbol(libcuda, "cuDeviceGet");
_cuDeviceGetAttribute = (cuDeviceGetAttribute_pt)load_symbol(libcuda, "cuDeviceGetAttribute");
_cuGetErrorString = (cuGetErrorString_pt)load_symbol(libcuda, "cuGetErrorString");
_cuMemGetAllocationGranularity = (cuMemGetAllocationGranularity_pt)load_symbol(libcuda, "cuMemGetAllocationGranularity");
_cuMemCreate = (cuMemCreate_pt)load_symbol(libcuda, "cuMemCreate");
_cuMemAddressReserve = (cuMemAddressReserve_pt)load_symbol(libcuda, "cuMemAddressReserve");
_cuMemMap = (cuMemMap_pt)load_symbol(libcuda, "cuMemMap");
_cuMemRelease = (cuMemRelease_pt)load_symbol(libcuda, "cuMemRelease");
_cuMemSetAccess = (cuMemSetAccess_pt)load_symbol(libcuda, "cuMemSetAccess");
#pragma GCC diagnostic pop

return 1;
}

fprintf(stderr, "error: failed to load libcuda.so: %s\n", dlerror());
fprintf(stderr, "error: failed to load the CUDA driver: %s\n", dlerror());

libcuda = (void*)1; // tried and failed
libcuda = (lib_handle)1; // tried and failed
return 0;
}

Expand Down

0 comments on commit 8be9f23

Please sign in to comment.