/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismAndInputInfosDecider;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link VertexParallelismAndInputInfosDecider}.
 *
 * <p>Source vertices (vertices without consumed results) get their initial parallelism if one is
 * set, otherwise the global default source parallelism capped by the vertex maximum parallelism.
 *
 * <p>For other vertices whose parallelism is not set, the parallelism is derived from the volume of
 * non-broadcast input data and the configured data volume per task. It is kept within the
 * effective minimum/maximum parallelism. If every input is all-to-all and not every input is
 * broadcast, consecutive subpartitions are packed into ranges by byte size so that tasks receive
 * similar amounts of data. Otherwise the input infos are computed by {@link
 * VertexInputInfoComputationUtils#computeVertexInputInfos}.
 */
public class DefaultVertexParallelismAndInputInfosDecider
        implements VertexParallelismAndInputInfosDecider {

    private static final Logger LOG =
            LoggerFactory.getLogger(DefaultVertexParallelismAndInputInfosDecider.class);

    /**
     * Upper bound on the number of subpartitions a single task is meant to consume. It bounds the
     * decided parallelism from below and the length of subpartition ranges from above.
     */
    private static final int MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME = 32768;

    /** Sentinel meaning "no parallelism specified", as accepted by the argument checks below. */
    private static final int PARALLELISM_NOT_SET = -1;

    private final int globalMaxParallelism;
    private final int globalMinParallelism;
    /** Target amount of input data per task, in bytes. */
    private final long dataVolumePerTask;
    private final int globalDefaultSourceParallelism;

    /**
     * Creates a decider with the given global bounds.
     *
     * @param globalMaxParallelism upper bound for decided parallelism; must be at least {@code
     *     globalMinParallelism}
     * @param globalMinParallelism lower bound for decided parallelism; must be positive
     * @param dataVolumePerTask target amount of input data per task; must not be null
     * @param globalDefaultSourceParallelism parallelism of source vertices without an initial
     *     parallelism (capped by the vertex maximum); must be positive
     * @throws IllegalArgumentException if any bound is invalid
     * @throws NullPointerException if {@code dataVolumePerTask} is null
     */
    private DefaultVertexParallelismAndInputInfosDecider(
            int globalMaxParallelism,
            int globalMinParallelism,
            MemorySize dataVolumePerTask,
            int globalDefaultSourceParallelism) {
        Preconditions.checkArgument(
                globalMinParallelism > 0, "The minimum parallelism must be larger than 0.");
        Preconditions.checkArgument(
                globalMaxParallelism >= globalMinParallelism,
                "Maximum parallelism should be greater than or equal to the minimum parallelism.");
        Preconditions.checkArgument(
                globalDefaultSourceParallelism > 0,
                "The default source parallelism must be larger than 0.");
        Preconditions.checkNotNull(dataVolumePerTask);

        this.globalMaxParallelism = globalMaxParallelism;
        this.globalMinParallelism = globalMinParallelism;
        this.dataVolumePerTask = dataVolumePerTask.getBytes();
        this.globalDefaultSourceParallelism = globalDefaultSourceParallelism;
    }

    /**
     * Decides the parallelism of a job vertex and how its consumed results are split among its
     * subtasks.
     *
     * @param jobVertexId the vertex being decided (used for logging)
     * @param consumedResults the blocking results the vertex consumes; empty for a source vertex
     * @param vertexInitialParallelism the fixed parallelism, or -1 if it should be decided
     * @param vertexMinParallelism the vertex lower bound, or -1 if none
     * @param vertexMaxParallelism the vertex upper bound; must be positive and not smaller than
     *     the initial/minimum parallelism
     * @return the decided parallelism together with the per-result input infos
     * @throws IllegalArgumentException if the vertex bounds are inconsistent
     * @throws IllegalStateException if the effective maximum is below the effective minimum
     */
    @Override
    public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int vertexInitialParallelism,
            int vertexMinParallelism,
            int vertexMaxParallelism) {
        Preconditions.checkArgument(
                vertexInitialParallelism == PARALLELISM_NOT_SET || vertexInitialParallelism > 0,
                "Invalid initial parallelism %s of job vertex %s.",
                vertexInitialParallelism,
                jobVertexId);
        Preconditions.checkArgument(
                vertexMinParallelism == PARALLELISM_NOT_SET || vertexMinParallelism > 0,
                "Invalid minimum parallelism %s of job vertex %s.",
                vertexMinParallelism,
                jobVertexId);
        Preconditions.checkArgument(
                vertexMaxParallelism > 0
                        && vertexMaxParallelism >= vertexInitialParallelism
                        && vertexMaxParallelism >= vertexMinParallelism,
                "Invalid maximum parallelism %s of job vertex %s.",
                vertexMaxParallelism,
                jobVertexId);

        if (consumedResults.isEmpty()) {
            // Source vertex: there is no input to measure, so use the initial parallelism if one
            // is given, otherwise the default source parallelism capped by the vertex maximum.
            int parallelism =
                    vertexInitialParallelism > 0
                            ? vertexInitialParallelism
                            : computeSourceParallelismUpperBound(jobVertexId, vertexMaxParallelism);
            return new ParallelismAndInputInfos(parallelism, Collections.emptyMap());
        }

        int minParallelism = Math.max(globalMinParallelism, vertexMinParallelism);
        int maxParallelism = globalMaxParallelism;
        boolean parallelismNotSet = vertexInitialParallelism == PARALLELISM_NOT_SET;

        // When the parallelism is to be decided, the vertex maximum narrows both global bounds.
        if (parallelismNotSet && vertexMaxParallelism < minParallelism) {
            LOG.info(
                    "The vertex maximum parallelism {} is smaller than the minimum parallelism {}. Use {} as the lower bound to decide parallelism of job vertex {}.",
                    vertexMaxParallelism,
                    minParallelism,
                    vertexMaxParallelism,
                    jobVertexId);
            minParallelism = vertexMaxParallelism;
        }
        if (parallelismNotSet && vertexMaxParallelism < maxParallelism) {
            LOG.info(
                    "The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. Use {} as the upper bound to decide parallelism of job vertex {}.",
                    vertexMaxParallelism,
                    maxParallelism,
                    vertexMaxParallelism,
                    jobVertexId);
            maxParallelism = vertexMaxParallelism;
        }
        Preconditions.checkState(
                maxParallelism >= minParallelism,
                "Maximum parallelism %s is smaller than minimum parallelism %s for job vertex %s.",
                maxParallelism,
                minParallelism,
                jobVertexId);

        if (parallelismNotSet
                && areAllInputsAllToAll(consumedResults)
                && !areAllInputsBroadcast(consumedResults)) {
            return decideParallelismAndEvenlyDistributeData(
                    jobVertexId,
                    consumedResults,
                    vertexInitialParallelism,
                    minParallelism,
                    maxParallelism);
        }
        return decideParallelismAndEvenlyDistributeSubpartitions(
                jobVertexId, consumedResults, vertexInitialParallelism, minParallelism, maxParallelism);
    }

    /**
     * Returns the global default source parallelism, capped at {@code maxParallelism}. An info
     * message is logged when the cap applies.
     */
    @Override
    public int computeSourceParallelismUpperBound(JobVertexID jobVertexId, int maxParallelism) {
        if (globalDefaultSourceParallelism > maxParallelism) {
            LOG.info(
                    "The global default source parallelism {} is larger than the maximum parallelism {}. Use {} as the upper bound parallelism of source job vertex {}.",
                    globalDefaultSourceParallelism,
                    maxParallelism,
                    maxParallelism,
                    jobVertexId);
            return maxParallelism;
        }
        return globalDefaultSourceParallelism;
    }

    /** Returns the configured target data volume per task, in bytes. */
    @Override
    public long getDataVolumePerTask() {
        return dataVolumePerTask;
    }

    /** Returns true if none of the given results is pointwise. */
    private static boolean areAllInputsAllToAll(List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream().noneMatch(IntermediateResultInfo::isPointwise);
    }

    /** Returns true if every given result is broadcast. */
    private static boolean areAllInputsBroadcast(List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream().allMatch(IntermediateResultInfo::isBroadcast);
    }

    /**
     * Uses the initial parallelism if set, otherwise {@link #decideParallelism}. The input infos
     * are then computed by {@link VertexInputInfoComputationUtils#computeVertexInputInfos}.
     */
    private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeSubpartitions(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int initialParallelism,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        int parallelism =
                initialParallelism > 0
                        ? initialParallelism
                        : decideParallelism(
                                jobVertexId, consumedResults, minParallelism, maxParallelism);
        return new ParallelismAndInputInfos(
                parallelism,
                VertexInputInfoComputationUtils.computeVertexInputInfos(
                        parallelism, consumedResults, true));
    }

    /**
     * Decides a parallelism from the size of the non-broadcast inputs.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Divide the total number of non-broadcast bytes by the data volume per task, rounding
     *       up.
     *   <li>Raise the result to at least {@code ceil(maxSubpartitions /
     *       MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME)}, where {@code maxSubpartitions} is the
     *       largest total number of subpartitions (summed over its partitions) of any single
     *       non-broadcast input.
     *   <li>Clamp the value to {@code [minParallelism, maxParallelism]}.
     * </ol>
     * If every input is broadcast, {@code minParallelism} is returned.
     */
    int decideParallelism(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(!consumedResults.isEmpty());

        List<BlockingResultInfo> nonBroadcastResults = getNonBroadcastResultInfos(consumedResults);
        if (nonBroadcastResults.isEmpty()) {
            return minParallelism;
        }

        long totalBytes =
                nonBroadcastResults.stream()
                        .mapToLong(BlockingResultInfo::getNumBytesProduced)
                        .sum();
        int parallelism = (int) Math.ceil((double) totalBytes / (double) dataVolumePerTask);
        int minParallelismLimitedByMaxSubpartitions =
                (int)
                        Math.ceil(
                                (double) getMaxNumSubpartitions(nonBroadcastResults)
                                        / MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME);
        parallelism = Math.max(parallelism, minParallelismLimitedByMaxSubpartitions);

        LOG.debug(
                "The total size of non-broadcast data is {}, the initially decided parallelism of job vertex {} is {}.",
                new MemorySize(totalBytes),
                jobVertexId,
                parallelism);

        if (parallelism < minParallelism) {
            LOG.info(
                    "The initially decided parallelism {} is smaller than the minimum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.",
                    parallelism,
                    minParallelism,
                    minParallelism,
                    jobVertexId);
            parallelism = minParallelism;
        } else if (parallelism > maxParallelism) {
            LOG.info(
                    "The initially decided parallelism {} is larger than the maximum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.",
                    parallelism,
                    maxParallelism,
                    maxParallelism,
                    jobVertexId);
            parallelism = maxParallelism;
        }
        return parallelism;
    }

    /**
     * Packs consecutive subpartitions into ranges by byte size; one range is used per subtask.
     *
     * <p>If the natural number of ranges is not within {@code [minParallelism, maxParallelism]},
     * the data volume limit is adjusted to reach the closest legal parallelism. If no legal
     * parallelism can be reached, this falls back to {@link
     * #decideParallelismAndEvenlyDistributeSubpartitions}.
     */
    private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeData(
            JobVertexID jobVertexId,
            List<BlockingResultInfo> consumedResults,
            int initialParallelism,
            int minParallelism,
            int maxParallelism) {
        Preconditions.checkArgument(initialParallelism == PARALLELISM_NOT_SET);
        Preconditions.checkArgument(!consumedResults.isEmpty());
        consumedResults.forEach(resultInfo -> Preconditions.checkState(!resultInfo.isPointwise()));

        List<BlockingResultInfo> nonBroadcastResults = getNonBroadcastResultInfos(consumedResults);
        int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastResults);
        long[] bytesBySubpartition =
                aggregateBytesBySubpartition(nonBroadcastResults, subpartitionNum);

        // Every subtask reads its subpartition range from all upstream partitions (see
        // createParallelismAndInputInfos). The range length is therefore capped so that
        // (length * partitions) stays within MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME. With more
        // partitions than the limit, the quotient is 0. A cap of 1 produces exactly the same
        // ranges as 0 (one range per subpartition) and states that intent explicitly.
        int maxNumPartitions = getMaxNumPartitions(nonBroadcastResults);
        int maxRangeSize = Math.max(1, MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME / maxNumPartitions);

        List<IndexRange> subpartitionRanges =
                computeSubpartitionRanges(bytesBySubpartition, dataVolumePerTask, maxRangeSize);

        if (!isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism)) {
            // Search the data volume limit between the smallest subpartition and the total size.
            Optional<List<IndexRange>> adjustedSubpartitionRanges =
                    adjustToClosestLegalParallelism(
                            dataVolumePerTask,
                            subpartitionRanges.size(),
                            minParallelism,
                            maxParallelism,
                            Arrays.stream(bytesBySubpartition).min().getAsLong(),
                            Arrays.stream(bytesBySubpartition).sum(),
                            limit -> computeParallelism(bytesBySubpartition, limit, maxRangeSize),
                            limit ->
                                    computeSubpartitionRanges(
                                            bytesBySubpartition, limit, maxRangeSize));
            if (!adjustedSubpartitionRanges.isPresent()) {
                LOG.info(
                        "Cannot find a legal parallelism to evenly distribute data for job vertex {}. Fall back to compute a parallelism that can evenly distribute subpartitions.",
                        jobVertexId);
                return decideParallelismAndEvenlyDistributeSubpartitions(
                        jobVertexId,
                        consumedResults,
                        initialParallelism,
                        minParallelism,
                        maxParallelism);
            }
            subpartitionRanges = adjustedSubpartitionRanges.get();
        }

        Preconditions.checkState(
                isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism));
        return createParallelismAndInputInfos(consumedResults, subpartitionRanges);
    }

    /**
     * Sums the aggregated bytes of all given results for each subpartition index.
     *
     * <p>NOTE(review): this assumes every non-pointwise result is an {@link
     * AllToAllBlockingResultInfo}. The cast fails otherwise.
     */
    private static long[] aggregateBytesBySubpartition(
            List<BlockingResultInfo> nonBroadcastResults, int subpartitionNum) {
        long[] bytesBySubpartition = new long[subpartitionNum];
        for (BlockingResultInfo resultInfo : nonBroadcastResults) {
            List<Long> subpartitionBytes =
                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes();
            for (int i = 0; i < subpartitionNum; ++i) {
                bytesBySubpartition[i] += subpartitionBytes.get(i);
            }
        }
        return bytesBySubpartition;
    }

    /** Returns true if {@code parallelism} lies within {@code [minParallelism, maxParallelism]}. */
    private static boolean isLegalParallelism(
            int parallelism, int minParallelism, int maxParallelism) {
        return parallelism >= minParallelism && parallelism <= maxParallelism;
    }

    /**
     * Returns the subpartition count shared by every partition of the given results.
     *
     * @throws IllegalStateException if the partitions do not all have the same subpartition count
     *     (or there are no partitions at all)
     */
    private static int checkAndGetSubpartitionNum(List<BlockingResultInfo> consumedResults) {
        Set<Integer> subpartitionNumSet =
                consumedResults.stream()
                        .flatMap(
                                resultInfo ->
                                        IntStream.range(0, resultInfo.getNumPartitions())
                                                .boxed()
                                                .map(resultInfo::getNumSubpartitions))
                        .collect(Collectors.toSet());
        Preconditions.checkState(
                subpartitionNumSet.size() == 1,
                "Expected all partitions to have the same number of subpartitions, but found %s.",
                subpartitionNumSet);
        return subpartitionNumSet.iterator().next();
    }

    /**
     * Searches for a data volume limit that yields a legal parallelism as close as possible to
     * {@code currentParallelism}.
     *
     * <p>A larger limit merges more subpartitions per range, so the parallelism computed by {@code
     * parallelismComputer} does not grow as the limit grows.
     *
     * @return the subpartition ranges for the adjusted limit, or empty if even the adjusted limit
     *     does not give a legal parallelism
     */
    private static Optional<List<IndexRange>> adjustToClosestLegalParallelism(
            long currentDataVolumeLimit,
            int currentParallelism,
            int minParallelism,
            int maxParallelism,
            long minLimit,
            long maxLimit,
            Function<Long, Integer> parallelismComputer,
            Function<Long, List<IndexRange>> subpartitionRangesComputer) {
        long adjustedDataVolumeLimit = currentDataVolumeLimit;
        if (currentParallelism < minParallelism) {
            // Too few ranges: shrink the limit. Presumably this finds the largest limit in
            // [minLimit, currentDataVolumeLimit] that still reaches minParallelism — confirm
            // against BisectionSearchUtils.
            adjustedDataVolumeLimit =
                    BisectionSearchUtils.findMaxLegalValue(
                            value -> parallelismComputer.apply(value) >= minParallelism,
                            minLimit,
                            currentDataVolumeLimit);

            // Several limits may produce that same parallelism. Prefer the smallest one, because
            // smaller limits spread the data more evenly across ranges.
            final long minPossibleLegalParallelism =
                    parallelismComputer.apply(adjustedDataVolumeLimit);
            adjustedDataVolumeLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            value -> parallelismComputer.apply(value) == minPossibleLegalParallelism,
                            minLimit,
                            adjustedDataVolumeLimit);
        } else if (currentParallelism > maxParallelism) {
            // Too many ranges: grow the limit. Presumably this finds the smallest limit in
            // [currentDataVolumeLimit, maxLimit] that stays within maxParallelism.
            adjustedDataVolumeLimit =
                    BisectionSearchUtils.findMinLegalValue(
                            value -> parallelismComputer.apply(value) <= maxParallelism,
                            currentDataVolumeLimit,
                            maxLimit);
        }

        int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit);
        if (isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) {
            return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit));
        }
        return Optional.empty();
    }

    /**
     * Builds one execution-vertex input per subpartition range.
     *
     * <p>Each subtask reads all partitions of every result. Non-broadcast results contribute the
     * subtask's subpartition range. Broadcast results always contribute subpartition 0.
     */
    private static ParallelismAndInputInfos createParallelismAndInputInfos(
            List<BlockingResultInfo> consumedResults, List<IndexRange> subpartitionRanges) {
        HashMap<IntermediateDataSetID, JobVertexInputInfo> vertexInputInfos = new HashMap<>();
        for (BlockingResultInfo resultInfo : consumedResults) {
            int sourceParallelism = resultInfo.getNumPartitions();
            IndexRange partitionRange = new IndexRange(0, sourceParallelism - 1);

            List<ExecutionVertexInputInfo> executionVertexInputInfos =
                    new ArrayList<>(subpartitionRanges.size());
            for (int i = 0; i < subpartitionRanges.size(); ++i) {
                IndexRange subpartitionRange =
                        resultInfo.isBroadcast() ? new IndexRange(0, 0) : subpartitionRanges.get(i);
                executionVertexInputInfos.add(
                        new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
            }
            vertexInputInfos.put(
                    resultInfo.getResultId(), new JobVertexInputInfo(executionVertexInputInfos));
        }
        return new ParallelismAndInputInfos(subpartitionRanges.size(), vertexInputInfos);
    }

    /**
     * Greedily packs consecutive entries of {@code nums} into ranges.
     *
     * <p>Each range sums to at most {@code limit} and spans at most {@code maxRangeSize} entries.
     * An entry that does not fit always starts a new range, so a single oversized entry forms its
     * own range.
     *
     * <p>Must stay in sync with {@link #computeParallelism}, which counts the same ranges without
     * materializing them.
     */
    private static List<IndexRange> computeSubpartitionRanges(
            long[] nums, long limit, int maxRangeSize) {
        List<IndexRange> subpartitionRanges = new ArrayList<>();
        long tmpSum = 0L;
        int startIndex = 0;
        for (int i = 0; i < nums.length; ++i) {
            long num = nums[i];
            if (i == startIndex
                    || (tmpSum + num <= limit && i - startIndex + 1 <= maxRangeSize)) {
                tmpSum += num;
                continue;
            }
            subpartitionRanges.add(new IndexRange(startIndex, i - 1));
            startIndex = i;
            tmpSum = num;
        }
        subpartitionRanges.add(new IndexRange(startIndex, nums.length - 1));
        return subpartitionRanges;
    }

    /**
     * Returns the number of ranges {@link #computeSubpartitionRanges} would produce for the same
     * arguments, without allocating them.
     */
    private static int computeParallelism(long[] nums, long limit, int maxRangeSize) {
        long tmpSum = 0L;
        int startIndex = 0;
        int count = 1;
        for (int i = 0; i < nums.length; ++i) {
            long num = nums[i];
            if (i == startIndex
                    || (tmpSum + num <= limit && i - startIndex + 1 <= maxRangeSize)) {
                tmpSum += num;
                continue;
            }
            startIndex = i;
            tmpSum = num;
            ++count;
        }
        return count;
    }

    /** Returns the largest partition count among the given (non-empty) results. */
    private static int getMaxNumPartitions(List<BlockingResultInfo> consumedResults) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        return consumedResults.stream()
                .mapToInt(IntermediateResultInfo::getNumPartitions)
                .max()
                .getAsInt();
    }

    /**
     * Returns, over the given (non-empty) results, the largest total subpartition count, where a
     * result's total is the sum over all of its partitions.
     */
    private static int getMaxNumSubpartitions(List<BlockingResultInfo> consumedResults) {
        Preconditions.checkArgument(!consumedResults.isEmpty());
        return consumedResults.stream()
                .mapToInt(
                        resultInfo ->
                                IntStream.range(0, resultInfo.getNumPartitions())
                                        .map(resultInfo::getNumSubpartitions)
                                        .sum())
                .max()
                .getAsInt();
    }

    /** Returns the results that are not broadcast, preserving their order. */
    private static List<BlockingResultInfo> getNonBroadcastResultInfos(
            List<BlockingResultInfo> consumedResults) {
        return consumedResults.stream()
                .filter(resultInfo -> !resultInfo.isBroadcast())
                .collect(Collectors.toList());
    }

    /**
     * Creates a decider from the batch execution options in {@code configuration}.
     *
     * <p>{@code maxParallelism} is used as the global maximum parallelism. It is also passed as the
     * override default for the default-source-parallelism option.
     */
    static DefaultVertexParallelismAndInputInfosDecider from(
            int maxParallelism, Configuration configuration) {
        return new DefaultVertexParallelismAndInputInfosDecider(
                maxParallelism,
                configuration.get(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM),
                configuration.get(
                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK),
                configuration.get(
                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM,
                        maxParallelism));
    }
}

