
any plans for adding repo using stable vicuna for conversation .. human: assistant #16

Open
andysingal opened this issue Aug 14, 2023 · 0 comments


Hi,
Thanks for the amazing repo. Are there any plans to add a repo that uses StableVicuna?
Also, could you share the hardware requirements for downloading the checkpoints? I ran out of memory in a Kaggle notebook.
Looking forward to hearing from you.
Best,
Andy

I tried to build the Human/Assistant prompt from a single message-text column, but I am still getting an error:

%%writefile sft_dataloader.py
def format_prompt(prompt: str) -> str:
    text = f"""
### Human: {prompt}
### Assistant:
    """
    return text.strip()


class SFTDataLoader(object):
    def __init__(self, data, CUTOFF_LEN, VAL_SET_SIZE, tokenizer) -> None:
        super(SFTDataLoader, self).__init__()

        self.data = data
        self.CUTOFF_LEN = CUTOFF_LEN
        self.VAL_SET_SIZE = VAL_SET_SIZE

        self.tokenizer = tokenizer

    def generate_prompt(self, data_point):
        return format_prompt(data_point["message_tree_text"])

    def tokenize(self, prompt):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        # Note: `return_unused_tokens` is not a valid tokenizer argument,
        # so it has been dropped here.
        return self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.CUTOFF_LEN + 1,
            padding="max_length",
        )

    def generate_and_tokenize_prompt(self, data_point):
        # This function masks out the labels for the prompt tokens,
        # so that our loss is computed only on the response.
        user_prompt = format_prompt(data_point["message_tree_text"])
        # Tokenize WITHOUT padding here, so the length reflects the actual
        # number of prompt tokens rather than always equaling max_length.
        len_user_prompt_tokens = len(
            self.tokenizer(
                user_prompt,
                truncation=True,
                max_length=self.CUTOFF_LEN + 1,
            )["input_ids"]
        )
        full_tokens = self.tokenizer(
            user_prompt,
            truncation=True,
            max_length=self.CUTOFF_LEN + 1,
            padding="max_length",
        )["input_ids"]
        return {
            # input_ids must be the full sequence; only the labels are
            # masked. -100 is the ignore index used by the loss function.
            "input_ids": full_tokens,
            "labels": [-100] * len_user_prompt_tokens
            + full_tokens[len_user_prompt_tokens:],
            "attention_mask": [1] * len(full_tokens),
        }

    def load_data(self):
        # Split the dataset held by this loader into train and validation.
        train_val = self.data.train_test_split(
            test_size=self.VAL_SET_SIZE, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(
            self.generate_and_tokenize_prompt
        )
        val_data = train_val["test"].shuffle().map(
            self.generate_and_tokenize_prompt
        )

        return train_data, val_data
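
To sanity-check the prompt template and the label-masking idea without a GPU or a downloaded checkpoint, here is a minimal sketch. The whitespace "tokenizer" and the example response string below are toy stand-ins for illustration only; the real class expects a Hugging Face tokenizer:

```python
# Reproduce the prompt template from sft_dataloader.py.
def format_prompt(prompt: str) -> str:
    text = f"""
### Human: {prompt}
### Assistant:
    """
    return text.strip()

# Toy stand-in for a real tokenizer: one fake integer id per
# whitespace-separated token.
def toy_tokenize(text: str) -> list:
    return list(range(len(text.split())))

example = format_prompt("What is LoRA?")
print(example)
# ### Human: What is LoRA?
# ### Assistant:

prompt_ids = toy_tokenize(example)
full_ids = toy_tokenize(example + " A low-rank adapter.")

# Mask the prompt portion with -100 (the ignore index used by PyTorch's
# cross-entropy loss) so the loss is computed only on the response tokens.
labels = [-100] * len(prompt_ids) + full_ids[len(prompt_ids):]
print(labels)
```

The key point is that `input_ids` stays the full sequence while only `labels` carries the -100 mask over the prompt tokens; slicing `input_ids` and `labels` into disjoint halves would misalign them and break training.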