
[FEATURE: ADD LISA ALGORITHM] #3103

Closed
wants to merge 16 commits into from
35 changes: 35 additions & 0 deletions examples/lisa_single_gpu/sft.sh
@@ -0,0 +1,35 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type full \
--use_lisa \
--lisa_activated_layers 2 \
--lisa_interval_steps 5 \
--output_dir ../../saves/LLaMA2-7B/lisa/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16
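
For context (not part of this PR): with the flags above, LISA keeps the embedding and the LM head trainable while re-sampling which decoder layers receive gradients every few optimizer steps. A minimal standalone sketch of that schedule, assuming Llama-2-7B's 32 decoder layers (the variable names below are illustrative only):

import numpy as np

total_layers = 32            # Llama-2-7B has 32 decoder layers
lisa_activated_layers = 2    # --lisa_activated_layers
lisa_interval_steps = 5      # --lisa_interval_steps

for step in range(15):
    if step % lisa_interval_steps == 0:
        # freeze everything, then unfreeze a fresh random subset of decoder layers
        active = sorted(np.random.choice(total_layers, lisa_activated_layers, replace=False).tolist())
        print(f"step {step:>2}: training decoder layers {active}")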

83 changes: 78 additions & 5 deletions src/llmtuner/extras/callbacks.py
@@ -2,23 +2,96 @@
import os
import time
from datetime import timedelta
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint, count_parameters
from ..hparams import FinetuningArguments

if TYPE_CHECKING:
    from transformers import Trainer, TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class LisaTrainCallback(TrainerCallback):
    def __init__(self, finetuning_args: "FinetuningArguments", trainer: "Trainer"):
super().__init__()
self.trainer = trainer
self.layers_attribute = self.attention_layer_auto_detect(finetuning_args.lisa_attention_name)
self.step_interval = finetuning_args.lisa_interval_steps
self.lisa_activated_layers = finetuning_args.lisa_activated_layers
self.total_layers = len(self.get_layers())
self.lisa_verbose = finetuning_args.lisa_verbose
self.trained_layers = set()
        if self.lisa_activated_layers > self.total_layers:
            raise ValueError(
                f"lisa_activated_layers ({self.lisa_activated_layers}) exceeds total_layers ({self.total_layers}), "
                f"please check your arguments."
            )
        logger.info(
            f"LISA will activate {self.lisa_activated_layers}/{self.total_layers} layers "
            f"({self.lisa_activated_layers * 100 / self.total_layers:.1f}%) every {self.step_interval} steps"
        )

def attention_layer_auto_detect(self, lisa_attention_name):
class_to_layers_map = {
'LlamaForCausalLM': 'model.layers',
'Qwen2ForCausalLM': 'model.layers',
'MistralForCausalLM': 'model.layers',
'MixtralForCausalLM': 'model.layers',
'GemmaForCausalLM': 'model.layers',
'GPT2LMHeadModel': 'transformer.h',
}
_atten_val = lisa_attention_name
model_class_name = self.trainer.model.__class__.__name__
if _atten_val is None:
# Determine the way to access layers based on the model type
if model_class_name in class_to_layers_map:
_atten_val = class_to_layers_map[model_class_name]

return _atten_val

def on_step_begin(self, args, state, control, **kwargs):
if state.global_step % self.step_interval == 0:
self.switch_active_layers()

def freeze_all_layers(self):
layers = self.get_layers()
for layer in layers:
for param in layer.parameters():
param.requires_grad = False

def get_layers(self):
return reduce(getattr, self.layers_attribute.split("."), self.trainer.model)

def switch_active_layers(self):
# disable gradients for all layers
self.freeze_all_layers()
layers = self.get_layers()
active_layers_indices = np.random.choice(range(self.total_layers), self.lisa_activated_layers,
replace=False)
self.trained_layers.update(active_layers_indices)
for idx in active_layers_indices:
for param in layers[idx].parameters():
param.requires_grad = True
if self.lisa_verbose:
trainable_params, all_param = count_parameters(self.trainer.model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
            logger.info(
                f"LISA will activate layers {','.join(map(str, sorted(active_layers_indices)))} "
                f"for the next {self.step_interval} steps. "
                f"{len(self.trained_layers)}/{self.total_layers} layers "
                f"({len(self.trained_layers) * 100 / self.total_layers:.1f}%) have been trained so far: "
                f"{','.join(map(str, sorted(self.trained_layers)))}"
            )


class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
@@ -107,7 +180,7 @@ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control:
self.max_steps = 0

def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
r"""
Event called after a successful prediction.
@@ -153,7 +226,7 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
f.write(json.dumps(logs) + "\n")

def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
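A side note on `get_layers()` above (not part of the PR): the `layers_attribute` string, e.g. `model.layers`, is resolved attribute by attribute with `reduce(getattr, ...)`. A minimal sketch with toy stand-ins (`ToyBlock`, `ToyBackbone`, `ToyModel` are hypothetical names used only for illustration):

from functools import reduce


class ToyBlock:          # stand-in for a single decoder layer
    pass


class ToyBackbone:       # stand-in for the inner `model` module
    def __init__(self):
        self.layers = [ToyBlock(), ToyBlock(), ToyBlock()]


class ToyModel:          # stand-in for a causal LM wrapper
    def __init__(self):
        self.model = ToyBackbone()


# mirrors LisaTrainCallback.get_layers(): split "model.layers" on "." and
# look up each attribute in turn, starting from the model object
layers = reduce(getattr, "model.layers".split("."), ToyModel())
print(len(layers))  # 3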
46 changes: 45 additions & 1 deletion src/llmtuner/hparams/finetuning_args.py
@@ -204,7 +204,45 @@ class GaloreArguments:


@dataclass
class LisaArguments:
r"""
paper: https://arxiv.org/abs/2403.17919
ref: https://github.com/OptimalScale/LMFlow
Arguments pertaining to the Lisa algorithm.
- 始终更新底层 embedding 和顶层 linear head;
- 随机更新少数中间的 self-attention 层,比如 2-4 层。
"""
    use_lisa: bool = field(
        default=False,
        metadata={
            "help": "whether to use the LISA training strategy."
        }
    )
    lisa_activated_layers: Optional[int] = field(
default=None,
metadata={
"help": "the number of activated layers in LISA."
}
)
    lisa_interval_steps: Optional[int] = field(
default=None,
metadata={
"help": "the number of steps in each freezing interval of LISA, i.e. "
"the selected unfrozen layers are randomly switched every {lisa_interval_steps} steps."
}
)
    lisa_attention_name: str = field(
        default="model.layers",
        metadata={"help": "attribute path of the decoder layer list inside the model, e.g. `model.layers`."}
    )
    lisa_verbose: bool = field(
        default=False,
        metadata={"help": "whether to log detailed information about the layers activated by LISA."},
    )


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, LisaArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
@@ -261,6 +299,12 @@ def split_arg(arg):
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")

if self.use_lisa:
            if self.finetuning_type != "full":
                raise ValueError("`use_lisa` requires `finetuning_type` to be `full`.")
            if self.lisa_interval_steps is None or self.lisa_activated_layers is None:
                raise ValueError("`use_lisa` requires both `lisa_interval_steps` and `lisa_activated_layers` to be set.")

def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
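A quick illustration of the new `__post_init__` checks (a sketch, not from the PR; it assumes the remaining `FinetuningArguments` fields can keep their defaults):

from llmtuner.hparams import FinetuningArguments

# a valid LISA configuration: full fine-tuning with both LISA knobs set
args = FinetuningArguments(
    finetuning_type="full",
    use_lisa=True,
    lisa_activated_layers=2,
    lisa_interval_steps=5,
)

# omitting one of the two knobs is rejected
try:
    FinetuningArguments(finetuning_type="full", use_lisa=True, lisa_activated_layers=2)
except ValueError as err:
    print(err)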
6 changes: 5 additions & 1 deletion src/llmtuner/train/sft/workflow.py
@@ -5,6 +5,7 @@
from transformers import DataCollatorForSeq2Seq

from ...data import get_dataset, split_dataset
from ...extras.callbacks import LisaTrainCallback
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
@@ -13,7 +14,6 @@
from .metric import ComputeMetrics
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -60,6 +60,10 @@ def run_sft(
**split_dataset(dataset, data_args, training_args),
)

    # register the LISA callback after the trainer is constructed, since it needs a reference to the trainer
if finetuning_args.use_lisa:
trainer.add_callback(LisaTrainCallback(finetuning_args, trainer))

# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids