Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support BAdam in WebUI #3487

Merged
merged 3 commits into from May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/llmtuner/webui/components/train.py
Expand Up @@ -210,6 +210,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)

with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
badam_switch_block_every = gr.Slider(value=50, minimum=-1, maximum=200, step=1)
badam_update_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)

input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_block_every, badam_update_ratio})
elem_dict.update(
dict(
badam_tab=badam_tab,
use_badam=use_badam,
badam_mode=badam_mode,
badam_switch_mode=badam_switch_mode,
badam_switch_block_every=badam_switch_block_every,
badam_update_ratio=badam_update_ratio,
)
)

with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
Expand Down
109 changes: 109 additions & 0 deletions src/llmtuner/webui/locales.py
Expand Up @@ -891,6 +891,115 @@
"info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
},
},
"badam_tab": {
"en": {
"label": "BAdam configurations",
},
"ru": {
"label": "Конфигурации BAdam",
},
"zh": {
"label": "BAdam 参数设置",
},
},
"use_badam": {
"en": {
"label": "Use BAdam",
"info": "Enable the block coordinate optimization with Adam.",
},
"ru": {
"label": "Использовать BAdam",
"info": "Включите блочную оптимизацию координат с Adam.",
},
"zh": {
"label": "使用 BAdam",
"info": "使用多Block协同的Adam优化器。",
},
},
"badam_mode": {
"en": {
"label": "BAdam mode",
"info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
},
"ru": {
"label": "Режим BAdam",
"info": "Использовать оптимизатор BAdam с обработкой слоев или с обработкой коэффициентов.",
},
"zh": {
"label": "BAdam 模式",
"info": "使用layer或者ratio比例模式。",
},
},
"badam_switch_block_every": {
"en": {
"label": "Switch block frequency",
"info": "How often to switch model's block update. Set to -1 to disable the block update.",
},
"ru": {
"label": "Частота переключения",
"info": "Как часто переключать обновление блока модели. Установите -1, чтобы отключить обновление блока.",
},
"zh": {
"label": "切换block的频率",
"info": "控制切换block切换的频率,如果是-1,则不切换。",
},
},
"badam_switch_mode": {
"en": {
"label": "Switch mode",
"info": "The strategy of picking block to update for layer-wise BAdam.",
},
"ru": {
"label": "Переключить режим",
"info": "Стратегия выбора блока для обновления в методе BAdam по слоям.",
},
"zh": {
"label": "Block切换策略",
"info": "如果是layer类型的训练模式,如何切换block。",
},
},
"badam_update_ratio": {
"en": {
"label": "Update ratio",
"info": "The ratio of the update for ratio-wise BAdam.",
},
"ru": {
"label": "Коэффициент обновления",
"info": "Коэффициент обновления для метода BAdam, основанного на коэффициентах.",
},
"zh": {
"label": "Block更新比例",
"info": "如果是比例类型的训练模式,block每次更新的范围比例。",
},
},
"badam_mask_mode": {
"en": {
"label": "Mask mode",
"info": "The mode of the mask for BAdam optimizer.",
},
"ru": {
"label": "Режим маски",
"info": "Режим маски для оптимизатора BAdam.",
},
"zh": {
"label": "Mask模式",
"info": "BAdam优化器内训练参数的mask关系。",
},
},
"badam_verbose": {
"en": {
"label": "Verbosity level",
"info": "0 for no print, 1 for print the block prefix, 2 for print trainable parameters.",
},
"ru": {
"label": "Уровень многословности",
"info": "0 для отсутствия печати, 1 для печати префикса блока, 2 для печати обучаемых параметров.",
},
"zh": {
"label": "输出日志级别",
"info": "0:不输出,1:输出block前缀, 1:输出可训练的参数。",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
Expand Down
9 changes: 9 additions & 0 deletions src/llmtuner/webui/runner.py
Expand Up @@ -151,6 +151,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
use_badam=get("train.use_badam"),
)
args["disable_tqdm"] = True

Expand Down Expand Up @@ -198,6 +199,14 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")

if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_block_every"] = get("train.badam_switch_block_every")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_update_ratio"] = get("train.badam_update_ratio")
args["badam_mask_mode"] = get("train.badam_mask_mode")
args["badam_verbose"] = get("train.badam_verbose")

return args

def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
Expand Down