How to perform INT8 quantization and calibration in Python? #3858

Open
yjiangling opened this issue May 13, 2024 · 5 comments
Labels: triaged (Issue has been triaged by maintainers)

@yjiangling

Hi all, I'm trying to convert an ONNX model to a TensorRT engine with INT8 quantization in a Python environment. Here is the code:

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
from glob import glob
import numpy as np
import librosa
import os

class DataLoader:
	def __init__(self, batch_size, calib_count, calib_files_dir):
		self.index = 0
		self.batch_size = batch_size
		self.calib_count = calib_count
		self.file_list = glob(os.path.join(calib_files_dir, "*.wav"))
		assert (
			len(self.file_list) > self.batch_size * self.calib_count
		), "{} must contains more than {} files for calibration.".format(
			calib_files_dir, self.batch_size * self.calib_count
		)

	def reset(self):
		self.index = 0

	def next_batch(self):
		if self.index < self.calib_count:
			file_path = self.file_list[self.index]
			x, x_len = read_wave(file_path, 160, 512)
			self.index += 1
			calib_data = {'xs': np.ascontiguousarray(x, dtype=np.float32), 'xlen': np.ascontiguousarray(x_len, dtype=np.int32)}
			return calib_data
		else:
			# return np.array([])
			return {'xs': None, 'xlen': None}

	def __len__(self):
		return self.calib_count

def read_wave(filename, hop_length, n_fft):
	pad_len = int(n_fft // 2)
	x, sr = librosa.load(filename, sr=None)
	x = np.pad(x, pad_width=(pad_len, pad_len), mode="reflect")
	nf = int(np.ceil((len(x) - 2*hop_length - n_fft) / hop_length / 4)) * 4 + 3
	x = np.pad(x, (0, (nf - 1) * hop_length + n_fft - len(x)))
	data_len = np.array([len(x)], dtype=np.int32)
	data = np.expand_dims(np.array(x, dtype=np.float32), axis=0)
	return data, data_len


class Calibrator(trt.IInt8EntropyCalibrator2):
	def __init__(self, data_loader, cache_file=""):
		trt.IInt8EntropyCalibrator2.__init__(self)
		self.data_loader = data_loader
		# self.d_input = cuda.mem_alloc(self.data_loader.calibration_data.nbytes)
		self.d_input = {'xs': cuda.mem_alloc(np.zeros([1, 480000], np.float32).nbytes), 
						'xlen': cuda.mem_alloc(np.zeros([1], np.int32).nbytes)}
		self.cache_file = cache_file
		data_loader.reset()

	def get_batch_size(self):
		return self.data_loader.batch_size

	def get_batch(self, names):
		batch = self.data_loader.next_batch()
		# if not batch.size:
		if all(value is None for value in batch.values()):
			return None
		# Copy the calibration data from the CPU (host) to the GPU (device)
		# cuda.memcpy_htod(self.d_input, batch)
		cuda.memcpy_htod(self.d_input['xs'], batch['xs'])
		cuda.memcpy_htod(self.d_input['xlen'], batch['xlen'])

		# return [self.d_input]
		return self.d_input

	def read_calibration_cache(self):
		# If the calibration table file exists, read the calibration table directly from it
		if os.path.exists(self.cache_file):
			with open(self.cache_file, "rb") as f:
				return f.read()

	def write_calibration_cache(self, cache):
		# If calibration was performed, write the calibration table to a file for later use
		with open(self.cache_file, "wb") as f:
			f.write(cache)
			f.flush()


def build_engine():
	onnx_file_path = './onnx_model/model.onnx'
	calibration_table_path = 'offline_asr_calib_cache'
	engine_file_path = './trt_model/model.INT8.plan'
	calib_files_path = './data/calibrationt/calib_5000/'  # dataset for calibration
	mode = "INT8"

	TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
	builder = trt.Builder(TRT_LOGGER)
	network = builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
	config = builder.create_builder_config()
	parser = trt.OnnxParser(network, TRT_LOGGER)
	assert os.path.exists(onnx_file_path), "The onnx file {} is not found".format(onnx_file_path)
	with open(onnx_file_path, "rb") as model:
		if not parser.parse(model.read()):
			print("Failed to parse the ONNX file.")
			for error in range(parser.num_errors):
				print(parser.get_error(error))
			return None

	print("Building an engine from file {}, this may take a while...".format(onnx_file_path))

	# build tensorrt engine
	config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 * (1 << 30))
	profile = builder.create_optimization_profile()
	profile.set_shape("xs", min=(1, 1120), opt=(1, 160000), max=(1, 480000))
	profile.set_shape("xlen", min=(1,), opt=(1,), max=(1,))
	config.add_optimization_profile(profile)

	if mode == "INT8":
		data_loader = DataLoader(1, 10, calib_files_path)
		config.set_flag(trt.BuilderFlag.INT8)
		calibrator = Calibrator(data_loader, calibration_table_path)
		config.int8_calibrator = calibrator
	elif mode == "FP16":
		config.set_flag(trt.BuilderFlag.FP16)

	engine = builder.build_engine(network, config)
	# engine = builder.build_cuda_engine(network, config)
	if engine is None:
		print("Failed to create the engine")
		return None
	with open(engine_file_path, "wb") as f:
		f.write(engine.serialize())

	return engine


if __name__ == '__main__':

	build_engine()

The model has two input tensors ("xs" and "xlen"), both with dynamic shapes. When I run this script, it always gives the following error:

[05/13/2024-17:39:19] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:367: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Building an engine from file ./onnx_model/model.onnx, this may take a while...
quantilize.py:173: DeprecationWarning: Use build_serialized_network instead.
engine = builder.build_engine(network, config)
[05/13/2024-17:39:26] [TRT] [W] Calibration Profile is not defined. Running calibration with Profile 0
[05/13/2024-17:39:26] [TRT] [W] Calibration Profile is not defined. Running calibration with Profile 0
[ERROR] Exception caught in get_batch(): Unable to cast Python instance to C++ type (compile in debug mode for details)
[05/13/2024-17:39:44] [TRT] [E] 1: Unexpected exception _Map_base::at
Failed to create the engine

What's wrong? Is there an error in my code? How can I fix this error and successfully finish this job? Can anyone help? Thanks a lot in advance!
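A likely cause of `Unable to cast Python instance to C++ type`: `get_batch()` returns the dict `self.d_input`. The TensorRT Python API documents `get_batch(names)` as returning a list of device memory pointers, one per input name in `names` (for pycuda buffers, `int(allocation)` gives the pointer), or an empty list / `None` when there is no more calibration data. Below is a minimal sketch of that conversion. `device_pointers_for` is a hypothetical helper name, and the commented `get_batch` only shows where it would go in the `Calibrator` class above; it is not run here.

```python
from typing import List, Mapping, Sequence, SupportsInt


def device_pointers_for(names: Sequence[str], buffers: Mapping[str, SupportsInt]) -> List[int]:
    """Return one raw device address (plain int) per input name, in the order of `names`.

    Hypothetical helper: TensorRT passes `names` to get_batch(), and the returned
    list is expected to follow that order, with each entry an int such as
    int(pycuda_device_allocation).
    """
    missing = [name for name in names if name not in buffers]
    if missing:
        raise KeyError(f"no device buffer for input(s): {missing}")
    return [int(buffers[name]) for name in names]


# Sketch of the corrected get_batch() in the Calibrator class (not executed here):
#
# def get_batch(self, names):
#     batch = self.data_loader.next_batch()
#     if all(value is None for value in batch.values()):
#         return None  # no more calibration data
#     for name in names:
#         cuda.memcpy_htod(self.d_input[name], batch[name])  # host -> device copy
#     return device_pointers_for(names, self.d_input)      # list of ints, not a dict
```

Iterating over `names` instead of hard-coding `"xs"` and `"xlen"` keeps the pointer order aligned with whatever order TensorRT asks for.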

@yjiangling (Author)

By the way, I'm using TensorRT 8.4.1. Does the calibration API not work in Python? I hope someone can help! Many thanks.

[05/15/2024-15:50:14] [TRT] [I] Starting Calibration.
[ERROR] Exception caught in get_batch(): Unable to cast Python instance to C++ type (compile in debug mode for details)
[05/15/2024-15:50:16] [TRT] [I] Post Processing Calibration data in 2.194e-06 seconds.
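The earlier warning `Calibration Profile is not defined. Running calibration with Profile 0` points at a second thing worth checking. My understanding of the TensorRT developer guide is that, for dynamic shapes, calibration runs at the kOPT dimensions of the calibration profile (set with `config.set_calibration_profile(profile)`), so each calibration batch should have exactly the OPT shape, here `xs` = (1, 160000). A sketch of a helper that zero-pads or crops the output of `read_wave()` to that shape; `fit_to_opt_shape` is a hypothetical name, and treating `xlen` as the number of valid samples is an assumption that depends on how the model uses it.

```python
import numpy as np

OPT_LEN = 160000  # opt=(1, 160000) for "xs" in the profile built in build_engine()


def fit_to_opt_shape(x, x_len, opt_len=OPT_LEN):
    """Zero-pad or crop a (1, T) waveform to the profile's OPT shape (1, opt_len).

    Assumption: "xlen" carries the number of valid (non-padding) samples, so it is
    clipped to opt_len when the waveform is cropped. Adjust if the model differs.
    """
    wave = np.asarray(x, dtype=np.float32).reshape(1, -1)
    n = min(wave.shape[1], opt_len)                # samples that fit in the OPT-sized buffer
    xs = np.zeros((1, opt_len), dtype=np.float32)  # zeros after the valid samples
    xs[0, :n] = wave[0, :n]
    valid = min(int(np.asarray(x_len).reshape(-1)[0]), opt_len)
    xlen = np.array([valid], dtype=np.int32)
    return xs, xlen


# Sketch of where it would go (not executed here):
#   DataLoader.next_batch():  x, x_len = fit_to_opt_shape(*read_wave(file_path, 160, 512))
#   build_engine():           config.set_calibration_profile(profile)  # after add_optimization_profile()
```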

@zerollzeng (Collaborator)

zerollzeng self-assigned this May 17, 2024
zerollzeng added the triaged label (Issue has been triaged by maintainers) May 17, 2024
@yjiangling (Author) commented May 21, 2024

> You can make use of Polygraphy, see https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/examples/cli/convert/01_int8_calibration_in_tensorrt
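Following the Polygraphy suggestion, here is a minimal sketch of a data loader script for this model. As far as I know, `polygraphy convert --data-loader-script` looks for a function named `load_data` by default that yields feed dicts mapping input names to NumPy arrays. Polygraphy's own `Calibrator` copies those arrays to the GPU, so yielding a dict is fine at this level, unlike returning one from a raw `trt.IInt8EntropyCalibrator2.get_batch()`. `OPT_LEN`, `NUM_BATCHES`, and the placeholder waveforms are assumptions; real calibration should use audio from the calibration set.

```python
# data_loader.py (hypothetical file name), e.g. used as:
#   polygraphy convert model.onnx --int8 --data-loader-script ./data_loader.py \
#       --calibration-cache calib.cache -o model.INT8.plan \
#       --trt-min-shapes xs:[1,1120] xlen:[1] --trt-opt-shapes xs:[1,160000] xlen:[1] \
#       --trt-max-shapes xs:[1,480000] xlen:[1]
# (check `polygraphy convert -h` for the flags supported by your version)
import numpy as np

OPT_LEN = 160000  # OPT length of "xs" in the optimization profile used in this thread
NUM_BATCHES = 10  # hypothetical number of calibration batches


def _placeholder_waveforms(n, seed=0):
    # HYPOTHETICAL stand-in so the script runs on its own. Replace with real audio;
    # calibrating on random noise produces meaningless INT8 scales.
    rng = np.random.default_rng(seed)
    for _ in range(n):
        length = int(rng.integers(1120, 2 * OPT_LEN))
        yield rng.uniform(-1.0, 1.0, size=length).astype(np.float32)


def load_data(waveforms=None):
    """Yield one feed dict (input name -> NumPy array) per calibration batch."""
    if waveforms is None:
        waveforms = _placeholder_waveforms(NUM_BATCHES)
    for wave in waveforms:
        wave = np.asarray(wave, dtype=np.float32).ravel()[:OPT_LEN]  # crop to OPT length
        xs = np.zeros((1, OPT_LEN), dtype=np.float32)                # zero-pad the rest
        xs[0, : wave.size] = wave
        yield {"xs": xs, "xlen": np.array([wave.size], dtype=np.int32)}
```

From Python, the same generator can be passed to `polygraphy.backend.trt.Calibrator(data_loader=load_data(), cache=...)`.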

@zerollzeng Thanks a lot for the support. By the way, can the engine built with the EngineFromNetwork API be saved to disk as a new quantized TensorRT engine file? I use the following code to generate the TensorRT engine:

from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, NetworkFromOnnxPath, save_engine

build_engine = EngineFromNetwork(
	NetworkFromOnnxPath(onnx_file_path),
	config=CreateConfig(int8=True, profiles=[calib_profile], calibrator=calibrator),
	)

# Save the quantized engine to a TensorRT plan file
save_engine(build_engine, engine_file_path)

But why is the new engine file not 1/4 the size of the old one generated with FP32? It only went from 156M to 95M, and the ONNX file is 153M. What's wrong? Is the engine-saving code right?
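A rough way to reason about the size question, under a strong simplifying assumption: engine size is dominated by weight storage, FP32 weights take 4 bytes and INT8 weights 1 byte, and only some layers end up in INT8 (the INT8 flag permits INT8 but does not force it, so TensorRT may keep a layer in higher precision when that is faster or no INT8 kernel exists). Under that assumption, a size well above 1/4 just means only part of the weights were stored as INT8. The helper names are hypothetical, and the model ignores other contributors (tactic-specific weight layouts, padding, other serialized data), so the numbers are illustrative only.

```python
def estimated_engine_mb(fp32_weights_mb: float, int8_fraction: float) -> float:
    """Crude size estimate where only weight bytes count (an assumption).

    int8_fraction is the share of FP32 weight bytes whose layers run in INT8;
    those bytes shrink 4x and the rest stay FP32.
    """
    if not 0.0 <= int8_fraction <= 1.0:
        raise ValueError("int8_fraction must be in [0, 1]")
    return fp32_weights_mb * (1.0 - 0.75 * int8_fraction)


def implied_int8_fraction(fp32_weights_mb: float, observed_mb: float) -> float:
    """Invert estimated_engine_mb(): which INT8 share would explain observed_mb?"""
    return (1.0 - observed_mb / fp32_weights_mb) / 0.75


if __name__ == "__main__":
    # Sizes reported in this thread: FP32 engine 156M, Polygraphy INT8 engine 95M,
    # trtexec INT8 engine 51M.
    for label, size in [("polygraphy", 95.0), ("trtexec", 51.0)]:
        frac = implied_int8_fraction(156.0, size)
        print(f"{label}: ~{frac:.0%} of weight bytes in INT8 under this crude model")
```

To see which layers actually ran in INT8, the verbose build log or per-layer engine information is the reliable source; this arithmetic only shows why "not 1/4" is not by itself a sign of a broken save.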

@yjiangling (Author)

@zerollzeng Sorry to bother you again. I tried using the trtexec tool to generate an INT8 quantized engine without calibration, like this:

trtexec --onnx=onnx_model/model.onnx \
	--minShapes=xs:1x1120,xlen:1 \
	--optShapes=xs:1x160000,xlen:1 \
	--maxShapes=xs:1x480000,xlen:1 \
	--workspace=10240 \
	--int8 \
	--saveEngine=trt_model/trt-INT8.plan \
	--verbose \
	--buildOnly \
	> trt_model/result-INT8.txt

The quantized TensorRT engine is now 51M. Why is it much smaller than the engine generated with Polygraphy? Is it because the latter contains Q/DQ layers? I also tested the inference speed of the FP32 and INT8 engines, and they are almost the same. What's wrong? I tested them on an A100 GPU.

@zerollzeng (Collaborator)

> The quantized TensorRT engine is now 51M. Why is it much smaller than the engine generated with Polygraphy?

Many factors affect the final engine size; I don't have a clear conclusion for your case.

> I also tested the inference speed of the FP32 and INT8 engines, and they are almost the same. What's wrong? I tested them on an A100 GPU.

I guess it's sub-optimal Q/DQ placement. You can check the engine layer information in the verbose log, or check the per-layer profile (see trtexec -h), to confirm.

You can take PTQ as the best-performance target: build the model without Q/DQ nodes using --best to see how good the performance can get.
