to_edge_transform_and_lower
Summary:
This diff introduces the to_edge_transform_and_lower API. The changes introduced are:
- Added support to the Partitioner class to register ops that it doesn't want to be decomposed
- Changes to _program.py to add the implementation of to_edge_transform_and_lower()
- Added a basic test case verifying that Linear, SDPA, and Linear + SDPA are not decomposed when requested, and that the corresponding backend consumes them. A usage sketch follows below.
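
Usage sketch (hypothetical call site; the exact to_edge_transform_and_lower signature, e.g. whether `partitioner` takes a single instance or a list, is an assumption; see the _program.py changes in this diff):

# Hypothetical call site for the new API; the exact signature is an assumption.
import torch
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program import to_edge_transform_and_lower


class LinearModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


exported = torch.export.export(LinearModule(), (torch.randn(2, 4),))
# The partitioner lists aten.linear in ops_to_not_decompose(), so the backend
# receives the op whole instead of its permute/addmm decomposition.
edge_manager = to_edge_transform_and_lower(exported, partitioner=[NonDecompTestPartitioner()])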

Differential Revision: D56401086
tarun292 authored and facebook-github-bot committed May 2, 2024
1 parent 0a916d4 commit e2312fb
Showing 7 changed files with 366 additions and 11 deletions.
15 changes: 15 additions & 0 deletions exir/backend/partitioner.py
@@ -9,6 +9,8 @@
from types import MappingProxyType
from typing import Dict, List, Mapping, NamedTuple, Union

import torch

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram
@@ -91,3 +93,16 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
            PartitionResult: includes the tagged graph and the delegation spec, indicating which backend_id and compile_spec are used for each node, along with the tags created by the backend developers.
        """
        pass

    @abstractmethod
    def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
        """
        Returns a list of operator overloads that should not be decomposed. When these ops
        are registered and the backend is invoked through to_edge_transform_and_lower, the
        program the backend receives is guaranteed not to have any of these ops decomposed.

        Returns:
            List[torch._ops.OpOverload]: a list of operator overloads that should not be decomposed.
        """
        pass
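
For illustration, a minimal sketch of a concrete partitioner overriding this hook (MyBackendPartitioner is hypothetical; its partition() body is elided):

# Hypothetical subclass sketch; only the new hook is fleshed out.
from typing import List

import torch
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from torch.export import ExportedProgram


class MyBackendPartitioner(Partitioner):
    def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
        # Ask to_edge_transform_and_lower to keep aten.linear intact.
        return [torch.ops.aten.linear.default]

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        ...  # tag nodes and return a PartitionResult as usual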
16 changes: 10 additions & 6 deletions exir/backend/test/backend_with_compiler_demo.py
@@ -83,15 +83,19 @@ def preprocess(
        processed_bytes = ""
        number_of_instruction = 0
        debug_handle_map = {}
        match_ops = [
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.add.Tensor,
            torch.ops.aten.sin.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
        ]

        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrating serde to the edge dialect.
                if node.target in match_ops:
                    simple_op = DemoOp(
                        node.target.__name__,
                        int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),
60 changes: 59 additions & 1 deletion exir/backend/test/op_partitioner_demo.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final, List

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -121,3 +121,61 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        return PartitionResult(
            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
        )


ops_not_to_decompose = [
    torch.ops.aten.linear.default,
    torch.ops.aten.scaled_dot_product_attention.default,
]

edge_ops_non_decomposed = [
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.scaled_dot_product_attention.default,
]


class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in edge_ops_non_decomposed


@final
class NonDecompTestPartitioner(Partitioner):
    """
    Partitions all linear and scaled_dot_product_attention nodes, the ops this
    partitioner asks to_edge_transform_and_lower not to decompose.
    """

    def __init__(self) -> None:
        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [CompileSpec("max_value", bytes([4]))],
        )

    def ops_to_not_decompose(self) -> List[torch._ops.OpOverload]:
        return ops_not_to_decompose

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        for _, submodule, _ in get_control_flow_submodules(graph_module):
            ret_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(ret_partition_tags)
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
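
A rough sketch of how the new test case might drive this partitioner end to end (the module, its shapes, and the call signature below are assumptions):

# Hypothetical test-style driver for NonDecompTestPartitioner.
import torch
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program import to_edge_transform_and_lower


class LinearSDPA(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(self.linear(q), k, v)


example_inputs = tuple(torch.randn(1, 1, 8, 8) for _ in range(3))
ep = torch.export.export(LinearSDPA(), example_inputs)
edge = to_edge_transform_and_lower(ep, partitioner=[NonDecompTestPartitioner()])
# Neither aten.linear nor aten.scaled_dot_product_attention should appear
# decomposed in the resulting graph; both are tagged for BackendWithCompilerDemo.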
1 change: 1 addition & 0 deletions exir/program/TARGETS
@@ -21,6 +21,7 @@ python_library(
    deps = [
        "//caffe2:torch",
        "//executorch/exir:error",
        "//executorch/exir:graph_module",
        "//executorch/exir:pass_manager",
        "//executorch/exir:print_program",
        "//executorch/exir:schema",
2 changes: 2 additions & 0 deletions exir/program/__init__.py
@@ -15,13 +15,15 @@
    ExecutorchProgramManager,
    ExirExportedProgram,
    to_edge,
    to_edge_transform_and_lower,
)

__all__ = [
    "ExirExportedProgram",
    "ExecutorchProgram",
    "_to_edge",
    "to_edge",
    "to_edge_transform_and_lower",
    "edge_to_executorch_passes",
    "EdgeProgramManager",
    "ExecutorchProgramManager",
