Skip to content

Commit

Permalink
Merge pull request #3487 from codemayq/main
Browse files Browse the repository at this point in the history
support BAdam in WebUI
  • Loading branch information
hiyouga committed May 1, 2024
2 parents 282b5d5 + dcd53cb commit f1c0eed
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/llmtuner/webui/components/train.py
Expand Up @@ -210,6 +210,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)

with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
badam_switch_block_every = gr.Slider(value=50, minimum=-1, maximum=200, step=1)
badam_update_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)

input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_block_every, badam_update_ratio})
elem_dict.update(
dict(
badam_tab=badam_tab,
use_badam=use_badam,
badam_mode=badam_mode,
badam_switch_mode=badam_switch_mode,
badam_switch_block_every=badam_switch_block_every,
badam_update_ratio=badam_update_ratio,
)
)

with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
Expand Down
109 changes: 109 additions & 0 deletions src/llmtuner/webui/locales.py
Expand Up @@ -891,6 +891,115 @@
"info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
},
},
"badam_tab": {
"en": {
"label": "BAdam configurations",
},
"ru": {
"label": "Конфигурации BAdam",
},
"zh": {
"label": "BAdam 参数设置",
},
},
"use_badam": {
"en": {
"label": "Use BAdam",
"info": "Enable the block coordinate optimization with Adam.",
},
"ru": {
"label": "Использовать BAdam",
"info": "Включите блочную оптимизацию координат с Adam.",
},
"zh": {
"label": "使用 BAdam",
"info": "使用多Block协同的Adam优化器。",
},
},
"badam_mode": {
"en": {
"label": "BAdam mode",
"info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
},
"ru": {
"label": "Режим BAdam",
"info": "Использовать оптимизатор BAdam с обработкой слоев или с обработкой коэффициентов.",
},
"zh": {
"label": "BAdam 模式",
"info": "使用layer或者ratio比例模式。",
},
},
"badam_switch_block_every": {
"en": {
"label": "Switch block frequency",
"info": "How often to switch model's block update. Set to -1 to disable the block update.",
},
"ru": {
"label": "Частота переключения",
"info": "Как часто переключать обновление блока модели. Установите -1, чтобы отключить обновление блока.",
},
"zh": {
"label": "切换block的频率",
"info": "控制切换block切换的频率,如果是-1,则不切换。",
},
},
"badam_switch_mode": {
"en": {
"label": "Switch mode",
"info": "The strategy of picking block to update for layer-wise BAdam.",
},
"ru": {
"label": "Переключить режим",
"info": "Стратегия выбора блока для обновления в методе BAdam по слоям.",
},
"zh": {
"label": "Block切换策略",
"info": "如果是layer类型的训练模式,如何切换block。",
},
},
"badam_update_ratio": {
"en": {
"label": "Update ratio",
"info": "The ratio of the update for ratio-wise BAdam.",
},
"ru": {
"label": "Коэффициент обновления",
"info": "Коэффициент обновления для метода BAdam, основанного на коэффициентах.",
},
"zh": {
"label": "Block更新比例",
"info": "如果是比例类型的训练模式,block每次更新的范围比例。",
},
},
"badam_mask_mode": {
"en": {
"label": "Mask mode",
"info": "The mode of the mask for BAdam optimizer.",
},
"ru": {
"label": "Режим маски",
"info": "Режим маски для оптимизатора BAdam.",
},
"zh": {
"label": "Mask模式",
"info": "BAdam优化器内训练参数的mask关系。",
},
},
"badam_verbose": {
"en": {
"label": "Verbosity level",
"info": "0 for no print, 1 for print the block prefix, 2 for print trainable parameters.",
},
"ru": {
"label": "Уровень многословности",
"info": "0 для отсутствия печати, 1 для печати префикса блока, 2 для печати обучаемых параметров.",
},
"zh": {
"label": "输出日志级别",
"info": "0:不输出,1:输出block前缀, 1:输出可训练的参数。",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
Expand Down
9 changes: 9 additions & 0 deletions src/llmtuner/webui/runner.py
Expand Up @@ -151,6 +151,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
use_badam=get("train.use_badam"),
)
args["disable_tqdm"] = True

Expand Down Expand Up @@ -198,6 +199,14 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")

if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_block_every"] = get("train.badam_switch_block_every")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_update_ratio"] = get("train.badam_update_ratio")
args["badam_mask_mode"] = get("train.badam_mask_mode")
args["badam_verbose"] = get("train.badam_verbose")

return args

def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit f1c0eed

Please sign in to comment.