/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import org.apache.sysds.common.InstructionType;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.cost.CostEstimator;
import org.apache.sysds.hops.cost.VarStats;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Static, analytical runtime model for CP instructions.
 * <p>
 * The estimate for a single instruction is the sum of
 * <ol>
 *   <li>the read time of inputs ({@code vs[0]}, {@code vs[1]}) that are not yet marked
 *       in-memory, plus a local file-system write when an input's estimated on-disk size
 *       exceeds the lazy write buffer limit,</li>
 *   <li>the compute time, i.e. a per-opcode operation count ({@code getNFLOP}) divided by
 *       {@code DEFAULT_FLOPS}, and</li>
 *   <li>the HDFS write time of {@code write} variable instructions.</li>
 * </ol>
 * Bandwidth constants are in MB/s (1 MB = 1048576 bytes) and all returned times are in
 * seconds. The constants are fixed model assumptions, not measurements of the current machine.
 * </p>
 */
public class CostEstimatorStaticRuntime
extends CostEstimator {
    /** Assumed compute throughput in operations per second (2^31). */
    private static final long DEFAULT_FLOPS = 0x80000000L;
    /** Operation count charged for instructions modeled as (near) no-ops. */
    private static final double DEFAULT_NFLOP_NOOP = 10.0;
    private static final double DEFAULT_NFLOP_UNKNOWN = 1.0;
    /** Operation count per cell for simple CP operations. */
    private static final double DEFAULT_NFLOP_CP = 1.0;
    /** Operation count per cell when writing a text format. */
    private static final double DEFAULT_NFLOP_TEXT_IO = 350.0;
    // I/O bandwidths in MB/s
    private static final double DEFAULT_MBS_FSREAD_BINARYBLOCK_DENSE = 200.0;
    private static final double DEFAULT_MBS_FSREAD_BINARYBLOCK_SPARSE = 100.0;
    private static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE = 150.0;
    public static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE = 75.0;
    private static final double DEFAULT_MBS_FSWRITE_BINARYBLOCK_DENSE = 150.0;
    private static final double DEFAULT_MBS_FSWRITE_BINARYBLOCK_SPARSE = 75.0;
    private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE = 120.0;
    private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE = 60.0;
    private static final double DEFAULT_MBS_HDFSWRITE_TEXT_DENSE = 40.0;
    private static final double DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE = 30.0;

    /** Bytes per MB, used to convert on-disk size estimates into MB. */
    private static final double BYTES_PER_MB = 1048576.0;
    /** Multiplier applied on top of the text-format HDFS write time. */
    private static final double TEXT_WRITE_FACTOR = 2.75;
    /** Sparsity at or above which full random generation uses the dense cost formula. */
    private static final double RAND_DENSE_SPARSITY = 0.4;
    /** Operation count per generated random number. */
    private static final double NFLOP_RAND = 32.0;

    /**
     * Estimates the runtime of one CP instruction as read time + compute time + write time.
     * <p>
     * Side effect: inputs {@code vs[0]} and {@code vs[1]} are marked in-memory after their read
     * cost has been charged, so subsequent instructions do not pay it again.
     * </p>
     *
     * @param inst the instruction, expected to be a {@link CPInstruction}
     * @param vs   statistics of operand slots 0..2 (at least three entries are read)
     * @param args instruction-specific arguments forwarded to the operation-count model
     * @return the estimated time in seconds
     */
    @Override
    protected double getCPInstTimeEstimate(Instruction inst, VarStats[] vs, String[] args) {
        CPInstruction cpinst = (CPInstruction)inst;

        // 1) read (and potential eviction) costs of inputs that are not in memory yet
        double ltime = getInputReadTime(vs[0]) + getInputReadTime(vs[1]);
        if (LOG.isDebugEnabled() && ltime != 0.0) {
            LOG.debug("Cost[" + cpinst.getOpcode() + " - read] = " + ltime);
        }

        // 2) compute costs; function calls are keyed by the opcode parsed from the instruction string
        String opcode = cpinst instanceof FunctionCallCPInstruction
            ? InstructionUtils.getOpCode(cpinst.toString())
            : cpinst.getOpcode();
        double etime = getInstTimeEstimate(opcode, vs, args, Types.ExecType.CP);

        // 3) write costs of persistent writes, based on operand slot 2 and the format in input 3
        double wtime = 0.0;
        if (inst instanceof VariableCPInstruction
            && ((VariableCPInstruction)inst).getOpcode().equals(Opcodes.WRITE.toString())) {
            VariableCPInstruction vinst = (VariableCPInstruction)inst;
            wtime += getHDFSWriteTime(vs[2].getRows(), vs[2].getCols(), vs[2].getSparsity(),
                vinst.getInput3().getName());
        }
        if (LOG.isDebugEnabled() && wtime != 0.0) {
            LOG.debug("Cost[" + cpinst.getOpcode() + " - write] = " + wtime);
        }

        return ltime + etime + wtime;
    }

    /**
     * Returns the time to bring one input into memory, or 0 if it is already marked in-memory.
     * <p>
     * The HDFS read time is always charged. If the estimated on-disk size exceeds the lazy write
     * buffer limit, a local file-system write is charged as well (presumably modeling buffer-pool
     * eviction — TODO confirm). The input is marked in-memory afterwards.
     * </p>
     *
     * @param stats statistics of the input; modified ({@code _inmem = true})
     * @return the estimated time in seconds
     */
    private static double getInputReadTime(VarStats stats) {
        if (stats._inmem) {
            return 0.0;
        }
        long rows = stats.getRows();
        long cols = stats.getCols();
        double time = getHDFSReadTime(rows, cols, stats.getSparsity());
        // unknown nnz (negative) is treated as fully dense for the size estimate
        long nnz = stats._dc.getNonZeros() < 0L ? rows * cols : stats._dc.getNonZeros();
        if (LazyWriteBuffer.getWriteBufferLimit() < MatrixBlock.estimateSizeOnDisk(rows, cols, nnz)) {
            time += Math.abs(getFSWriteTime(rows, cols, stats.getSparsity()));
        }
        stats._inmem = true;
        return time;
    }

    /** Estimated number of non-zeros of a {@code dm x dn} matrix with sparsity {@code ds}. */
    private static long estimateNnz(long dm, long dn, double ds) {
        return (long)(ds * (double)dm * (double)dn);
    }

    /**
     * Shared disk-time model: estimated on-disk size in MB divided by the bandwidth that matches
     * the on-disk format (sparse or dense) chosen by {@link MatrixBlock#evalSparseFormatOnDisk}.
     *
     * @return the estimated time in seconds
     */
    private static double getDiskTime(long dm, long dn, double ds, double denseMBs, double sparseMBs) {
        long nnz = estimateNnz(dm, dn, ds);
        boolean sparse = MatrixBlock.evalSparseFormatOnDisk(dm, dn, nnz);
        double mbytes = MatrixBlock.estimateSizeOnDisk(dm, dn, nnz) / BYTES_PER_MB;
        return mbytes / (sparse ? sparseMBs : denseMBs);
    }

    /** Estimated time in seconds to read a binary-block matrix from HDFS. */
    private static double getHDFSReadTime(long dm, long dn, double ds) {
        return getDiskTime(dm, dn, ds,
            DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE, DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE);
    }

    /** Estimated time in seconds to write a binary-block matrix to HDFS. */
    private static double getHDFSWriteTime(long dm, long dn, double ds) {
        return getDiskTime(dm, dn, ds,
            DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE, DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE);
    }

    /**
     * Estimated time in seconds to write a matrix to HDFS in the given file format.
     * Text formats use the text bandwidths and are additionally scaled by
     * {@code TEXT_WRITE_FACTOR}; all other formats use the binary-block model.
     *
     * @param format file format name, resolved via {@code Types.FileFormat.safeValueOf}
     */
    private static double getHDFSWriteTime(long dm, long dn, double ds, String format) {
        Types.FileFormat fmt = Types.FileFormat.safeValueOf(format);
        if (fmt.isTextFormat()) {
            return getDiskTime(dm, dn, ds, DEFAULT_MBS_HDFSWRITE_TEXT_DENSE, DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE)
                * TEXT_WRITE_FACTOR;
        }
        return getHDFSWriteTime(dm, dn, ds);
    }

    /**
     * Estimated time in seconds to read a binary-block matrix from the local file system.
     *
     * @param dm number of rows
     * @param dn number of columns
     * @param ds sparsity (fraction of non-zeros)
     */
    public static double getFSReadTime(long dm, long dn, double ds) {
        return getDiskTime(dm, dn, ds,
            DEFAULT_MBS_FSREAD_BINARYBLOCK_DENSE, DEFAULT_MBS_FSREAD_BINARYBLOCK_SPARSE);
    }

    /**
     * Estimated time in seconds to write a binary-block matrix to the local file system.
     *
     * @param dm number of rows
     * @param dn number of columns
     * @param ds sparsity (fraction of non-zeros)
     */
    public static double getFSWriteTime(long dm, long dn, double ds) {
        return getDiskTime(dm, dn, ds,
            DEFAULT_MBS_FSWRITE_BINARYBLOCK_DENSE, DEFAULT_MBS_FSWRITE_BINARYBLOCK_SPARSE);
    }

    /**
     * Compute-time estimate from operand statistics; unknown nnz is modeled as sparsity 1.0.
     * The exec type {@code et} is currently not used.
     */
    private static double getInstTimeEstimate(String opcode, VarStats[] vs, String[] args, Types.ExecType et) {
        return getInstTimeEstimate(opcode, false,
            vs[0].getRows(), vs[0].getCols(), sparsityOrDense(vs[0]),
            vs[1].getRows(), vs[1].getCols(), sparsityOrDense(vs[1]),
            vs[2].getRows(), vs[2].getCols(), sparsityOrDense(vs[2]),
            args);
    }

    /** Returns the known sparsity of {@code stats}, or 1.0 (dense) if nnz is unknown. */
    private static double sparsityOrDense(VarStats stats) {
        return stats._dc.nnzKnown() ? stats.getSparsity() : 1.0;
    }

    /** Converts the modeled operation count into seconds using {@code DEFAULT_FLOPS}. */
    private static double getInstTimeEstimate(String opcode, boolean inMR, long d1m, long d1n, double d1s, long d2m, long d2n, double d2s, long d3m, long d3n, double d3s, String[] args) {
        double nflops = getNFLOP(opcode, inMR, d1m, d1n, d1s, d2m, d2n, d2s, d3m, d3n, d3s, args);
        double time = nflops / DEFAULT_FLOPS;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Cost[" + opcode + "] = " + time + "s, " + nflops + " flops (" + d1m + "," + d1n + "," + d1s + "," + d2m + "," + d2n + "," + d2s + "," + d3m + "," + d3n + "," + d3s + ").");
        }
        return time;
    }

    /** Number of cells of an {@code m x n} matrix, computed in double to avoid long overflow. */
    private static double cells(long m, long n) {
        return (double)m * (double)n;
    }

    /** Cells processed by a kernel: estimated non-zeros if {@code sparse}, otherwise all cells. */
    private static double touchedCells(boolean sparse, long m, long n, double s) {
        return sparse ? cells(m, n) * s : cells(m, n);
    }

    /** Returns true if {@code optype} equals the opcode string of any of the given opcodes. */
    private static boolean isAnyOf(String optype, Opcodes... candidates) {
        for (Opcodes candidate : candidates) {
            if (optype.equals(candidate.toString())) {
                return true;
            }
        }
        return false;
    }

    /** Per-row operation factor of central moment ({@code cm}) for the numeric type selector. */
    private static double getCentralMomentFactor(int type) {
        switch (type) {
            case 1: return 8.0;
            case 2: return 16.0;
            case 3: return 31.0;
            case 4: return 51.0;
            case 5: return 16.0;
            default: return 1.0; // type 0 and unrecognized selectors
        }
    }

    /** Per-row operation factor of grouped aggregation for the numeric type selector. */
    private static double getGroupedAggFactor(int type) {
        switch (type) {
            case 0: return 4.0;
            case 2: return 8.0;
            case 3: return 16.0;
            case 4: return 31.0;
            case 5: return 51.0;
            case 6: return 16.0;
            default: return 1.0; // type 1 and unrecognized selectors
        }
    }

    /**
     * Returns the modeled number of operations of one CP instruction.
     * <p>
     * Operand slots 1..3 ({@code d1*}, {@code d2*}, {@code d3*}) are rows, columns and sparsity
     * of the caller's {@code vs[0..2]}; slot 3 is presumably the output (TODO confirm). Negative
     * dimensions are treated as an absent operand (see {@code onlyLeft} / {@code allExists}).
     * </p>
     *
     * @param optype opcode string; its CP instruction type selects the cost formula
     * @param inMR   if false, partitioning additionally charges an HDFS write (converted to ops)
     * @param args   instruction-specific arguments; {@code args[0]} is parsed as an integer
     *               selector (cm, rand, groupedagg, rmempty), a file format name (write) or an
     *               {@code MMTSJType} name (tsmm)
     * @return the modeled operation count
     * @throws DMLRuntimeException if the opcode has no CP instruction type or the type is not modeled
     */
    private static double getNFLOP(String optype, boolean inMR, long d1m, long d1n, double d1s, long d2m, long d2n, double d2s, long d3m, long d3n, double d3s, String[] args) {
        boolean leftSparse = MatrixBlock.evalSparseFormatInMemory(d1m, d1n, estimateNnz(d1m, d1n, d1s));
        boolean rightSparse = MatrixBlock.evalSparseFormatInMemory(d2m, d2n, estimateNnz(d2m, d2n, d2s));
        boolean onlyLeft = d1m >= 0L && d1n >= 0L && d2m < 0L && d2n < 0L;
        boolean allExists = d1m >= 0L && d1n >= 0L && d2m >= 0L && d2n >= 0L && d3m >= 0L && d3n >= 0L;

        InstructionType cptype = Opcodes.getTypeByOpcode(optype, Types.ExecType.CP);
        if (cptype == null) {
            throw new DMLRuntimeException("CostEstimator: unsupported instruction type: " + optype);
        }

        switch (cptype) {
            case AggregateBinary: {
                if (optype.equals(Opcodes.MMULT.toString())) {
                    // sparse operands scale the work by their sparsity
                    double work;
                    if (rightSparse) {
                        work = cells(d1m, d1n) * d1s * (double)d2n * d2s;
                    } else if (leftSparse) {
                        work = cells(d1m, d1n) * d1s * (double)d2n;
                    } else {
                        work = cells(d1m, d1n) * (d2n > 1L ? d1s : 1.0) * (double)d2n;
                    }
                    return 2.0 * work / 2.0;
                }
                if (optype.equals(Opcodes.COV.toString())) {
                    return 23.0 * (double)d1m;
                }
                return 0.0;
            }
            case MMChain: {
                return 4.0 * touchedCells(leftSparse, d1m, d1n, d1s) / 2.0;
            }
            case AggregateTernary: {
                return 6.0 * cells(d1m, d1n);
            }
            case AggregateUnary: {
                if (optype.equals("nrow") || optype.equals("ncol") || optype.equals("length")) {
                    return DEFAULT_NFLOP_NOOP;
                }
                if (optype.equals(Opcodes.CM.toString())) {
                    double xcm = getCentralMomentFactor(Integer.parseInt(args[0]));
                    return leftSparse ? xcm * ((double)d1m * d1s + 1.0) : xcm * (double)d1m;
                }
                if (isAnyOf(optype, Opcodes.UATRACE, Opcodes.UAKTRACE)) {
                    return 2.0 * cells(d1m, d1n);
                }
                if (isAnyOf(optype, Opcodes.UAP, Opcodes.UARP, Opcodes.UACP)) {
                    return touchedCells(leftSparse, d1m, d1n, d1s);
                }
                if (isAnyOf(optype, Opcodes.UAKP, Opcodes.UARKP, Opcodes.UACKP)) {
                    return 4.0 * cells(d1m, d1n);
                }
                if (isAnyOf(optype, Opcodes.UASQKP, Opcodes.UARSQKP, Opcodes.UACSQKP)) {
                    return 5.0 * cells(d1m, d1n);
                }
                if (isAnyOf(optype, Opcodes.UAMEAN, Opcodes.UARMEAN, Opcodes.UACMEAN)) {
                    return 7.0 * cells(d1m, d1n);
                }
                if (isAnyOf(optype, Opcodes.UAVAR, Opcodes.UARVAR, Opcodes.UACVAR)) {
                    return 14.0 * cells(d1m, d1n);
                }
                if (isAnyOf(optype, Opcodes.UAMAX, Opcodes.UARMAX, Opcodes.UACMAX,
                    Opcodes.UAMIN, Opcodes.UARMIN, Opcodes.UACMIN, Opcodes.UARIMAX, Opcodes.UAM)) {
                    return cells(d1m, d1n);
                }
                return 0.0;
            }
            case Binary: {
                // parenthesized so the sparsity condition applies to both + and -
                // (&& binds tighter than ||)
                if ((optype.equals(Opcodes.PLUS.toString()) || optype.equals(Opcodes.MINUS.toString()))
                    && (leftSparse || rightSparse)) {
                    return cells(d1m, d1n) * d1s + cells(d2m, d2n) * d2s;
                }
                if (optype.equals(Opcodes.SOLVE.toString())) {
                    return cells(d1m, d1n) * (double)d1n;
                }
                return cells(d3m, d3n);
            }
            case Ternary: {
                return 2.0 * cells(d1m, d1n);
            }
            case Ctable: {
                if (optype.equals(Opcodes.CTABLE.toString())) {
                    return touchedCells(leftSparse, d1m, d1n, d1s);
                }
                return 0.0;
            }
            case Builtin: {
                return allExists ? 3.0 * cells(d3m, d3n) : cells(d3m, d3n);
            }
            case Unary: {
                if (optype.equals(Opcodes.PRINT.toString())) {
                    return DEFAULT_NFLOP_CP;
                }
                double xbu = 1.0;
                if (optype.equals(Opcodes.PLOGP.toString())) {
                    xbu = 2.0;
                } else if (optype.equals(Opcodes.ROUND.toString())) {
                    xbu = 4.0;
                }
                // only these unary ops are modeled to scale with the non-zeros of sparse inputs
                boolean scalesWithNnz = isAnyOf(optype, Opcodes.SIN, Opcodes.TAN, Opcodes.ROUND,
                    Opcodes.ABS, Opcodes.SQRT, Opcodes.SPROP, Opcodes.SIGMOID, Opcodes.SIGN);
                if (scalesWithNnz && leftSparse) {
                    return xbu * (double)d1m * (double)d1n * d1s;
                }
                return xbu * (double)d1m * (double)d1n;
            }
            case Reorg:
            case Reshape: {
                return touchedCells(leftSparse, d1m, d1n, d1s);
            }
            case Append: {
                return DEFAULT_NFLOP_CP * (touchedCells(leftSparse, d1m, d1n, d1s)
                    + touchedCells(rightSparse, d2m, d2n, d2s));
            }
            case Variable: {
                if (optype.equals(Opcodes.WRITE.toString())) {
                    boolean text = Types.FileFormat.safeValueOf(args[0]).isTextFormat();
                    double xwrite = text ? DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP;
                    return touchedCells(leftSparse, d1m, d1n, d1s) * xwrite;
                }
                if (optype.equals("inmem-iqm")) {
                    return (2.0 * d1m + 5.0) + 0.25 * (double)d1m + 4.0 * (double)d1m;
                }
                return DEFAULT_NFLOP_NOOP;
            }
            case Rand: {
                if (optype.equals(Opcodes.RANDOM.toString())) {
                    double outCells = cells(d3m, d3n);
                    switch (Integer.parseInt(args[0])) {
                        case 0:
                            return DEFAULT_NFLOP_NOOP;
                        case 1:
                            return outCells * 8.0;
                        case 2:
                            if (d3s == 1.0) {
                                return outCells * NFLOP_RAND + outCells * 8.0;
                            }
                            if (d3s >= RAND_DENSE_SPARSITY) {
                                return 2.0 * outCells * NFLOP_RAND + outCells * 8.0;
                            }
                            return 3.0 * outCells * d3s * NFLOP_RAND + outCells * d3s * 24.0;
                        default:
                            break;
                    }
                }
                // non-rand data generation and unrecognized rand selectors: one op per output cell
                return cells(d3m, d3n) * DEFAULT_NFLOP_CP;
            }
            case StringInit: {
                return cells(d3m, d3n) * DEFAULT_NFLOP_CP;
            }
            case FCall: {
                return cells(d1m, d1n) * d1s * DEFAULT_NFLOP_CP;
            }
            case MultiReturnBuiltin: {
                double xf = 2.0;
                if (optype.equals(Opcodes.EIGEN.toString())) {
                    xf = 32.0;
                } else if (optype.equals(Opcodes.LU.toString())) {
                    xf = 16.0;
                } else if (optype.equals(Opcodes.SVD.toString())) {
                    xf = 32.0;
                }
                return xf * (double)d1m * (double)d1n * (double)d1n;
            }
            case ParameterizedBuiltin: {
                if (isAnyOf(optype, Opcodes.CDF, Opcodes.INVCDF)) {
                    return DEFAULT_NFLOP_CP;
                }
                if (optype.equals(Opcodes.GROUPEDAGG.toString())) {
                    double xga = getGroupedAggFactor(Integer.parseInt(args[0]));
                    return 2.0 * d1m + xga * (double)d1m;
                }
                if (optype.equals(Opcodes.RMEMPTY.toString())) {
                    switch (Integer.parseInt(args[0])) {
                        case 0:
                            return (leftSparse ? (double)d1m : (double)d1m * Math.ceil(1.0 / d1s) / 2.0)
                                + DEFAULT_NFLOP_CP * (double)d3m * (double)d2m;
                        case 1:
                            return (double)d1n * Math.ceil(1.0 / d1s) / 2.0
                                + DEFAULT_NFLOP_CP * (double)d3m * (double)d2m;
                        default:
                            break;
                    }
                }
                return 0.0;
            }
            case QSort: {
                if (optype.equals("sort")) {
                    double sortCosts = onlyLeft
                        ? DEFAULT_NFLOP_CP * (double)d1m + (double)d1m
                        : DEFAULT_NFLOP_CP * (leftSparse ? (double)d1m * d1s : (double)d1m);
                    // truncated log2 of the row count (0 for non-positive counts due to NaN/-inf casts)
                    long log2Rows = (int)(Math.log(d1m) / Math.log(2.0));
                    return sortCosts + (double)(d1m * log2Rows) + DEFAULT_NFLOP_CP * (double)d1m;
                }
                return 0.0;
            }
            case MatrixIndexing: {
                if (optype.equals(Opcodes.LEFT_INDEX.toString())) {
                    return DEFAULT_NFLOP_CP * touchedCells(leftSparse, d1m, d1n, d1s)
                        + 2.0 * touchedCells(rightSparse, d2m, d2n, d2s);
                }
                if (optype.equals(Opcodes.RIGHT_INDEX.toString())) {
                    // NOTE(review): combines operand-1 sparse format with operand-2 dims — confirm intent
                    return DEFAULT_NFLOP_CP * touchedCells(leftSparse, d2m, d2n, d2s);
                }
                return 0.0;
            }
            case MMTSJ: {
                if (MMTSJ.MMTSJType.valueOf(args[0]).isLeft()) {
                    // tsmm has a single input (operand 1), so its format decides the formula
                    if (!leftSparse) {
                        return cells(d1m, d1n) * d1s * (double)d1n / 2.0;
                    }
                    return cells(d1m, d1n) * d1s * (double)d1n * d1s / 2.0;
                }
                if (onlyLeft) {
                    if (!leftSparse) {
                        return (double)d1m * (double)d1n * (double)d1m / 2.0;
                    }
                    return cells(d1m, d1n) * d1s + cells(d1m, d1n) * d1s * (double)d1n * d1s / 2.0;
                }
                return 0.0;
            }
            case Partition: {
                return cells(d1m, d1n) * d1s
                    + (inMR ? 0.0 : getHDFSWriteTime(d1m, d1n, d1s) * DEFAULT_FLOPS);
            }
            default:
                break;
        }
        throw new DMLRuntimeException("CostEstimator: unsupported instruction type: " + optype);
    }
}

