src_seq_length fixed length error when training MoE model #5495

Open

meenakshi-mittal opened this issue May 3, 2024 · 0 comments
I am trying to train an MoE model from the moe branch on the provided wikitext-103 dataset. The example training script for MoE models does not generate "checkpoint_best.pt" or "checkpoint_last.pt" files, which the evaluation script needs.

I believe these files were missing because of the "--disable-validation" flag in the default script, since the best checkpoint appears to be chosen based on validation scores.

I removed this flag and attempted to train an MoE model with the following command:

fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints/moe_wikitext-103-best \
    --tokens-per-sample 512 \
    --ddp-backend fully_sharded --memory-efficient-fp16 --checkpoint-activations \
    --arch transformer_lm --share-decoder-input-output-embed \
    --decoder-layers 24 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 \
    --decoder-attention-heads 16 \
    --moe-expert-count 8 --moe-freq 2 \
    --moe-gating-use-fp32 --moe-second-expert-policy all \
    --moe-normalize-expert-grad sqrt_world_size \
    --moe-eval-capacity-token-fraction -1.0 \
    --max-sentences-valid 1 --num-workers-valid 0 \
    --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 \
    --moe-gate-loss-combine-method sum \
    --optimizer adam --fp16 --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0005 --warmup-updates 750 \
    --sample-break-mode none --dropout 0.2 --attention-dropout 0.2 \
    --batch-size 2 --update-freq 2 \
    --max-update 250 --log-format json --log-interval 10

The model finishes training with no issues, but when it reaches the validation stage at the end I get this error:

Traceback (most recent call last):
File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/data/meenakshi/MoE/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
main(cfg, **kwargs)
File "/data/meenakshi/MoE/fairseq/fairseq_cli/train.py", line 191, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/meenakshi/MoE/fairseq/fairseq_cli/train.py", line 318, in train
valid_losses, should_stop = validate_and_save(
File "/data/meenakshi/MoE/fairseq/fairseq_cli/train.py", line 403, in validate_and_save
valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
File "/data/meenakshi/MoE/fairseq/fairseq_cli/train.py", line 497, in validate
trainer.valid_step(sample)
File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/meenakshi/MoE/fairseq/fairseq/trainer.py", line 1031, in valid_step
assert sample['net_input']['src_tokens'].shape[1] == fixed_src_seq_length,
AssertionError: got src_seq_length 459, expected 512
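
For what it's worth, I think this happens because "--sample-break-mode none" just slices the flattened validation token stream into tokens-per-sample blocks, so the final block keeps whatever is left over and comes out shorter than 512. A toy illustration of that chunking (the token counts below are made up, not the real wikitext-103 sizes):

# Hypothetical stream length chosen so the leftover block is 459 tokens long.
tokens_per_sample = 512
total_valid_tokens = 2 * tokens_per_sample + 459

# Cut the flat stream into fixed-size blocks; the last block is whatever remains.
block_lengths = [
    min(tokens_per_sample, total_valid_tokens - start)
    for start in range(0, total_valid_tokens, tokens_per_sample)
]
print(block_lengths)  # [512, 512, 459] -> the short final block trips the assert above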

Things I tried:

  • The error does not occur when training a base transformer model. If I take all of the parameters/flags from the example transformer training script and add only the MoE flags, the error comes back.
  • I tried a few different values for "tokens-per-sample" (2048, 1024, and 512); I get the same error each time.
  • I tried writing a padding script that adds pad tokens to each sample until it reaches the "tokens-per-sample" value. Either the script was flawed or the approach simply doesn't work, as I got "exploding loss" errors once training started (a rough sketch of what I tried is below).
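
For reference, this is roughly the kind of padding logic I attempted (a rough sketch rather than the exact script; the function name and arguments are illustrative, and pad_idx would come from the task dictionary):

import torch

def pad_to_fixed_length(src_tokens, tokens_per_sample, pad_idx):
    """Right-pad a (batch, seq_len) tensor of token ids up to tokens_per_sample."""
    batch_size, seq_len = src_tokens.shape
    if seq_len >= tokens_per_sample:
        return src_tokens
    padding = src_tokens.new_full((batch_size, tokens_per_sample - seq_len), pad_idx)
    return torch.cat([src_tokens, padding], dim=1)

# NOTE: if the padded positions are not also excluded from the targets/loss,
# the loss can blow up, which may be what happened in my attempt.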

Environment details:

  • fairseq Version: moe branch
  • PyTorch Version: 2.0.1
  • OS: Linux
  • How you installed fairseq: source
  • Build command you used: pip install --editable ./
  • Python version: 3.9.19
  • CUDA/cuDNN version: 3.7
  • GPU models and configuration: 8 GPUs: Tesla P100-PCIE-16GB