support gguf_q4k_m / gguf_q4k_s #10887

Merged
merged 7 commits on May 17, 2024
5 changes: 5 additions & 0 deletions python/llm/src/ipex_llm/ggml/quantize.py
@@ -47,8 +47,13 @@
"gguf_iq1_m": 25,
"q6_k": 26,
"q4_k": 27,
"q5_k": 28,
"fp6": 29}

# mixed precision from llama.cpp
gguf_mixed_qtype = {"gguf_q4k_s": 101,
"gguf_q4k_m": 102}

_llama_quantize_type = {"q4_0": 2,
"q4_1": 3,
"q5_0": 8,
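The new gguf_mixed_qtype dict keeps the llama.cpp-style mixed-precision presets separate from ggml_tensor_qtype by giving them values above 100. A minimal sketch of resolving a load_in_low_bit string against both tables (the helper name is hypothetical; it simply mirrors the lookup added to load_convert further down):

from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype

def resolve_qtype(q_k: str) -> int:
    # single-format qtypes (e.g. "q4_k") live in ggml_tensor_qtype;
    # mixed-precision presets ("gguf_q4k_s", "gguf_q4k_m") live in gguf_mixed_qtype
    if q_k in ggml_tensor_qtype:
        return ggml_tensor_qtype[q_k]
    if q_k in gguf_mixed_qtype:
        return gguf_mixed_qtype[q_k]
    raise ValueError(f"unknown load_in_low_bit value: {q_k}")

print(resolve_qtype("gguf_q4k_m"))  # 102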
34 changes: 21 additions & 13 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -42,7 +42,7 @@
import warnings
import transformers
import importlib.util
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
from .utils import logger, get_cur_qtype_and_imatrix
from typing import Union
import numpy as np
@@ -337,15 +337,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if in_features % 64 != 0:
# now our kernel requires in_features is a multiple of 64
continue
new_linear = LowBitLinear(
in_features,
out_features,
qtype,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
)
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data,
@@ -355,6 +346,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if cur_qtype in [ggml_tensor_qtype["sym_int4"],
ggml_tensor_qtype["asym_int4"]]:
cur_qtype = ggml_tensor_qtype["sym_int8"]

new_linear = LowBitLinear(
in_features,
out_features,
cur_qtype,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
)
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
@@ -766,9 +767,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
embedding_qtype=None,
enable_xetla=False,
mixed_precision=False):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
if qtype in ggml_tensor_qtype.values():
index = list(ggml_tensor_qtype.values()).index(qtype)
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[index]} "
f"format......")
else:
index = list(gguf_mixed_qtype.values()).index(qtype)
logger.info(f"Converting the current model to "
f"{list(gguf_mixed_qtype.keys())[index]} "
f"format......")
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

# using ipex_llm optimizer before changing to bigdl linear
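ggml_convert_low_bit now looks up the human-readable name in whichever table contains the qtype before logging. The same reverse lookup as a standalone sketch (qtype_name is an illustrative helper, not part of the library):

from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype

def qtype_name(qtype: int) -> str:
    # mixed-precision presets use values above 100; everything else is a plain ggml qtype
    table = gguf_mixed_qtype if qtype in gguf_mixed_qtype.values() else ggml_tensor_qtype
    return list(table.keys())[list(table.values()).index(qtype)]

print(qtype_name(102))  # 'gguf_q4k_m'
print(qtype_name(27))   # 'q4_k'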
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -79,6 +79,7 @@
IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
Q4_K = ggml_tensor_qtype["q4_k"]
Q6_K = ggml_tensor_qtype["q6_k"]
Q5_K = ggml_tensor_qtype["q5_k"]


# For sym_int4
@@ -219,7 +220,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if not convert_shape_only and device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K]:
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K]:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
else:
if imatrix is not None:
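Adding Q5_K to the exclusion list routes q5_k tensors through the same k-quant branch as q4_k and q6_k instead of the plain histogram-based ggml_quantize_tensor call. A small sketch of the same membership test with the constants spelled out by name (values as defined in quantize.py above):

from ipex_llm.ggml.quantize import ggml_tensor_qtype

# qtypes that must not go through the plain quantizer in ggml_convert_qtype
K_AND_I_QUANTS = {ggml_tensor_qtype[name] for name in
                  ("gguf_iq2_xxs", "gguf_iq2_xs", "q2_k", "gguf_iq1_s",
                   "q4_k", "q6_k", "q5_k")}

def needs_kquant_path(qtype: int) -> bool:
    return qtype in K_AND_I_QUANTS

print(needs_kquant_path(ggml_tensor_qtype["q5_k"]))      # True
print(needs_kquant_path(ggml_tensor_qtype["sym_int4"]))  # False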
50 changes: 32 additions & 18 deletions python/llm/src/ipex_llm/transformers/model.py
@@ -42,7 +42,7 @@
from .utils import extract_local_archive_file, \
load_state_dict, \
get_local_shard_files, load_imatrix_data
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.gguf.api import load_gguf_model
import torch
@@ -117,12 +117,12 @@ def from_pretrained(cls,
Default to be ``False``.
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp6'`` ``'fp8'``, ``'fp8_e4m3'``,
``'fp8_e5m2'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``,
``'gguf_iq1_s'``, ``'fp16'``, ``'bf16'``, ``'q4_k'`` or
``'q6_k'``, ``'sym_int4'`` means symmetric int 4,
``'asym_int4'`` means asymmetric int 4,
``'nf4'`` means 4-bit NormalFloat, etc.
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
``'fp6'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``,
``'gguf_iq1_s'``, ``'gguf_q4k_m'``, ``'gguf_q4k_s'``,
``'fp16'``, ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``True``.
@@ -139,8 +139,9 @@
added to llama.cpp.
:param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
specify the model hub. Default to be ``'huggingface'``.
:param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
Relevant low bit optimizations will be applied to nn.Embedding layer.
:param embedding_qtype: str value, options are ``'q2_k'``, ``'q4_k'`` now.
Default to be None. Relevant low bit optimizations will be applied to
``nn.Embedding`` layer.
:param mixed_precision: boolean value, Whether to use mixed precision quantization.
Default to be False. If set to True, we will use sym_int8 for lm_head when
load_in_low_bit is sym_int4 or asym_int4.
@@ -321,10 +322,12 @@ def from_pretrained(cls,
"For gguf_iq2 and gguf_iq1 quantization,"
"imatrix is needed.")
cpu_embedding = kwargs.get("cpu_embedding", False)
# for 2bit, default use embedding_quantization
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \
not cpu_embedding and embedding_qtype is None:
embedding_qtype = "q2_k"
# for iq2/k-quants, default use embedding_quantization
if not cpu_embedding and embedding_qtype is None:
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"]:
embedding_qtype = "q2_k"
elif q_k in ["gguf_q4k_s", "gguf_q4k_m"]:
embedding_qtype = "q4_k"
if imatrix_file is not None:
imatrix_data = load_imatrix_data(imatrix_file)
kwargs["imatrix_data"] = imatrix_data
@@ -376,12 +379,16 @@ def from_gguf(fpath: str, optimize_model: bool = True,
@classmethod
def load_convert(cls, q_k, optimize_model, *args, **kwargs):
from .convert import ggml_convert_low_bit
invalidInputError(q_k in ggml_tensor_qtype,
invalidInputError(q_k in ggml_tensor_qtype or q_k in gguf_mixed_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
f"fp4, fp6, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q6_k, mixed_fp4 or mixed_fp8.")
qtype = ggml_tensor_qtype[q_k]
f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q5_k, q6_k, "
f"gguf_q4k_s, gguf_q4k_m, mixed_fp4 or mixed_fp8.")
if q_k in ggml_tensor_qtype:
qtype = ggml_tensor_qtype[q_k]
else:
qtype = gguf_mixed_qtype[q_k]

# In case it needs a second try,
# `from_pretrained`` may pop items out in dict
@@ -550,17 +557,24 @@ def load_low_bit(cls,
" with load_in_4bit or load_in_low_bit to get a low-bit model , and "
" serialize the model using save_low_bit first.")

invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype or
bigdl_transformers_low_bit in gguf_mixed_qtype,
f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")

# set default optimize_model=True
optimize_model = kwargs.pop("optimize_model", True)

qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
if bigdl_transformers_low_bit in ggml_tensor_qtype:
qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
else:
qtype = gguf_mixed_qtype[bigdl_transformers_low_bit]
if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \
not cpu_embedding:
embedding_qtype = "q2_k"
elif bigdl_transformers_low_bit in ["gguf_q4k_s", "gguf_q4k_m"] and \
not cpu_embedding:
embedding_qtype = "q4_k"
if embedding_qtype is not None:
embedding_qtype = ggml_tensor_qtype[embedding_qtype]

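With these changes the two presets can be requested directly through from_pretrained, and a model saved with save_low_bit records the preset so load_low_bit restores the matching embedding_qtype. A usage sketch (model id and output path are placeholders):

from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # illustrative model id
    load_in_low_bit="gguf_q4k_m",        # or "gguf_q4k_s"
    optimize_model=True,
    trust_remote_code=True,
)
model.save_low_bit("./llama2-7b-gguf-q4k-m")

# later / elsewhere
model = AutoModelForCausalLM.load_low_bit("./llama2-7b-gguf-q4k-m", trust_remote_code=True)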
25 changes: 22 additions & 3 deletions python/llm/src/ipex_llm/transformers/utils.py
@@ -41,7 +41,7 @@
# SOFTWARE.
import os
from transformers.modeling_utils import _add_variant
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
from ..utils.common import invalidInputError
from typing import Union, Optional
import torch
Expand Down Expand Up @@ -271,10 +271,12 @@ def module_name_process(full_module_name):

def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
cur_qtype = qtype
cur_imatrix = None
if model_config is not None:
model_type = getattr(model_config, "model_type", None)
else:
model_type = None

if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
ggml_tensor_qtype["gguf_iq1_s"]]:
# For quantization which needs importance matrix
@@ -306,7 +308,6 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
cur_imatrix = None
if new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['sym_int8']
return cur_qtype, cur_imatrix
elif qtype == ggml_tensor_qtype["q2_k"]:
new_module_name, layer, cur_module = module_name_process(full_module_name)
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
@@ -319,8 +320,26 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
cur_imatrix = None
if new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['sym_int8']
elif qtype > 100:
# gguf mixed precision
new_module_name, layer, cur_module = module_name_process(full_module_name)
num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['q6_k']
elif qtype == gguf_mixed_qtype["gguf_q4k_m"]:
if int(layer) < int(num_hidden_layers/2) and cur_module in ['v', 'down']:
cur_qtype = ggml_tensor_qtype['q6_k']
else:
cur_qtype = ggml_tensor_qtype['q4_k']
elif qtype == gguf_mixed_qtype["gguf_q4k_s"]:
if int(layer) < int(num_hidden_layers/8) and cur_module in ['v', 'down']:
cur_qtype = ggml_tensor_qtype['q5_k']
else:
cur_qtype = ggml_tensor_qtype['q4_k']
else:
return qtype, None
pass
return cur_qtype, cur_imatrix


def get_modelscope_hf_config(model_id_or_path: str,
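The gguf_q4k_s / gguf_q4k_m branches above reproduce llama.cpp's layer-dependent recipe: lm_head stays at q6_k, the v/down projections in the first half (q4k_m) or first eighth (q4k_s) of the layers get a higher-precision k-quant, and everything else falls back to q4_k. A self-contained sketch of that selection rule, with the module-name parsing reduced to explicit arguments (the function itself is illustrative):

from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype

def pick_layer_qtype(mixed_qtype: int, module: str, layer: int, num_hidden_layers: int) -> int:
    # mirrors the gguf_q4k_s / gguf_q4k_m branches of get_cur_qtype_and_imatrix
    if module == "lm_head":
        return ggml_tensor_qtype["q6_k"]
    if mixed_qtype == gguf_mixed_qtype["gguf_q4k_m"]:
        if layer < num_hidden_layers // 2 and module in ("v", "down"):
            return ggml_tensor_qtype["q6_k"]
    elif mixed_qtype == gguf_mixed_qtype["gguf_q4k_s"]:
        if layer < num_hidden_layers // 8 and module in ("v", "down"):
            return ggml_tensor_qtype["q5_k"]
    return ggml_tensor_qtype["q4_k"]

# e.g. a 32-layer model under gguf_q4k_m: the layer-3 'v' projection gets q6_k, layer 20 gets q4_k
print(pick_layer_qtype(gguf_mixed_qtype["gguf_q4k_m"], "v", 3, 32) == ggml_tensor_qtype["q6_k"])   # True
print(pick_layer_qtype(gguf_mixed_qtype["gguf_q4k_m"], "v", 20, 32) == ggml_tensor_qtype["q4_k"])  # True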