remove exir.capture from hta_partitioner_demo.py (#3504)
Summary:

title

Differential Revision: D56941930
JacobSzwejbka authored and facebook-github-bot committed May 3, 2024
1 parent b9488fe commit 8506bde
Showing 2 changed files with 36 additions and 322 deletions.
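For context: exir.capture was ExecuTorch's legacy one-step capture API. The flow this commit moves to uses torch.export.export to produce a torch.export.ExportedProgram and exir.to_edge to wrap it in an EdgeProgramManager, so exported_program becomes a method call rather than an attribute. A minimal before/after sketch of the migration, assuming an installed executorch; the module and input shapes are illustrative rather than taken from the file:

    import torch
    from executorch import exir
    from executorch.exir import to_edge
    from torch.export import export

    class Sub(torch.nn.Module):
        def forward(self, x, y):
            return torch.sub(x, y)

    inputs = (torch.ones(1, 32), torch.ones(1, 32))

    # Old (removed) flow; exir.capture also accepted plain functions:
    # gm = (
    #     exir.capture(Sub(), inputs, exir.CaptureConfig(enable_aot=True))
    #     .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
    #     .exported_program.graph_module
    # )

    # New flow; torch.export.export takes an nn.Module, which is why the diff
    # below wraps the old `sub` function in a Sub module:
    gm = (
        to_edge(
            export(Sub(), inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        .exported_program()  # a method on EdgeProgramManager, hence the ()
        .graph_module
    )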
97 changes: 36 additions & 61 deletions exir/backend/test/hta_partitioner_demo.py
@@ -8,7 +8,9 @@
 from typing import final, List
 
 import torch
+
 from executorch import exir
+from executorch.exir import to_edge
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
     generate_pattern_op_partitions,
 )
@@ -20,7 +22,7 @@
 )
 from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
 from executorch.exir.backend.utils import tag_constant_data
-from torch.export import ExportedProgram
+from torch.export import export, ExportedProgram
 from torch.fx.passes.infra.partitioner import Partition
 
 
@@ -64,56 +66,41 @@ def forward(self, x_raw, h, c):
         input_c = torch.ones([1, 32])
 
         pattern_lstm_conv_lifted = (
-            exir.capture(
-                LSTMConvPattern(),
-                (input_x, input_h, input_c),
-                exir.CaptureConfig(enable_aot=True),
-            )
-            .to_edge(
+            to_edge(
+                export(
+                    LSTMConvPattern(),
+                    (input_x, input_h, input_c),
+                ),
+                compile_config=
                 # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
-                exir.EdgeCompileConfig(_check_ir_validity=False)
+                exir.EdgeCompileConfig(_check_ir_validity=False),
             )
-            .exported_program.graph_module
-        )
-        pattern_lstm_conv = (
-            exir.capture(
-                LSTMConvPattern(),
-                (input_x, input_h, input_c),
-                exir.CaptureConfig(),
-            )
-            .to_edge(
-                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
-                exir.EdgeCompileConfig(_check_ir_validity=False)
-            )
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
 
-        def sub(x, y):
-            return torch.sub(x, y)
+        class Sub(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return torch.sub(x, y)
+
         pattern_sub_lifted = (
-            exir.capture(
-                sub,
-                (input_x, input_h),
-                exir.CaptureConfig(enable_aot=True, _unlift=False),
+            to_edge(
+                export(
+                    Sub(),
+                    (input_x, input_h),
+                ),
+                compile_config=exir.EdgeCompileConfig(_use_edge_ops=True),
             )
-            .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
-            .exported_program.graph_module
-        )
-        pattern_sub = (
-            exir.capture(
-                sub,
-                (input_x, input_h),
-                exir.CaptureConfig(),
-            )
-            .to_edge()
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
 
         self.patterns = [
             pattern_lstm_conv_lifted.graph,
-            pattern_lstm_conv.graph,
             pattern_sub_lifted.graph,
-            pattern_sub.graph,
         ]
 
         backend_id = QnnBackend.__name__
@@ -240,32 +227,20 @@ def forward(self, x_raw, h, c):
         input_c = torch.ones([1, 32])
 
         pattern_lstm_conv_lifted = (
-            exir.capture(
-                LSTMConvPattern(),
-                (input_x, input_h, input_c),
-                exir.CaptureConfig(enable_aot=True),
-            )
-            .to_edge(
-                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
-                exir.EdgeCompileConfig(_check_ir_validity=False)
-            )
-            .exported_program.graph_module
-        )
-        pattern_lstm_conv_unlifted = (
-            exir.capture(
-                LSTMConvPattern(),
-                (input_x, input_h, input_c),
-                exir.CaptureConfig(),
-            )
-            .to_edge(
+            to_edge(
+                export(
+                    LSTMConvPattern(),
+                    (input_x, input_h, input_c),
+                ),
+                compile_config=
                 # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
-                exir.EdgeCompileConfig(_check_ir_validity=False)
+                exir.EdgeCompileConfig(_check_ir_validity=False),
             )
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
         self.patterns = [
             pattern_lstm_conv_lifted.graph,
-            pattern_lstm_conv_unlifted.graph,
         ]
         # Only (lstm + conv) pattern is lowerable
 
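After this change, self.patterns still holds torch.fx.Graph objects, now taken from each edge program's graph_module. How generate_pattern_op_partitions consumes them is outside this diff, but the kind of matching such pattern graphs enable can be sketched with torch.fx's public SubgraphMatcher; the Model module below is made up for illustration:

    import torch
    from torch.export import export
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

    class Sub(torch.nn.Module):
        def forward(self, x, y):
            return torch.sub(x, y)

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return torch.sub(x, y) * 2

    example = (torch.ones(1, 32), torch.ones(1, 32))
    # Build a pattern graph the same way the demo now builds its patterns.
    pattern = export(Sub(), example).graph_module.graph
    target = export(Model(), example).graph_module.graph

    # Anchor on the pattern's compute nodes; placeholders and outputs
    # need not line up between pattern and target.
    matcher = SubgraphMatcher(pattern, match_output=False, match_placeholder=False)
    matches = matcher.match(target)
    print(f"found {len(matches)} match(es) for the sub pattern")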
