Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48155][SQL] AQEPropagateEmptyRelation for join should check if remain child is just BroadcastQueryStageExec #46523

Closed
wants to merge 7 commits into from

Conversation

AngersZhuuuu
Copy link
Contributor

@AngersZhuuuu AngersZhuuuu commented May 10, 2024

What changes were proposed in this pull request?

It's a new approach to fix SPARK-39551
This situation happened for AQEPropagateEmptyRelation when one side is empty and one side is BroadcastQueryStateExec
This pr avoid do propagate, not to revert all queryStagePreparationRules's result.

Why are the changes needed?

Fix bug

Does this PR introduce any user-facing change?

No

How was this patch tested?

Manuel tested SPARK-39551: Invalid plan check - invalid broadcast query stage, it can work well without origin fix and current pr

For added UT,

  test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
      val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
        """
          |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
          |INNER JOIN (
          |  SELECT * FROM testData2
          |  WHERE b = 0
          |  UNION ALL
          |  SELECT * FROM testData2
          |  WHErE b != 0
          |) t2
          |ON t1.b = t2.b AND t1.a = 0
          |RIGHT OUTER JOIN testData2 t3
          |ON t1.a > t3.a
          |GROUP BY t3.b
        """.stripMargin
      )
      assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
      assert(findTopLevelUnion(adaptivePlan).size == 0)
    }
  }

before this pr the adaptive plan is

*(9) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, count(a)#228L])
+- AQEShuffleRead coalesced
   +- ShuffleQueryStage 3
      +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, [plan_id=356]
         +- *(8) HashAggregate(keys=[b#226], functions=[partial_count(1)], output=[b#226, count#232L])
            +- *(8) Project [b#226]
               +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > a#225)
                  :- *(7) Project [a#23]
                  :  +- *(7) SortMergeJoin [b#24], [b#220], Inner
                  :     :- *(5) Sort [b#24 ASC NULLS FIRST], false, 0
                  :     :  +- AQEShuffleRead coalesced
                  :     :     +- ShuffleQueryStage 0
                  :     :        +- Exchange hashpartitioning(b#24, 5), ENSURE_REQUIREMENTS, [plan_id=211]
                  :     :           +- *(1) Filter (a#23 = 0)
                  :     :              +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24]
                  :     :                 +- Scan[obj#22]
                  :     +- *(6) Sort [b#220 ASC NULLS FIRST], false, 0
                  :        +- AQEShuffleRead coalesced
                  :           +- ShuffleQueryStage 1
                  :              +- Exchange hashpartitioning(b#220, 5), ENSURE_REQUIREMENTS, [plan_id=233]
                  :                 +- Union
                  :                    :- *(2) Project [b#220]
                  :                    :  +- *(2) Filter (b#220 = 0)
                  :                    :     +- *(2) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#219, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#220]
                  :                    :        +- Scan[obj#218]
                  :                    +- *(3) Project [b#223]
                  :                       +- *(3) Filter NOT (b#223 = 0)
                  :                          +- *(3) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#222, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#223]
                  :                             +- Scan[obj#221]
                  +- BroadcastQueryStage 2
                     +- BroadcastExchange IdentityBroadcastMode, [plan_id=260]
                        +- *(4) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226]
                           +- Scan[obj#224]

After this patch

*(6) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, count(a)#228L])
+- AQEShuffleRead coalesced
   +- ShuffleQueryStage 3
      +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, [plan_id=319]
         +- *(5) HashAggregate(keys=[b#226], functions=[partial_count(1)], output=[b#226, count#232L])
            +- *(5) Project [b#226]
               +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > a#225)
                  :- LocalTableScan <empty>, [a#23]
                  +- BroadcastQueryStage 2
                     +- BroadcastExchange IdentityBroadcastMode, [plan_id=260]
                        +- *(4) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226]
                           +- Scan[obj#224]
[info] - xxxx (3 seconds, 136 milliseconds)

Was this patch authored or co-authored using generative AI tooling?

No

… remain child is just BroadcastQueryStageExec
@github-actions github-actions bot added the SQL label May 10, 2024
@AngersZhuuuu
Copy link
Contributor Author

ping @cloud-fan @maryannxue Pls take a look

Copy link
Member

@dongjoon-hyun dongjoon-hyun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think you can provide test cases for this, @AngersZhuuuu ?

@AngersZhuuuu
Copy link
Contributor Author

Do you think you can provide test cases for this, @AngersZhuuuu ?

SPARK-39551: Invalid plan check - invalid broadcast query stage Can cover this, I don't know if we need to remove ValidateSparkPlan rule, it's too weird and rough.

@AngersZhuuuu
Copy link
Contributor Author

Do you think you can provide test cases for this, @AngersZhuuuu ?

Added a new UT to show the difference, pls take a look again @dongjoon-hyun

// Project
// +- LogicalQueryStage(_, BroadcastQueryStage)
// Then after LogicalQueryStageStrategy, will only remain BroadcastQueryStage after project,
// the plan can't execute.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simply say

// A broadcast query stage can't be executed without the join operator.
// TODO: we can return the original query plan before broadcast.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how hard is it to return the original query plan? Seems not hard as we just need to add a new def returnSingleJoinSide function in the base class, and unwrap broadcast stage in the AQE rule.

@cloud-fan
Copy link
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in e5ad5e9 May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
3 participants