Commit 4bedcd0

[PIR] Mark ShareData as an inplace OP and fix inplace pass (#64195)

SigureMo committed May 15, 2024
1 parent 8a612f3 commit 4bedcd0
Showing 11 changed files with 135 additions and 28 deletions.
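Summary: `pd_op.share_data` — the PIR op behind `Value.detach()` and the translation target of the legacy `share_buffer` op — is renamed to `pd_op.share_data_` and annotated as inplace (`x -> out`), so downstream passes know its output aliases its input. The inplace pass is reworked accordingly: instead of the coarse `reused_input_values` / `reused_output_values` sets, it now tracks per-value remaining-use counts plus an explicit output-to-input alias map. The bug this guards against, as a minimal sketch distilled from the new `TestInputsHasBeenModified` test below (assumes PIR execution is enabled, e.g. `FLAGS_enable_pir_in_executor=1`):

```python
import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [2, 2], dtype='float32')
    y = paddle.static.data('y', [2, 2], dtype='float32')
    z = paddle.add(x, y)
    detached_z = z.detach()  # lowers to pd_op.share_data_: detached_z aliases z
    out = detached_z + 1

    exe = paddle.static.Executor(paddle.CPUPlace())
    z_val, out_val = exe.run(
        feed={
            'x': np.ones([2, 2], np.float32),
            'y': np.ones([2, 2], np.float32) * 2,
        },
        fetch_list=[z, out],
    )
    # If the inplace pass rewrote `detached_z + 1` into an inplace add on the
    # shared buffer, the fetched z would read 4.0 instead of 3.0.
    np.testing.assert_allclose(z_val, np.ones([2, 2], np.float32) * 3)
```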
4 changes: 2 additions & 2 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2976,12 +2976,12 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber {
 struct ShareBufferOpTranscriber : public OpTranscriber {
   pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                            const OpDesc& op_desc) override {
-    std::string target_op_name = dialect::ShareDataOp::name();
+    std::string target_op_name = dialect::ShareData_Op::name();
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       PADDLE_THROW(phi::errors::InvalidArgument(
           "Op share_buffer should have corresponding OpInfo "
-          "pd_op.share_data"));
+          "pd_op.share_data_"));
     }

     return op_info;
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -109,7 +109,7 @@
     'print',
     'number_count',
     'assign_value',
-    'share_data',
+    'share_data_',
     'onednn_to_paddle_layout',
     'lrn',
     'multi_gru',
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1627,7 +1627,7 @@
     func: shadow_feed_tensors
     param: [x]

-- op : share_data
+- op : share_data_
   args : (Tensor x)
   output : Tensor(out)
   infer_meta:
@@ -1636,6 +1636,7 @@
   kernel:
     func: share_data
     param: [x]
+  inplace : (x -> out)

 - op : shuffle_batch
   args : (Tensor x, Tensor seed, int startup_seed=0)
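The new `inplace : (x -> out)` entry declares that result `out` may reuse operand `x`'s buffer; following Paddle's trailing-underscore convention for inplace ops, the op is renamed `share_data_` (C++ class `ShareData_Op`), which is what every other file in this commit adapts to. A back-of-the-envelope model of the aliasing semantics the annotation declares (plain Python, illustrative only — not Paddle API):

```python
import numpy as np

def share_data_(x: np.ndarray) -> np.ndarray:
    # `inplace : (x -> out)`: out is the same buffer as x, not a copy.
    return x

buf = np.zeros([2, 2], dtype=np.float32)
alias = share_data_(buf)
alias += 1.0              # writes through to `buf`
assert buf[0, 0] == 1.0   # out and x genuinely alias
```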
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -67,7 +67,7 @@ const std::unordered_set<std::string> LegacyOpList = {
     CSplitOp::name(),
     PushDenseOp::name(),
     SeedOp::name(),
-    ShareDataOp::name(),
+    ShareData_Op::name(),
     SparseMomentumOp::name(),
     GetTensorFromSelectedRowsOp::name(),
     RankAttentionOp::name(),
@@ -571,8 +571,8 @@ class AutoMixedPrecisionPass : public pir::Pass {
       return;
     }

-    // Rewrite ShareDataOp
-    if (op->isa<paddle::dialect::ShareDataOp>() && OpRunLowPrecision(op)) {
+    // Rewrite ShareData_Op
+    if (op->isa<paddle::dialect::ShareData_Op>() && OpRunLowPrecision(op)) {
       SetResultDataType(op->result(0), precision_mode_, builder.ir_context());
       return;
     }
54 changes: 40 additions & 14 deletions paddle/fluid/pir/transforms/general/inplace_pass.cc
@@ -67,6 +67,19 @@ bool CanBeDeleted(pir::Value value) {
   return !(persist_attr && persist_attr.data());
 }

+bool HasNotUser(const pir::Value& value,
+                const std::unordered_map<pir::Value, size_t>& use_count_map,
+                const std::unordered_map<pir::Value, pir::Value>& inplace_map) {
+  auto current_value = value;
+  while (use_count_map.at(current_value) == 0) {
+    if (inplace_map.count(current_value) == 0) {
+      return false;
+    }
+    current_value = inplace_map.at(current_value);
+  }
+  return true;
+}
+
 bool CanDoInplace(const std::unordered_set<pir::Value>& eager_dels,
                   pir::Value input,
                   pir::Value output,
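`HasNotUser` is the core of the fix. `use_count_map` holds, for each value, how many of its uses the block walk has not yet visited, and `inplace_map` maps each inplace result back to the operand it aliases. Read carefully, the helper answers: does this value, or an older value it transitively aliases, still have a pending user? If it does, the underlying buffer is still live and must not be written in place. A line-for-line Python rendering (a sketch mirroring the C++ above, with a name spelling out that reading):

```python
def has_pending_user(value, use_count_map, inplace_map):
    """Walk the alias chain built by earlier inplace rewrites
    (inplace_map: result -> aliased operand) and report whether any
    value sharing this buffer still has an unvisited user."""
    current = value
    while use_count_map[current] == 0:   # this SSA value looks dead...
        if current not in inplace_map:
            return False                 # ...and aliases nothing: truly dead
        current = inplace_map[current]   # ...but it aliases an older value
    return True                          # some alias is still read later
```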
@@ -295,15 +308,25 @@ GetEagerDeletionValues(const pir::Block& block) {
 std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
     const pir::Block& block) {
   const auto eager_dels = GetEagerDeletionValues(block);
+  auto use_count_map = [](const pir::Block& block) {
+    std::unordered_map<pir::Value, size_t> use_count_map;
+    for (auto& op : block) {
+      for (auto value : op.results()) {
+        use_count_map[value] = value.use_count();
+      }
+    }
+    return use_count_map;
+  }(block);
+  std::unordered_map<pir::Value, pir::Value> inplace_map;
+
   std::unordered_map<pir::Operation*, std::string> inplace_ops;

   std::unordered_set<pir::Value> visited_values;
-  std::unordered_set<pir::Value> reused_input_values;
-  std::unordered_set<pir::Value> reused_output_values;

   for (auto& op : block) {
     for (size_t i = 0; i < op.num_operands(); ++i) {
       visited_values.insert(op.operand_source(i));
+      use_count_map[op.operand_source(i)]--;
     }

     if (op.dialect()->name().compare(paddle::dialect::KernelDialect::name()) !=
@@ -339,11 +362,18 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
     if (upper_op_attrs.count("is_inplace") != 0 &&
         upper_op_attrs.at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
       VLOG(6) << upper_op_name << " is already an inplace op.";
-      for (size_t i = 0; i < op.num_operands(); ++i) {
-        reused_input_values.insert(op.operand_source(i));
+      auto op_info =
+          pir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name);
+      auto op_yaml_interface =
+          op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
+      paddle::dialect::OpYamlInfoParser op_info_parser(
+          op_yaml_interface->get_op_info_(upper_op_name));
+      for (auto [out_slot, in_slot] : op_info_parser.GetInplaceIdMap()) {
+        auto out_value = op.result(out_slot);
+        auto in_value = op.operand_source(in_slot);
+        inplace_map[out_value] = in_value;
       }
       for (auto& result : op.results()) {
-        reused_output_values.insert(result);
         visited_values.insert(result);
       }
       continue;
@@ -409,8 +439,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
                                      upper_op_name)) ||
             (visited_values.count(op.result(out_slot)) > 0) ||
             (!CanBeDeleted(op.result(out_slot))) ||
-            (reused_input_values.count(op.operand_source(in_slot)) > 0) ||
-            (reused_output_values.count(op.result(out_slot)) > 0) ||
+            HasNotUser(op.operand_source(in_slot), use_count_map, inplace_map) ||
             (std::find(used_external_values.begin(),
                        used_external_values.end(),
                        op.operand_source(in_slot)) !=
@@ -435,19 +464,16 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
<< " -- result " << out_slot
<< " visited: " << (visited_values.count(op.result(out_slot)) > 0);
VLOG_IF(8, in_slot < op.num_operands())
<< " -- operand " << in_slot << " has been reused: "
<< (reused_input_values.count(op.operand_source(in_slot)) > 0);
VLOG_IF(8, out_slot < op.num_results())
<< " -- result " << out_slot << " has been reused: "
<< (reused_output_values.count(op.result(out_slot)) > 0);
<< " -- operand " << in_slot << " has not user: "
<< HasNotUser(
op.operand_source(in_slot), use_count_map, inplace_map);
break;
}
}
if (can_do_inplace) {
inplace_ops[&op] = upper_op_name + "_";
for (auto& kv : inplace_out_2_in) {
reused_input_values.insert(op.operand_source(kv.second));
reused_output_values.insert(op.result(kv.first));
inplace_map[op.result(kv.first)] = op.operand_source(kv.second);
}
VLOG(6) << upper_op_name
<< " will change to inplace version op: " << upper_op_name + "_";
Expand Down
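Tracing the rewritten pass over the detach scenario from `TestInputsHasBeenModified` shows why the alias map matters (value names are illustrative; this reuses the `has_pending_user` sketch above). Under the old bookkeeping, `share_data_` put `%z` into `reused_input_values` and `%dz` into `reused_output_values`, but the later `add` reads `%dz` as an operand and writes a fresh result, so neither check fired and the add could legally become inplace, clobbering the fetched `%z`:

```python
# Pseudo-IR, walked top to bottom by GetInplaceOps:
#   %z   = pd_op.add(%x, %y)
#   %dz  = pd_op.share_data_(%z)   # inplace op: records inplace_map[%dz] = %z
#   %out = pd_op.add(%dz, %one)    # candidate to become pd_op.add_
#   pd_op.fetch(%z); pd_op.fetch(%out)

use_count = {'%z': 2, '%dz': 1, '%out': 1}  # total uses per value
inplace_map = {}

use_count['%z'] -= 1                # visit share_data_: consume %z once
inplace_map['%dz'] = '%z'           # ...and record the alias

use_count['%dz'] -= 1               # visit add: consume %dz
# %dz itself is now dead, but it aliases %z, which the fetch still reads:
assert has_pending_user('%dz', use_count, inplace_map)
# => the pass refuses to rewrite add into add_, and the fetched %z stays 3.0
```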
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -1156,7 +1156,7 @@ void BindValue(py::module *m) {
           auto share_data_op =
               ApiBuilder::Instance()
                   .GetBuilder()
-                  ->Build<paddle::dialect::ShareDataOp>(self);
+                  ->Build<paddle::dialect::ShareData_Op>(self);
           auto out = share_data_op.out();
           out.set_attribute(
               kAttrStopGradients,
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/op_compat.yaml
@@ -4179,7 +4179,7 @@
     data_type : int64_t
     tensors_name : StepsTensorList

-- op: share_data
+- op: share_data_ (share_data)
   inputs :
     x : X
   outputs :
4 changes: 2 additions & 2 deletions test/deprecated/ir/pir/test_special_op_translator.py
@@ -515,8 +515,8 @@ def test_program(self):
         )
         l = pir.translate_to_pir(main_program.desc)
         assert (
-            l.global_block().ops[2].name() == "pd_op.share_data"
-        ), "share_buffer should be translated to share_data"
+            l.global_block().ops[2].name() == "pd_op.share_data_"
+        ), "share_buffer should be translated to share_data_"


 class TestDataOp(unittest.TestCase):
51 changes: 51 additions & 0 deletions test/dygraph_to_static/test_tensor_detach.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_legacy_and_pt_and_pir,
+)
+
+import paddle
+
+
+def detach_fn(x, y):
+    u = x + y
+    v = u.detach()
+    o1 = v + 1
+
+    return o1, u
+
+
+class TestDetach(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
+    def test_detach(self):
+        static_fn = paddle.jit.to_static(detach_fn)
+        x = paddle.ones([], 'float32')
+        y = paddle.ones([], 'float32')
+        static_res = static_fn(x, y)
+        dygraph_res = detach_fn(x, y)
+        np.testing.assert_allclose(
+            static_res[0].numpy(), dygraph_res[0].numpy()
+        )
+        np.testing.assert_allclose(
+            static_res[1].numpy(), dygraph_res[1].numpy()
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
35 changes: 32 additions & 3 deletions test/ir/pir/test_pd_inplace_pass.py
@@ -37,9 +37,38 @@ def test_pd_inplace_pass(self):
         exe = paddle.static.Executor()
         x_feed = np.ones([2, 2], dtype=np.float32) * 10
         (sum_value,) = exe.run(feed={'x': x_feed}, fetch_list=[out])
-        self.assertEqual(
-            (sum_value == np.ones([2, 2], dtype="float32") * 10).all(),
-            True,
+        np.testing.assert_allclose(
+            sum_value, np.ones([2, 2], dtype="float32") * 10
         )


+class TestInputsHasBeenModified(unittest.TestCase):
+    def test_inputs_has_been_modified(self):
+        place = paddle.framework.core.Place()
+        place.set_place(paddle.CPUPlace())
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                x = paddle.static.data('x', [2, 2], dtype='float32')
+                y = paddle.static.data('y', [2, 2], dtype='float32')
+                z = paddle.add(x, y)
+
+                detached_z = z.detach()
+                out = detached_z + 1
+
+                exe = paddle.static.Executor()
+                x_feed = np.ones([2, 2], dtype=np.float32) * 1
+                y_feed = np.ones([2, 2], dtype=np.float32) * 2
+                (z_data, out_data) = exe.run(
+                    feed={"x": x_feed, "y": y_feed},
+                    fetch_list=[z, out],
+                )
+                np.testing.assert_allclose(
+                    z_data, np.ones([2, 2], dtype="float32") * 3
+                )
+                np.testing.assert_allclose(
+                    out_data, np.ones([2, 2], dtype="float32") * 4
+                )
+
+
