When "write_predictions_to_file" is true, generate will fail. #9170

Open
gaojingwei opened this issue May 11, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@gaojingwei

I am using the image nvcr.io/nvidia/nemo:24.03.framework and successfully ran /NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py.
However, when I run /NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py, the result depends on one setting: with model.data.test_ds.write_predictions_to_file=False the run succeeds; with model.data.test_ds.write_predictions_to_file=True it fails.
This is the error log:

Traceback (most recent call last):
  File "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py", line 170, in <module>
    main()
  File "/opt/NeMo/nemo/core/config/hydra_runner.py", line 129, in wrapper
    _run_hydra(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 389, in _run_hydra
    _run_app(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 452, in _run_app
    run_and_report(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 216, in run_and_report
    raise ex
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 213, in run_and_report
    return func()
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 453, in <lambda>
    lambda: hydra.run(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py", line 164, in main
    trainer.test(model)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 754, in test
    return call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 794, in _test_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
    return self._evaluation_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 425, in test_step
    return self.lightning_module.test_step(*args, **kwargs)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 431, in test_step
    return self.inference_step(dataloader_iter, 'test')
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 438, in inference_step
    outputs = self.inference_step_validation_call(batch, batch_idx, data_cfg, dataloader_idx)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 468, in inference_step_validation_call
    output = self.predict_step(batch, batch_idx, dataloader_idx)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 680, in predict_step
    response = generate(self, **inference_config)
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 649, in generate
    output = synced_generate(
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 512, in synced_generate
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 801, in sample_sequence_batch
    output = inference_strategy.forward_step(batch, tensor_shape)
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_strategy.py", line 67, in forward_step
    output_tensor = fwd_bwd_function(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 1211, in forward_backward_pipelining_without_interleaving
    output_tensor = forward_step(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 192, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1160, in fwd_output_only_func
    output_tensor = model(tokens, position_ids, attention_mask, **extra_arg)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/models/gpt/gpt_model.py", line 174, in forward
    hidden_states = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 370, in forward
    hidden_states, context = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_layer.py", line 176, in forward
    attention_output_with_bias = self.self_attention(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/attention.py", line 254, in forward
    query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
  File "/opt/megatron-lm/megatron/core/transformer/attention.py", line 370, in get_query_key_value_tensors
    mixed_qkv, _ = self.linear_qkv(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/custom_layers/transformer_engine.py", line 245, in forward
    out = super().forward(x, is_first_microbatch=_is_first_microbatch)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 1105, in forward
    out = fwd_fn(*args)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 127, in forward
    ln_out, mu, rsigma = _apply_normalization(inputmat,
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/_common.py", line 90, in _apply_normalization
    return normalization_func(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/cpp_extensions/normalization.py", line 179, in rmsnorm_fwd_inf
    return torch.ops.tex_ts.rmsnorm_fwd_inf_ts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 825, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: /opt/TransformerEngine/transformer_engine/common/transformer_engine.cpp:39 in function CheckInputTensor: Assertion failed: t.data.dptr != nullptr. Input x is not allocated!
Error executing job with overrides: ['model.restore_from_path=/workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo']
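
To make the comparison concrete, the two runs differ in a single setting. The following is a minimal sketch, not taken from the original report: it assumes the flag is passed as a Hydra command-line override (the log above lists only `model.restore_from_path` as an override, so the flag was apparently set in the YAML file instead), and `build_command` is a hypothetical helper. A real 2-node launch would also need a distributed launcher, which is omitted here.

```python
import shlex

# Script path as it appears in the traceback above.
SCRIPT = "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py"

# Checkpoint path as it appears in the logged overrides.
CHECKPOINT = "/workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo"


def build_command(write_predictions: bool) -> list[str]:
    """Hypothetical helper: build the argv for one run, toggling only the flag under test."""
    return [
        "python",
        SCRIPT,
        f"model.restore_from_path={CHECKPOINT}",
        # Assumption: the flag is given as a Hydra override rather than in the YAML file.
        f"model.data.test_ds.write_predictions_to_file={write_predictions}",
    ]


passing = build_command(False)  # reported to run successfully
failing = build_command(True)   # reported to raise "Input x is not allocated!"

print(shlex.join(passing))
print(shlex.join(failing))
```

Keeping every other argument identical makes it easier to confirm that the flag alone triggers the failure.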

This is my config:

name: megatron_gpt_peft_${model.peft.peft_scheme}_tuning

trainer:
  devices: 8
  accelerator: gpu
  num_nodes: 2
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: 9999
  max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged 
  val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.test_ds.metric.name}
    save_top_k: 1
    mode: max
    save_nemo_on_train_end: True
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: True
model:
  seed: 1234
  tensor_model_parallel_size: 8 # intra-layer model parallelism
  pipeline_model_parallel_size: 2 # inter-layer model parallelism

  global_batch_size: 16
  micro_batch_size: 1
  restore_from_path: /workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo # Path to an existing .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training.
  sync_batch_comm: False
  megatron_amp_O2: False

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: False

  ## Activation Checkpoint
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: null # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  answer_only_loss: True
  gradient_as_bucket_view: False

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0

  peft:
    peft_scheme: "none"  # can be either adapter,ia3, or ptuning
    restore_from_path: null
    restore_from_ckpt:
      checkpoint_dir: null
      checkpoint_name: null

    # Used for adapter peft training
    adapter_tuning:
      type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
      adapter_dim: 32
      adapter_dropout: 0.0
      norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used.
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used,  options are ['layernorm', 'mixedfusedlayernorm']
      layer_selection: null  # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    lora_tuning:
      target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
      adapter_dim: 32
      adapter_dropout: 0.0
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      layer_selection:  null  # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    # Used for p-tuning peft training
    p_tuning:
      virtual_tokens: 10  # The number of virtual tokens the prompt encoder should add at the start of the sequence
      bottleneck_dim: 1024  # the size of the prompt encoder mlp bottleneck
      embedding_dim: 1024  # the size of the prompt encoder embeddings
      init_std: 0.023

    ia3_tuning:
      layer_selection:  null  # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers

  data:
    test_ds:
      file_names: ['/workspace/databricks-dolly-15k/test.jsonl'] # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      names: ['dolly-15k_test'] # Names of the corresponding datasets used to log metrics.
      global_batch_size: 16
      micro_batch_size: 1
      shuffle: False
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: False
      context_key: 'input'
      label_key: ${data.train_ds.label_key}
      add_eos: ${data.train_ds.add_eos}
      add_sep: ${data.train_ds.add_sep}
      add_bos: ${data.train_ds.add_bos}
      write_predictions_to_file: True
      output_file_path_prefix: /workspace/results/sft_results # Prefix of the file to write predictions to.
      truncation_field: ${data.train_ds.truncation_field} # Options: keys in prompt_template
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: ${data.train_ds.prompt_template}
      tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
      truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']

      metric:
        name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null
    
inference:
  greedy: True # Whether or not to use sampling ; use greedy decoding otherwise
  top_k: 0  # The number of highest probability vocabulary tokens to keep for top-k-filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  all_probs: False  # whether return the log prob for all the tokens in vocab
  repetition_penalty: 1.2  # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0  # The minimum length of the sequence to be generated.
  compute_logprob: False  # a flag used to compute logprob of all the input text, a very special case of running inference, default False
  outfile_path: /workspace/results/output.txt
  compute_attention_mask: True

# server-related configs
server: False  # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
share: True  # whether create a public URL
username: test # user name for web client
password: test2  # password for web client
web_port: 9889 # the port number of the web server 1058
chat: False # use the chat interface
chatbot_config:
  value: False   # whether to inject the value attributes
  attributes:
    - name: Quality
      min: 0
      max: 4
      key: quality
      type: int
      default: 4
    - name: Toxicity
      min: 0
      max: 4
      key: toxcity
      type: int
      default: 0
    - name: Humor
      min: 0
      max: 4
      key: humor
      type: int
      default: 0
    - name: Creativity
      min: 0
      max: 4
      key: creativity
      type: int
      default: 0
    - name: Violence
      min: 0
      max: 4
      key: violence
      type: int
      default: 0
    - name: Helpfulness
      min: 0
      max: 4
      key: helpfulness
      type: int
      default: 4
    - name: Not_Appropriate
      min: 0
      max: 4
      key: not_appropriate
      type: int
      default: 0
    - name: Language
      choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh']
      key: lang
      type: list
      default: en
   
  user: User
  assistant: Assistant
  system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
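
For anyone trying to reproduce this, the parallel layout implied by the config can be worked out by hand. The sketch below assumes the usual Megatron-style relation in which the world size is the product of the tensor-, pipeline-, and data-parallel sizes; the variable names are illustrative and the values are copied from the config above.

```python
# Values copied from the config above.
devices = 8
num_nodes = 2
tensor_model_parallel_size = 8
pipeline_model_parallel_size = 2
global_batch_size = 16
micro_batch_size = 1

# Total number of GPU ranks across all nodes.
world_size = devices * num_nodes

# Ranks consumed by one model replica (tensor parallel x pipeline parallel).
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size

# Assumption: world_size = TP * PP * DP, so the division must be exact.
if world_size % model_parallel_size != 0:
    raise ValueError("world size must be divisible by TP * PP")
data_parallel_size = world_size // model_parallel_size

# Micro-batches each replica processes per global batch.
num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)

print(f"world_size={world_size}")
print(f"model_parallel_size={model_parallel_size}")
print(f"data_parallel_size={data_parallel_size}")
print(f"num_micro_batches={num_micro_batches}")
```

With these values the 16 GPUs hold exactly one model replica (data-parallel size 1) split into two pipeline stages, and each global batch is processed as 16 micro-batches.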
@gaojingwei gaojingwei added the bug Something isn't working label May 11, 2024
@aditya-malte
Contributor

Hi, do you also see this issue in megatron_gpt_finetuning.py when write_predictions_to_file is set to True? I am hitting the same error that way.
