update delegate for derived bias qspec #3511

Open · wants to merge 1 commit into base: main
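Some context on the title: with a "derived" bias quantization spec, the PT2E quantizer computes the bias qparams from the conv/linear input and weight observers rather than observing the bias directly, so the exported graph carries an explicit quantize/dequantize pair on the bias. A minimal sketch of such a spec, assuming torch.ao's DerivedQuantizationSpec; the helper names and node variables are illustrative, not taken from this PR:

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec


def _derive_bias_qparams(obs_or_fqs):
    # bias_scale = input_scale * weight_scale, zero_point = 0 (symmetric int32)
    input_obs, weight_obs = obs_or_fqs
    input_scale, _ = input_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    scale = (input_scale * weight_scale).flatten()
    return scale, torch.zeros_like(scale, dtype=torch.int32)


def derived_bias_qspec(input_node, weight_node, conv_node):
    # hypothetical helper: builds a qspec for conv_node's bias, derived
    # from its input and weight edges
    return DerivedQuantizationSpec(
        derived_from=[(input_node, conv_node), (weight_node, conv_node)],
        derive_qparams_fn=_derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )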
backends/xnnpack/operators/node_visitor.py (2 changes: 1 addition & 1 deletion)

@@ -625,7 +625,7 @@ def define_nodes_tensor_inputs_outputs(
         if input_type_map.node_bias is not None:
             bias_node = get_input_node(node, input_type_map.node_bias)
             bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
+                bias_node, self._exported_program
             )
             self.define_tensor(
                 bias_node,
backends/xnnpack/operators/op_conv2d.py (4 changes: 1 addition & 3 deletions)

@@ -104,9 +104,7 @@ def define_node(
         if node.args[2] is not None:
             # If there is a bias
             bias_node = get_input_node(node, 2)
-            bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
-            )
+            bias_quant_params = QuantParams.from_bias(bias_node, self._exported_program)
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
backends/xnnpack/operators/op_linear.py (4 changes: 1 addition & 3 deletions)

@@ -66,9 +66,7 @@ def define_node(
         # bias
         if len(node.args) > 2:
             bias_node = get_input_node(node, 2)
-            bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
-            )
+            bias_quant_params = QuantParams.from_bias(bias_node, self._exported_program)
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
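Both call sites change the same way: from_bias no longer receives the input and weight QuantParams and instead takes the exported program, deriving everything from the bias node itself. The graph pattern it relies on, sketched from the check added in quant_params.py below:

# bias (get_attr / placeholder)            <- static tensor
#   -> quantize_per_tensor(...)            <- matched by is_quant
#   -> dequantize_per_tensor(...)          <- matched by is_dequant
#   -> aten.convolution / aten.linear
#
# A float bias has no dequantize in front of it, so from_bias returns None
# and the tensor is serialized unquantized.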
backends/xnnpack/operators/quant_params.py (45 changes: 14 additions & 31 deletions)

@@ -306,41 +306,24 @@ def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:

     @classmethod
     def from_bias(
-        cls,
-        bias: torch.fx.Node,
-        weight_quantizer: Optional[QuantParams],
-        input_quantizer: Optional[QuantParams],
+        cls, tensor_node: torch.fx.Node, ep: ExportedProgram
     ) -> Optional[QuantParams]:
-        if weight_quantizer is None or input_quantizer is None:
-            check_or_raise(
-                weight_quantizer is None and input_quantizer is None,
-                "Weight and Input should both be quantized",
-            )
-            return None
+        # source node for quant params
+        dq = tensor_node

-        if input_quantizer.is_dynamic:
-            # No need to quantize bias for dyanamic quantization
+        if not is_dequant(dq):
             return None

-        check_or_raise(
-            not input_quantizer.per_channel,
-            "Input can not be quantized per channel",
-        )
+        src = dq

-        # Only per_tensor quantization is supported for input here
+        # is input of dq is q?
+        dq_input = dq.all_input_nodes[0]
+        if is_quant(dq_input):
+            src = dq_input
+
         check_or_raise(
-            isinstance(input_quantizer.scale, float),
-            f"q_input scale should be float, but got {input_quantizer.scale}",
-        )
-        return cls(
-            per_channel=weight_quantizer.per_channel,
-            q_input=bias,
-            scale=weight_quantizer.scale * cast(float, input_quantizer.scale),
-            zp=weight_quantizer.zp * 0,
-            axis=0,  # not using weight_quantizer.axis because bias is always of shape [out_channels] i.e. 1D
-            dtype=torch.int32,
-            qmin=-(2**31),
-            qmax=(2**31) - 1,
-            is_output=False,
-            is_input=False,
+            src.all_input_nodes[0].op in ["get_attr", "placeholder"],
+            f"Expected input to quant -> dequant chain from bias to be static tensor, but instead got: {src.all_input_nodes[0]}",
         )
+
+        return cls.from_q_dq_node(src, ep)
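Numerically this is the same quantization as before: the qparams carried by the bias q/dq nodes encode the relation the deleted code computed inline. A quick sanity check with made-up scales:

input_scale, weight_scale = 0.02, 0.005      # example values, not from the PR
bias_scale = input_scale * weight_scale      # 1e-4, what the old code computed
bias_zero_point = 0                          # old code forced zp to 0 (zp * 0)
q_min, q_max = -(2**31), 2**31 - 1           # int32 bias range, unchanged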