Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add Fine-tuning method: AdaLoRA #844

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/llmtuner/hparams/finetuning_args.py
Expand Up @@ -8,7 +8,7 @@ class FinetuningArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
finetuning_type: Optional[Literal["lora", "adalora", "freeze", "full", "none"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
Expand Down Expand Up @@ -37,6 +37,14 @@ class FinetuningArguments:
Qwen choices: [\"mlp\", \"attn\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
# AdaLoRA-only hyperparameter: final average rank after rank pruning (PEFT AdaLoraConfig.target_r).
target_r: Optional[int] = field(
    default=8,
    metadata={"help": "The target average rank of incremental matrix for AdaLoRA fine-tuning."}
)
# AdaLoRA-only hyperparameter: starting rank before pruning; must be >= target_r (PEFT AdaLoraConfig.init_r).
init_r: Optional[int] = field(
    default=12,
    metadata={"help": "The initial rank for each incremental matrix for AdaLoRA fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
Expand Down Expand Up @@ -82,7 +90,7 @@ def __post_init__(self):

self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "adalora", "freeze", "full", "none"], "Invalid fine-tuning method."

def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
Expand Down
36 changes: 26 additions & 10 deletions src/llmtuner/tuner/core/adapter.py
Expand Up @@ -6,6 +6,7 @@
PeftModel,
TaskType,
LoraConfig,
AdaLoraConfig,
get_peft_model
)
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
Expand Down Expand Up @@ -55,8 +56,11 @@ def init_adapter(
if model_args.checkpoint_dir is not None:
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
# Both LoRA and AdaLoRA share the same checkpoint-resume / adapter-creation flow;
# use a membership test instead of chained == comparisons, and if/else instead of
# two independent ifs since the two branches are mutually exclusive.
if finetuning_args.finetuning_type in ("lora", "adalora"):
    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
    else:
        logger.info("Fine-tuning method: AdaLoRA")
latest_checkpoint = None

if model_args.checkpoint_dir is not None:
Expand All @@ -81,14 +85,26 @@ def init_adapter(
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)

if is_trainable and latest_checkpoint is None: # create new lora weights while training
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
)
# Build the PEFT adapter config. The two branches are mutually exclusive, so use
# if/elif; without it, a future third finetuning_type would silently fall through
# and leave `lora_config` unbound. The shared keyword arguments are hoisted so the
# LoRA/AdaLoRA difference (the extra rank-scheduling knobs) is visible at a glance.
peft_kwargs = {
    "task_type": TaskType.CAUSAL_LM,
    "inference_mode": False,
    "r": finetuning_args.lora_rank,
    "lora_alpha": finetuning_args.lora_alpha,
    "lora_dropout": finetuning_args.lora_dropout,
    "target_modules": finetuning_args.lora_target
}
if finetuning_args.finetuning_type == "lora":
    lora_config = LoraConfig(**peft_kwargs)
elif finetuning_args.finetuning_type == "adalora":
    # AdaLoRA additionally schedules ranks from init_r down to target_r.
    lora_config = AdaLoraConfig(
        target_r=finetuning_args.target_r,
        init_r=finetuning_args.init_r,
        **peft_kwargs
    )
else:  # unreachable under the outer finetuning_type guard; fail loudly if it changes
    raise ValueError("Unsupported finetuning_type: {}".format(finetuning_args.finetuning_type))
model = get_peft_model(model, lora_config)

if model_args.checkpoint_dir is not None:
Expand Down