/*
 * Decompiled with CFR 0.152.
 */
package org.apache.helix.controller.stages;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.helix.HelixDefinedState;
import org.apache.helix.api.config.StateTransitionThrottleConfig;
import org.apache.helix.controller.common.PartitionStateMap;
import org.apache.helix.controller.pipeline.AbstractBaseStage;
import org.apache.helix.controller.pipeline.StageException;
import org.apache.helix.controller.stages.AttributeName;
import org.apache.helix.controller.stages.BestPossibleStateOutput;
import org.apache.helix.controller.stages.ClusterDataCache;
import org.apache.helix.controller.stages.ClusterEvent;
import org.apache.helix.controller.stages.CurrentStateOutput;
import org.apache.helix.controller.stages.IntermediateStateOutput;
import org.apache.helix.controller.stages.StateTransitionThrottleController;
import org.apache.helix.model.IdealState;
import org.apache.helix.model.Partition;
import org.apache.helix.model.Resource;
import org.apache.helix.model.StateModelDefinition;
import org.apache.log4j.Logger;

/**
 * Pipeline stage that computes, per resource, the intermediate partition state map the controller
 * should move toward in this round.
 *
 * <p>When no transition throttling is enabled the best possible state is used as-is. Otherwise:
 * <ul>
 *   <li>partitions already in their best possible state, and partitions that need a recovery
 *       rebalance, are given their best possible state (recovery is not throttled here);</li>
 *   <li>partitions that only need a load-balance move are admitted one by one while the
 *       {@link StateTransitionThrottleController} quotas allow it, and otherwise keep their
 *       current state;</li>
 *   <li>while any partition of the resource needs recovery, all load-balance moves of that
 *       resource are deferred and those partitions keep their current state.</li>
 * </ul>
 */
public class IntermediateStateCalcStage extends AbstractBaseStage {
    private static final Logger logger = Logger.getLogger(IntermediateStateCalcStage.class.getName());

    /**
     * Reads the CURRENT_STATE, BEST_POSSIBLE_STATE and RESOURCES attributes plus the
     * "ClusterDataCache" attribute from the event, computes the intermediate state for every
     * resource and stores it on the event as INTERMEDIATE_STATE.
     *
     * @param event pipeline event carrying the inputs; receives the INTERMEDIATE_STATE output
     * @throws StageException if any required attribute is missing from the event
     */
    @Override
    public void process(ClusterEvent event) throws Exception {
        long startTime = System.currentTimeMillis();
        logger.info("START IntermediateStateCalcStage.process()");
        CurrentStateOutput currentStateOutput = (CurrentStateOutput) event.getAttribute(AttributeName.CURRENT_STATE.name());
        BestPossibleStateOutput bestPossibleStateOutput = (BestPossibleStateOutput) event.getAttribute(AttributeName.BEST_POSSIBLE_STATE.name());
        // The RESOURCES attribute is stored untyped on the event; it is used as resource name -> Resource.
        @SuppressWarnings("unchecked")
        Map<String, Resource> resourceMap = (Map<String, Resource>) event.getAttribute(AttributeName.RESOURCES.name());
        ClusterDataCache cache = (ClusterDataCache) event.getAttribute("ClusterDataCache");
        if (currentStateOutput == null || bestPossibleStateOutput == null || resourceMap == null || cache == null) {
            throw new StageException("Missing attributes in event:" + event + ". Requires CURRENT_STATE|BEST_POSSIBLE_STATE|RESOURCES|DataCache");
        }
        IntermediateStateOutput intermediateStateOutput = compute(cache, resourceMap, currentStateOutput, bestPossibleStateOutput);
        event.addAttribute(AttributeName.INTERMEDIATE_STATE.name(), intermediateStateOutput);
        long endTime = System.currentTimeMillis();
        logger.info("END IntermediateStateCalcStage.process(). took: " + (endTime - startTime) + " ms");
    }

    /**
     * Computes the intermediate state of every resource, sharing one throttle controller across
     * all resources so cluster-wide and per-instance quotas are enforced across the whole round.
     */
    private IntermediateStateOutput compute(ClusterDataCache dataCache, Map<String, Resource> resourceMap, CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput) {
        IntermediateStateOutput output = new IntermediateStateOutput();
        StateTransitionThrottleController throttleController = new StateTransitionThrottleController(resourceMap.keySet(), dataCache.getClusterConfig(), dataCache.getLiveInstances().keySet());
        for (String resourceName : resourceMap.keySet()) {
            PartitionStateMap intermediatePartitionStateMap = computeIntermediatePartitionState(dataCache, dataCache.getIdealState(resourceName), resourceMap.get(resourceName), currentStateOutput, bestPossibleStateOutput.getPartitionStateMap(resourceName), throttleController);
            output.setState(resourceName, intermediatePartitionStateMap);
        }
        return output;
    }

    /**
     * Computes the intermediate partition state map of a single resource.
     *
     * @param cache cluster data cache used to look up the resource's state model definition
     * @param idealState ideal state of the resource (supplies the state model definition name)
     * @param resource the resource whose partitions are processed
     * @param currentStateOutput current and pending states of all partitions
     * @param bestPossiblePartitionStateMap best possible state of the resource's partitions
     * @param throttleController throttle quotas; charged for pending and newly admitted transitions
     * @return {@code bestPossiblePartitionStateMap} itself when throttling is disabled, otherwise a
     *     new map in which every partition has either its best possible or its current state
     */
    public PartitionStateMap computeIntermediatePartitionState(ClusterDataCache cache, IdealState idealState, Resource resource, CurrentStateOutput currentStateOutput, PartitionStateMap bestPossiblePartitionStateMap, StateTransitionThrottleController throttleController) {
        String resourceName = resource.getResourceName();
        logger.info("Processing resource:" + resourceName);
        if (!throttleController.isThrottleEnabled()) {
            logger.info("None of any type of transition throttling is set for resource " + resourceName + " skip computing intermediate partition state.");
            return bestPossiblePartitionStateMap;
        }
        StateModelDefinition stateModelDef = cache.getStateModelDef(idealState.getStateModelDefRef());

        // Transitions already in flight consume quota before any new move is admitted.
        boolean pendingRecoveryRebalance = chargePendingTransitions(resource, currentStateOutput, bestPossiblePartitionStateMap, stateModelDef, throttleController);

        PartitionStateMap output = new PartitionStateMap(resourceName);
        Set<Partition> partitionsNeedRecovery = new HashSet<Partition>();
        Set<Partition> partitionsNeedLoadbalance = new HashSet<Partition>();
        // Recovery moves are never throttled by this stage; the set only keeps the report uniform.
        Set<Partition> partitionsRecoveryThrottled = new HashSet<Partition>();
        Set<Partition> partitionsLoadbalanceThrottled = new HashSet<Partition>();

        for (Partition partition : resource.getPartitions()) {
            Map<String, String> currentStateMap = currentStateOutput.getCurrentStateMap(resourceName, partition);
            Map<String, String> bestPossibleMap = bestPossiblePartitionStateMap.getPartitionMap(partition);
            if (currentStateMap.equals(bestPossibleMap)) {
                output.setState(partition, new HashMap<String, String>(bestPossibleMap));
            } else if (needRecoveryRebalance(bestPossibleMap, stateModelDef, currentStateMap)) {
                output.setState(partition, new HashMap<String, String>(bestPossibleMap));
                pendingRecoveryRebalance = true;
                partitionsNeedRecovery.add(partition);
            } else {
                // Hold the current state until the load-balance move is admitted below. Leaving the
                // map empty (as before) would publish "no replicas" for a deferred partition.
                output.setState(partition, new HashMap<String, String>(currentStateMap));
                partitionsNeedLoadbalance.add(partition);
            }
        }

        if (pendingRecoveryRebalance) {
            // Load balance waits until recovery completes; deferred partitions count as throttled.
            partitionsLoadbalanceThrottled.addAll(partitionsNeedLoadbalance);
        } else {
            for (Partition partition : partitionsNeedLoadbalance) {
                Map<String, String> currentStateMap = currentStateOutput.getCurrentStateMap(resourceName, partition);
                Map<String, String> bestPossibleMap = bestPossiblePartitionStateMap.getPartitionMap(partition);
                if (isLoadBalanceThrottled(resourceName, partition, currentStateMap, bestPossibleMap, throttleController)) {
                    partitionsLoadbalanceThrottled.add(partition);
                } else {
                    chargeLoadBalance(resourceName, currentStateMap, bestPossibleMap, throttleController);
                    output.setState(partition, new HashMap<String, String>(bestPossibleMap));
                }
            }
        }

        logger.info(String.format("RecoveryNeeded: %d, RecoveryThrottled: %d, loadbalanceNeeded: %d, loadbalanceThrottled: %d", partitionsNeedRecovery.size(), partitionsRecoveryThrottled.size(), partitionsNeedLoadbalance.size(), partitionsLoadbalanceThrottled.size()));
        if (logger.isDebugEnabled()) {
            logPartitionMapState(resourceName, new HashSet<Partition>(resource.getPartitions()), partitionsNeedRecovery, partitionsRecoveryThrottled, partitionsNeedLoadbalance, partitionsLoadbalanceThrottled, currentStateOutput, bestPossiblePartitionStateMap, output);
        }
        logger.info("End processing resource:" + resourceName);
        return output;
    }

    /**
     * Charges the throttle controller for transitions already pending on the resource.
     *
     * <p>Each partition is classified as recovery or load balance. A partition with any pending
     * state is charged once against the cluster and resource quota of that type. Each instance
     * whose pending state differs from its current state is charged against its instance quota.
     *
     * @return true if at least one partition of the resource needs a recovery rebalance
     */
    private boolean chargePendingTransitions(Resource resource, CurrentStateOutput currentStateOutput, PartitionStateMap bestPossiblePartitionStateMap, StateModelDefinition stateModelDef, StateTransitionThrottleController throttleController) {
        String resourceName = resource.getResourceName();
        boolean recoveryNeeded = false;
        for (Partition partition : resource.getPartitions()) {
            Map<String, String> currentStateMap = currentStateOutput.getCurrentStateMap(resourceName, partition);
            Map<String, String> pendingMap = currentStateOutput.getPendingStateMap(resourceName, partition);
            Map<String, String> bestPossibleMap = bestPossiblePartitionStateMap.getPartitionMap(partition);
            StateTransitionThrottleConfig.RebalanceType rebalanceType;
            if (needRecoveryRebalance(bestPossibleMap, stateModelDef, currentStateMap)) {
                rebalanceType = StateTransitionThrottleConfig.RebalanceType.RECOVERY_BALANCE;
                recoveryNeeded = true;
            } else {
                rebalanceType = StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE;
            }
            if (!pendingMap.isEmpty()) {
                throttleController.chargeCluster(rebalanceType);
                throttleController.chargeResource(rebalanceType, resourceName);
            }
            for (Map.Entry<String, String> pending : pendingMap.entrySet()) {
                String instance = pending.getKey();
                String pendingState = pending.getValue();
                if (pendingState != null && !pendingState.equals(currentStateMap.get(instance))) {
                    throttleController.chargeInstance(rebalanceType, instance);
                }
            }
        }
        return recoveryNeeded;
    }

    /**
     * Decides whether the load-balance move of one partition must be held back. That happens when
     * the controller throttles load balance for the resource, or when any instance that would
     * change state is throttled for load balance.
     */
    private boolean isLoadBalanceThrottled(String resourceName, Partition partition, Map<String, String> currentStateMap, Map<String, String> bestPossibleMap, StateTransitionThrottleController throttleController) {
        if (throttleController.throttleforResource(StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE, resourceName)) {
            logger.debug("Load balance throttled on resource for " + resourceName + " " + partition.getPartitionName());
            return true;
        }
        boolean throttled = false;
        // No early exit: every throttled instance is reported in the debug log.
        for (Map.Entry<String, String> target : bestPossibleMap.entrySet()) {
            String instance = target.getKey();
            String bestPossibleState = target.getValue();
            if (bestPossibleState == null || bestPossibleState.equals(currentStateMap.get(instance))) {
                continue;
            }
            if (throttleController.throttleForInstance(StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE, instance)) {
                throttled = true;
                logger.debug("Load balance throttled because instance " + instance + " for " + resourceName + " " + partition.getPartitionName());
            }
        }
        return throttled;
    }

    /**
     * Charges an admitted load-balance move: one instance charge per instance that changes state,
     * plus one cluster and one resource charge.
     */
    private void chargeLoadBalance(String resourceName, Map<String, String> currentStateMap, Map<String, String> bestPossibleMap, StateTransitionThrottleController throttleController) {
        for (Map.Entry<String, String> target : bestPossibleMap.entrySet()) {
            String bestPossibleState = target.getValue();
            if (bestPossibleState != null && !bestPossibleState.equals(currentStateMap.get(target.getKey()))) {
                throttleController.chargeInstance(StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE, target.getKey());
            }
        }
        throttleController.chargeCluster(StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE);
        throttleController.chargeResource(StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE, resourceName);
    }

    /**
     * Writes a debug report for every partition of the resource: its rebalance classification,
     * whether it was throttled, and its best possible, current, pending and intermediate state.
     */
    private void logPartitionMapState(String resource, Set<Partition> allPartitions, Set<Partition> recoveryPartitions, Set<Partition> recoveryThrottledPartitions, Set<Partition> loadbalancePartitions, Set<Partition> loadbalanceThrottledPartitions, CurrentStateOutput currentStateOutput, PartitionStateMap bestPossibleStateMap, PartitionStateMap intermediateStateMap) {
        logger.debug("Partitions need recovery: " + recoveryPartitions + "\nPartitions get throttled on recovery: " + recoveryThrottledPartitions);
        logger.debug("Partitions need loadbalance: " + loadbalancePartitions + "\nPartitions get throttled on load-balance: " + loadbalanceThrottledPartitions);
        for (Partition partition : allPartitions) {
            if (recoveryPartitions.contains(partition)) {
                logger.debug("recovery balance needed for " + resource + " " + partition.getPartitionName());
                if (recoveryThrottledPartitions.contains(partition)) {
                    logger.debug("Recovery balance throttled on resource for " + resource + " " + partition.getPartitionName());
                }
            } else if (loadbalancePartitions.contains(partition)) {
                logger.debug("load balance needed for " + resource + " " + partition.getPartitionName());
                if (loadbalanceThrottledPartitions.contains(partition)) {
                    logger.debug("Load balance throttled on resource for " + resource + " " + partition.getPartitionName());
                }
            } else {
                logger.debug("no balance needed for " + resource + " " + partition.getPartitionName());
            }
            logger.debug(partition + ": Best possible map: " + bestPossibleStateMap.getPartitionMap(partition));
            logger.debug(partition + ": Current State: " + currentStateOutput.getCurrentStateMap(resource, partition));
            logger.debug(partition + ": Pending state: " + currentStateOutput.getPendingMessageMap(resource, partition));
            logger.debug(partition + ": Intermediate state: " + intermediateStateMap.getPartitionMap(partition));
        }
    }

    /**
     * Returns true when, for some state of the state model other than DROPPED, ERROR and the
     * initial state, the number of replicas in that state differs between the best possible map and
     * the current state map.
     */
    private boolean needRecoveryRebalance(Map<String, String> bestPossibleMap, StateModelDefinition stateModelDef, Map<String, String> currentStateMap) {
        List<String> states = stateModelDef.getStatesPriorityList();
        Map<String, Long> bestPossibleStateCounts = getStateCounts(bestPossibleMap);
        Map<String, Long> currentStateCounts = getStateCounts(currentStateMap);
        for (String state : states) {
            if (state.equals(HelixDefinedState.DROPPED.name()) || state.equals(HelixDefinedState.ERROR.name()) || state.equals(stateModelDef.getInitialState())) {
                continue;
            }
            Long bestPossibleCount = bestPossibleStateCounts.get(state);
            Long currentCount = currentStateCounts.get(state);
            boolean sameCount = bestPossibleCount == null ? currentCount == null : bestPossibleCount.equals(currentCount);
            if (!sameCount) {
                return true;
            }
        }
        return false;
    }

    /** Counts how many entries of {@code stateMap} hold each state value. */
    private Map<String, Long> getStateCounts(Map<String, String> stateMap) {
        Map<String, Long> stateCounts = new HashMap<String, Long>();
        for (String state : stateMap.values()) {
            Long count = stateCounts.get(state);
            stateCounts.put(state, count == null ? 1L : count + 1L);
        }
        return stateCounts;
    }
}

