/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.Context;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
import org.apache.hadoop.hive.ql.optimizer.physical.GenMRSkewJoinProcessor;
import org.apache.hadoop.hive.ql.optimizer.physical.GenSparkSkewJoinProcessor;
import org.apache.hadoop.hive.ql.optimizer.physical.SkewJoinProcFactory;
import org.apache.hadoop.hive.ql.optimizer.physical.SparkMapJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinResolver;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SMBJoinDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hadoop.hive.ql.plan.TableDesc;

/**
 * Static holder for the node processors used when resolving skew joins in a
 * Spark plan, together with the private helpers those processors rely on.
 *
 * <p>The default processor is borrowed from {@link SkewJoinProcFactory}; the
 * join processor is {@link SparkSkewJoinJoinProcessor}.
 */
public class SparkSkewJoinProcFactory {
    /**
     * Join operators that {@link SparkSkewJoinJoinProcessor} has already
     * handled. The set is static and handed out as-is by
     * {@link #getVisitedJoinOp()}, so callers share this one instance.
     */
    private static final Set<JoinOperator> visitedJoinOp = new HashSet<JoinOperator>();

    private SparkSkewJoinProcFactory() {
        // Only static members; not meant to be instantiated.
    }

    /**
     * Returns the default processor, delegating to
     * {@link SkewJoinProcFactory#getDefaultProc()}.
     *
     * @return the processor supplied by {@code SkewJoinProcFactory}
     */
    public static NodeProcessor getDefaultProc() {
        return SkewJoinProcFactory.getDefaultProc();
    }

    /**
     * Creates the processor applied to join operators.
     *
     * @return a new {@link SparkSkewJoinJoinProcessor}
     */
    public static NodeProcessor getJoinProc() {
        return new SparkSkewJoinJoinProcessor();
    }

    /**
     * Cuts the Spark work of {@code currentTask} right after {@code reduceWork}
     * and moves everything downstream of it into a new task that depends on
     * {@code currentTask}.
     *
     * <p>Nothing happens unless all of the following hold: {@code reduceWork}
     * has exactly one child work, no work in the graph has more than one child
     * (see {@link #canSplit(SparkWork)}), and exactly one
     * {@link ReduceSinkOperator} is found in {@code reduceWork}.
     *
     * @param currentTask  task whose work graph is split
     * @param reduceWork   work after which the graph is cut
     * @param parseContext supplies the configuration, the context used for the
     *                     temporary path, and the registered join operators
     * @throws SemanticException declared by the original block; presumably
     *                           raised by the plan utilities it calls
     */
    private static void splitTask(SparkTask currentTask, ReduceWork reduceWork, ParseContext parseContext) throws SemanticException {
        SparkWork sparkWork = (SparkWork)currentTask.getWork();
        Set<Operator<?>> sinks = SparkMapJoinResolver.getOp(reduceWork, ReduceSinkOperator.class);
        boolean splittable = sparkWork.getChildren(reduceWork).size() == 1
                && canSplit(sparkWork)
                && sinks.size() == 1;
        if (!splittable) {
            return;
        }

        // Detach the single downstream work, remembering the edge so it can be
        // reused when the new map work is wired in.
        ReduceSinkOperator sink = (ReduceSinkOperator)sinks.iterator().next();
        BaseWork downstream = sparkWork.getChildren(reduceWork).get(0);
        SparkEdgeProperty originalEdge = sparkWork.getEdgeProperty(reduceWork, downstream);
        sparkWork.disconnect(reduceWork, downstream);

        // Collect the downstream work and everything still connected to it into
        // a fresh SparkWork, then drop those works from the current one.
        SparkWork followUpWork = new SparkWork(parseContext.getConf().getVar(HiveConf.ConfVars.HIVEQUERYID));
        followUpWork.add(downstream);
        copyWorkGraph(sparkWork, followUpWork, downstream);
        for (BaseWork moved : followUpWork.getAllWorkUnsorted()) {
            sparkWork.remove(moved);
            sparkWork.getCloneToWork().remove(moved);
        }

        // Route the data through a temporary location placed between the reduce
        // sink and its first parent operator.
        Path tmpDir = parseContext.getContext().getMRTmpPath();
        Operator<OperatorDesc> sinkParent = sink.getParentOperators().get(0);
        TableDesc tmpTableDesc = PlanUtils.getIntermediateFileTableDesc(PlanUtils.getFieldSchemasFromRowSchema(sinkParent.getSchema(), "temporarycol"));
        TableScanOperator tmpScan = GenMapRedUtils.createTemporaryFile(sinkParent, sink, tmpDir, tmpTableDesc, parseContext);

        // New map work that takes reduceWork's place as parent of the downstream work.
        MapWork readerWork = PlanUtils.getMapRedWork().getMapWork();
        readerWork.setName("Map " + GenSparkUtils.getUtils().getNextSeqNumber());
        followUpWork.add(readerWork);
        followUpWork.connect(readerWork, downstream, originalEdge);

        // The alias defaults to the temporary directory URI; when the downstream
        // work needs tagging it becomes "[<joinId>:]$INTNAME", with a numeric
        // suffix appended until it is not yet present in the map work.
        String streamAlias = tmpDir.toUri().toString();
        if (GenMapRedUtils.needsTagging((ReduceWork)downstream)) {
            Operator<?> downstreamReducer = ((ReduceWork)downstream).getReducer();
            String joinId = registeredJoinId(downstreamReducer, parseContext);
            String baseAlias = joinId != null ? joinId + ":$INTNAME" : "$INTNAME";
            streamAlias = baseAlias;
            for (int suffix = 1; readerWork.getAliasToWork().get(streamAlias) != null; suffix++) {
                streamAlias = baseAlias + suffix;
            }
        }
        GenMapRedUtils.setTaskPlan(tmpDir, streamAlias, tmpScan, readerWork, false, tmpTableDesc);

        // Insert the new task between currentTask and its first child task (if any).
        Task<SparkWork> followUpTask = TaskFactory.get(followUpWork, parseContext.getConf(), new Task[0]);
        List<Task<Serializable>> dependents = currentTask.getChildTasks();
        if (dependents != null && !dependents.isEmpty()) {
            Task<Serializable> firstDependent = dependents.get(0);
            currentTask.removeDependentTask(firstDependent);
            followUpTask.addDependentTask(firstDependent);
        }
        currentTask.addDependentTask(followUpTask);
        followUpTask.setFetchSource(currentTask.isFetchSource());
    }

    /**
     * Looks up the id of a join-type reducer, but only when the reducer is
     * registered in the matching collection of {@code parseContext}.
     *
     * @param reducer      reducer operator of the downstream work
     * @param parseContext holder of the registered join, map-join and SMB
     *                     map-join operators
     * @return the id from the reducer's descriptor, or {@code null} when the
     *         reducer is not one of the three join kinds or is not registered
     */
    private static String registeredJoinId(Operator<?> reducer, ParseContext parseContext) {
        if (reducer instanceof JoinOperator) {
            return parseContext.getJoinOps().contains(reducer)
                    ? ((JoinDesc)((JoinOperator)reducer).getConf()).getId()
                    : null;
        }
        if (reducer instanceof MapJoinOperator) {
            return parseContext.getMapJoinOps().contains(reducer)
                    ? ((MapJoinDesc)((MapJoinOperator)reducer).getConf()).getId()
                    : null;
        }
        if (reducer instanceof SMBMapJoinOperator && parseContext.getSmbMapJoinOps().contains(reducer)) {
            return ((SMBJoinDesc)((SMBMapJoinOperator)reducer).getConf()).getId();
        }
        return null;
    }

    /**
     * Tells whether the work graph is free of fan-out.
     *
     * @param sparkWork graph to inspect
     * @return {@code false} as soon as a work with more than one child is
     *         found, {@code true} otherwise
     */
    private static boolean canSplit(SparkWork sparkWork) {
        for (BaseWork work : sparkWork.getAllWorkUnsorted()) {
            if (sparkWork.getChildren(work).size() > 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Recursively copies into {@code newWork} every work reachable from
     * {@code baseWork} in {@code originWork}, following both child and parent
     * links and reusing the edge properties of {@code originWork}. Works that
     * {@code newWork} already contains are skipped, which also ends the
     * recursion.
     *
     * @param originWork graph to copy from
     * @param newWork    graph to copy into; expected to already contain
     *                   {@code baseWork}
     * @param baseWork   work whose neighbours are copied
     */
    private static void copyWorkGraph(SparkWork originWork, SparkWork newWork, BaseWork baseWork) {
        for (BaseWork child : originWork.getChildren(baseWork)) {
            if (!newWork.contains(child)) {
                newWork.add(child);
                SparkEdgeProperty toChild = originWork.getEdgeProperty(baseWork, child);
                newWork.connect(baseWork, child, toChild);
                copyWorkGraph(originWork, newWork, child);
            }
        }
        for (BaseWork parent : originWork.getParents(baseWork)) {
            if (!newWork.contains(parent)) {
                newWork.add(parent);
                SparkEdgeProperty fromParent = originWork.getEdgeProperty(parent, baseWork);
                newWork.connect(parent, baseWork, fromParent);
                copyWorkGraph(originWork, newWork, parent);
            }
        }
    }

    /**
     * Exposes the shared set of already-processed join operators.
     *
     * @return the backing set itself (not a copy)
     */
    public static Set<JoinOperator> getVisitedJoinOp() {
        return visitedJoinOp;
    }

    /**
     * Decides whether the join can be handled by the runtime skew-join path.
     *
     * <p>Requires a {@link SparkTask} for which
     * {@code GenMRSkewJoinProcessor.skewJoinEnabled} holds, a join descriptor
     * that is not fixed as sorted, a Spark work containing {@code reduceWork},
     * at most one child task, and exactly one {@link CommonJoinOperator} found
     * in {@code reduceWork}.
     *
     * @param joinOp     join operator under consideration
     * @param reduceWork reduce work associated with the join
     * @param currTask   task currently being processed
     * @param hiveConf   configuration passed to the skew-join enablement check
     * @return {@code true} when every condition above is met
     */
    private static boolean supportRuntimeSkewJoin(JoinOperator joinOp, ReduceWork reduceWork, Task<? extends Serializable> currTask, HiveConf hiveConf) {
        if (!(currTask instanceof SparkTask) || !GenMRSkewJoinProcessor.skewJoinEnabled(hiveConf, joinOp)) {
            return false;
        }
        SparkWork sparkWork = (SparkWork)((SparkTask)currTask).getWork();
        List<Task<Serializable>> children = currTask.getChildTasks();
        if (((JoinDesc)joinOp.getConf()).isFixedAsSorted()) {
            return false;
        }
        if (!sparkWork.contains(reduceWork)) {
            return false;
        }
        if (children != null && children.size() > 1) {
            return false;
        }
        return SparkMapJoinResolver.getOp(reduceWork, CommonJoinOperator.class).size() == 1;
    }

    /**
     * Processor applied to each {@link JoinOperator}: when the join has a
     * reduce work, has not been visited yet, and qualifies for runtime skew
     * join, the current task is split after that reduce work and the join is
     * passed to {@code GenSparkSkewJoinProcessor.processSkewJoin}.
     */
    public static class SparkSkewJoinJoinProcessor
    implements NodeProcessor {
        /**
         * Handles one join operator node.
         *
         * @param nd          node being visited; cast to {@link JoinOperator}
         * @param stack       traversal stack (unused here)
         * @param procCtx     cast to {@code SparkSkewJoinResolver.SparkSkewJoinProcCtx}
         * @param nodeOutputs outputs of previously processed nodes (unused here)
         * @return always {@code null}
         * @throws SemanticException propagated from the split / skew-join steps
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            SparkSkewJoinResolver.SparkSkewJoinProcCtx skewCtx = (SparkSkewJoinResolver.SparkSkewJoinProcCtx)procCtx;
            Task<? extends Serializable> task = skewCtx.getCurrentTask();
            JoinOperator joinOp = (JoinOperator)nd;
            ReduceWork joinWork = skewCtx.getReducerToReduceWork().get(joinOp);
            ParseContext parseCtx = skewCtx.getParseCtx();

            if (joinWork == null
                    || visitedJoinOp.contains(joinOp)
                    || !SparkSkewJoinProcFactory.supportRuntimeSkewJoin(joinOp, joinWork, task, parseCtx.getConf())) {
                return null;
            }

            SparkSkewJoinProcFactory.splitTask((SparkTask)task, joinWork, parseCtx);
            GenSparkSkewJoinProcessor.processSkewJoin(joinOp, task, joinWork, parseCtx);
            // Remember the operator so a later visit skips it.
            visitedJoinOp.add(joinOp);
            return null;
        }
    }
}

