/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hive.druid.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.hive.druid.com.google.common.collect.ImmutableList;
import org.apache.hive.druid.com.google.common.collect.Lists;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptUtil;
import org.apache.hive.druid.org.apache.calcite.rel.RelNode;
import org.apache.hive.druid.org.apache.calcite.rel.core.JoinRelType;
import org.apache.hive.druid.org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.hive.druid.org.apache.calcite.rel.rules.LoptJoinTree;
import org.apache.hive.druid.org.apache.calcite.rel.rules.MultiJoin;
import org.apache.hive.druid.org.apache.calcite.rel.type.RelDataType;
import org.apache.hive.druid.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.hive.druid.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.hive.druid.org.apache.calcite.rex.RexCall;
import org.apache.hive.druid.org.apache.calcite.rex.RexNode;
import org.apache.hive.druid.org.apache.calcite.sql.SqlKind;
import org.apache.hive.druid.org.apache.calcite.util.BitSets;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableBitSet;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableIntList;
import org.checkerframework.checker.initialization.qual.UnderInitialization;
import org.checkerframework.checker.initialization.qual.UnknownInitialization;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.RequiresNonNull;

/**
 * Wraps a {@link MultiJoin} and pre-computes the lookup structures used while
 * ordering its inputs ("join factors"): the field offset and field count of
 * each factor, the conjuncts of the join filter, which factors and fields
 * each conjunct references, and (on demand) pairwise factor weights.
 *
 * <p>The object also carries mutable bookkeeping that callers fill in through
 * the {@code set...}/{@code add...} methods: join-removal slots, removable
 * outer-join factors and removable self-join pairs.
 */
public class LoptMultiJoin {
    /** The MultiJoin this object describes. */
    MultiJoin multiJoin;

    /** Conjuncts of the MultiJoin's join filter (outer-join conditions excluded). */
    private final List<RexNode> joinFilters;

    /**
     * Conjuncts of the join filter plus the conjuncts of every outer-join
     * condition; always-true conjuncts are removed by {@link #setJoinFilterRefs}.
     */
    private final List<RexNode> allJoinFilters;

    /** Number of inputs of the MultiJoin. */
    private final int nJoinFactors;

    /** Field count of the MultiJoin's row type. */
    private final int nTotalFields;

    /** Inputs of the MultiJoin, in input order. */
    private final ImmutableList<RelNode> joinFactors;

    /** Join type recorded by the MultiJoin for each factor. */
    private final ImmutableList<JoinRelType> joinTypes;

    /**
     * For each factor that has an outer-join condition, the other factors that
     * condition references; entries for factors without a condition stay
     * {@code null}.
     */
    private final ImmutableBitSet[] outerJoinFactors;

    /** Per-factor projected-field bitmaps as supplied by the MultiJoin. */
    private final List<@Nullable ImmutableBitSet> projFields;

    /** Per-factor field reference counts, obtained from {@code getCopyJoinFieldRefCountsMap()}. */
    private final Map<Integer, int[]> joinFieldRefCountsMap;

    /** Factors referenced by each filter registered in {@link #setJoinFilterRefs}. */
    private final Map<RexNode, ImmutableBitSet> factorsRefByJoinFilter = new HashMap<>();

    /** Fields referenced by each filter registered in {@link #setJoinFilterRefs}. */
    private final Map<RexNode, ImmutableBitSet> fieldsRefByJoinFilter = new HashMap<>();

    /** Offset of each factor's first field within the concatenated factor fields. */
    int[] joinStart;

    /** Field count of each factor. */
    int[] nFieldsInJoinFactor;

    /**
     * For each factor, the other factors that appear together with it in a
     * comparison filter. {@code null} until {@link #setFactorWeights} runs.
     */
    ImmutableBitSet @MonotonicNonNull [] factorsRefByFactor;

    /** Symmetric factor-pair weights. {@code null} until {@link #setFactorWeights} runs. */
    int @MonotonicNonNull [][] factorWeights;

    /** Type factory of the MultiJoin's cluster. */
    final RelDataTypeFactory factory;

    /** Per-factor slot written by {@link #setJoinRemovalFactor}; {@code null} when unset. */
    @Nullable Integer[] joinRemovalFactors;

    /** Per-factor slot written by {@link #setJoinRemovalSemiJoin}; {@code null} when unset. */
    LogicalJoin[] joinRemovalSemiJoins;

    /** Factors registered through {@link #addRemovableOuterJoinFactor}. */
    Set<Integer> removableOuterJoinFactors;

    /**
     * Self-join pairs registered through {@link #addRemovableSelfJoinPair},
     * keyed by both the left and the right factor of the pair.
     */
    Map<Integer, RemovableSelfJoin> removableSelfJoinPairs;

    /**
     * Builds the lookup structures for {@code multiJoin}.
     *
     * @param multiJoin the MultiJoin to describe
     */
    public LoptMultiJoin(MultiJoin multiJoin) {
        this.multiJoin = multiJoin;
        this.joinFactors = ImmutableList.copyOf(multiJoin.getInputs());
        this.nJoinFactors = this.joinFactors.size();
        this.projFields = multiJoin.getProjFields();
        this.joinFieldRefCountsMap = multiJoin.getCopyJoinFieldRefCountsMap();

        // joinFilters holds only the join filter's conjuncts; allJoinFilters
        // additionally receives the conjuncts of every outer-join condition.
        this.joinFilters = Lists.newArrayList(RelOptUtil.conjunctions(multiJoin.getJoinFilter()));
        this.allJoinFilters = new ArrayList<>(this.joinFilters);
        List<@Nullable RexNode> outerJoinConds = multiJoin.getOuterJoinConditions();
        for (int i = 0; i < this.nJoinFactors; ++i) {
            this.allJoinFilters.addAll(RelOptUtil.conjunctions(outerJoinConds.get(i)));
        }

        // Record where each factor's fields start and how many fields it has.
        this.nTotalFields = multiJoin.getRowType().getFieldCount();
        this.joinStart = new int[this.nJoinFactors];
        this.nFieldsInJoinFactor = new int[this.nJoinFactors];
        int start = 0;
        for (int i = 0; i < this.nJoinFactors; ++i) {
            int fieldCount = this.joinFactors.get(i).getRowType().getFieldCount();
            this.joinStart[i] = start;
            this.nFieldsInJoinFactor[i] = fieldCount;
            start += fieldCount;
        }

        this.joinTypes = ImmutableList.copyOf(multiJoin.getJoinTypes());

        // Must run after joinStart/nFieldsInJoinFactor are complete, because
        // mapping fields to factors relies on them.
        this.outerJoinFactors = new ImmutableBitSet[this.nJoinFactors];
        for (int i = 0; i < this.nJoinFactors; ++i) {
            RexNode outerJoinCond = outerJoinConds.get(i);
            if (outerJoinCond != null) {
                // A factor is not counted as a dependency of its own condition.
                this.outerJoinFactors[i] = this.getJoinFilterFactorBitmap(outerJoinCond, false).clear(i);
            }
        }

        this.setJoinFilterRefs();
        this.factory = multiJoin.getCluster().getTypeFactory();
        this.joinRemovalFactors = new Integer[this.nJoinFactors];
        this.joinRemovalSemiJoins = new LogicalJoin[this.nJoinFactors];
        this.removableOuterJoinFactors = new HashSet<>();
        this.removableSelfJoinPairs = new HashMap<>();
    }

    /** Returns the wrapped MultiJoin. */
    public MultiJoin getMultiJoinRel() {
        return this.multiJoin;
    }

    /** Returns the number of join factors (inputs of the MultiJoin). */
    public int getNumJoinFactors() {
        return this.nJoinFactors;
    }

    /**
     * Returns the input at position {@code factIdx}.
     *
     * @param factIdx factor index
     */
    public RelNode getJoinFactor(int factIdx) {
        return this.joinFactors.get(factIdx);
    }

    /** Returns the field count of the MultiJoin's row type. */
    public int getNumTotalFields() {
        return this.nTotalFields;
    }

    /**
     * Returns the number of fields of the given factor.
     *
     * @param factIdx factor index
     */
    public int getNumFieldsInJoinFactor(int factIdx) {
        return this.nFieldsInJoinFactor[factIdx];
    }

    /**
     * Returns the conjuncts of the join filter, without outer-join conditions.
     * The internal list itself is returned, not a copy.
     */
    public List<RexNode> getJoinFilters() {
        return this.joinFilters;
    }

    /**
     * Returns the factors referenced by a registered join filter.
     *
     * @param joinFilter filter to look up
     * @throws NullPointerException if the filter was never registered
     */
    public ImmutableBitSet getFactorsRefByJoinFilter(RexNode joinFilter) {
        return Objects.requireNonNull(this.factorsRefByJoinFilter.get(joinFilter), () -> "joinFilter is not found in factorsRefByJoinFilter: " + joinFilter);
    }

    /** Returns the field list of the MultiJoin's row type. */
    public List<RelDataTypeField> getMultiJoinFields() {
        return this.multiJoin.getRowType().getFieldList();
    }

    /**
     * Returns the fields referenced by a registered join filter.
     *
     * @param joinFilter filter to look up
     * @throws NullPointerException if the filter was never registered
     */
    public ImmutableBitSet getFieldsRefByJoinFilter(RexNode joinFilter) {
        return Objects.requireNonNull(this.fieldsRefByJoinFilter.get(joinFilter), () -> "joinFilter is not found in fieldsRefByJoinFilter: " + joinFilter);
    }

    /**
     * Returns the factor-weight matrix, or {@code null} if
     * {@link #setFactorWeights} has not been called yet.
     */
    public int @Nullable [][] getFactorWeights() {
        return this.factorWeights;
    }

    /**
     * Returns the factors that share a comparison filter with {@code factIdx}.
     *
     * @param factIdx factor index
     * @throws NullPointerException if {@link #setFactorWeights} has not been called
     */
    public ImmutableBitSet getFactorsRefByFactor(int factIdx) {
        return Objects.requireNonNull(this.factorsRefByFactor, "factorsRefByFactor")[factIdx];
    }

    /**
     * Returns the offset of the first field of the given factor.
     *
     * @param factIdx factor index
     */
    public int getJoinStart(int factIdx) {
        return this.joinStart[factIdx];
    }

    /**
     * Returns whether the join type recorded for the factor reports
     * {@code isOuterJoin()}.
     *
     * @param factIdx factor index
     */
    public boolean isNullGenerating(int factIdx) {
        return this.joinTypes.get(factIdx).isOuterJoin();
    }

    /**
     * Returns the other factors referenced by the factor's outer-join
     * condition. The result is {@code null} when the factor has no outer-join
     * condition, because the slot is never populated in that case.
     *
     * @param factIdx factor index
     */
    public ImmutableBitSet getOuterJoinFactors(int factIdx) {
        return this.outerJoinFactors[factIdx];
    }

    /**
     * Returns the outer-join condition the MultiJoin holds for the factor,
     * possibly {@code null}.
     *
     * @param factIdx factor index
     */
    public @Nullable RexNode getOuterJoinCond(int factIdx) {
        return this.multiJoin.getOuterJoinConditions().get(factIdx);
    }

    /**
     * Returns the projected-field bitmap the MultiJoin holds for the factor,
     * possibly {@code null}.
     *
     * @param factIdx factor index
     */
    public @Nullable ImmutableBitSet getProjFields(int factIdx) {
        return this.projFields.get(factIdx);
    }

    /**
     * Returns the field reference counts of the factor.
     *
     * @param factIdx factor index
     * @throws NullPointerException if the map has no entry for the factor
     */
    public int[] getJoinFieldRefCounts(int factIdx) {
        return Objects.requireNonNull(this.joinFieldRefCountsMap.get(factIdx), () -> "no entry in joinFieldRefCountsMap found for " + factIdx);
    }

    /**
     * Returns the value stored by {@link #setJoinRemovalFactor} for
     * {@code dimIdx}, or {@code null} if none was stored.
     *
     * @param dimIdx factor index used as the slot
     */
    public @Nullable Integer getJoinRemovalFactor(int dimIdx) {
        return this.joinRemovalFactors[dimIdx];
    }

    /**
     * Returns the join stored by {@link #setJoinRemovalSemiJoin} for
     * {@code dimIdx}; the slot is {@code null} if none was stored.
     *
     * @param dimIdx factor index used as the slot
     */
    public LogicalJoin getJoinRemovalSemiJoin(int dimIdx) {
        return this.joinRemovalSemiJoins[dimIdx];
    }

    /**
     * Stores {@code factIdx} in the join-removal slot of {@code dimIdx}.
     *
     * @param dimIdx factor index used as the slot
     * @param factIdx factor index to store
     */
    public void setJoinRemovalFactor(int dimIdx, int factIdx) {
        this.joinRemovalFactors[dimIdx] = factIdx;
    }

    /**
     * Stores {@code semiJoin} in the semi-join slot of {@code dimIdx}.
     *
     * @param dimIdx factor index used as the slot
     * @param semiJoin join to store
     */
    public void setJoinRemovalSemiJoin(int dimIdx, LogicalJoin semiJoin) {
        this.joinRemovalSemiJoins[dimIdx] = semiJoin;
    }

    /**
     * Computes the factors referenced by a filter.
     *
     * @param joinFilter filter to analyse
     * @param setFields when {@code true}, the filter's field bitmap is also
     *     recorded in {@code fieldsRefByJoinFilter}
     * @return bitmap of the factors owning the referenced fields
     */
    @RequiresNonNull(value={"joinStart", "nFieldsInJoinFactor"})
    ImmutableBitSet getJoinFilterFactorBitmap(@UnderInitialization LoptMultiJoin this, RexNode joinFilter, boolean setFields) {
        ImmutableBitSet fieldRefBitmap = LoptMultiJoin.fieldBitmap(joinFilter);
        if (setFields) {
            this.fieldsRefByJoinFilter.put(joinFilter, fieldRefBitmap);
        }
        return this.factorBitmap(fieldRefBitmap);
    }

    /** Collects the input references of {@code joinFilter} with an {@code InputFinder}. */
    private static ImmutableBitSet fieldBitmap(RexNode joinFilter) {
        RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder();
        joinFilter.accept(inputFinder);
        return inputFinder.build();
    }

    /**
     * Registers every filter of {@code allJoinFilters} in the field and factor
     * lookup maps, dropping always-true filters from the list on the way.
     */
    @RequiresNonNull(value={"allJoinFilters", "joinStart", "nFieldsInJoinFactor"})
    private void setJoinFilterRefs(@UnderInitialization LoptMultiJoin this) {
        ListIterator<RexNode> filterIter = this.allJoinFilters.listIterator();
        while (filterIter.hasNext()) {
            RexNode joinFilter = filterIter.next();
            if (joinFilter.isAlwaysTrue()) {
                filterIter.remove();
            }
            // An always-true filter is removed from allJoinFilters but is still
            // registered in both maps: joinFilters is a separate copy that may
            // still contain it, and the lookup getters throw on unknown filters.
            ImmutableBitSet factorRefBitmap = this.getJoinFilterFactorBitmap(joinFilter, true);
            this.factorsRefByJoinFilter.put(joinFilter, factorRefBitmap);
        }
    }

    /** Maps a bitmap of field references to the bitmap of factors owning those fields. */
    @RequiresNonNull(value={"joinStart", "nFieldsInJoinFactor"})
    private ImmutableBitSet factorBitmap(@UnknownInitialization LoptMultiJoin this, ImmutableBitSet fieldRefBitmap) {
        ImmutableBitSet.Builder factorRefBitmap = ImmutableBitSet.builder();
        for (int field : fieldRefBitmap) {
            factorRefBitmap.set(this.findRef(field));
        }
        return factorRefBitmap.build();
    }

    /**
     * Finds the factor whose field range contains the given input reference.
     *
     * @param rexInputRef field index across the concatenated factor fields
     * @return index of the owning factor
     * @throws AssertionError if no factor's range contains the reference
     */
    @RequiresNonNull(value={"joinStart", "nFieldsInJoinFactor"})
    public int findRef(@UnknownInitialization LoptMultiJoin this, int rexInputRef) {
        for (int i = 0; i < this.nJoinFactors; ++i) {
            int first = this.joinStart[i];
            if (rexInputRef >= first && rexInputRef < first + this.nFieldsInJoinFactor[i]) {
                return i;
            }
        }
        throw new AssertionError("no join factor contains input reference " + rexInputRef);
    }

    /**
     * (Re)builds {@code factorWeights} and {@code factorsRefByFactor} from the
     * comparison filters in {@code allJoinFilters}.
     *
     * <p>Weights per filter: 3 for an {@code EQUALS} between exactly two
     * factors whose first operand touches a single factor, 2 for any other
     * comparison of that shape, and 1 for every other pairing of factors that
     * occur in the same comparison. A pair keeps the largest weight seen.
     */
    public void setFactorWeights() {
        this.factorWeights = new int[this.nJoinFactors][this.nJoinFactors];
        this.factorsRefByFactor = new ImmutableBitSet[this.nJoinFactors];
        for (int i = 0; i < this.nJoinFactors; ++i) {
            this.factorsRefByFactor[i] = ImmutableBitSet.of();
        }

        for (RexNode joinFilter : this.allJoinFilters) {
            // Only comparison calls contribute weights.
            if (!(joinFilter instanceof RexCall) || !joinFilter.isA(SqlKind.COMPARISON)) {
                continue;
            }
            ImmutableBitSet factorRefs = Objects.requireNonNull(this.factorsRefByJoinFilter.get(joinFilter), "factorRefs");

            // Each referenced factor learns about the others referenced by
            // the same filter (never about itself).
            for (int factor : factorRefs) {
                this.factorsRefByFactor[factor] = this.factorsRefByFactor[factor].rebuild().addAll(factorRefs).clear(factor).build();
            }

            if (factorRefs.cardinality() == 2) {
                int leftFactor = factorRefs.nextSetBit(0);
                int rightFactor = factorRefs.nextSetBit(leftFactor + 1);
                RexCall call = (RexCall) joinFilter;
                ImmutableBitSet leftFields = LoptMultiJoin.fieldBitmap(call.getOperands().get(0));
                ImmutableBitSet leftBitmap = this.factorBitmap(leftFields);
                int weight;
                if (leftBitmap.cardinality() == 1) {
                    // First operand touches exactly one factor; equality
                    // ranks above the other comparisons.
                    weight = joinFilter.getKind() == SqlKind.EQUALS ? 3 : 2;
                } else {
                    weight = 1;
                }
                this.setFactorWeight(weight, leftFactor, rightFactor);
            } else {
                // Not exactly two factors: weight 1 for every ordered pair of
                // distinct factors referenced by the filter.
                ImmutableIntList factorList = ImmutableIntList.copyOf(factorRefs);
                for (int outer : factorList) {
                    for (int inner : factorList) {
                        if (outer != inner) {
                            this.setFactorWeight(1, outer, inner);
                        }
                    }
                }
            }
        }
    }

    /**
     * Raises the weight of a factor pair, in both directions, when
     * {@code weight} exceeds the value currently stored at
     * {@code [leftFactor][rightFactor]}.
     */
    @RequiresNonNull(value={"factorWeights"})
    private void setFactorWeight(int weight, int leftFactor, int rightFactor) {
        if (this.factorWeights[leftFactor][rightFactor] < weight) {
            this.factorWeights[leftFactor][rightFactor] = weight;
            this.factorWeights[rightFactor][leftFactor] = weight;
        }
    }

    /**
     * Evaluates {@code BitSets.contains} on the bit set built from the join
     * tree's tree order and {@code factorsNeeded}.
     *
     * @param joinTree join tree whose factors are examined
     * @param factorsNeeded factors that must be present
     */
    public boolean hasAllFactors(LoptJoinTree joinTree, BitSet factorsNeeded) {
        return BitSets.contains(BitSets.of(joinTree.getTreeOrder()), factorsNeeded);
    }

    /**
     * Sets a bit in {@code childFactors} for every factor in the join tree's
     * tree order.
     *
     * @param joinTree join tree to read
     * @param childFactors builder that receives the factor bits
     */
    @Deprecated
    public void getChildFactors(LoptJoinTree joinTree, ImmutableBitSet.Builder childFactors) {
        for (int child : joinTree.getTreeOrder()) {
            childFactors.set(child);
        }
    }

    /**
     * Returns the fields of the join type that the type factory creates from
     * the row types of the two join trees.
     *
     * @param left left join tree
     * @param right right join tree
     */
    public List<RelDataTypeField> getJoinFields(LoptJoinTree left, LoptJoinTree right) {
        RelDataType rowType = this.factory.createJoinType(left.getJoinTree().getRowType(), right.getJoinTree().getRowType());
        return rowType.getFieldList();
    }

    /**
     * Marks a factor as a removable outer-join factor.
     *
     * @param factIdx factor index
     */
    public void addRemovableOuterJoinFactor(int factIdx) {
        this.removableOuterJoinFactors.add(factIdx);
    }

    /**
     * Returns whether the factor was marked through
     * {@link #addRemovableOuterJoinFactor}.
     *
     * @param factIdx factor index
     */
    public boolean isRemovableOuterJoinFactor(int factIdx) {
        return this.removableOuterJoinFactors.contains(factIdx);
    }

    /**
     * Registers two factors as a removable self-join pair.
     *
     * <p>The factor with strictly more fields becomes the left factor; on a
     * tie {@code factor2} is the left one. A mapping from right-factor column
     * offsets to left-factor column offsets is recorded for columns whose
     * column origins report the same origin ordinal.
     *
     * @param factor1 one factor of the pair
     * @param factor2 the other factor of the pair
     */
    public void addRemovableSelfJoinPair(int factor1, int factor2) {
        boolean firstIsWider = this.getNumFieldsInJoinFactor(factor1) > this.getNumFieldsInJoinFactor(factor2);
        int leftFactor = firstIsWider ? factor1 : factor2;
        int rightFactor = firstIsWider ? factor2 : factor1;

        // Origin column ordinal -> offset in the left factor.
        // NOTE(review): only origins for which isDerived() is true are used,
        // here and below; confirm this polarity against RelColumnOrigin.
        RelNode left = this.getJoinFactor(leftFactor);
        RelMetadataQuery mq = left.getCluster().getMetadataQuery();
        Map<Integer, Integer> leftFactorColMapping = new HashMap<>();
        int leftFieldCount = left.getRowType().getFieldCount();
        for (int i = 0; i < leftFieldCount; ++i) {
            RelColumnOrigin colOrigin = mq.getColumnOrigin(left, i);
            if (colOrigin != null && colOrigin.isDerived()) {
                leftFactorColMapping.put(colOrigin.getOriginColumnOrdinal(), i);
            }
        }

        // Right offset -> left offset, for right columns whose origin ordinal
        // was also seen on the left.
        Map<Integer, Integer> columnMapping = new HashMap<>();
        RelNode right = this.getJoinFactor(rightFactor);
        int rightFieldCount = right.getRowType().getFieldCount();
        for (int i = 0; i < rightFieldCount; ++i) {
            RelColumnOrigin colOrigin = mq.getColumnOrigin(right, i);
            if (colOrigin == null || !colOrigin.isDerived()) {
                continue;
            }
            Integer leftOffset = leftFactorColMapping.get(colOrigin.getOriginColumnOrdinal());
            if (leftOffset != null) {
                columnMapping.put(i, leftOffset);
            }
        }

        // The pair is reachable from either of its factors.
        RemovableSelfJoin selfJoin = new RemovableSelfJoin(leftFactor, rightFactor, columnMapping);
        this.removableSelfJoinPairs.put(leftFactor, selfJoin);
        this.removableSelfJoinPairs.put(rightFactor, selfJoin);
    }

    /**
     * Returns the partner of {@code factIdx} in its removable self-join pair,
     * or {@code null} if the factor belongs to no pair.
     *
     * @param factIdx factor index
     */
    public @Nullable Integer getOtherSelfJoinFactor(int factIdx) {
        RemovableSelfJoin selfJoin = this.removableSelfJoinPairs.get(factIdx);
        if (selfJoin == null) {
            return null;
        }
        return selfJoin.rightFactor == factIdx ? selfJoin.leftFactor : selfJoin.rightFactor;
    }

    /**
     * Returns whether the factor is the left factor of a removable self-join pair.
     *
     * @param factIdx factor index
     */
    public boolean isLeftFactorInRemovableSelfJoin(int factIdx) {
        RemovableSelfJoin selfJoin = this.removableSelfJoinPairs.get(factIdx);
        return selfJoin != null && selfJoin.leftFactor == factIdx;
    }

    /**
     * Returns whether the factor is the right factor of a removable self-join pair.
     *
     * @param factIdx factor index
     */
    public boolean isRightFactorInRemovableSelfJoin(int factIdx) {
        RemovableSelfJoin selfJoin = this.removableSelfJoinPairs.get(factIdx);
        return selfJoin != null && selfJoin.rightFactor == factIdx;
    }

    /**
     * Returns the left-factor column offset mapped to a right-factor column
     * offset, or {@code null} if the column has no mapping.
     *
     * @param rightFactor right factor of a registered self-join pair
     * @param rightOffset column offset within the right factor
     * @throws NullPointerException if {@code rightFactor} belongs to no pair
     */
    public @Nullable Integer getRightColumnMapping(int rightFactor, int rightOffset) {
        RemovableSelfJoin selfJoin = Objects.requireNonNull(this.removableSelfJoinPairs.get(rightFactor), () -> "removableSelfJoinPairs.get(rightFactor) is null for " + rightFactor + ", map=" + this.removableSelfJoinPairs);
        assert (selfJoin.rightFactor == rightFactor);
        return selfJoin.columnMapping.get(rightOffset);
    }

    /**
     * Creates an {@link Edge} holding the condition together with the fields
     * and factors it references.
     *
     * @param condition join condition
     */
    public Edge createEdge(RexNode condition) {
        ImmutableBitSet fieldRefBitmap = LoptMultiJoin.fieldBitmap(condition);
        ImmutableBitSet factorRefBitmap = this.factorBitmap(fieldRefBitmap);
        return new Edge(condition, factorRefBitmap, fieldRefBitmap);
    }

    /** Immutable record of one removable self-join pair and its column mapping. */
    private static class RemovableSelfJoin {
        /** Factor chosen as the left side of the pair. */
        private final int leftFactor;
        /** Factor chosen as the right side of the pair. */
        private final int rightFactor;
        /** Right-factor column offset mapped to left-factor column offset. */
        private final Map<Integer, Integer> columnMapping;

        RemovableSelfJoin(int leftFactor, int rightFactor, Map<Integer, Integer> columnMapping) {
            this.leftFactor = leftFactor;
            this.rightFactor = rightFactor;
            this.columnMapping = columnMapping;
        }
    }

    /** A join condition together with the factors and columns it references. */
    static class Edge {
        /** Factors referenced by {@link #condition}. */
        final ImmutableBitSet factors;
        /** Fields referenced by {@link #condition}. */
        final ImmutableBitSet columns;
        /** The join condition itself. */
        final RexNode condition;

        Edge(RexNode condition, ImmutableBitSet factors, ImmutableBitSet columns) {
            this.condition = condition;
            this.factors = factors;
            this.columns = columns;
        }

        @Override
        public String toString() {
            return "Edge(condition: " + this.condition + ", factors: " + this.factors + ", columns: " + this.columns + ")";
        }
    }
}

