/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.calcite.FlinkRelFactories;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableCorrelateSortToRankRule;
import org.apache.flink.table.runtime.operators.rank.ConstantRankRange;
import org.apache.flink.table.runtime.operators.rank.RankRange;
import org.apache.flink.table.runtime.operators.rank.RankType;
import org.immutables.value.Value;

@Value.Enclosing
/**
 * Planner rule that rewrites a {@link Correlate} of the shape
 *
 * <pre>
 * Correlate (INNER)
 * :- Aggregate   (no aggregate calls, at most one grouping set)
 * :  +- Project  (pure field mapping)
 * +- Sort        (literal fetch, no offset)
 *    +- Project  (pure field mapping)
 *       +- Filter (conjunction of "inputRef = $corN.field" equalities on this correlate's id)
 * </pre>
 *
 * <p>into a ROW_NUMBER rank over the filter's input, partitioned by the input columns that the
 * filter compares against the correlation variable and keeping ranks {@code 1..fetch}, followed
 * by a projection of the partition columns and the sort-side projection. The rewrite only fires
 * when the aggregate's input and the filter's input have identical digests, i.e. both sides are
 * built over the same input.
 */
public class CorrelateSortToRankRule
        extends RelRule<CorrelateSortToRankRule.CorrelateSortToRankRuleConfig> {

    /** Rule instance created from {@link CorrelateSortToRankRuleConfig#DEFAULT}. */
    public static final CorrelateSortToRankRule INSTANCE =
            CorrelateSortToRankRuleConfig.DEFAULT.toRule();

    protected CorrelateSortToRankRule(CorrelateSortToRankRuleConfig config) {
        super(config);
    }

    /**
     * Checks every precondition that {@link #onMatch(RelOptRuleCall)} relies on.
     *
     * <p>Operand ordinals follow the operand tree of the config: 0 = Correlate, 1 = Aggregate,
     * 2 = Aggregate's input Project, 3 = Sort, 4 = Sort's input Project, 5 = Filter.
     *
     * @param call the rule call holding the matched operands
     * @return {@code true} if the matched tree can be rewritten into a rank
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Correlate correlate = call.rel(0);
        if (correlate.getJoinType() != JoinRelType.INNER) {
            return false;
        }
        Aggregate agg = call.rel(1);
        if (!agg.getAggCallList().isEmpty() || agg.getGroupSets().size() > 1) {
            return false;
        }
        Project aggInput = call.rel(2);
        if (!aggInput.isMapping()) {
            return false;
        }
        Sort sort = call.rel(3);
        // onMatch turns the fetch into a constant rank range by casting it to RexLiteral, so a
        // non-literal fetch (e.g. a dynamic parameter from "LIMIT ?") must be rejected here.
        // A null fetch also fails the instanceof check.
        if (sort.offset != null || !(sort.fetch instanceof RexLiteral)) {
            return false;
        }
        Project sortInput = call.rel(4);
        // onMatch casts each sort-side projection to RexInputRef to remap the collation.
        if (!sortInput.isMapping()) {
            return false;
        }
        Filter filter = call.rel(5);
        for (RexNode conjunct : RelOptUtil.conjunctions(filter.getCondition())) {
            if (!isValidCondition(conjunct, correlate)) {
                return false;
            }
        }
        // Both branches must read the very same input for the rank rewrite to be equivalent.
        return aggInput.getInput().getDigest().equals(filter.getInput().getDigest());
    }

    /**
     * Returns whether {@code condition} is an equality between an input reference and a field of
     * the given correlate's correlation variable (in either operand order).
     */
    private boolean isValidCondition(RexNode condition, Correlate correlate) {
        // Only equi-conditions can be turned into partition keys.
        if (condition.getKind() != SqlKind.EQUALS) {
            return false;
        }
        Tuple2<RexInputRef, RexFieldAccess> tuple = resolveFilterCondition(condition);
        if (tuple.f0 == null) {
            return false;
        }
        // A field access is not necessarily on a correlation variable: e.g. "$1 = $0.f" on a
        // ROW-typed column has a RexInputRef here. Reject those instead of failing with a
        // ClassCastException.
        RexNode referenceExpr = tuple.f1.getReferenceExpr();
        return referenceExpr instanceof RexCorrelVariable
                && ((RexCorrelVariable) referenceExpr).id.equals(correlate.getCorrelationId());
    }

    /**
     * Splits a binary call into its (input reference, field access) operands, accepting either
     * operand order.
     *
     * @param condition a {@link RexCall} with at least two operands; callers only pass EQUALS calls
     * @return the pair {@code (inputRef, fieldAccess)}, or {@code (null, null)} if the operands are
     *     not one input reference and one field access
     */
    private Tuple2<RexInputRef, RexFieldAccess> resolveFilterCondition(RexNode condition) {
        RexCall condCall = (RexCall) condition;
        RexNode operand0 = condCall.getOperands().get(0);
        RexNode operand1 = condCall.getOperands().get(1);
        if (operand0.isA(SqlKind.INPUT_REF) && operand1.isA(SqlKind.FIELD_ACCESS)) {
            return Tuple2.of((RexInputRef) operand0, (RexFieldAccess) operand1);
        }
        if (operand0.isA(SqlKind.FIELD_ACCESS) && operand1.isA(SqlKind.INPUT_REF)) {
            return Tuple2.of((RexInputRef) operand1, (RexFieldAccess) operand0);
        }
        return Tuple2.of(null, null);
    }

    /**
     * Replaces the matched tree with {@code Rank(ROW_NUMBER, 1..fetch)} over the filter's input,
     * followed by a projection of the partition-key columns and the sort-side projection.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        RelBuilder builder = call.builder();
        Sort sort = call.rel(3);
        Project sortInput = call.rel(4);
        Filter filter = call.rel(5);

        // matches() validated every conjunct as "inputRef = $corN.field"; the input refs form
        // the partition key (duplicates collapse in the bit set).
        List<RexNode> cnfCond = RelOptUtil.conjunctions(filter.getCondition());
        ImmutableBitSet partitionKey =
                ImmutableBitSet.of(
                        cnfCond.stream()
                                .map(c -> resolveFilterCondition(c).f0.getIndex())
                                .collect(Collectors.toList()));

        // Output columns: partition keys first, then the original sort-side projection, both
        // expressed over the filter's input row type.
        RelDataType baseType = sortInput.getInput().getRowType();
        List<RexNode> projects = new ArrayList<>();
        partitionKey.asList().forEach(k -> projects.add(RexInputRef.of(k, baseType)));
        projects.addAll(sortInput.getProjects());

        // The sort collation refers to sortInput's output fields; remap each index to the input
        // column it projects (safe because matches() requires sortInput to be a mapping).
        List<RelFieldCollation> newFieldCollations =
                sort.getCollation().getFieldCollations().stream()
                        .map(
                                fc -> {
                                    RexInputRef projected =
                                            (RexInputRef)
                                                    sortInput.getProjects().get(fc.getFieldIndex());
                                    return fc.withFieldIndex(projected.getIndex());
                                })
                        .collect(Collectors.toList());
        RelCollation newCollation = RelCollations.of(newFieldCollations);

        // matches() guarantees the fetch is a RexLiteral.
        long fetch = ((RexLiteral) sort.fetch).getValueAs(Long.class);
        RankRange rankRange = new ConstantRankRange(1L, fetch);

        // The config installs FLINK_REL_BUILDER as relBuilderFactory, so the builder is a
        // FlinkRelBuilder. The trailing (null, false) arguments are passed through unchanged;
        // presumably they mean "no rank-number field / do not output it" -- confirm against
        // FlinkRelBuilder#rank.
        RelNode newRel =
                ((FlinkRelBuilder) builder.push(filter.getInput()))
                        .rank(
                                partitionKey,
                                newCollation,
                                RankType.ROW_NUMBER,
                                rankRange,
                                null,
                                false)
                        .project(projects)
                        .build();
        call.transformTo(newRel);
    }

    /** Configuration for {@link CorrelateSortToRankRule}. */
    @Value.Immutable(singleton = false)
    public interface CorrelateSortToRankRuleConfig extends RelRule.Config {

        /**
         * Default configuration: matches the operand tree described on the rule class and builds
         * new relational expressions with {@code FlinkRelFactories.FLINK_REL_BUILDER()}.
         */
        CorrelateSortToRankRuleConfig DEFAULT =
                ImmutableCorrelateSortToRankRule.CorrelateSortToRankRuleConfig.builder()
                        .operandSupplier(
                                b0 ->
                                        b0.operand(Correlate.class)
                                                .inputs(
                                                        b1 ->
                                                                b1.operand(Aggregate.class)
                                                                        .oneInput(
                                                                                b2 ->
                                                                                        b2.operand(Project.class)
                                                                                                .anyInputs()),
                                                        b2 ->
                                                                b2.operand(Sort.class)
                                                                        .inputs(
                                                                                b3 ->
                                                                                        b3.operand(Project.class)
                                                                                                .inputs(
                                                                                                        b4 ->
                                                                                                                b4.operand(Filter.class)
                                                                                                                        .anyInputs()))))
                        .relBuilderFactory(FlinkRelFactories.FLINK_REL_BUILDER())
                        .description("CorrelateSortToRankRule")
                        .build();

        @Override
        default CorrelateSortToRankRule toRule() {
            return new CorrelateSortToRankRule(this);
        }
    }
}

