Describe the bug
Traceback (most recent call last):
  File "train_controlnet_sdxl.py", line 1252, in <module>
    main(args)
  File "train_controlnet_sdxl.py", line 1013, in main
    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3093, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3489, in _map_single
    writer.write_batch(batch)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 557, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 248, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 113, in pyarrow.lib._handle_arrow_array_protocol
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 191, in __arrow_array__
    out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 447, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 324, in _cast_to_python_objects
    for x in obj.detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16
Steps to reproduce the bug
Here is my training script. I train my model with diffusers and use the bf16 data type:
export MODEL_DIR="/home/mhh/sd_models/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="./control_net"
export VAE_NAME="/home/mhh/sd_models/sdxl-vae-fp16-fix"
accelerate launch train_controlnet_sdxl.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--pretrained_vae_model_name_or_path=$VAE_NAME \
--dataset_name=/home/mhh/sd_datasets/fusing/fill50k \
--mixed_precision="bf16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=200 \
--validation_image "/home/mhh/sd_datasets/controlnet_image/conditioning_image_1.png" "/home/mhh/sd_datasets/controlnet_image/conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=50 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42
Expected behavior
Training should run with bf16 as well. When I changed the data type to fp16, it worked.
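A likely explanation for the difference, offered as background rather than a confirmed diagnosis: the failing frame calls `.numpy()` on a tensor, and NumPy has a built-in `float16` dtype but no built-in bfloat16 dtype. A bfloat16 value is the upper 16 bits of a float32, so widening it to float32 loses nothing. The standard-library sketch below illustrates that bit layout. The helper names are hypothetical, and the float32-to-bfloat16 direction uses plain truncation for simplicity, whereas real frameworks typically round to nearest.

```python
import struct


def bf16_bits_to_float32(bits: int) -> float:
    """Widen a 16-bit bfloat16 pattern to float32 by zero-padding the low 16 bits."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]


def float32_to_bf16_bits(value: float) -> int:
    """Narrow a float32 to bfloat16 by keeping only its upper 16 bits (truncation)."""
    (as_int,) = struct.unpack(">I", struct.pack(">f", value))
    return as_int >> 16


# 1.5 is exactly representable in bfloat16, so the round trip is lossless.
bits = float32_to_bf16_bits(1.5)
restored = bf16_bits_to_float32(bits)
print(hex(bits), restored)
```

If this is indeed the cause, casting the embeddings returned by `compute_embeddings_fn` to float32 before `datasets` writes them would be one workaround to try. That suggestion is an assumption and has not been verified against this script.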
Environment info
datasets 2.16.1
numpy 1.24.4