FareedKhan-dev/Improve-Weak-LLM-Using-SPIN-Technique

Convert Weak LLM to Strong LLM Using SPIN Technique

Can we help a weak LLM get better without getting more data?

Much of the development in Large Language Models (LLMs) has already taken place as we enter 2024. One important area is alignment, which includes Supervised Fine-Tuning (SFT) on human-written examples and Reinforcement Learning from Human Feedback (RLHF) based on human preferences. These methods have played a crucial role in recent improvements to LLMs. Their main drawback, however, is that they require large amounts of human-annotated data. This keeps fine-tuning an active area of research, with researchers working on methods that make more effective use of the human data that is available.

Visual illustration of how SFT and RLHF work (from Llama 2: Open Foundation and Fine-Tuned Chat Models)

A recent study from the University of California, Los Angeles (UCLA) introduced a technique named SPIN (Self-Play fIne-tuNing), inspired by the success of self-play mechanisms in games such as AlphaGo Zero and AlphaZero. SPIN starts from a supervised fine-tuned model. What makes it stand out is that it lets the LLM improve by playing against itself, which removes the need for an expert annotator, whether a human or a more advanced LLM such as GPT-4. In simple terms, SPIN trains a new language model, over a series of iterations, to differentiate between its own previously generated responses and human-written responses. The ultimate goal is a language model whose responses are indistinguishable from those written by humans.

What is Self Play?

Self-play is a technique where an algorithm learns by playing against copies of itself. Because the opponent improves along with the learner, the learning environment becomes progressively more challenging, and agents get to interact with many versions of themselves. The approach has gained significant attention in multi-agent reinforcement learning (MARL) because of its effectiveness. A notable example is AlphaGo Zero, which learned the game of Go purely through self-play and reached superhuman performance.

How Self-Play Environment Works (Created by Fareed Khan)

Researchers have explored various adaptations and implementations of self-play, including variations in the number of agents, the type of interactions, and the learning algorithms used. The effectiveness of self-play in MARL is well-established, but its application to the enhancement of large language models (LLMs) is a new approach. The application of self-play to LLMs has the potential to further enhance their capabilities and enable them to generate more coherent, informative, and engaging text.

Self-play can be used in both competitive and cooperative settings.

  • In competitive settings, the copies of the algorithm compete against each other to achieve a specific goal.

  • In cooperative settings, the copies of the algorithm work together to achieve a common goal.

It can also be combined with other learning techniques, such as supervised learning and reinforcement learning, to further enhance the performance of the algorithm.

How SPIN Works

SPIN operates like a two-player game. In this game:

  1. Main Player (New LLM) — This player’s role is to learn how to distinguish between responses generated by the Language Model (LLM) and those created by humans. In each iteration (round), the main player is the LLM being actively trained. Its objective is to improve its ability to recognize and differentiate between responses.

  2. Opponent (Old LLM) — The opponent’s task is to generate responses that are indistinguishable from those produced by humans. The opponent, in this case, is the LLM from the previous iteration (round). It uses the self-play mechanism, generating responses based on its past knowledge. The opponent’s goal is to create responses so realistic that the main player (new LLM) has a challenging time deciding whether they are from a human or the LLM.

The dynamics of SPIN involve using a Supervised Fine-Tuning (SFT) dataset, which consists of pairs of input (x) and output (y). These examples are annotated by humans and serve as the basis for training the main player to recognize human-like responses. Some public SFT datasets include Dolly15K, Baize, Ultrachat, and more.
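To make the roles concrete, here is a minimal sketch of the game's control flow, not the paper's implementation. The `generate` and `train_main_player` callables are hypothetical placeholders for LLM sampling and fine-tuning; the point is that the main player trained in one round becomes the opponent in the next.

```python
from typing import Callable, List, Tuple, TypeVar

Model = TypeVar("Model")
Pair = Tuple[str, str]  # (prompt x, human response y) from the SFT dataset


def run_spin(
    sft_model: Model,
    sft_data: List[Pair],
    generate: Callable[[Model, str], str],
    train_main_player: Callable[[Model, List[Pair], List[Pair]], Model],
    num_iterations: int,
) -> Model:
    """Control flow of SPIN self-play; `generate` and `train_main_player` are hypothetical stubs."""
    opponent = sft_model  # round 0: the opponent is the SFT model itself
    for _ in range(num_iterations):
        # Opponent (old LLM) answers every prompt x, producing synthetic pairs (x, y')
        synthetic = [(x, generate(opponent, x)) for x, _ in sft_data]
        # Main player (new LLM) starts from the opponent and learns to separate (x, y) from (x, y')
        main_player = train_main_player(opponent, sft_data, synthetic)
        # The trained main player becomes the opponent of the next round
        opponent = main_player
    return opponent


# Toy stand-ins: the "model" is an integer and "training" just increments it
final = run_spin(0, [("What is 2 + 2?", "4")], lambda m, x: f"answer from v{m}", lambda m, real, fake: m + 1, 3)
print(final)  # 3
```

With these toy stand-ins, the opponent used for generation is `0`, `1` and `2` in successive rounds, and the model returned after three rounds is `3`.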

Training Main Player

To train the main player to tell language model (LLM) responses apart from human responses, SPIN uses an objective function that measures the expected value gap between real data and the responses generated by the opponent player. The main player aims to maximize this gap: it should assign high values to pairs where a prompt is paired with a response from real data, and low values to pairs where the response was generated by the opponent. In practice, this maximization is rewritten as the minimization of a loss.

Concretely, the main player minimizes a loss function of the difference between the values it assigns to real pairs and to the opponent's pairs, adjusting its parameters throughout training until it can reliably tell LLM responses from human responses. The choice of this loss function matters for the main player's performance; the paper uses the logistic loss ℓ(t) = log(1 + exp(−t)), which is non-negative, smooth, and decreases as the value gap t grows.
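A small sketch of the main player's per-prompt loss, assuming the values `f_real` and `f_synthetic` are plain numbers already produced by some value function (hypothetical here). It applies the logistic loss ℓ(t) = log(1 + exp(−t)) to the value gap, written in a numerically stable form:

```python
import math


def logistic_loss(t: float) -> float:
    """l(t) = log(1 + exp(-t)), rearranged so exp() never overflows."""
    return max(-t, 0.0) + math.log1p(math.exp(-abs(t)))


def main_player_loss(f_real: float, f_synthetic: float) -> float:
    """Small when the real response is valued above the synthetic one."""
    return logistic_loss(f_real - f_synthetic)


print(round(main_player_loss(0.0, 0.0), 4))  # 0.6931 (log 2): real and synthetic look the same
print(round(main_player_loss(3.0, 0.0), 4))  # 0.0486: real response valued higher, small loss
print(round(main_player_loss(0.0, 3.0), 4))  # 3.0486: synthetic response valued higher, large loss
```

Minimizing this loss pushes the value of real pairs up and the value of synthetic pairs down, which is exactly the "maximize the value gap" goal described above.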

Updating Opponent Player

After the main player has learned to distinguish real data from language model responses, the opponent player's parameters are updated. Given two responses to the same prompt, the improved main player evaluates each one with its learned value function; the response with the higher value is judged to come from real data and the other from the language model.

The opponent's goal is then to improve the language model so that, according to the main player, its responses are indistinguishable from real data. To achieve this, the language model's parameters are adjusted to maximize the main player's evaluation of the model's responses while keeping the update stable: the new model should improve without straying too far from the previous one.

Formally, this means finding a new response distribution that maximizes the main player's expected value, minus a KL-divergence regularization term that penalizes deviation from the previous iteration's model; the regularization keeps the improvement gradual and controlled. This problem has a closed-form solution: the new distribution is proportional to the previous model's distribution multiplied by the exponentiated value function. However, that solution is not guaranteed to be representable by a language model. To make it so, the paper chooses the main player's value function to be proportional to the log-ratio between the new and previous models' probabilities; with this choice, the closed-form solution is exactly the new language model. As a result, the opponent for the next iteration is simply the trained main player, whose responses now better match what the main player considers real data.
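In symbols (our transcription of the paper's opponent update, with q the prompt distribution, f_{t+1} the main player's learned value function and λ > 0 the regularization weight), the regularized problem, its closed-form solution, and the choice of f_{t+1} that turns the solution into an LLM are:

```latex
\hat{p} = \arg\max_{p} \; \mathbb{E}_{x \sim q(\cdot),\, y \sim p(\cdot \mid x)}\big[ f_{t+1}(x, y) \big]
  - \lambda \, \mathbb{E}_{x \sim q(\cdot)} \, \mathrm{KL}\big( p(\cdot \mid x) \,\|\, p_{\theta_t}(\cdot \mid x) \big)

\hat{p}(y \mid x) \propto p_{\theta_t}(y \mid x) \, \exp\!\big( \lambda^{-1} f_{t+1}(x, y) \big)

f_{t+1}(x, y) = \lambda \log \frac{p_{\theta}(y \mid x)}{p_{\theta_t}(y \mid x)}
  \quad \Longrightarrow \quad \hat{p}(y \mid x) \propto p_{\theta}(y \mid x)
```

Substituting the third line into the second, the factors p_{θ_t}(y | x) cancel, so the next opponent p_{θ_{t+1}} is simply the trained model p_θ.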

Coding SPIN Algorithm

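The SPIN algorithm first uses the current model (initially the SFT model) to generate synthetic responses to prompts from the SFT dataset. The model is then fine-tuned to prefer the human responses over these synthetic ones, and the process repeats, with the updated model generating the next round of synthetic data.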

SPIN Algorithm Pseudocode (From Original Paper)

SPIN's pseudocode in the original paper can be hard to follow, but coding it in Python lets us break down each term and understand how it works.

Initializing the Parameters and SFT Dataset

The original paper uses the Zephyr-7B-SFT-Full as the base model. This model is derived from the pre-trained Mistral-7B. For the dataset, they have used the Ultrachat200k subset of the larger UltraChat corpus, which consists of approximately 1.4 million dialogues produced using OpenAI’s Turbo APIs. From UltraChat200k, they randomly sample 50k prompts and use the base model to generate synthetic responses.

# Import necessary libraries
from datasets import load_dataset
import pandas as pd

# Load the UltraChat 200k dataset (a DatasetDict with several splits)
ultrachat_dataset = load_dataset("HuggingFaceH4/ultrachat_200k")

# Convert every split to a pandas DataFrame and concatenate them in one call
combined_df = pd.concat(
    [ultrachat_dataset[split].to_pandas() for split in ultrachat_dataset.keys()],
    ignore_index=True,
)

# Shuffle the combined DataFrame and reset the index
combined_df = combined_df.sample(frac=1, random_state=123).reset_index(drop=True)

# Select the first 50,000 rows from the shuffled DataFrame
ultrachat_50k_sample = combined_df.head(50000)

We follow the same approach: combine all the splits, then randomly sample 50k prompts from the combined DataFrame. Because UltraChat200k contains multi-round conversations, the authors use the prompting template "### Instruction: {prompt}\n\n### Response:" and keep only the first round of each conversation as their prompt and ground-truth completion pair.

# for storing each template in a list
templates_data = []

for index, row in ultrachat_50k_sample.iterrows():
    messages = row['messages']
    
    # Check if there are at least two messages (user and assistant)
    if len(messages) >= 2:
        user_message = messages[0]['content']
        assistant_message = messages[1]['content']
        
        # Create the template
        instruction_response_template = f"### Instruction: {user_message}\n\n### Response: {assistant_message}"
        
        # Append the template to the list
        templates_data.append({'Template': instruction_response_template})

# Create a new DataFrame with the generated templates (ground truth)
ground_truth_df = pd.DataFrame(templates_data)

We have now transformed our DataFrame into prompt-template strings. This is what a single transformation looks like:

Transformation of dataset into prompt template (Created by Fareed Khan)

The prompt template dataset serves as a ground truth dataset for our use, consisting of human responses. Zephyr-7B-SFT-Full will then generate a response for the same prompt, and the SPIN algorithm aims to align it with the ground truth response by iteratively updating the parameters of the language model (LLM). This process continues until it becomes challenging to distinguish between the generated response and the ground truth, achieving a high level of similarity (lowering the loss).
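The ground-truth strings attach the response to the template, but generation needs the same template with the response slot left empty. The two helpers below (hypothetical names, not from the paper) are a minimal sketch of keeping both forms consistent, so a synthetic pair differs from its ground-truth pair only in the response text:

```python
PROMPT_TEMPLATE = "### Instruction: {prompt}\n\n### Response:"


def build_generation_prompt(prompt: str) -> str:
    """Templated prompt x, with the response slot left empty for the model to fill."""
    return PROMPT_TEMPLATE.format(prompt=prompt)


def build_pair_text(prompt: str, response: str) -> str:
    """Full (x, y) text in the same layout as the ground-truth templates."""
    return f"{build_generation_prompt(prompt)} {response}"


x = "Name a primary color."
print(repr(build_generation_prompt(x)))  # '### Instruction: Name a primary color.\n\n### Response:'
print(build_pair_text(x, "Red.") == "### Instruction: Name a primary color.\n\n### Response: Red.")  # True
```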

The SPIN algorithm has two loops. The inner loop runs over the samples we are working with (50k prompts), and the outer loop runs for a total of 3 iterations, because the authors found that the model's performance stops improving noticeably after that. The authors use the Alignment Handbook library as the codebase for self-play fine-tuning, together with DeepSpeed to reduce training costs. They train Zephyr-7B-SFT-Full with the RMSProp optimizer and no weight decay in all iterations, as is common when fine-tuning LLMs for alignment. The global batch size is 64 and bfloat16 precision is used. The peak learning rate is 5e-7 for iterations 0 and 1 and is decayed to 1e-7 for iterations 2 and 3 as self-play fine-tuning approaches its end. Finally, they choose β = 0.1 and set the maximum sequence length to 2048 tokens.
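The settings quoted above can be collected in one place. The dictionary keys and the `peak_learning_rate` helper are our own naming, not taken from the paper or the Alignment Handbook; the helper returns 5e-7 for iterations 0 and 1 and 1e-7 for iterations 2 and 3 (and, in this sketch, any later iteration):

```python
# Training settings as quoted above; the key names are our own labels
SPIN_HPARAMS = {
    "optimizer": "RMSprop",
    "weight_decay": 0.0,
    "global_batch_size": 64,
    "precision": "bfloat16",
    "beta": 0.1,
    "max_seq_length": 2048,
}


def peak_learning_rate(iteration: int) -> float:
    """Peak learning rate for a SPIN iteration (numbered from 0)."""
    # 5e-7 for iterations 0 and 1, decayed to 1e-7 from iteration 2 onward
    return 5e-7 if iteration < 2 else 1e-7


for t in range(4):
    print(t, peak_learning_rate(t))
```

By comparison, a scheduler that multiplies the rate by `0.2 ** epoch` gives 5e-7, 1e-7 and 2e-8 for epochs 0, 1 and 2, so only epoch 0 matches this per-iteration schedule.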

# Importing the PyTorch library
import torch

# Importing the DeepSpeed library for distributed training
import deepspeed

# Importing the AutoTokenizer and AutoModelForCausalLM classes from the transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM

# Loading the zephyr-7b-sft-full model from HuggingFace in bfloat16 precision
tokenizer = AutoTokenizer.from_pretrained("alignment-handbook/zephyr-7b-sft-full")
model = AutoModelForCausalLM.from_pretrained("alignment-handbook/zephyr-7b-sft-full", torch_dtype=torch.bfloat16)

# Defining the RMSprop optimizer (no weight decay) with the peak learning rate 5e-7
optimizer = torch.optim.RMSprop(model.parameters(), lr=5e-7, weight_decay=0.0)

# Learning rate scheduler: multiply the learning rate by 0.2 at each epoch (stepped manually once per epoch)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.2 ** epoch)

# DeepSpeed configuration: global batch size 64, micro-batch 4 per GPU, bfloat16, ZeRO stage 2
deepspeed_config = {
    "train_batch_size": 64,
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

# Wrapping the model and optimizer in a DeepSpeed engine; the scheduler is not passed in,
# because DeepSpeed would then step it after every optimizer step instead of once per epoch
model, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=deepspeed_config)

# Setting hyperparameters for training
num_epochs = 3
max_seq_length = 2048
beta = 0.1

We first load the zephyr-7b-sft-full model and its tokenizer from HuggingFace using the AutoTokenizer and AutoModelForCausalLM classes, then configure DeepSpeed with a global batch size of 64 and a micro-batch size of 4 per GPU. The number of GPUs (4 in our setup) comes from the launcher rather than from this configuration, for example `deepspeed --num_gpus 4 train.py`. The optimizer is RMSprop with a learning rate of 5e-7, and a LambdaLR scheduler from PyTorch multiplies the learning rate by 0.2 at each epoch. Finally, we set the training hyperparameters: the number of epochs (3), the maximum sequence length (2048), and beta (0.1). The beta value is changed later, inside the outer training loop.
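DeepSpeed requires the global batch size to equal the micro-batch size per GPU, times the number of gradient-accumulation steps, times the number of GPUs. The hypothetical helper below solves that relation for the accumulation steps, which is the value DeepSpeed infers when only the two batch sizes are configured:

```python
def gradient_accumulation_steps(train_batch_size: int, micro_batch_per_gpu: int, num_gpus: int) -> int:
    """Solve train_batch_size = micro_batch_per_gpu * steps * num_gpus for steps."""
    per_step = micro_batch_per_gpu * num_gpus
    if train_batch_size % per_step != 0:
        raise ValueError("train_batch_size must be divisible by micro_batch_per_gpu * num_gpus")
    return train_batch_size // per_step


print(gradient_accumulation_steps(64, 4, 4))  # 4
print(gradient_accumulation_steps(64, 4, 8))  # 2
```

On 4 GPUs, each optimizer step therefore accumulates gradients over 4 micro-batches of 4 examples per GPU.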

Generating Synthetic Data (Inner Loop of SPIN Algorithm)

Now that we have the ground truth dataset and parameters initialized for our Zephyr-SFT LLM training, we need to code the inner loop of the SPIN algorithm. This inner loop is responsible for generating responses that need to be aligned with the ground truth data.

# Collect (prompt, generated_output) rows in a list; DataFrame.append was removed in pandas 2.0
generated_rows = []

# Looping through each row in the 'ultrachat_50k_sample' dataframe
for index, row in ultrachat_50k_sample.iterrows():
    # Extracting the 'prompt' column value and wrapping it in the paper's prompt template
    prompt = row['prompt']
    templated_prompt = f"### Instruction: {prompt}\n\n### Response:"

    # Tokenizing the prompt and moving it to the model's device
    input_ids = tokenizer(templated_prompt, return_tensors="pt").input_ids.to(model.device)

    # Generating a response with the current Zephyr model (do_sample=True so top_k/top_p take effect)
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=200, do_sample=True, top_k=50, top_p=0.95, no_repeat_ngram_size=2)

    # Decoding only the newly generated tokens, not the prompt
    generated_text = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)

    # Storing the current prompt and its generated output
    generated_rows.append({'prompt': prompt, 'generated_output': generated_text})

# zephyr-sft dataframe (contains the outputs that will be improved during training)
zephyr_sft_output = pd.DataFrame(generated_rows, columns=['prompt', 'generated_output'])

This is what the ground-truth and synthetic responses look like for a single prompt.

Comparison of ground-truth and zephyr-sft answer (Created by Fareed Khan)

Since our dataset contains 50k prompts, the inner loop runs 50k times in the first outer-loop iteration, generating one response per prompt. The result is a new DataFrame, zephyr_sft_output, that contains each prompt and the corresponding output generated by our base model, Zephyr-7B-SFT-Full.

Implementing Update Rule

Before coding the minimization problem, it is crucial to understand how the conditional probability of an LLM-generated output is calculated. The original paper uses a Markov process view, in which the conditional probability distribution p_θ(y | x) of a response y = [y_1, ..., y_m] given a prompt x = [x_1, ..., x_n] can be decomposed as follows:

Conditional Probability Formula (Markov Process)
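The decomposition from the figure, written out in LaTeX (the standard autoregressive factorization, with y_{<j} = [y_1, ..., y_{j−1}]), together with its log form, which is what code usually computes to avoid numerical underflow:

```latex
p_{\theta}(y \mid x) = \prod_{j=1}^{m} p_{\theta}\big(y_j \mid x, y_{<j}\big),
\qquad
\log p_{\theta}(y \mid x) = \sum_{j=1}^{m} \log p_{\theta}\big(y_j \mid x, y_{<j}\big)
```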

This decomposition means that the probability of the output sequence given the input sequence is the product of the probabilities of each output token, given the input sequence and all previous output tokens. For example, if the input sequence is "I enjoy" and the output sequence is "reading books", then p(reading books | I enjoy) = p(reading | I enjoy) × p(books | I enjoy reading).

Calculating Conditional Probability of small sentence
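A worked version of this example, using made-up probabilities (0.2 for "reading" and 0.5 for "books"; illustrative only, not from any real model). It also shows why implementations sum log-probabilities instead of multiplying probabilities: a 1,000-token response with per-token probability 0.1 underflows to zero in floating point.

```python
import math
from typing import List


def sequence_probability(token_probs: List[float]) -> float:
    """p(y | x) as the product of per-token conditionals p(y_j | x, y_<j)."""
    return math.prod(token_probs)


def sequence_log_probability(token_probs: List[float]) -> float:
    """log p(y | x) as the sum of per-token log conditionals."""
    return sum(math.log(p) for p in token_probs)


# Hypothetical values: p("reading" | "I enjoy") = 0.2, p("books" | "I enjoy reading") = 0.5
example = [0.2, 0.5]
print(sequence_probability(example))                # 0.1
print(math.exp(sequence_log_probability(example)))  # approximately 0.1

# A long response: the product underflows, the log-sum stays finite
long_response = [0.1] * 1000
print(sequence_probability(long_response))          # 0.0
print(sequence_log_probability(long_response))      # approximately -2302.59
```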

We will use this decomposition to compute the conditional probabilities of the ground-truth response and of the Zephyr-generated response, which are then used to compute the loss function. But first, we need to code the conditional probability function.

import torch.nn.functional as F

# Conditional probability function of a templated text "### Instruction: {x}\n\n### Response: {y}".
# Returns log p(y | x) as a differentiable tensor; working in log-space avoids the underflow
# that multiplying hundreds of token probabilities would cause.
def compute_conditional_probability(tokenizer, model, input_text, separator="### Response:"):
    # Split the text into the prompt x (including the separator) and the response y
    prompt_text, response_text = input_text.split(separator, 1)
    prompt_text = prompt_text + separator

    # Number of prompt tokens (assumes the prompt tokenizes to a prefix of the full text)
    prompt_length = tokenizer(prompt_text, return_tensors="pt").input_ids.shape[1]

    # Tokenize the full text (x followed by y)
    full_ids = tokenizer(prompt_text + response_text, return_tensors="pt").input_ids.to(model.device)

    # One forward pass: the logits at position t score the token at position t + 1
    logits = model(full_ids).logits[:, :-1, :]
    log_probs = F.log_softmax(logits.float(), dim=-1)

    # Log-probability the model assigns to each actual next token
    token_log_probs = log_probs.gather(-1, full_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    # Keep only the response tokens and sum: log p(y|x) = sum_j log p(y_j | x, y_<j)
    return token_log_probs[:, prompt_length - 1:].sum(dim=-1)

The loss function that training minimizes in each outer-loop iteration is an expectation (in practice, an average) over prompts, ground-truth responses, and synthetic responses. It contains four important conditional probabilities: the ground-truth response and the synthetic response, each scored by both the model being trained and the model from the previous iteration.

L(SPIN) Loss Function Equation (Created by Fareed Khan)
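The objective in the figure, transcribed into LaTeX as we read it from the paper: y is the ground-truth response, y′ the synthetic response sampled from the previous iteration's model p_{θ_t}, and ℓ the logistic loss.

```latex
L_{\mathrm{SPIN}}(\theta, \theta_t) =
  \mathbb{E}_{x \sim q(\cdot),\; y \sim p_{\mathrm{data}}(\cdot \mid x),\; y' \sim p_{\theta_t}(\cdot \mid x)}
  \left[ \ell\!\left( \lambda \log \frac{p_{\theta}(y \mid x)}{p_{\theta_t}(y \mid x)}
  - \lambda \log \frac{p_{\theta}(y' \mid x)}{p_{\theta_t}(y' \mid x)} \right) \right],
\qquad \ell(t) = \log\!\big(1 + e^{-t}\big)
```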

Lambda (λ) is a regularization parameter that controls how far the opponent player may deviate. It appears in the KL regularization term, which penalizes divergence between the updated opponent's distribution and the previous iteration's model. The specific value of lambda is not explicitly mentioned in the paper, as it is likely tuned to the task and dataset being used.

import torch
import torch.nn.functional as F

def LSPIN_loss(model, opponent_model, tokenizer, ground_truth_text, synthetic_text, lambda_val=0.01):
    # Log conditional probabilities under the model being trained (p_theta)
    logp_theta_ground_truth = compute_conditional_probability(tokenizer, model, ground_truth_text)
    logp_theta_synthetic = compute_conditional_probability(tokenizer, model, synthetic_text)

    # Log conditional probabilities under the frozen previous-iteration model (p_theta_t)
    with torch.no_grad():
        logp_theta_t_ground_truth = compute_conditional_probability(tokenizer, opponent_model, ground_truth_text)
        logp_theta_t_synthetic = compute_conditional_probability(tokenizer, opponent_model, synthetic_text)

    # Log likelihood ratios log(p_theta / p_theta_t) for the real and synthetic responses
    lr_ground_truth = logp_theta_ground_truth - logp_theta_t_ground_truth
    lr_synthetic = logp_theta_synthetic - logp_theta_t_synthetic

    # Logistic loss l(t) = log(1 + exp(-t)) = softplus(-t), applied to the lambda-scaled gap
    loss = F.softplus(-(lambda_val * lr_ground_truth - lambda_val * lr_synthetic))

    return loss.mean()

A quick rule of thumb: with a large dataset you can use a smaller value of lambda, while with a small dataset you may need a larger value to prevent overfitting. For our 50k-sample dataset, we use 0.01 as the value of lambda.
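To see what lambda does, here is a plain-Python sketch of the per-example loss computed from four log-probabilities (the values are hypothetical). Lambda multiplies the log-ratio gap before the logistic loss, so it sets how strongly a given shift in probabilities reduces the loss; when the trained model still equals the previous one, the loss is log 2 regardless of lambda.

```python
import math


def spin_loss(logp_theta_real: float, logp_ref_real: float,
              logp_theta_syn: float, logp_ref_syn: float, lam: float) -> float:
    """Per-example SPIN loss: logistic loss of the lambda-scaled log-ratio gap."""
    gap = lam * (logp_theta_real - logp_ref_real) - lam * (logp_theta_syn - logp_ref_syn)
    # l(t) = log(1 + exp(-t)), in a numerically stable form
    return max(-gap, 0.0) + math.log1p(math.exp(-abs(gap)))


# Model unchanged from the previous iteration: loss is log 2 for any lambda
print(round(spin_loss(-10.0, -10.0, -12.0, -12.0, lam=0.01), 4))  # 0.6931

# Hypothetical update: real response +2 nats, synthetic response -2 nats
print(round(spin_loss(-8.0, -10.0, -14.0, -12.0, lam=0.01), 4))   # 0.6733
print(round(spin_loss(-8.0, -10.0, -14.0, -12.0, lam=0.1), 4))    # 0.513
```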

Training (Outer Loop of SPIN Algorithm)

Coding the outer loop will include all the code we have developed so far. This encompasses generating synthetic data and utilizing the LSPIN loss function to compute the loss. This loss is then used to update our model parameters, resulting in the formation of a new model in the next iteration. Subsequently, this new model generates its output, which is compared to the ground truth, representing human responses.

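# Training loop (outer loop of SPIN): each epoch is one self-play iteration
opponent_path = "alignment-handbook/zephyr-7b-sft-full"  # theta_0 is the SFT model

for epoch in range(num_epochs):

    # Opponent: frozen model from the previous iteration (the SFT model in epoch 0)
    opponent_model = AutoModelForCausalLM.from_pretrained(opponent_path, torch_dtype=torch.bfloat16).to(model.device)
    opponent_model.eval()

    # Initialize total loss for the epoch
    total_loss = 0.0

    # Generating Synthetic Data (Inner loop): the model still equals the opponent at this point
    for index, row in ultrachat_50k_sample.iterrows():

        # Rest of the code (the generation loop shown earlier), which fills zephyr_sft_output
        ...

    # Computing loss using LSPIN function: pair each ground-truth response with the synthetic one
    for (index1, row1), (index2, row2) in zip(ultrachat_50k_sample.iterrows(), zephyr_sft_output.iterrows()):
        prompt = row1['prompt']
        ground_truth = row1['messages'][1]['content']  # first assistant turn
        ground_truth_text = f"### Instruction: {prompt}\n\n### Response: {ground_truth}"
        synthetic_text = f"### Instruction: {prompt}\n\n### Response: {row2['generated_output']}"

        # Compute LSPIN loss of the model being trained against the frozen opponent
        loss = LSPIN_loss(model, opponent_model, tokenizer, ground_truth_text, synthetic_text)

        # Backward pass and parameter update through the DeepSpeed engine
        model.backward(loss)
        model.step()

        # Accumulate the loss
        total_loss += loss.item()

    # Update the learning rate for the next epoch
    scheduler.step()

    # Save theta_{t+1} (rank 0 only); it becomes the opponent in the next epoch
    opponent_path = f"spin-iter-{epoch}"
    if torch.distributed.get_rank() == 0:
        model.module.save_pretrained(opponent_path)
    torch.distributed.barrier()

    # Update the value of beta
    if epoch == 2:
        beta = 5.0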

Running this training loop for 3 epochs produces a final, self-play fine-tuned version of Zephyr-7B-SFT-Full whose outputs should be closer to the ground-truth human responses. Keep in mind that this is our own reimplementation: at the time of writing, the official implementation was not yet available as open source on GitHub, so results may differ from those in the paper. Let's visually explore how the training occurs.

Training using SPIN Algorithm (Created by Fareed Khan)

Performance and Results

The results demonstrate that SPIN can significantly enhance the LLM’s performance across various benchmarks and even surpass models trained through direct preference optimization (DPO) supplemented with additional GPT-4 preference data.

Performance comparison after each epoch (From Original Paper)

As training continues, the improvements become smaller with each iteration, which suggests that the model reaches a point where further iterations bring no significant gains. Here is what the response to a sample prompt from the training data looks like after each iteration.

Generation example of Fine Tuned SPIN LLM (From Original Paper)

The response generated with the updated parameters after each iteration aims to match the ground-truth response more closely. The authors evaluate the method empirically on several benchmark datasets, including the HuggingFace Open LLM Leaderboard, MT-Bench, and datasets from Big-Bench.

Resources

  1. Wolfe, C. R. (2023, September 11). Understanding and Using Supervised Fine-Tuning (SFT) for Language Models. Retrieved from https://cameronrwolfe.substack.com/p/understanding-and-using-supervised

  2. Chen, Z., Deng, Y., Yuan, H., Ji, K., & Gu, Q. (2024, January 2). Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models. Retrieved from https://arxiv.org/pdf/2401.01335.pdf

  3. Touvron, H. (2023, July 13). Llama 2: Open Foundation and Fine-Tuned Chat Models. Retrieved from https://arxiv.org/abs/2307.09288