llmodel: add CUDA to the DLL search path if CUDA_PATH is set (#2357)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
cebtenzzre committed May 16, 2024
1 parent a92d266 · commit 2025d2d
Showing 2 changed files with 28 additions and 3 deletions.
5 changes: 3 additions & 2 deletions gpt4all-backend/dlhandle.h
@@ -58,14 +58,15 @@ class Dlhandle {
 #include <string>
 #include <exception>
 #include <stdexcept>
 
 #define WIN32_LEAN_AND_MEAN
-#define NOMINMAX
+#ifndef NOMINMAX
+# define NOMINMAX
+#endif
 #include <windows.h>
 #include <libloaderapi.h>
-
 
 
 class Dlhandle {
     HMODULE chandle;
 
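Note: the new #ifndef guard matters because the build may already define NOMINMAX (for example on the compiler command line), and an unconditional #define would then trigger a macro-redefinition warning. What NOMINMAX itself prevents, in a minimal Windows-only sketch (an illustration, not code from this commit):

    // Windows-only sketch: without NOMINMAX, <windows.h> defines function-like
    // min()/max() macros, so a call spelled std::min(a, b) is macro-expanded
    // into invalid tokens and fails to compile.
    #define WIN32_LEAN_AND_MEAN
    #ifndef NOMINMAX
    # define NOMINMAX  // guarded: the build may already define it
    #endif
    #include <windows.h>

    #include <algorithm>

    int smaller(int a, int b) {
        return std::min(a, b);  // fine with NOMINMAX; a parse error if min is a macro
    }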
26 changes: 25 additions & 1 deletion gpt4all-backend/llmodel.cpp
@@ -15,8 +15,16 @@
 #include <unordered_map>
 #include <vector>
 
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+# include <windows.h>
+#endif
+
 #ifdef _MSC_VER
-#include <intrin.h>
+# include <intrin.h>
 #endif
 
 #ifndef __APPLE__
@@ -85,6 +93,20 @@ static bool isImplementation(const Dlhandle &dl) {
     return dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation");
 }
 
+// Add the CUDA Toolkit to the DLL search path on Windows.
+// This is necessary for chat.exe to find CUDA when started from Qt Creator.
+static void addCudaSearchPath() {
+#ifdef _WIN32
+    if (const auto *cudaPath = _wgetenv(L"CUDA_PATH")) {
+        auto libDir = std::wstring(cudaPath) + L"\\bin";
+        if (!AddDllDirectory(libDir.c_str())) {
+            auto err = GetLastError();
+            std::wcerr << L"AddDllDirectory(\"" << libDir << L"\") failed with error 0x" << std::hex << err << L"\n";
+        }
+    }
+#endif
+}
+
 const std::vector<LLModel::Implementation> &LLModel::Implementation::implementationList() {
     if (cpu_supports_avx() == 0) {
         throw std::runtime_error("CPU does not support AVX");
@@ -95,6 +117,8 @@ const std::vector<LLModel::Implementation> &LLModel::Implementation::implementationList() {
     static auto* libs = new std::vector<Implementation>([] () {
         std::vector<Implementation> fres;
 
+        addCudaSearchPath();
+
         std::string impl_name_re = "(gptj|llamamodel-mainline)-(cpu|metal|kompute|vulkan|cuda)";
         if (cpu_supports_avx2() == 0) {
             impl_name_re += "-avxonly";

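Note: AddDllDirectory() does not affect plain LoadLibrary() calls. The registered directory is only searched when a library is loaded through LoadLibraryEx() with one of the LOAD_LIBRARY_SEARCH_* flags (LOAD_LIBRARY_SEARCH_DEFAULT_DIRS includes the user-registered directories), or when the process has opted in via SetDefaultDllDirectories(); that same search path is then used to resolve the loaded DLL's static imports, such as the CUDA runtime. A minimal sketch under those assumptions — the toolkit path and DLL name below are illustrative, not taken from this commit:

    // Sketch of how a directory registered with AddDllDirectory() is consulted.
    // Windows-only; the install path and DLL name are hypothetical examples.
    #define WIN32_LEAN_AND_MEAN
    #include <windows.h>

    #include <iostream>

    int main() {
        // Register the CUDA bin directory, as addCudaSearchPath() does above.
        AddDllDirectory(L"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\bin");

        // The registered directory is searched only because
        // LOAD_LIBRARY_SEARCH_DEFAULT_DIRS includes the user directories;
        // a plain LoadLibrary() call would not consult it.
        HMODULE h = LoadLibraryExW(L"cudart64_12.dll", nullptr, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
        if (!h) {
            std::wcerr << L"LoadLibraryExW failed with error 0x" << std::hex << GetLastError() << L"\n";
            return 1;
        }
        FreeLibrary(h);
        return 0;
    }

Assuming the backend DLLs are themselves loaded with matching search flags, registering %CUDA_PATH%\bin this way is enough for their CUDA runtime imports to resolve even when CUDA is not on PATH.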