Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to_edge_transform_and_lower #3483

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 21 additions & 1 deletion exir/backend/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import MappingProxyType
from typing import Dict, List, Mapping, NamedTuple, Union
from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
Expand Down Expand Up @@ -91,3 +93,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
"""
pass

def ops_to_not_decompose(
self,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Returns a list of operator names that should not be decomposed. When these ops are
registered and the `to_backend` is invoked through to_edge_transform_and_lower it will be
guaranteed that the program that the backend receives will not have any of these ops
decomposed.

Returns:
List[torch._ops.OpOverload]: a list of operator names that should not be decomposed.
Optional[Callable[[torch.fx.Node], bool]]]: an optional callable, acting as a filter, that users can provide
which will be called for each node in the graph that users can use as a filter for certain
nodes that should be continued to be decomposed even though the op they correspond to is
in the list returned by ops_to_not_decompose.
"""
return ([], None)
16 changes: 10 additions & 6 deletions exir/backend/test/backend_with_compiler_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,19 @@ def preprocess(
processed_bytes = ""
number_of_instruction = 0
debug_handle_map = {}
match_ops = [
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.add.Tensor,
torch.ops.aten.sin.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.scaled_dot_product_attention.default,
]

for node in edge_program.graph.nodes:
if node.op == "call_function":
# TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrate serde to edge dialect.
if (
node.target == exir_ops.edge.aten.sin.default
or node.target == exir_ops.edge.aten.mm.default
or node.target == exir_ops.edge.aten.add.Tensor
or node.target == torch.ops.aten.sin.default
):
if node.target in match_ops:
simple_op = DemoOp(
node.target.__name__,
int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),
Expand Down
74 changes: 73 additions & 1 deletion exir/backend/test/op_partitioner_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final
from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
Expand Down Expand Up @@ -71,6 +71,7 @@ def _partition_graph_module(
for _, submodule, _ in get_control_flow_submodules(graph_module):
ret_partition_tags = self._partition_graph_module(submodule)
partition_tags.update(ret_partition_tags)

return partition_tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down Expand Up @@ -121,3 +122,74 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
return PartitionResult(
tagged_exported_program=edge_exported_program, partition_tags=partition_tags
)


ops_not_to_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.scaled_dot_product_attention.default,
]

edge_ops_non_decomposed = [
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.scaled_dot_product_attention.default,
]


class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in edge_ops_non_decomposed


@final
class NonDecompTestPartitioner(Partitioner):
"""
Partitions all add/mul nodes regardless of order
"""

def __init__(self) -> None:
self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
self.delegation_spec = DelegationSpec(
BackendWithCompilerDemo.__name__,
[CompileSpec("max_value", bytes([4]))],
)

def ops_to_not_decompose(
self,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
def filter_ops(node: torch.fx.Node) -> bool:
if node.op == "call_function" and node.target in ops_not_to_decompose:
if len(node.args) == 3:
# This means that linear has a bias which is the only linear we support in this
# demo partitioner.
return True
else:
return False

return True

return (ops_not_to_decompose, filter_ops)

def _partition_graph_module(
self,
graph_module: torch.fx.GraphModule,
) -> Dict[str, DelegationSpec]:
partition_tags: Dict[str, DelegationSpec] = {}
partition_list = generate_pattern_op_partitions(
graph_module, op_support=self.op_support
)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec

for _, submodule, _ in get_control_flow_submodules(graph_module):
ret_partition_tags = self._partition_graph_module(submodule)
partition_tags.update(ret_partition_tags)
return partition_tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags = self._partition_graph_module(exported_program.graph_module)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
2 changes: 2 additions & 0 deletions exir/program/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ python_library(
deps = [
"//caffe2:torch",
"//executorch/exir:error",
"//executorch/exir:graph_module",
"//executorch/exir:pass_manager",
"//executorch/exir:print_program",
"//executorch/exir:schema",
Expand All @@ -36,6 +37,7 @@ python_library(
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_aten_with_edge_pass",
"//executorch/exir/passes:replace_view_copy_with_view_pass",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/verification:verifier",
Expand Down
2 changes: 2 additions & 0 deletions exir/program/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from executorch.exir.program._fake_program import get_fake_program
from executorch.exir.program._program import (
_to_edge,
_to_edge_transform_and_lower,
edge_to_executorch_passes,
EdgeProgramManager,
ExecutorchProgram,
Expand All @@ -22,6 +23,7 @@
"ExecutorchProgram",
"_to_edge",
"to_edge",
"_to_edge_transform_and_lower",
"edge_to_executorch_passes",
"EdgeProgramManager",
"ExecutorchProgramManager",
Expand Down