Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

你好,请教下,bash train_full.sh -m train微调后的last.ckpt模型,有办法转成类似 https://huggingface.co/Qwen/Qwen-1_8B-Chat/tree/main 下面的safetensors文件吗 #29

Open
dgo2dance opened this issue Jan 22, 2024 · 1 comment

Comments

@dgo2dance
Copy link

dgo2dance commented Jan 22, 2024

你好,请教下,bash train_full.sh -m train微调后的last.ckpt模型,有办法转成类似https://huggingface.co/Qwen/Qwen-1_8B-Chat/tree/main 下面的safetensors文件吗

以方便之后用llama.cpp做进一步转换

@ssbuild
Copy link
Owner

ssbuild commented Jan 23, 2024

import argparse
import os
from collections import defaultdict

import torch

from safetensors.torch import load_file, save_file

def shared_pointers(tensors):
    """Group state-dict keys whose tensors start at the same memory address.

    Args:
        tensors: Mapping of parameter name to ``torch.Tensor``.

    Returns:
        A list of name lists. Each inner list holds two or more keys (in
        mapping order) whose tensors report the same ``data_ptr()``.
        Keys whose tensors have no elements are never reported.
    """
    ptrs = defaultdict(list)
    for name, tensor in tensors.items():
        # Zero-element tensors own no memory and typically all report the
        # same (null) data pointer; grouping them would make unrelated empty
        # tensors look shared, and the caller would then drop all but one.
        if tensor.numel() == 0:
            continue
        ptrs[tensor.data_ptr()].append(name)
    return [names for names in ptrs.values() if len(names) > 1]


def check_file_size(sf_filename: str, pt_filename: str):
    """Verify the safetensors file is not noticeably larger than the source.

    Args:
        sf_filename: Path of the written safetensors file.
        pt_filename: Path of the original PyTorch checkpoint file.

    Raises:
        RuntimeError: If the safetensors file is more than 1% larger than
            the PyTorch file. A smaller safetensors file is accepted.
    """
    sf_size = os.path.getsize(sf_filename)
    pt_size = os.path.getsize(pt_filename)

    # Relative growth of the converted file; negative when it shrank.
    growth = (sf_size - pt_size) / pt_size
    if growth <= 0.01:
        return

    message = f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """
    raise RuntimeError(message)


def convert_file(
    pt_filename: str,
    sf_filename: str,
):
    """Convert a PyTorch checkpoint file into a safetensors file.

    The checkpoint is loaded on CPU. If it contains a ``"state_dict"`` entry,
    only that entry is converted. For tensors that share a data pointer, only
    the first name is kept. The written file is then size-checked against the
    input and reloaded to confirm every saved tensor round-trips exactly.

    Args:
        pt_filename: Path of the PyTorch checkpoint to read.
        sf_filename: Path of the safetensors file to write. Its parent
            directory is created when missing.

    Raises:
        RuntimeError: If the written file is more than 1% larger than the
            input, or if a reloaded tensor differs from the one saved.
    """
    # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary
    # code; only run this on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Keep only the first name of every group of tensors sharing memory.
    for shared_weights in shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    dirname = os.path.dirname(sf_filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    save_file(loaded, sf_filename, metadata={"format": "pt"})
    check_file_size(sf_filename, pt_filename)

    # Round-trip check: every saved tensor must reload bit-for-bit.
    reloaded = load_file(sf_filename)
    for k, pt_tensor in loaded.items():
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def derive_output_path(file):
    """Return the path of *file* with its final extension replaced by ``.safetensors``.

    Only the last extension of the final path component is replaced, so dots
    in directory names (``./out/last.ckpt``) or inside the file stem
    (``model.v1.bin``) are preserved. A name without an extension simply gets
    ``.safetensors`` appended.
    """
    stem, _ext = os.path.splitext(file)
    return f'{stem}.safetensors'


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
    It is PyTorch exclusive for now.
    It converts them locally.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "file",
        type=str,
        help="Weights file",
    )
    args = parser.parse_args()
    file = args.file
    # The output is written next to the input, with the same stem.
    dist = derive_output_path(file)
    print(dist)
    # Refuse to overwrite an existing conversion result.
    if os.path.exists(dist):
        print(f'Error: {dist} already exists')
    else:
        convert_file(file, dist)
        print('Success')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants