Training controlnet_sdxl in bf16 raises an unsupported ScalarType error in datasets #6566

Closed
HelloWorldBeginner opened this issue Jan 8, 2024 · 1 comment · Fixed by #6607

Comments

@HelloWorldBeginner

Describe the bug

Traceback (most recent call last):
  File "train_controlnet_sdxl.py", line 1252, in <module>
    main(args)
  File "train_controlnet_sdxl.py", line 1013, in main
    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3093, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3489, in _map_single
    writer.write_batch(batch)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 557, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 248, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 113, in pyarrow.lib._handle_arrow_array_protocol
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 191, in __arrow_array__
    out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 447, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 324, in _cast_to_python_objects
    for x in obj.detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16
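The last frame calls .numpy() on a PyTorch tensor, so the error can likely be reproduced without diffusers or datasets at all. The sketch below assumes only that PyTorch is installed; the helper name try_numpy is hypothetical. It converts a small float16 tensor and a small bfloat16 tensor the same way that line does and reports which one fails.

```python
import torch


def try_numpy(dtype):
    """Hypothetical helper: try the same conversion as the failing frame.

    Returns (ok, message): the NumPy dtype name on success, or the error
    text if PyTorch refuses the conversion.
    """
    t = torch.ones(3, dtype=dtype)
    try:
        arr = t.detach().cpu().numpy()
        return True, str(arr.dtype)
    except TypeError as err:
        # bfloat16 has no NumPy equivalent, so PyTorch raises here
        return False, str(err)


ok_fp16, msg_fp16 = try_numpy(torch.float16)
ok_bf16, msg_bf16 = try_numpy(torch.bfloat16)
print(ok_fp16, msg_fp16)  # True float16
print(ok_bf16, msg_bf16)
```

NumPy has a native float16 dtype but no bfloat16 dtype, which is consistent with the fp16 run succeeding and the bf16 run failing.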

Steps to reproduce the bug

Here is my training script. I train the model with diffusers, using the BF16 data type:

export MODEL_DIR="/home/mhh/sd_models/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="./control_net"
export VAE_NAME="/home/mhh/sd_models/sdxl-vae-fp16-fix"

accelerate launch train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --pretrained_vae_model_name_or_path=$VAE_NAME \
 --dataset_name=/home/mhh/sd_datasets/fusing/fill50k \
 --mixed_precision="bf16" \
 --resolution=1024 \
 --learning_rate=1e-5 \
 --max_train_steps=200 \
 --validation_image "/home/mhh/sd_datasets/controlnet_image/conditioning_image_1.png" "/home/mhh/sd_datasets/controlnet_image/conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --validation_steps=50 \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --report_to="wandb" \
 --seed=42

Expected behavior

When I changed the data type to fp16, it worked, so I expected bf16 to work as well.

Environment info

datasets 2.16.1
numpy 1.24.4

@skaulintel
Contributor

I also see the same error and got past it by casting the tensor on that line to float (float32).

So in features.py, the line

for x in obj.detach().cpu().numpy()

becomes

for x in obj.detach().to(torch.float).cpu().numpy()
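The same idea can be sketched as a standalone conversion helper. This is an illustration under assumptions, not the actual datasets code: the name tensor_to_numpy is hypothetical, and unlike the one-line patch above it upcasts only bfloat16 tensors. Casting unconditionally with .to(torch.float) would also turn integer tensors (for example token IDs) into float32.

```python
import torch


def tensor_to_numpy(t: torch.Tensor):
    """Hypothetical helper: convert a tensor to NumPy, upcasting bfloat16.

    NumPy has no bfloat16 dtype, so bfloat16 tensors are cast to float32
    first. Other dtypes (float16, int64, ...) are left untouched.
    """
    t = t.detach()
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return t.cpu().numpy()


x = torch.tensor([1.5, -2.25, 3.0], dtype=torch.bfloat16)
arr = tensor_to_numpy(x)
print(arr.dtype, arr.tolist())  # float32 [1.5, -2.25, 3.0]

# bfloat16 -> float32 is exact, so casting back recovers the original values
assert bool((torch.from_numpy(arr).to(torch.bfloat16) == x).all())
```

Because every bfloat16 value is exactly representable in float32, the upcast loses no information; the cost is that the cached column takes twice the storage.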

I got the idea from this PR, where someone was facing a similar issue in a different repository. I suspect the cause is that NumPy has no bfloat16 dtype, so a bfloat16 tensor cannot be converted with .numpy().
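If patching datasets is not an option, a user-side workaround is to make sure the function passed to train_dataset.map returns no bfloat16 tensors. The body of compute_embeddings_fn is not shown in this issue, so the sketch below assumes it returns a dict of tensors; the wrapper upcast_bf16_outputs and the stand-in fake_compute_embeddings are hypothetical.

```python
import functools

import torch


def _upcast(value):
    # Only bfloat16 tensors are cast; everything else passes through as-is.
    if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
        return value.to(torch.float32)
    return value


def upcast_bf16_outputs(fn):
    """Hypothetical wrapper for a batched Dataset.map function.

    Casts any bfloat16 tensor in the returned dict to float32 so that
    datasets can convert it to Arrow through NumPy.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        return {key: _upcast(value) for key, value in out.items()}

    return wrapper


# Stand-in for compute_embeddings_fn; the real function is not shown here.
def fake_compute_embeddings(batch):
    n = len(batch["text"])
    return {
        "prompt_embeds": torch.zeros(n, 4, dtype=torch.bfloat16),
        "input_ids": torch.ones(n, 2, dtype=torch.int64),
    }


out = upcast_bf16_outputs(fake_compute_embeddings)({"text": ["a", "b"]})
print({key: value.dtype for key, value in out.items()})
```

In the training script this would be applied as train_dataset.map(upcast_bf16_outputs(compute_embeddings_fn), batched=True, new_fingerprint=new_fingerprint), leaving the rest of the bf16 training unchanged.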
