TinyLLaVA/TinyLLaVA_Factory

🎉 News

🔥 Takeaways

  • Our best model, TinyLLaVA-Phi-2-SigLIP-3.1B, achieves better overall performance than existing 7B models such as LLaVA-1.5 and Qwen-VL.

  • TinyLLaVA Factory is an open-source modular codebase for small-scale large multimodal models (LMMs), implemented in PyTorch and Hugging Face, with a focus on code simplicity, extensibility to new features, and reproducibility of training results.

  • TinyLLaVA Factory integrates a suite of cutting-edge models and methods.

    • LLM currently supports OpenELM, TinyLlama, StableLM, Qwen, Gemma, and Phi.

    • Vision tower currently supports CLIP, SigLIP, Dino, and a combination of CLIP and Dino.

    • Connector currently supports MLP, Qformer, and Resampler.

Installation and Requirements

Please note that our environment requirements differ from LLaVA's. We strongly recommend creating the environment from scratch as follows.

  1. Clone this repository and navigate to the folder
git clone https://github.com/TinyLLaVA/TinyLLaVA_Factory.git
cd TinyLLaVA_Factory
  2. Create a conda environment, activate it, and install packages
conda create -n tinyllava_factory python=3.10 -y
conda activate tinyllava_factory
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
  3. Install additional packages
pip install flash-attn --no-build-isolation

Upgrade to the latest code base

git pull
pip install -e .

Get Started

1. Data Preparation

Please refer to the Data Preparation section in our Documentation.

2. Train

Here's an example of training an LMM using Phi-2.

  • Replace data paths with yours in scripts/train/train_phi.sh
  • Replace output_dir with yours in scripts/train/pretrain.sh
  • Replace pretrained_model_path and output_dir with yours in scripts/train/finetune.sh
  • Adjust your GPU ids (localhost) and per_device_train_batch_size in scripts/train/pretrain.sh and scripts/train/finetune.sh
bash scripts/train/train_phi.sh

Important hyperparameters used in pretraining and finetuning are provided below.

| Training Stage | Global Batch Size | Learning Rate | conv_version |
|---|---|---|---|
| Pretraining | 256 | 1e-3 | pretrain |
| Finetuning | 128 | 2e-5 | phi |

Tips:

Global Batch Size = number of GPUs * per_device_train_batch_size * gradient_accumulation_steps. We recommend always keeping the global batch size and learning rate as above, except when tuning your model with LoRA.
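As a quick check of the formula above, the sketch below solves it for gradient_accumulation_steps. The function name and the GPU counts and per-device batch sizes in the usage lines are illustrative assumptions, not values taken from the provided scripts.

```python
def grad_accum_steps(global_batch_size: int, num_gpus: int, per_device_batch_size: int) -> int:
    """Return gradient_accumulation_steps such that
    num_gpus * per_device_batch_size * steps == global_batch_size."""
    per_step = num_gpus * per_device_batch_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_gpus} GPUs x {per_device_batch_size} per device"
        )
    return global_batch_size // per_step


# Hypothetical setup: 8 GPUs, targeting the global batch sizes from the table above.
print(grad_accum_steps(256, num_gpus=8, per_device_batch_size=32))  # pretraining -> 1
print(grad_accum_steps(128, num_gpus=8, per_device_batch_size=4))   # finetuning -> 4
```

If you change the number of GPUs, adjust per_device_train_batch_size or gradient_accumulation_steps so that the product stays at the global batch size in the table.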

conv_version is a hyperparameter that selects the chat template for each LLM. In the pretraining stage, all LLMs use the same conv_version, pretrain. In the finetuning stage, we use:

  • phi for Phi-2, StableLM, and Qwen-1.5
  • llama for TinyLlama and OpenELM
  • gemma for Gemma
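The mapping above can be written as a small lookup table. This is a hypothetical helper for illustration only: in practice you set conv_version (CONV_VERSION) directly in the training scripts, and the dictionary keys below are informal family names we chose, not official identifiers.

```python
# Hypothetical lookup: LLM family -> conv_version used in the finetuning stage.
FINETUNE_CONV_VERSION = {
    "phi-2": "phi",
    "stablelm": "phi",
    "qwen-1.5": "phi",
    "tinyllama": "llama",
    "openelm": "llama",
    "gemma": "gemma",
}


def conv_version_for(llm_family: str, stage: str) -> str:
    """Return the conv_version for a given LLM family and training stage."""
    if stage == "pretrain":
        return "pretrain"  # the same template for every LLM during pretraining
    if stage != "finetune":
        raise ValueError(f"unknown stage: {stage!r}")
    try:
        return FINETUNE_CONV_VERSION[llm_family.lower()]
    except KeyError:
        raise ValueError(f"no conv_version listed for {llm_family!r}") from None


print(conv_version_for("TinyLlama", "finetune"))  # llama
print(conv_version_for("Gemma", "pretrain"))      # pretrain
```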

3. Evaluation

Please refer to the Evaluation section in our Documentation.

Model Zoo

Trained Models

These models were trained using TinyLLaVA Factory.

Model Performance

| Vision Tower (HF Path) | LLM (HF Path) | Recipe | VQA-v2 | GQA | SQA-image | TextVQA | MM-Vet | POPE | MME | MMMU |
|---|---|---|---|---|---|---|---|---|---|---|
| openai/clip-vit-large-patch14-336 | apple/OpenELM-450M-Instruct | base | 69.5 | 52.1 | 50.6 | 40.4 | 20.0 | 83.6 | 1052.9 | 23.9 |
| google/siglip-so400m-patch14-384 | apple/OpenELM-450M-Instruct | base | 71.7 | 53.9 | 54.1 | 44.0 | 20.0 | 85.4 | 1118.8 | 24.0 |
| openai/clip-vit-large-patch14-336 | TinyLlama/TinyLlama-1.1B-Chat-v1.0 | base | 73.7 | 58.0 | 59.9 | 46.3 | 23.2 | 85.5 | 1284.6 | 27.9 |
| google/siglip-so400m-patch14-384 | TinyLlama/TinyLlama-1.1B-Chat-v1.0 | base | 75.5 | 58.6 | 64.0 | 49.6 | 23.5 | 86.3 | 1256.5 | 28.3 |
| openai/clip-vit-large-patch14-336 | stabilityai/stablelm-2-zephyr-1_6b | base | 75.9 | 59.5 | 64.6 | 50.5 | 27.3 | 86.1 | 1368.1 | 31.8 |
| google/siglip-so400m-patch14-384 | stabilityai/stablelm-2-zephyr-1_6b | base | 78.2 | 60.7 | 66.7 | 56.0 | 29.4 | 86.3 | 1319.3 | 32.6 |
| google/siglip-so400m-patch14-384 | google/gemma-2b-it | base | 78.4 | 61.6 | 64.4 | 53.6 | 26.9 | 86.4 | 1339.0 | 31.7 |
| openai/clip-vit-large-patch14-336 | microsoft/phi-2 | base | 76.8 | 59.4 | 71.2 | 53.4 | 31.7 | 86.8 | 1448.6 | 36.3 |
| google/siglip-so400m-patch14-384 | microsoft/phi-2 | base | 79.2 | 61.6 | 71.9 | 57.4 | 35.0 | 87.2 | 1462.4 | 38.2 |
| google/siglip-so400m-patch14-384 | microsoft/phi-2 | base&lora | 77.6 | 59.7 | 71.6 | 53.8 | 33.3 | 87.9 | 1413.2 | 35.6 |
| google/siglip-so400m-patch14-384 | microsoft/phi-2 | share | 80.1 | 62.1 | 73.0 | 60.3 | 37.5 | 87.2 | 1466.4 | 38.4 |

Legacy Models

These models were trained using our old codebase, TinyLLaVABench.

If you have models trained with TinyLLaVABench and still want to use them, the example below shows how to use such legacy models, taking TinyLLaVA-3.1B as an example.

Example of using legacy models
from tinyllava.eval.run_tiny_llava import eval_model
from tinyllava.model.convert_legecy_weights_to_tinyllavafactory import convert_legecy_weights_to_tinyllavafactory

model = convert_legecy_weights_to_tinyllavafactory('bczhou/TinyLLaVA-3.1B')

prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"

args = type('Args', (), {
    "model_path": None,
    "model": model,
    "query": prompt,
    "conv_mode": "phi",  # same as conv_version in the training stage; different LLMs use different values, so set it to match your LLM
    "image_file": image_file,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512
})()

eval_model(args)

"""
Output: 
When visiting this serene lakeside location with a wooden dock, there are a few things to be cautious about. First, ensure that the dock is stable and secure before stepping onto it, as it might be slippery or wet, especially if it's a wooden structure. Second, be mindful of the surrounding water, as it can be deep or have hidden obstacles, such as rocks or debris, that could pose a risk. Additionally, be aware of the weather conditions, as sudden changes in weather can make the area more dangerous. Lastly, respect the natural environment and wildlife, and avoid littering or disturbing the ecosystem.
"""

Launch Demo Locally

To run a model trained by yourself or by us locally, see the examples below.

Run inference with a model you trained yourself
from tinyllava.eval.run_tiny_llava import eval_model

model_path = "/absolute/path/to/your/model/"
prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"

args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "query": prompt,
    "conv_mode": "phi",
    "image_file": image_file,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512
})()

eval_model(args)

"""
Output: 
XXXXXXXXXXXXXXXXX
"""
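The `type('Args', (), {...})()` expression in these examples just creates an object whose fields are read as attributes. If you prefer, `types.SimpleNamespace` from the Python standard library builds an equivalent object more readably. This sketch assumes eval_model only reads these fields as attributes, as the examples suggest; the path and prompt are the same placeholders as above.

```python
from types import SimpleNamespace

prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"

# Same fields as the type('Args', (), {...})() object above, accessed as args.<name>.
args = SimpleNamespace(
    model_path="/absolute/path/to/your/model/",
    model_base=None,
    query=prompt,
    conv_mode="phi",
    image_file=image_file,
    sep=",",
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)

print(args.conv_mode, args.max_new_tokens)  # phi 512
# eval_model(args)  # eval_model from tinyllava.eval.run_tiny_llava, as above
```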
Run inference with a model trained by us using Hugging Face Transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_path = 'tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B'
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
model.cuda()
config = model.config
tokenizer = AutoTokenizer.from_pretrained(
    hf_path,
    use_fast=False,
    model_max_length=config.tokenizer_model_max_length,
    padding_side=config.tokenizer_padding_side,
)
prompt = "What are these?"
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# chat() comes from the model's remote code (loaded via trust_remote_code=True)
output_text, generation_time = model.chat(prompt=prompt, image=image_url, tokenizer=tokenizer)

print('model output:', output_text)
print('running time:', generation_time)

Customize Your Own Multimodal Models

LLM

To add a new LLM, create two files: one for the chat template under tinyllava/data/template/ and one for the language model under tinyllava/model/llm/.

Here is an example of adding the Gemma model.

First, create tinyllava/data/template/gemma_template.py, which is used in the finetuning stage.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from packaging import version

from .formatter import EmptyFormatter, StringFormatter
from .base import Template
from .formatter import Formatter
from . import register_template
from ...utils.constants import *

from transformers import PreTrainedTokenizer
import torch
import tokenizers

    
system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

@register_template('gemma')
@dataclass
class GemmaTemplate(Template):
    format_image_token: "Formatter" = StringFormatter(slot="<image>\n{{content}}")
    format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ")
    format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "<eos>") # to be modified
    system: "Formatter" = EmptyFormatter(slot=system+" ")
    separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '<eos>']) # to be modified

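    def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds):
        # Your code here: mask `labels` as described in the tips below and compute cur_len,
        # then return both as (labels, cur_len).
        raise NotImplementedError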

Tips:

Please ensure that the labels returned by the _make_masks function follow this format: the answer tokens and the eos token id are not masked, the image token is masked with -200, and all other tokens are masked with -100.
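To make this label format concrete, here is a toy, single-round illustration in plain Python. It is not TinyLLaVA Factory's _make_masks: the helper name build_labels, the (token_ids, is_answer) segment layout, and the token ids are all hypothetical; only the -100 and -200 conventions come from the tip above.

```python
IGNORE_INDEX = -100       # masked (non-supervised) positions, per the tip above
IMAGE_TOKEN_INDEX = -200  # the image token, per the tip above


def build_labels(segments):
    """Hypothetical helper, not TinyLLaVA's _make_masks.

    segments: list of (token_ids, is_answer) pairs in conversation order.
    Answer segments (which should already end with the eos token id) keep their ids,
    the image token is set to IMAGE_TOKEN_INDEX, and every other position to IGNORE_INDEX.
    """
    labels = []
    for token_ids, is_answer in segments:
        for tok in token_ids:
            if is_answer:
                labels.append(tok)
            elif tok == IMAGE_TOKEN_INDEX:
                labels.append(IMAGE_TOKEN_INDEX)
            else:
                labels.append(IGNORE_INDEX)
    return labels


# Toy ids: 2 stands for bos, 1 for eos.
segments = [
    ([2, 11, 12], False),                  # bos + system prompt
    ([IMAGE_TOKEN_INDEX, 13, 14], False),  # image token + user question + "ASSISTANT:"
    ([21, 22, 1], True),                   # answer tokens + eos
]
print(build_labels(segments))
# [-100, -100, -100, -200, -100, -100, 21, 22, 1]
```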

Second, create tinyllava/model/llm/gemma.py.

from transformers import GemmaForCausalLM, AutoTokenizer
# The LLM you want to add along with its corresponding tokenizer.

from . import register_llm

# Add GemmaForCausalLM along with its corresponding tokenizer and handle special tokens.
@register_llm('gemma') # Enable the LLMFactory to obtain the added LLM by this string ('gemma').
def return_gemmaclass(): 
    def tokenizer_and_post_load(tokenizer):
        tokenizer.pad_token = tokenizer.unk_token
        return tokenizer
    return (GemmaForCausalLM, (AutoTokenizer, tokenizer_and_post_load))
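The `@register_llm('gemma')` decorator is what lets the LLM factory look up the added LLM by its string name. Below is a minimal sketch of how such a string-keyed registry typically works, assuming a simple module-level dictionary. The names LLM_FACTORY and get_llm and the toy entry are hypothetical, and TinyLLaVA Factory's actual implementation may differ.

```python
# Hypothetical, simplified registry; not TinyLLaVA Factory's own register_llm.
LLM_FACTORY = {}


def register_llm(name):
    """Decorator that stores the decorated function under `name`."""
    def decorator(fn):
        if name in LLM_FACTORY:
            raise KeyError(f"LLM {name!r} is already registered")
        LLM_FACTORY[name] = fn
        return fn
    return decorator


def get_llm(name):
    """Look up a registered entry by string and call it to obtain its classes."""
    return LLM_FACTORY[name]()


@register_llm("toy")
def return_toyclass():
    def tokenizer_and_post_load(tokenizer):
        return tokenizer
    # Stand-ins for (model class, (tokenizer class, post-load hook)), mirroring the Gemma entry.
    return (dict, (list, tokenizer_and_post_load))


model_cls, (tokenizer_cls, post_load) = get_llm("toy")
print(model_cls, tokenizer_cls)  # <class 'dict'> <class 'list'>
```

The design keeps the factory decoupled from individual model files: adding an LLM only requires decorating a new function, with no edits to the lookup code.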

Finally, create scripts/train/train_gemma.sh with the corresponding LLM_VERSION and CONV_VERSION.

Vision Tower

To add a new vision tower, implement a new vision tower class that inherits from the base class VisionTower. Here's an example of the MoF vision tower.

First, create tinyllava/model/vision_tower/mof.py

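from . import register_vision_tower
from .base import VisionTower


@register_vision_tower('mof')
class MoFVisionTower(VisionTower):
    def __init__(self, cfg):
        super().__init__(cfg)

        self._vision_tower = MoF(cfg)  # MoF: your vision model class (defined elsewhere, not shown)
        self._image_processor = None   # replace with your image processor

    def _load_model(self, vision_tower_name, **kwargs):
        # Your code here: make sure your model can be correctly loaded from pretrained
        # parameters, either by Hugging Face or PyTorch loading.
        raise NotImplementedError

    def forward(self, x, **kwargs):
        # Your code here: return the image features for input x.
        raise NotImplementedError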

Then, set the corresponding CT_VERSION in your training scripts.
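A vision tower turns each image into a sequence of patch features, which the connector then projects into the LLM's hidden size. For ViT-style towers such as those in the model zoo, the number of visual tokens follows from the image and patch sizes in the checkpoint name. A quick sketch, assuming square inputs, a non-overlapping patch embedding without padding, and that only patch tokens (no CLS token) are passed on:

```python
def num_visual_tokens(image_size: int, patch_size: int) -> int:
    # A stride-`patch_size` patch embedding without padding yields
    # floor(image_size / patch_size) patches per side; one feature per patch.
    per_side = image_size // patch_size
    return per_side * per_side


print(num_visual_tokens(336, 14))  # openai/clip-vit-large-patch14-336 -> 576
print(num_visual_tokens(384, 14))  # google/siglip-so400m-patch14-384  -> 729
```

This count matters when sizing model_max_length, since every image contributes that many tokens to the LLM's input sequence.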

Connector

To add a new connector, implement a new connector class that inherits from the base class Connector. Here's an example of the Linear connector.

First, create tinyllava/model/connector/linear.py

import torch.nn as nn

from . import register_connector
from .base import Connector
    
@register_connector('linear')  # Enables the ConnectorFactory to obtain the added connector by this string ('linear').
class LinearConnector(Connector):
    def __init__(self, config):
        super().__init__()
        self._connector = nn.Linear(config.vision_hidden_size, config.hidden_size)  # define your connector model

Then, set the corresponding CN_VERSION in your training scripts.
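For comparison with the linear connector, an MLP connector inserts a nonlinearity between two projections. The sketch below is a standalone two-layer MLP (Linear, GELU, Linear) in the style of LLaVA-1.5's projector. It subclasses nn.Module directly rather than TinyLLaVA Factory's Connector base class, uses toy sizes, and is not necessarily identical to the MLP connector shipped with the codebase.

```python
import torch
import torch.nn as nn


class MLPConnectorSketch(nn.Module):
    """Two-layer MLP projector: vision features -> LLM hidden size (standalone sketch)."""

    def __init__(self, vision_hidden_size: int, hidden_size: int):
        super().__init__()
        self._connector = nn.Sequential(
            nn.Linear(vision_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, num_visual_tokens, vision_hidden_size) -> (batch, num_visual_tokens, hidden_size)
        return self._connector(x)


# Toy sizes for illustration; real values come from the vision tower and LLM configs.
connector = MLPConnectorSketch(vision_hidden_size=64, hidden_size=128)
features = torch.randn(2, 10, 64)
print(connector(features).shape)  # torch.Size([2, 10, 128])
```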

For QA

If you have any questions about TinyLLaVA Factory, feel free to contact the WeChat accounts TinyLLaVA or YingHuCS.

✏ Citation

If you find our paper and code useful in your research, please consider giving a star ⭐ and citation 📝.

@misc{zhou2024tinyllava,
      title={TinyLLaVA: A Framework of Small-scale Large Multimodal Models}, 
      author={Baichuan Zhou and Ying Hu and Xi Weng and Junlong Jia and Jie Luo and Xien Liu and Ji Wu and Lei Huang},
      year={2024},
      eprint={2402.14289},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

❤️ Community efforts

  • Our codebase is built upon the LLaVA project. Great work!
  • Our project uses data from the ShareGPT4V project. Great work!