You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
import argparse
import os
from collections import defaultdict
import torch
from safetensors.torch import load_file, save_file
def shared_pointers(tensors):
    """Group the names of tensors that point at the same underlying data.

    Args:
        tensors: Mapping of name -> tensor. Each value must provide
            ``numel()`` and ``data_ptr()``.

    Returns:
        A list of name lists, one for every data pointer that is referenced
        by more than one entry. Names keep the mapping's iteration order.
    """
    ptrs = defaultdict(list)
    for name, tensor in tensors.items():
        # A zero-element tensor holds no data, so its pointer says nothing
        # about sharing (several empty tensors may well report the same
        # pointer). Grouping them would make callers treat unrelated empty
        # tensors as duplicates of one another.
        if tensor.numel() == 0:
            continue
        ptrs[tensor.data_ptr()].append(name)
    return [names for names in ptrs.values() if len(names) > 1]
def check_file_size(sf_filename: str, pt_filename: str):
    """Verify that the safetensors file is not noticeably larger than the source.

    Only growth is checked: the converted file may be smaller than the
    original without triggering an error.

    Args:
        sf_filename: Path of the converted ``safetensors`` file.
        pt_filename: Path of the original PyTorch file used as reference.

    Raises:
        RuntimeError: If ``sf_filename`` is more than 1% larger than
            ``pt_filename``.
        OSError: If either file cannot be stat'ed.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size

    if pt_size == 0:
        # An empty reference file has no meaningful relative difference;
        # avoid dividing by zero and treat any non-empty output as too large.
        too_large = sf_size > 0
    else:
        too_large = (sf_size - pt_size) / pt_size > 0.01

    if too_large:
        raise RuntimeError(
            "The file size different is more than 1%:\n"
            f"- {sf_filename}: {sf_size}\n"
            f"- {pt_filename}: {pt_size}\n"
        )
def convert_file(
    pt_filename: str,
    sf_filename: str,
):
    """Convert a PyTorch weights file to the ``safetensors`` format.

    The checkpoint is loaded on CPU; if it contains a ``"state_dict"`` entry
    that entry is used as the set of weights. For every group of names
    reported by ``shared_pointers`` only the first name is kept. The result
    is written to ``sf_filename``, its size is checked against the source,
    and the file is reloaded to verify every tensor round-trips unchanged.

    Args:
        pt_filename: Path of the PyTorch file to read.
        sf_filename: Path of the ``safetensors`` file to write. Missing
            parent directories are created.

    Raises:
        RuntimeError: If the size check fails or a reloaded tensor differs
            from the one that was saved.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Keep a single name per group of tensors sharing the same data.
    for shared_weights in shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when there is one.
    dirname = os.path.dirname(sf_filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    save_file(loaded, sf_filename, metadata={"format": "pt"})
    check_file_size(sf_filename, pt_filename)

    # Round-trip check: what was written must read back identically.
    reloaded = load_file(sf_filename)
    for k, pt_tensor in loaded.items():
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
# Command-line entry point: convert one weights file to a sibling
# `<name>.safetensors` file, refusing to overwrite an existing output.
if __name__ == "__main__":
    DESCRIPTION = """
Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
It is PyTorch exclusive for now.
It converts them locally.
"""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "file",
        type=str,
        help="Weights file",
    )
    args = parser.parse_args()
    file = args.file

    # Replace only the final extension. Splitting on every "." and joining
    # the pieces back without it would mangle paths such as "./last.ckpt"
    # (-> "/last.safetensors") or "model.v1.ckpt" (-> "modelv1.safetensors").
    name, _ext = os.path.splitext(file)
    dist = f'{name}.safetensors'
    print(dist)

    if os.path.exists(dist):
        print(f'Error: {dist} already exists')
    else:
        convert_file(file, dist)
        print('Success')
你好，请教一下：用 bash train_full.sh -m train 微调后得到的 last.ckpt 模型，有办法转成类似 https://huggingface.co/Qwen/Qwen-1_8B-Chat/tree/main 下面的 safetensors 文件吗？
这样方便之后用 llama.cpp 做进一步转换。
The text was updated successfully, but these errors were encountered: