/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableFlinkBushyJoinReorderRule;
import org.immutables.value.Value;

/**
 * Planner rule that reorders the inputs of a {@link MultiJoin} into a bushy join tree.
 *
 * <p>The non null-generating (inner-join) factors are enumerated level by level: level {@code k}
 * maps every set of {@code k + 1} factors that could be connected by join conditions to the plan
 * with the lowest cumulative cost found for it. A new level is built by joining pairs of disjoint
 * plans taken from two lower levels whose sizes add up to the new size, provided at least one join
 * condition links the two sides.
 *
 * <p>The best plan of the last level is then extended with the outer-join factors and finally with
 * any factor that is still missing (joined with a {@code TRUE} condition). A top project restores
 * the original field order of the {@link MultiJoin}, and its post-join filter, if any, is applied
 * on top of that project.
 */
@Value.Enclosing
public class FlinkBushyJoinReorderRule extends RelRule<FlinkBushyJoinReorderRule.Config>
        implements TransformationRule {

    /** Creates the rule from its configuration. */
    protected FlinkBushyJoinReorderRule(Config config) {
        super(config);
    }

    /**
     * Creates the rule with the given builder factory.
     *
     * @deprecated use {@link Config#DEFAULT} with {@code withRelBuilderFactory} instead
     */
    @Deprecated
    public FlinkBushyJoinReorderRule(RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class));
    }

    /**
     * Creates the rule from individual rel factories.
     *
     * @deprecated use {@link Config#DEFAULT} with {@code withRelBuilderFactory} instead
     */
    @Deprecated
    public FlinkBushyJoinReorderRule(
            RelFactories.JoinFactory joinFactory,
            RelFactories.ProjectFactory projectFactory,
            RelFactories.FilterFactory filterFactory) {
        this(RelBuilder.proto(joinFactory, projectFactory, filterFactory));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final RelBuilder relBuilder = call.builder();
        final MultiJoin multiJoinRel = call.rel(0);
        final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
        call.transformTo(findBestOrder(relBuilder, multiJoin));
    }

    /**
     * Builds the reordered join tree for the given multi-join.
     *
     * @return a project (optionally followed by the post-join filter) over the chosen join tree,
     *     exposing the fields in the original {@link MultiJoin} order and names
     */
    private static RelNode findBestOrder(RelBuilder relBuilder, LoptMultiJoin multiJoin) {
        // 1. Reorder the inner-join (non null-generating) factors.
        List<Map<Set<Integer>, JoinPlan>> foundPlansForInnerJoin =
                reorderInnerJoin(relBuilder, multiJoin);
        Map<Set<Integer>, JoinPlan> lastLevelOfInnerJoin =
                foundPlansForInnerJoin.get(foundPlansForInnerJoin.size() - 1);
        JoinPlan bestPlanForInnerJoin = getBestPlan(lastLevelOfInnerJoin);

        // 2. Stack the outer-join factors on top of the best inner-join plan.
        JoinPlan containOuterJoinPlan;
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()
                || outerJoinConditionExists(multiJoin)) {
            containOuterJoinPlan = addOuterJoinToTop(bestPlanForInnerJoin, multiJoin, relBuilder);
        } else {
            containOuterJoinPlan = bestPlanForInnerJoin;
        }

        // 3. Any factor that is still missing is joined with a TRUE condition.
        JoinPlan finalPlan;
        if (containOuterJoinPlan.factorIds.size() != multiJoin.getNumJoinFactors()) {
            finalPlan = addCrossJoinToTop(containOuterJoinPlan, multiJoin, relBuilder);
        } else {
            finalPlan = containOuterJoinPlan;
        }

        List<String> fieldNames = multiJoin.getMultiJoinRel().getRowType().getFieldNames();
        return createTopProject(relBuilder, multiJoin, finalPlan, fieldNames);
    }

    /**
     * Enumerates join plans for the non null-generating factors, level by level.
     *
     * @return list whose element {@code k} maps each set of {@code k + 1} factor ids to the best
     *     plan found for it; the list has at least one element (the single-factor level)
     */
    private static List<Map<Set<Integer>, JoinPlan>> reorderInnerJoin(
            RelBuilder relBuilder, LoptMultiJoin multiJoin) {
        int numJoinFactors = multiJoin.getNumJoinFactors();
        List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>();

        // Level 0: one plan per non null-generating factor.
        Map<Set<Integer>, JoinPlan> firstLevelJoinPlanMap = new LinkedHashMap<>();
        for (int i = 0; i < numJoinFactors; ++i) {
            if (multiJoin.isNullGenerating(i)) {
                continue;
            }
            Set<Integer> factorIds = new LinkedHashSet<>();
            factorIds.add(i);
            firstLevelJoinPlanMap.put(
                    new HashSet<>(factorIds), new JoinPlan(factorIds, multiJoin.getJoinFactor(i)));
        }
        foundPlans.add(firstLevelJoinPlanMap);

        // Full outer joins are not reordered; addOuterJoinToTop adds the remaining factor.
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return foundPlans;
        }

        // foundNextLevel only reads foundPlans, so no defensive copy is needed.
        while (foundPlans.size() < numJoinFactors) {
            Map<Set<Integer>, JoinPlan> nextLevelJoinPlanMap =
                    foundNextLevel(relBuilder, foundPlans, multiJoin);
            if (nextLevelJoinPlanMap.isEmpty()) {
                break;
            }
            foundPlans.add(nextLevelJoinPlanMap);
        }
        return foundPlans;
    }

    /** Returns whether any factor has a non-trivial (non-empty after conjunction split) outer-join condition. */
    private static boolean outerJoinConditionExists(LoptMultiJoin multiJoin) {
        for (int i = 0; i < multiJoin.getNumJoinFactors(); ++i) {
            RexNode outerJoinCond = multiJoin.getOuterJoinCond(i);
            if (outerJoinCond != null && !RelOptUtil.conjunctions(outerJoinCond).isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the plan of the level that is strictly cheaper than all plans before it, i.e. the
     * first plan wins on ties or when a cost is unavailable.
     *
     * @return the best plan, or {@code null} if {@code levelPlan} is empty
     */
    private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlan) {
        JoinPlan bestPlan = null;
        for (JoinPlan candidate : levelPlan.values()) {
            if (bestPlan == null || candidate.betterThan(bestPlan)) {
                bestPlan = candidate;
            }
        }
        return bestPlan;
    }

    /**
     * Joins every factor that is not yet in {@code bestPlan} and has matching outer-join conditions
     * on top of it, using a LEFT join (or a FULL join when the multi-join is a full outer join).
     * Factors without matching conditions are skipped and left for {@link #addCrossJoinToTop}.
     */
    private static JoinPlan addOuterJoinToTop(
            JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) {
        boolean isFullOuterJoin = multiJoin.getMultiJoinRel().isFullOuterJoin();
        JoinRelType joinType = isFullOuterJoin ? JoinRelType.FULL : JoinRelType.LEFT;
        List<Integer> remainIndexes =
                getRemainIndexes(multiJoin.getNumJoinFactors(), bestPlan.factorIds);

        RelNode leftNode = bestPlan.relNode;
        LinkedHashSet<Integer> joinedFactorIds = new LinkedHashSet<>(bestPlan.factorIds);
        for (int index : remainIndexes) {
            // Match conditions against every factor joined so far, not only the inner factors,
            // so that a condition referencing an earlier outer-joined factor is still found.
            Optional<List<RexCall>> joinConditions =
                    getJoinConditions(
                            joinedFactorIds, Collections.singleton(index), multiJoin, true);
            if (!joinConditions.isPresent()) {
                continue;
            }
            assert !isFullOuterJoin || remainIndexes.size() == 1;

            List<RexCall> newCondition =
                    convertToNewCondition(
                            new ArrayList<>(joinedFactorIds),
                            Collections.singletonList(index),
                            joinConditions.get(),
                            multiJoin);
            relBuilder.clear();
            leftNode =
                    relBuilder
                            .push(leftNode)
                            .push(multiJoin.getJoinFactor(index))
                            .join(joinType, newCondition)
                            .build();
            joinedFactorIds.add(index);
        }
        return new JoinPlan(joinedFactorIds, leftNode);
    }

    /**
     * Joins every factor not yet in {@code bestPlan} on top of it with a {@code TRUE} condition,
     * using the factor's join type from the {@link MultiJoin}.
     *
     * <p>NOTE(review): join filters not consumed by any join built earlier are not re-applied
     * here. Examples are filters referencing a single factor, or all filters when
     * {@code getJoinConditions} bails out on a non-{@code RexCall} condition. Confirm that the
     * MultiJoin producer rules these cases out.
     */
    private static JoinPlan addCrossJoinToTop(
            JoinPlan bestPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        RexNode alwaysTrue = rexBuilder.makeLiteral(true);
        List<JoinRelType> joinTypes = multiJoin.getMultiJoinRel().getJoinTypes();
        List<Integer> remainIndexes =
                getRemainIndexes(multiJoin.getNumJoinFactors(), bestPlan.factorIds);

        RelNode leftNode = bestPlan.relNode;
        LinkedHashSet<Integer> joinedFactorIds = new LinkedHashSet<>(bestPlan.factorIds);
        for (int index : remainIndexes) {
            relBuilder.clear();
            joinedFactorIds.add(index);
            leftNode =
                    relBuilder
                            .push(leftNode)
                            .push(multiJoin.getJoinFactor(index))
                            .join(joinTypes.get(index), alwaysTrue)
                            .build();
        }
        return new JoinPlan(joinedFactorIds, leftNode);
    }

    /**
     * Projects the output of {@code finalPlan} back to the original field order of the multi-join.
     * Fields of a right factor in a removable self-join are redirected to the mapped column of the
     * other factor. The post-join filter is applied last.
     */
    private static RelNode createTopProject(
            RelBuilder relBuilder,
            LoptMultiJoin multiJoin,
            JoinPlan finalPlan,
            List<String> fieldNames) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        List<RelDataTypeField> fields = multiJoin.getMultiJoinFields();
        int nJoinFactors = multiJoin.getNumJoinFactors();

        // Offset of each factor's first field in finalPlan.relNode; factorIds is in output order.
        List<Integer> newJoinOrder = new ArrayList<>(finalPlan.factorIds);
        Map<Integer, Integer> factorToOffsetMap = new HashMap<>();
        int fieldStart = 0;
        for (int pos = 0; pos < nJoinFactors; ++pos) {
            int factorId = newJoinOrder.get(pos);
            factorToOffsetMap.put(factorId, fieldStart);
            fieldStart += multiJoin.getNumFieldsInJoinFactor(factorId);
        }

        List<RexNode> newProjExprs = new ArrayList<>();
        for (int currFactor = 0; currFactor < nJoinFactors; ++currFactor) {
            Integer leftFactor =
                    multiJoin.isRightFactorInRemovableSelfJoin(currFactor)
                            ? multiJoin.getOtherSelfJoinFactor(currFactor)
                            : null;
            for (int fieldPos = 0;
                    fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor);
                    ++fieldPos) {
                int newOffset =
                        Objects.requireNonNull(
                                        factorToOffsetMap.get(currFactor),
                                        "factorToOffsetMap.get(currFactor)")
                                + fieldPos;
                if (leftFactor != null) {
                    Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos);
                    if (leftOffset != null) {
                        newOffset =
                                Objects.requireNonNull(
                                                factorToOffsetMap.get(leftFactor),
                                                "factorToOffsetMap.get(leftFactor)")
                                        + leftOffset;
                    }
                }
                newProjExprs.add(
                        rexBuilder.makeInputRef(
                                fields.get(newProjExprs.size()).getType(), newOffset));
            }
        }

        relBuilder.clear();
        relBuilder.push(finalPlan.relNode).project(newProjExprs, fieldNames);
        RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter();
        if (postJoinFilter != null) {
            relBuilder.filter(postJoinFilter);
        }
        return relBuilder.build();
    }

    /**
     * Builds the next level of plans by joining plans from level pairs {@code (l, r)} with
     * {@code l + r == foundPlans.size() - 1} and {@code l <= r}. Only the cheapest plan per factor
     * set is kept; an earlier plan is replaced only by a strictly cheaper one.
     */
    private static Map<Set<Integer>, JoinPlan> foundNextLevel(
            RelBuilder relBuilder,
            List<Map<Set<Integer>, JoinPlan>> foundPlans,
            LoptMultiJoin multiJoin) {
        Map<Set<Integer>, JoinPlan> currentLevelJoinPlanMap = new LinkedHashMap<>();
        int foundPlansLevel = foundPlans.size() - 1;
        for (int leftLevel = 0, rightLevel = foundPlansLevel;
                leftLevel <= rightLevel;
                ++leftLevel, --rightLevel) {
            List<JoinPlan> leftPlans = new ArrayList<>(foundPlans.get(leftLevel).values());
            List<JoinPlan> rightLevelPlans =
                    leftLevel == rightLevel
                            ? leftPlans
                            : new ArrayList<>(foundPlans.get(rightLevel).values());
            for (int i = 0; i < leftPlans.size(); ++i) {
                JoinPlan leftPlan = leftPlans.get(i);
                // On the same level only pair with plans at position >= i, so each unordered
                // pair is tried once (pairing with itself is rejected by buildInnerJoin).
                List<JoinPlan> rightPlans =
                        leftLevel == rightLevel
                                ? leftPlans.subList(i, leftPlans.size())
                                : rightLevelPlans;
                for (JoinPlan rightPlan : rightPlans) {
                    Optional<JoinPlan> newJoinPlan =
                            buildInnerJoin(relBuilder, leftPlan, rightPlan, multiJoin);
                    if (!newJoinPlan.isPresent()) {
                        continue;
                    }
                    JoinPlan candidate = newJoinPlan.get();
                    JoinPlan existing = currentLevelJoinPlanMap.get(candidate.factorIds);
                    if (existing == null || candidate.betterThan(existing)) {
                        currentLevelJoinPlanMap.put(candidate.factorIds, candidate);
                    }
                }
            }
        }
        return currentLevelJoinPlanMap;
    }

    /**
     * Inner-joins two plans, placing the plan with more factors on the left (ties keep the given
     * order).
     *
     * @return empty if the plans share a factor or no join condition links them
     */
    private static Optional<JoinPlan> buildInnerJoin(
            RelBuilder relBuilder,
            JoinPlan leftSidePlan,
            JoinPlan rightSidePlan,
            LoptMultiJoin multiJoin) {
        if (!Collections.disjoint(leftSidePlan.factorIds, rightSidePlan.factorIds)) {
            return Optional.empty();
        }
        Optional<List<RexCall>> joinConditions =
                getJoinConditions(
                        leftSidePlan.factorIds, rightSidePlan.factorIds, multiJoin, false);
        if (!joinConditions.isPresent()) {
            return Optional.empty();
        }

        boolean keepOrder = leftSidePlan.factorIds.size() >= rightSidePlan.factorIds.size();
        JoinPlan newLeftSidePlan = keepOrder ? leftSidePlan : rightSidePlan;
        JoinPlan newRightSidePlan = keepOrder ? rightSidePlan : leftSidePlan;

        // Factor order mirrors the output field order of the join (left fields, then right).
        LinkedHashSet<Integer> newFactorIds = new LinkedHashSet<>(newLeftSidePlan.factorIds);
        newFactorIds.addAll(newRightSidePlan.factorIds);

        List<RexCall> newCondition =
                convertToNewCondition(
                        new ArrayList<>(newLeftSidePlan.factorIds),
                        new ArrayList<>(newRightSidePlan.factorIds),
                        joinConditions.get(),
                        multiJoin);
        relBuilder.clear();
        RelNode newJoin =
                relBuilder
                        .push(newLeftSidePlan.relNode)
                        .push(newRightSidePlan.relNode)
                        .join(JoinRelType.INNER, newCondition)
                        .build();
        return Optional.of(new JoinPlan(newFactorIds, newJoin));
    }

    /**
     * Rewrites conditions expressed over multi-join field indexes so that they reference the
     * fields of {@code join(left factors, right factors)} instead.
     */
    private static List<RexCall> convertToNewCondition(
            List<Integer> leftFactorIds,
            List<Integer> rightFactorIds,
            List<RexCall> rexNodes,
            LoptMultiJoin multiJoin) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        // The shuttle holds no per-visit state, so one instance serves all operands.
        JoinConditionShuttle shuttle =
                new JoinConditionShuttle(multiJoin, leftFactorIds, rightFactorIds);
        List<RexCall> newCondition = new ArrayList<>(rexNodes.size());
        for (RexCall rexCond : rexNodes) {
            List<RexNode> newOperands = new ArrayList<>(rexCond.getOperands().size());
            for (RexNode operand : rexCond.getOperands()) {
                newOperands.add(operand.accept(shuttle));
            }
            newCondition.add((RexCall) rexBuilder.makeCall(rexCond.op, newOperands));
        }
        return newCondition;
    }

    /**
     * Collects the conditions that can be used to join the two factor sets.
     *
     * <p>For an outer-join lookup on a multi-join that is not a full outer join, the candidates
     * are the conjuncts of every factor's outer-join condition. Otherwise they are the multi-join's
     * join filters. A candidate is kept when it references at least one factor on each side and no
     * factor outside the two sides.
     *
     * @return the kept conditions, or empty if none qualify or any candidate is not a
     *     {@link RexCall}
     */
    private static Optional<List<RexCall>> getJoinConditions(
            Set<Integer> leftSideFactorIds,
            Set<Integer> rightSideFactorIds,
            LoptMultiJoin multiJoin,
            boolean isOuterJoin) {
        List<RexNode> joinConditions;
        if (isOuterJoin && !multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            joinConditions = new ArrayList<>();
            for (int i = 0; i < multiJoin.getNumJoinFactors(); ++i) {
                joinConditions.addAll(RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)));
            }
        } else {
            joinConditions = multiJoin.getJoinFilters();
        }

        List<RexCall> resultRexCall = new ArrayList<>();
        for (RexNode rexNode : joinConditions) {
            if (!(rexNode instanceof RexCall)) {
                return Optional.empty();
            }
            RexCall callCondition = (RexCall) rexNode;
            ImmutableBitSet factorsRefByJoinFilter =
                    multiJoin.getFactorsRefByJoinFilter(callCondition);
            int leftSideFactorNumbers =
                    countReferencedFactors(factorsRefByJoinFilter, leftSideFactorIds);
            int rightSideFactorNumbers =
                    countReferencedFactors(factorsRefByJoinFilter, rightSideFactorIds);
            if (leftSideFactorNumbers > 0
                    && rightSideFactorNumbers > 0
                    && leftSideFactorNumbers + rightSideFactorNumbers
                            == factorsRefByJoinFilter.cardinality()) {
                resultRexCall.add(callCondition);
            }
        }
        if (resultRexCall.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(resultRexCall);
    }

    /** Counts how many of {@code factorIds} are set in {@code referencedFactors}. */
    private static int countReferencedFactors(
            ImmutableBitSet referencedFactors, Set<Integer> factorIds) {
        int count = 0;
        for (int factorId : factorIds) {
            if (referencedFactors.get(factorId)) {
                ++count;
            }
        }
        return count;
    }

    /** Returns, in ascending order, the ids in {@code [0, totalNumOfJoinFactors)} not in {@code factorIds}. */
    private static List<Integer> getRemainIndexes(int totalNumOfJoinFactors, Set<Integer> factorIds) {
        List<Integer> remainIndexes = new ArrayList<>();
        for (int i = 0; i < totalNumOfJoinFactors; ++i) {
            if (!factorIds.contains(i)) {
                remainIndexes.add(i);
            }
        }
        return remainIndexes;
    }

    /** Rule configuration; the default matches any {@link MultiJoin}. */
    @Value.Immutable(singleton = false)
    public interface Config extends RelRule.Config {
        Config DEFAULT =
                ImmutableFlinkBushyJoinReorderRule.Config.builder()
                        .build()
                        .withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs())
                        .as(Config.class);

        @Override
        default FlinkBushyJoinReorderRule toRule() {
            return new FlinkBushyJoinReorderRule(this);
        }
    }

    /**
     * Remaps input references from multi-join field indexes to the field indexes of a join whose
     * inputs are the left factors followed by the right factors, each in list order. References
     * to a factor in neither list are returned unchanged.
     */
    private static class JoinConditionShuttle extends RexShuttle {
        private final LoptMultiJoin multiJoin;
        private final List<Integer> leftFactorIds;
        private final List<Integer> rightFactorIds;

        public JoinConditionShuttle(
                LoptMultiJoin multiJoin, List<Integer> leftFactorIds, List<Integer> rightFactorIds) {
            this.multiJoin = multiJoin;
            this.leftFactorIds = leftFactorIds;
            this.rightFactorIds = rightFactorIds;
        }

        @Override
        public RexNode visitInputRef(RexInputRef var) {
            int index = var.getIndex();
            int factorRef = multiJoin.findRef(index);
            int offsetInFactor = index - factorStart(factorRef);

            // Walk the left factors, then the right ones, accumulating the field offset.
            int startOfFactor = 0;
            for (int leftFactorId : leftFactorIds) {
                if (leftFactorId == factorRef) {
                    return new RexInputRef(startOfFactor + offsetInFactor, var.getType());
                }
                startOfFactor += multiJoin.getNumFieldsInJoinFactor(leftFactorId);
            }
            for (int rightFactorId : rightFactorIds) {
                if (rightFactorId == factorRef) {
                    return new RexInputRef(startOfFactor + offsetInFactor, var.getType());
                }
                startOfFactor += multiJoin.getNumFieldsInJoinFactor(rightFactorId);
            }
            return var;
        }

        /** Index of the first field of {@code factorId} within the multi-join's fields. */
        private int factorStart(int factorId) {
            int start = 0;
            for (int i = 0; i < factorId; ++i) {
                start += multiJoin.getNumFieldsInJoinFactor(i);
            }
            return start;
        }
    }

    /**
     * A candidate plan: the multi-join factor ids it covers, iterated in the order their fields
     * appear in {@link #relNode}'s output, and the relational expression implementing it.
     */
    private static class JoinPlan {
        final Set<Integer> factorIds;
        final RelNode relNode;

        JoinPlan(Set<Integer> factorIds, RelNode relNode) {
            this.factorIds = factorIds;
            this.relNode = relNode;
        }

        /**
         * Returns whether this plan's cumulative cost is strictly lower than {@code otherPlan}'s;
         * {@code false} when either cost is unavailable.
         */
        private boolean betterThan(JoinPlan otherPlan) {
            RelMetadataQuery mq = this.relNode.getCluster().getMetadataQuery();
            RelOptCost thisCost = mq.getCumulativeCost(this.relNode);
            RelOptCost otherCost = mq.getCumulativeCost(otherPlan.relNode);
            if (thisCost == null || otherCost == null) {
                return false;
            }
            return thisCost.isLt(otherCost);
        }
    }
}

