/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;

/**
 * A {@link RelHomogeneousShuttle} that moves "simple" correlated expressions from the right input
 * of a {@link LogicalCorrelate} into a projection on its left input.
 *
 * <p>An expression is considered simple for a given correlation id when it references that
 * correlation variable at least once and contains nothing that disqualifies it: input references,
 * local references, table input references, pattern field references, windowed aggregates
 * ({@link RexOver}) and sub-queries all disqualify it, while operands the detector leaves at the
 * {@link RexVisitorImpl} default (for example literals) are ignored. Such expressions are collected
 * from the {@link Project} and {@link Filter} nodes of the right input.
 *
 * <p>When at least one collected expression is a call (rather than a bare field access), the
 * correlate is rebuilt as follows:
 * <ol>
 *   <li>the left input gets one extra column per collected expression, computed over the left
 *       fields;</li>
 *   <li>each collected expression in the right input is replaced by a field access on the
 *       corresponding extra column, and the new correlate requires exactly those columns;</li>
 *   <li>a final projection removes the extra columns again, so the output exposes the same fields
 *       as the original correlate.</li>
 * </ol>
 * Otherwise the correlate is returned unchanged, or copied when one of its inputs was rewritten.
 */
public final class CorrelateProjectExtractor
extends RelHomogeneousShuttle {
    /** Factory used to create the {@link RelBuilder} that assembles rewritten correlates. */
    private final RelBuilderFactory builderFactory;

    /**
     * Creates the shuttle.
     *
     * @param factory factory for the {@link RelBuilder} used to rebuild correlates; must not be null
     * @throws NullPointerException if {@code factory} is null
     */
    public CorrelateProjectExtractor(RelBuilderFactory factory) {
        this.builderFactory = Objects.requireNonNull(factory, "factory");
    }

    /**
     * Visits both inputs of {@code correlate} and, when the right input contains correlated calls
     * that can be evaluated on the left input alone, moves them into a projection on the left.
     *
     * @param correlate the correlate to rewrite
     * @return {@code correlate} itself when nothing changed, a copy with the visited inputs when
     *     only the inputs changed, or a rebuilt plan exposing the same fields otherwise
     * @throws AssertionError if a rewrite is needed and the join type is not SEMI, ANTI, LEFT or
     *     INNER
     */
    @Override
    public RelNode visit(LogicalCorrelate correlate) {
        RelNode left = correlate.getLeft().accept(this);
        RelNode right = correlate.getRight().accept(this);
        int oldLeft = left.getRowType().getFieldCount();
        CorrelationId correlationId = correlate.getCorrelationId();

        Set<RexNode> callsWithCorrelationInRight =
            findCorrelationDependentCalls(correlationId, right);
        // Bare field accesses can already be read from the left row; only calls gain anything
        // from being precomputed on the left side.
        boolean isTrivialCorrelation = callsWithCorrelationInRight.stream()
            .allMatch(exp -> exp instanceof RexFieldAccess);
        if (isTrivialCorrelation) {
            if (correlate.getLeft().equals(left) && correlate.getRight().equals(right)) {
                return correlate;
            }
            return correlate.copy(correlate.getTraitSet(), left, right, correlationId,
                correlate.getRequiredColumns(), correlate.getJoinType());
        }

        RelBuilder builder = this.builderFactory.create(correlate.getCluster(), null);
        // New left side: the original fields followed by one column per correlated expression,
        // in the iteration order of callsWithCorrelationInRight.
        builder.push(left);
        List<RexNode> callsWithCorrelationOverLeft =
            new ArrayList<>(callsWithCorrelationInRight.size());
        for (RexNode callInRight : callsWithCorrelationInRight) {
            callsWithCorrelationOverLeft.add(replaceCorrelationsWithInputRef(callInRight, builder));
        }
        builder.projectPlus(callsWithCorrelationOverLeft);

        // Map each correlated expression of the right side to a field access on the column of the
        // new left side that now computes it. The correlation variable is typed after the new
        // left row type so that these accesses resolve to the appended columns.
        RexBuilder rexBuilder = builder.getRexBuilder();
        RexNode correlVariable = rexBuilder.makeCorrel(builder.peek().getRowType(), correlationId);
        Map<RexNode, RexNode> transformMapping = new HashMap<>();
        int position = oldLeft;
        for (RexNode callInRight : callsWithCorrelationInRight) {
            transformMapping.put(callInRight, rexBuilder.makeFieldAccess(correlVariable, position++));
        }

        // The appended columns are the only left fields the rewritten right side needs.
        ImmutableList<RexNode> requiredFields = builder.fields(
            ImmutableBitSet.range(oldLeft, oldLeft + callsWithCorrelationOverLeft.size()).asList());
        int newLeft = builder.fields().size();

        right = replaceExpressionsUsingMap(right, transformMapping);
        builder.push(right);
        builder.correlate(correlate.getJoinType(), correlationId, requiredFields);
        // Drop the helper columns so the result exposes the same fields as the original correlate.
        builder.project(builder.fields(retainedFieldOrdinals(
            correlate.getJoinType(), oldLeft, newLeft, right.getRowType().getFieldCount())));
        return builder.build();
    }

    /**
     * Returns the ordinals, in the output of the rebuilt correlate, of the fields that the
     * original correlate exposed.
     *
     * @param joinType join type of the correlate
     * @param oldLeft field count of the left input before the helper columns were appended
     * @param newLeft field count of the left input after the helper columns were appended
     * @param rightFieldCount field count of the rewritten right input
     * @return ordinals to keep, in ascending order
     * @throws AssertionError if {@code joinType} is not SEMI, ANTI, LEFT or INNER
     */
    private static List<Integer> retainedFieldOrdinals(
            JoinRelType joinType, int oldLeft, int newLeft, int rightFieldCount) {
        return switch (joinType) {
            // SEMI and ANTI correlates only output the left side.
            case SEMI, ANTI -> ImmutableBitSet.range(0, oldLeft).asList();
            case LEFT, INNER -> ImmutableBitSet.builder()
                .set(0, oldLeft)
                .set(newLeft, newLeft + rightFieldCount)
                .build()
                .asList();
            default -> throw new AssertionError(joinType);
        };
    }

    /**
     * Collects the simple correlated expressions over {@code corrId} that appear in the
     * {@link Project} and {@link Filter} nodes of {@code plan}.
     *
     * @param corrId correlation id whose references are looked for
     * @param plan plan to scan; it is not modified
     * @return the expressions without duplicates, in insertion (traversal) order
     */
    private static Set<RexNode> findCorrelationDependentCalls(CorrelationId corrId, RelNode plan) {
        final SimpleCorrelationCollector finder = new SimpleCorrelationCollector(corrId);
        plan.accept(new RelHomogeneousShuttle() {
            @Override
            public RelNode visit(RelNode other) {
                if (other instanceof Project || other instanceof Filter) {
                    // Applied only for its side effect of filling finder.correlations; the
                    // returned node is discarded.
                    other.accept(finder);
                }
                return super.visit(other);
            }
        });
        return finder.correlations;
    }

    /**
     * Rewrites every node of {@code plan}, children first, replacing the expressions that are
     * keys of {@code mapping} with the associated values.
     *
     * @param plan plan to rewrite
     * @param mapping expressions to replace, keyed by the expression found in {@code plan}
     * @return the rewritten plan
     */
    private static RelNode replaceExpressionsUsingMap(RelNode plan, Map<RexNode, RexNode> mapping) {
        final CallReplacer replacer = new CallReplacer(mapping);
        return plan.accept(new RelHomogeneousShuttle() {
            @Override
            public RelNode visit(RelNode other) {
                RelNode withNewChildren = super.visitChildren(other);
                return withNewChildren.accept(replacer);
            }
        });
    }

    /**
     * Tells whether {@code node} is a simple correlated expression over {@code id}.
     *
     * @return {@code true} only if {@link SimpleCorrelationDetector} yields {@code TRUE}
     */
    private static boolean isSimpleCorrelatedExpression(RexNode node, CorrelationId id) {
        Boolean result = node.accept(new SimpleCorrelationDetector(id));
        return result != null && result;
    }

    /**
     * Rewrites {@code exp} so that every field access on a correlation variable becomes a
     * reference to the same-index field of the relational expression on top of {@code b}.
     *
     * @param exp expression taken from the right side of a correlate
     * @param b builder whose top of stack is the correlate's left input
     * @return the expression expressed over the left input
     */
    private static RexNode replaceCorrelationsWithInputRef(RexNode exp, final RelBuilder b) {
        return exp.accept(new RexShuttle() {
            @Override
            public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
                if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
                    return b.field(fieldAccess.getField().getIndex());
                }
                return super.visitFieldAccess(fieldAccess);
            }
        });
    }

    /**
     * Shuttle that records the maximal simple correlated sub-expressions (calls or field
     * accesses) it meets. The shuttle never rewrites anything; it is used only for this
     * side effect.
     */
    private static final class SimpleCorrelationCollector
    extends RexShuttle {
        private final CorrelationId correlationId;
        /** Collected expressions in insertion order, without duplicates. */
        private final Set<RexNode> correlations = new LinkedHashSet<>();

        SimpleCorrelationCollector(CorrelationId corrId) {
            this.correlationId = corrId;
        }

        @Override
        public RexNode visitCall(RexCall call) {
            if (isSimpleCorrelatedExpression(call, this.correlationId)) {
                // Record the whole call and do not descend: its operands move with it.
                this.correlations.add(call);
                return call;
            }
            return super.visitCall(call);
        }

        @Override
        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            if (isSimpleCorrelatedExpression(fieldAccess, this.correlationId)) {
                this.correlations.add(fieldAccess);
                return fieldAccess;
            }
            return super.visitFieldAccess(fieldAccess);
        }
    }

    /**
     * Shuttle replacing calls and field accesses according to a mapping.
     *
     * <p>Field accesses are replaced as well as calls. The collector also records bare field
     * accesses on the correlation variable, and those must be rewritten too. Otherwise the right
     * side would keep reading left fields that the rebuilt correlate does not list as required.
     */
    private static final class CallReplacer
    extends RexShuttle {
        private final Map<RexNode, RexNode> mapping;

        CallReplacer(Map<RexNode, RexNode> mapping) {
            this.mapping = mapping;
        }

        @Override
        public RexNode visitCall(RexCall oldCall) {
            RexNode newCall = this.mapping.get(oldCall);
            if (newCall != null) {
                return newCall;
            }
            return super.visitCall(oldCall);
        }

        @Override
        public RexNode visitFieldAccess(RexFieldAccess oldAccess) {
            RexNode newAccess = this.mapping.get(oldAccess);
            if (newAccess != null) {
                return newAccess;
            }
            return super.visitFieldAccess(oldAccess);
        }
    }

    /**
     * Classifies an expression with respect to one correlation id.
     *
     * <p>It yields {@code TRUE} when the expression references the correlation variable and
     * contains nothing disqualifying. It yields {@code FALSE} when a disqualifying node is found:
     * an input, local, table-input or pattern-field reference, a windowed aggregate, a sub-query,
     * or another correlation variable. Nodes left at the {@link RexVisitorImpl} default (for
     * example literals) yield {@code null} and are ignored by {@link #visitCall}.
     */
    private static final class SimpleCorrelationDetector
    extends RexVisitorImpl<Boolean> {
        private final CorrelationId corrId;

        private SimpleCorrelationDetector(CorrelationId corrId) {
            super(true);
            this.corrId = corrId;
        }

        @Override
        public Boolean visitOver(RexOver over) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitSubQuery(RexSubQuery subQuery) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitCall(RexCall call) {
            Boolean hasSimpleCorrelation = null;
            for (RexNode op : call.operands) {
                Boolean b = op.accept(this);
                if (b == null) {
                    continue;
                }
                // AND together the classified operands; unclassified ones (null) are skipped.
                hasSimpleCorrelation = hasSimpleCorrelation == null ? b : hasSimpleCorrelation && b;
            }
            return hasSimpleCorrelation == null ? Boolean.FALSE : hasSimpleCorrelation;
        }

        @Override
        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
            return fieldAccess.getReferenceExpr().accept(this);
        }

        @Override
        public Boolean visitInputRef(RexInputRef inputRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) {
            return correlVariable.id.equals(this.corrId);
        }

        @Override
        public Boolean visitTableInputRef(RexTableInputRef ref) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitLocalRef(RexLocalRef localRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitPatternFieldRef(RexPatternFieldRef fieldRef) {
            return Boolean.FALSE;
        }
    }
}

