How can I convert llama3 safetensors to the pth file needed to use with executorch? #3303
Take a look at some of the example util functions in torchtune. Let us know if that works.
So I need to use PyTorch to save the state dict file? I tried that with a Llama3 fine-tune and then tried to compile it for XNNPACK, and got this error:
This is the code I used to save the state dict:
What am I doing wrong here?
@l3utterfly Could you share the command you used in torchtune, as well as the export_llama command?
@iseeyuan sorry, I'm a little new to torchtune; I'm following the documentation here: https://pytorch.org/torchtune/stable/deep_dives/checkpointer.html#understand-checkpointer
This is the command I'm using to convert:
@l3utterfly I can see two options:
@iseeyuan this is the error I'm getting after converting the safetensor file with the torchtune util function you suggested. The error happens after running the compile-to-ExecuTorch command:
I tried to fine-tune with torchtune, but it appears torchtune does not support finetuning on top of another finetune? It is still looking for the original checkpoint files. Any way forward from here?
@l3utterfly Let me try to convert the safetensor files and let you know if there's a way to work around it.
@kartikayk, are you aware of this?
@l3utterfly torchtune doesn't really care about how the checkpoint is produced, i.e. whether it's a finetuned model or a pre-trained model. All it cares about is that the format matches what the checkpointer expects. When the model changes, you need to update the config to point to the right checkpoint, etc. Can you share the exact torchtune config you're using and the command you used to launch training?
@kartikayk I am trying to convert this model: https://huggingface.co/ResplendentAI/Aura_Uncensored_l3_8B It doesn't contain any original PyTorch checkpoint files that torchtune supports, so trying to finetune with torchtune gets me back to square one: how can I convert the safetensor files to a PyTorch checkpoint?
This should work OOTB if you update the checkpoint files in the config to point to the safetensors. I tried loading the checkpoint into the llama3 8B model in torchtune and the keys loaded successfully. Note that you need to update the checkpointer to point to the HFCheckpointer, since the safetensor files are in the HF format. The deep dive you pointed to above has a lot more information about checkpointer formats, but let me know if you have questions. I'm not sure what your config looks like, but just update the Llama3 config to the following:
For examples of how to use safetensors, take a look at the Gemma configs.
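For reference, the checkpointer section of such a torchtune YAML config might look roughly like the sketch below. The paths are placeholders and the exact component path may differ across torchtune versions; check the checkpointer deep-dive linked above for your version.

```yaml
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /path/to/Aura_Uncensored_l3_8B
  checkpoint_files:
    - model-00001-of-00002.safetensors
    - model-00002-of-00002.safetensors
  output_dir: /path/to/output
  model_type: LLAMA3
```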
Thanks for the help, I got the finetuning to run now with this config. First, I tried to load and save the state dict right away:
This still gives the same error as before when trying to compile down to executorch.
I am now in the process of doing a "finetune" with
Glad this worked!
Yes, if you want to just get dummy checkpoints you can set
I don't think ExecuTorch supports the HF format, though. @iseeyuan can confirm.
@l3utterfly I took a deeper look into the state dict.
Try exactly what @kartikayk suggested:

```python
sd = checkpointer.load_checkpoint()
print("saving checkpoint")
torch.save(sd['model'], "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B/checkpoint.pth")
```

But the converted layer 0 looks like,

I don't know what causes these name differences. Since our flow starts from the original llama3 checkpoint, I'd suggest you use the same checkpoint and iterate based on that.
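The name differences come down to the key naming scheme used in the state dict. As an illustrative sketch only (this is not torchtune's actual implementation, and the mapping entries here are hypothetical examples), the kind of regex-based key renaming that a format converter performs looks like this:

```python
import re

# Hypothetical examples of a "torchtune name -> Meta llama3 name" mapping.
# The real mapping lives inside torchtune's convert_weights module.
TUNE_TO_META = {
    r"tok_embeddings\.weight": "tok_embeddings.weight",
    r"layers\.(\d+)\.attn\.q_proj\.weight": r"layers.\1.attention.wq.weight",
    r"layers\.(\d+)\.attn\.k_proj\.weight": r"layers.\1.attention.wk.weight",
    r"norm\.scale": "norm.weight",
}

def rename_keys(state_dict):
    """Return a new dict with each key renamed via the first matching pattern."""
    out = {}
    for key, value in state_dict.items():
        for pattern, replacement in TUNE_TO_META.items():
            if re.fullmatch(pattern, key):
                key = re.sub(pattern, replacement, key)
                break
        out[key] = value
    return out
```

If a checkpoint is saved without such a conversion step, downstream consumers that expect Meta's naming (like ExecuTorch's llama export) will fail to find the keys they need.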
Update: After chatting with @kartikayk, we need another conversion from torchtune's format to Meta's llama3 format. So the code below should work:

```python
from torchtune.utils import FullModelHFCheckpointer
from torchtune.models import convert_weights
import torch

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir='/Users/myuan/.cache/huggingface/hub/models--ResplendentAI--Aura_Uncensored_l3_8B/snapshots/e7720d40e4d8d3c0fa07a8a579fda4d0644aa731',
    checkpoint_files=['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors'],
    output_dir='/Users/myuan/data/Aura_Uncensored_l3_8B',
    model_type='LLAMA3'
)

print("loading checkpoint")
sd = checkpointer.load_checkpoint()
sd = convert_weights.tune_to_meta(sd['model'])

print("saving checkpoint")
torch.save(sd, "/Users/myuan/data/Aura_Uncensored_l3_8B/checkpoint.pth")
```

This works on my side: I can successfully lower the checkpoint to ExecuTorch. @l3utterfly, could you try the conversion above and let us know if it works?
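Before handing the saved `.pth` file to the export step, it can be worth sanity-checking that the state-dict keys actually follow Meta's llama3 naming. A minimal hypothetical helper (the key names below are the standard Meta llama3 ones, but this check is a heuristic, not part of torchtune or ExecuTorch):

```python
def looks_like_meta_format(keys):
    """Heuristic: do these state-dict keys follow Meta's llama3 naming?"""
    keys = set(keys)
    # Top-level tensors present in Meta-format llama checkpoints.
    required = {"tok_embeddings.weight", "norm.weight", "output.weight"}
    # Per-layer attention tensors live under layers.<i>.attention.*
    has_layer0 = any(k.startswith("layers.0.attention.") for k in keys)
    return required.issubset(keys) and has_layer0

# Usage sketch (after running the conversion above):
#   sd = torch.load("/Users/myuan/data/Aura_Uncensored_l3_8B/checkpoint.pth")
#   assert looks_like_meta_format(sd.keys())
```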
Yes! This works, thank you so much for helping me! I think it may be helpful to put/link this script in the ExecuTorch Llama3 docs? This would be very beneficial for accelerating adoption of ExecuTorch by the wider community: people can now load Llama3 finetunes with ExecuTorch instead of only the base model. It will really encourage the local AI community to build infrastructure around ExecuTorch!
@l3utterfly That's a great idea. Let me put up a PR for this with documentation.
Summary: As titled. It's pretty common for users to download LLM models in safetensor format. Add instructions and an example script to convert them to PyTorch format so that the export_llama script can accept them. It leverages the utils from torchtune. Thanks @l3utterfly and @kartikayk for the discussions and suggestions! More context in #3303.

Pull Request resolved: #3523
Reviewed By: mergennachin
Differential Revision: D57026658
Pulled By: iseeyuan
fbshipit-source-id: 11badf709920ff945cdfdd2b244c52c750943412
Fine-tunes of Llama3 usually only have safetensors uploaded. In order to compile a Llama3 model following the tutorial, I need the original pth checkpoint file.
Is there a way to convert the safetensors to the checkpoint file?