Skip to content

Commit

Permalink
backend: do not crash if GGUF lacks general.architecture (#2346)
Browse files Browse the repository at this point in the history
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
  • Loading branch information
cebtenzzre committed May 15, 2024
1 parent 6d8888b commit 9f9d8e6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 24 deletions.
16 changes: 11 additions & 5 deletions gpt4all-backend/gptj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,12 +786,14 @@ const std::vector<LLModel::Token> &GPTJ::endTokens() const
}

const char *get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
if (kid == -1)
throw std::runtime_error("key not found in model: general.architecture");

enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != GGUF_TYPE_STRING) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
if (ktype != GGUF_TYPE_STRING)
throw std::runtime_error("key general.architecture has wrong type");

return gguf_get_val_str(ctx_gguf, kid);
}

Expand Down Expand Up @@ -824,7 +826,11 @@ DLL_EXPORT char *get_file_arch(const char *fname) {

char *arch = nullptr;
if (ctx_gguf && gguf_get_version(ctx_gguf) <= 3) {
arch = strdup(get_arch_name(ctx_gguf));
try {
arch = strdup(get_arch_name(ctx_gguf));
} catch (const std::runtime_error &) {
// cannot read key -> return null
}
}

gguf_free(ctx_gguf);
Expand Down
69 changes: 50 additions & 19 deletions gpt4all-backend/llamamodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ static int llama_sample_top_p_top_k(
}

const char *get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
if (kid == -1)
throw std::runtime_error("key not found in model: general.architecture");

enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != (GGUF_TYPE_STRING)) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
if (ktype != GGUF_TYPE_STRING)
throw std::runtime_error("key general.architecture has wrong type");

return gguf_get_val_str(ctx_gguf, kid);
}

Expand All @@ -136,13 +138,20 @@ static gguf_context *load_gguf(const char *fname) {
}

static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
int32_t value = -1;
std::string arch;

auto * ctx = load_gguf(modelPath.c_str());
if (!ctx)
return -1;
std::string arch = get_arch_name(ctx);
goto cleanup;

int32_t value = -1;
if (ctx) {
try {
arch = get_arch_name(ctx);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
}

{
auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) {
Expand All @@ -152,6 +161,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
}
}

cleanup:
gguf_free(ctx);
return value;
}
Expand Down Expand Up @@ -244,15 +254,26 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const {
}

bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const {
bool result = false;
std::string arch;

auto *ctx_gguf = load_gguf(modelPath.c_str());
if (!ctx_gguf) {
std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n";
return false;
goto cleanup;
}

try {
arch = get_arch_name(ctx_gguf);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
}

std::string arch = get_arch_name(ctx_gguf);
result = is_embedding_arch(arch);

cleanup:
gguf_free(ctx_gguf);
return is_embedding_arch(arch);
return result;
}

bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
Expand Down Expand Up @@ -964,16 +985,26 @@ DLL_EXPORT const char *get_build_variant() {
}

DLL_EXPORT char *get_file_arch(const char *fname) {
auto *ctx = load_gguf(fname);
char *arch = nullptr;
if (ctx) {
std::string archStr = get_arch_name(ctx);
if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) {
// old bert.cpp embedding model
} else {
arch = strdup(archStr.c_str());
}
std::string archStr;

auto *ctx = load_gguf(fname);
if (!ctx)
goto cleanup;

try {
archStr = get_arch_name(ctx);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
}

if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) {
// old bert.cpp embedding model
} else {
arch = strdup(archStr.c_str());
}

cleanup:
gguf_free(ctx);
return arch;
}
Expand Down

0 comments on commit 9f9d8e6

Please sign in to comment.