remove exir.capture from test_pass_infra (#3505)
Summary:

title

Differential Revision: D56942317
JacobSzwejbka authored and facebook-github-bot committed May 3, 2024
1 parent b9488fe commit 5d47769
Showing 2 changed files with 28 additions and 24 deletions.
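For context on the API change the diff below applies: the tests move off the deprecated exir.capture tracer and onto torch.export.export plus executorch.exir.to_edge. A minimal before/after sketch of the pattern (the toy module and input shape here are illustrative, modeled on the updated tests rather than copied from any one hunk):

import torch
from executorch.exir import to_edge
from torch.export import export


class F(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)


# Old pattern removed by this commit (exir.capture + CaptureConfig):
#   gm = (
#       exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
#       .to_edge()
#       .exported_program.graph_module
#   )

# New pattern: torch.export returns an ExportedProgram, to_edge wraps it in an
# EdgeProgramManager, and exported_program is now a method call rather than an
# attribute.
gm = to_edge(export(F(), (torch.randn(10),))).exported_program().graph_module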
1 change: 1 addition & 0 deletions exir/tests/TARGETS
@@ -324,6 +324,7 @@ python_unittest(
     ],
     deps = [
         ":models",
+        "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/passes:lib",
     ],
51 changes: 27 additions & 24 deletions exir/tests/test_pass_infra.py
@@ -8,12 +8,12 @@
 
 import unittest
 
-import executorch.exir as exir
-
 import torch
+from executorch.exir import to_edge
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import ScalarToTensorPass
 from executorch.exir.passes.pass_registry import PassRegistry
+from torch.export import export
 from torch.fx.passes.infra.pass_base import PassBase
 
 
@@ -99,16 +99,16 @@ def replace_mul_with_div(gm: torch.fx.GraphModule) -> None:
                 if node.op == "call_function" and node.target == torch.mul:
                     node.target = torch.div
 
-        def f(x: torch.Tensor) -> torch.Tensor:
-            y = torch.add(x, x)
-            z = torch.add(y, x)
-            return z
+        class F(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                y = torch.add(x, x)
+                z = torch.add(y, x)
+                return z
 
-        f = (
-            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
-            .to_edge()
-            .exported_program.graph_module
-        )
+        f = to_edge(export(F(), (torch.randn(10),))).exported_program().graph_module
         pm = PassManager(passes=[replace_add_with_mul, replace_mul_with_div])
         self.assertEqual(len(pm.passes), 2)
         pm(f)
@@ -144,15 +144,17 @@ def introduce_call_module(gm: torch.fx.GraphModule) -> None:
                 new_node = gm.graph.call_module("foo", (torch.randn(2),))
                 node.replace_all_uses_with(new_node)
 
-        def f(x: torch.Tensor) -> torch.Tensor:
-            y = torch.add(x, x)
-            z = torch.add(y, x)
-            return z
+        class F(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                y = torch.add(x, x)
+                z = torch.add(y, x)
+                return z
 
         traced_f1 = (
-            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
-            .to_edge()
-            .exported_program.graph_module
+            to_edge(export(F(), (torch.randn(10),))).exported_program().graph_module
         )
         pm1 = PassManager(
             passes=[introduce_call_method], run_checks_after_each_pass=True
@@ -162,13 +164,14 @@ def f(x: torch.Tensor) -> torch.Tensor:
         pm1(traced_f1)
 
     def test_pass_metadata(self) -> None:
-        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-            return x + y
+        class F(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x + y
 
-        sample_inputs = (torch.randn(1, 3), torch.randn(1, 3))
-        gm = exir.capture(
-            f, sample_inputs, exir.CaptureConfig()
-        ).exported_program.graph_module
+        gm = export(F(), (torch.ones(10), torch.ones(10))).graph_module
 
         pass_result = ScalarToTensorPass()(gm)
         self.assertIsNotNone(pass_result)
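For anyone reproducing the updated pattern locally, a rough end-to-end sketch (assumes an executorch install; the module name and the no-op logging pass are toy examples, not part of this diff):

import torch
from executorch.exir import to_edge
from executorch.exir.pass_manager import PassManager
from torch.export import export


class AddTwice(torch.nn.Module):
    def forward(self, x):
        return torch.add(torch.add(x, x), x)


def log_call_functions(gm: torch.fx.GraphModule) -> None:
    # Toy pass: prints the call_function targets it sees. Whether these are
    # aten or edge-dialect ops depends on how the graph was produced.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.target)


gm = to_edge(export(AddTwice(), (torch.randn(10),))).exported_program().graph_module
pm = PassManager(passes=[log_call_functions])
pm(gm)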
