Question: Configuring ZeroShotClassificationModel with DeBERTaV2 - Documentation #433

Open
Philipp-Sc opened this issue Nov 14, 2023 · 1 comment

Comments

Hi @jondot @guillaume-be

I work on an llm-fraud-detection library using rust-bert and llama.cpp.

While updating the repo, I would like to test DeBERTaV2 for my zero-shot classification task.

Currently I am using BERT, which is very simple to set up thanks to the provided Default implementation.

ZeroShotClassificationModel::new(Default::default());

See

Do I need to download the model myself, convert it to the Rust format (rust_model.ot), and provide the paths like this? See #406

fn generation_config(base_path: &str) -> ZeroShotClassificationConfig {
    let model_path = PathBuf::from(base_path.to_owned() + "rust_model.ot");
    let config_path = PathBuf::from(base_path.to_owned() + "config.json");
    let vocab_path = PathBuf::from(base_path.to_owned() + "vocab.json");
    let merges_path = PathBuf::from(base_path.to_owned() + "merges.txt");

    ZeroShotClassificationConfig {
        model_type: ModelType::DebertaV2,
        model_resource: Box::new(LocalResource::from(model_path)),
        config_resource: Box::new(LocalResource::from(config_path)),
        vocab_resource: Box::new(LocalResource::from(vocab_path)),
        merges_resource: Some(Box::new(LocalResource::from(merges_path))),
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}
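
A side note on the path handling in the snippet above: building paths with `base_path.to_owned() + "config.json"` only works when `base_path` ends with a separator. The following is a minimal sketch using only the standard library; `model_file` is a hypothetical helper name that is not part of rust-bert, and the directory names in the usage note are made up for illustration.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper (not part of rust-bert): resolve a file inside `base_dir`.
/// `Path::join` inserts the separator when it is missing, so the result is the
/// same whether or not `base_dir` ends with a slash.
fn model_file(base_dir: &str, file_name: &str) -> PathBuf {
    Path::new(base_dir).join(file_name)
}
```

With this helper, `model_file("models/deberta", "config.json")` and `model_file("models/deberta/", "config.json")` resolve to the same path, whereas plain string concatenation would turn the first into `models/debertaconfig.json`.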

Can you point me in the right direction on how to set up the ZeroShotClassificationModel for DeBERTaV2?

This might be a good time to add an example for this, which would likely help #425 as well. (I would be happy to propose a PR once I get the hang of it.)

Thanks in advance.

Best regards,
Philipp-Sc

mikkel1156 commented Jan 21, 2024

In my case the problem was the wrong vocab file. It turned out that spm.model is the vocab file for the model I'm using.

However, this was with the ONNX backend (which needs the corresponding feature enabled). The model I used is https://huggingface.co/Xenova/nli-deberta-v3-large

use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::PathBuf;

fn model(base_path: &str) {
    let model_path = PathBuf::from(base_path.to_owned() + "onnx/model.onnx");
    let config_path = PathBuf::from(base_path.to_owned() + "config.json");
    // For this model, the SentencePiece file spm.model serves as the vocab file.
    let vocab_path = PathBuf::from(base_path.to_owned() + "spm.model");

    let classification_model = ZeroShotClassificationModel::new(ZeroShotClassificationConfig::new(
        ModelType::DebertaV2,
        ModelResource::ONNX(ONNXModelResources {
            encoder_resource: Some(Box::new(LocalResource::from(model_path))),
            ..Default::default()
        }),
        LocalResource::from(config_path),
        LocalResource::from(vocab_path),
        None,
        false,
        None,
        None,
    )).expect("could not create zero_shot_classification model");
    let input = [
        "Who are you voting for in 2020?",
        "The prime minister has announced a stimulus package which was widely criticized by the opposition.",
    ];
    let labels = &["politics", "public health", "economics", "sports"];
    let output = classification_model.predict_multilabel(
        &input,
        labels,
        None,
        128
    );
    println!("{:?}", output);
}
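
Since the failure here came down to pointing at the wrong vocab file, a small pre-flight check before building the config can make that kind of mistake easier to spot. This is only a sketch using the standard library: `find_vocab_file` and `missing_files` are hypothetical helpers, not rust-bert APIs, and the candidate file names (and their order) are just the two mentioned in this thread, which may not cover other models.

```rust
use std::path::{Path, PathBuf};

/// Candidate vocab file names, taken from this thread only: `spm.model`
/// (what worked for the DeBERTa model above) and `vocab.json` (what the
/// original question tried). The preference order is an assumption.
const VOCAB_CANDIDATES: [&str; 2] = ["spm.model", "vocab.json"];

/// Hypothetical helper: return the first candidate that exists in `model_dir`.
fn find_vocab_file(model_dir: &Path) -> Option<PathBuf> {
    VOCAB_CANDIDATES
        .iter()
        .map(|name| model_dir.join(name))
        .find(|path| path.is_file())
}

/// Hypothetical pre-flight check: names from `required` that are not
/// present as regular files in `model_dir`.
fn missing_files(model_dir: &Path, required: &[&str]) -> Vec<String> {
    required
        .iter()
        .copied()
        .filter(|name| !model_dir.join(name).is_file())
        .map(str::to_owned)
        .collect()
}
```

One possible use is to call `missing_files(dir, &["config.json", "onnx/model.onnx"])` and `find_vocab_file(dir)` first, and only construct the ZeroShotClassificationConfig once nothing is missing and a vocab file has been found.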
