/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark map-side append for matrices: appends the broadcast input {@code input2} to the
 * RDD-backed input {@code input1}, column-wise when {@code _cbind} is true and row-wise otherwise.
 * Only the blocks of {@code input1} that sit in its last block column (cbind) or last block row
 * (rbind) are modified; all other blocks are passed through unchanged.
 */
public class MatrixAppendMSPInstruction
extends AppendMSPInstruction {
    protected MatrixAppendMSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand offset, CPOperand out, boolean cbind, String opcode, String istr) {
        super(op, in1, in2, offset, out, cbind, opcode, istr);
    }

    /**
     * Executes the append. If the number of blocks along the append dimension does not change,
     * the last blocks are extended in place via {@code mapPartitionsToPair} with partitioning
     * preserved; otherwise {@code flatMapToPair} is used because new blocks may be emitted.
     *
     * @param ec execution context; expected to be a {@link SparkExecutionContext}
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        checkBinaryAppendInputCharacteristics(sec, _cbind, false, false);
        DataCharacteristics mc1 = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mc2 = sec.getDataCharacteristics(input2.getName());
        int blen = mc1.getBlocksize();

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(input2.getName());
        // The offset is used to locate the last block of input1 along the append dimension.
        long off = sec.getScalarInput(_offset).getLongValue();

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (preservesPartitioning(mc1, mc2, _cbind)) {
            // Block count is unchanged: existing last blocks are extended under their own keys,
            // so the existing partitioning can be kept.
            out = in1.mapPartitionsToPair(new MapSideAppendPartitionFunction(in2, _cbind, off, blen), true);
        } else {
            // The append may spill into additional blocks, so one input block can yield several.
            out = in1.flatMapToPair(new MapSideAppendFunction(in2, _cbind, off, blen));
        }

        updateBinaryAppendOutputDataCharacteristics(sec, _cbind);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        sec.addLineageBroadcast(output.getName(), input2.getName());
    }

    /**
     * Returns true if appending {@code mcIn2} to {@code mcIn1} leaves the number of blocks along
     * the append dimension unchanged, i.e. no new blocks are created and partitioning can be kept.
     */
    private static boolean preservesPartitioning(DataCharacteristics mcIn1, DataCharacteristics mcIn2, boolean cbind) {
        long ncblksIn1 = cbind ? mcIn1.getNumColBlocks() : mcIn1.getNumRowBlocks();
        double outLen = cbind
            ? (double) mcIn1.getCols() + (double) mcIn2.getCols()
            : (double) mcIn1.getRows() + (double) mcIn2.getRows();
        long ncblksOut = Math.max((long) Math.ceil(outLen / (double) mcIn1.getBlocksize()), 1L);
        return ncblksIn1 == ncblksOut;
    }

    /**
     * Computes the 1-based index of the last block of input1 along the append dimension,
     * i.e. {@code max(ceil(offset / blen), 1)}.
     */
    private static long lastBlockIndex(long offset, int blen) {
        return Math.max((long) Math.ceil((double) offset / (double) blen), 1L);
    }

    /**
     * Partition-level append used when the block count is preserved: each last block of input1
     * is replaced by the same-keyed block with the matching broadcast block appended.
     */
    private static class MapSideAppendPartitionFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 5767240739761027220L;
        private final PartitionedBroadcast<MatrixBlock> _pm;
        private final boolean _cbind;
        // Last block-column index for cbind, last block-row index for rbind.
        private final long _lastBlockColIndex;

        public MapSideAppendPartitionFunction(PartitionedBroadcast<MatrixBlock> binput, boolean cbind, long offset, int blen) {
            _pm = binput;
            _cbind = cbind;
            _lastBlockColIndex = lastBlockIndex(offset, blen);
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception {
            return new MapAppendPartitionIterator(arg0);
        }

        /** Lazily maps each input block, appending to it only if it is a last block. */
        private class MapAppendPartitionIterator
        extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapAppendPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                final boolean cbind = MapSideAppendPartitionFunction.this._cbind;
                MatrixIndexes ix = arg._1();
                MatrixBlock in1 = arg._2();

                // Blocks not on the last block column/row are passed through untouched.
                long blockIx = cbind ? ix.getColumnIndex() : ix.getRowIndex();
                if (blockIx != MapSideAppendPartitionFunction.this._lastBlockColIndex) {
                    return arg;
                }

                // Broadcast input2 has a single block column (cbind) or block row (rbind).
                int rowix = cbind ? (int) ix.getRowIndex() : 1;
                int colix = cbind ? 1 : (int) ix.getColumnIndex();
                MatrixBlock in2 = MapSideAppendPartitionFunction.this._pm.getBlock(rowix, colix);
                MatrixBlock out = in1.append(in2, new MatrixBlock(), cbind);
                return new Tuple2<>(ix, out);
            }
        }
    }

    /**
     * Per-block append used when new blocks may be created. For a last block of input1:
     * if it is already full ({@code blen} wide/tall), the broadcast block is emitted as a new
     * block at the next index; otherwise the broadcast block is merged into it, spilling into a
     * second block when the combined size exceeds {@code blen}.
     */
    private static class MapSideAppendFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2738541014432173450L;
        private final PartitionedBroadcast<MatrixBlock> _pm;
        private final boolean _cbind;
        private final int _blen;
        // Last block-column index for cbind, last block-row index for rbind.
        private final long _lastBlockColIndex;

        public MapSideAppendFunction(PartitionedBroadcast<MatrixBlock> binput, boolean cbind, long offset, int blen) {
            _pm = binput;
            _cbind = cbind;
            _blen = blen;
            _lastBlockColIndex = lastBlockIndex(offset, blen);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
            IndexedMatrixValue in1 = SparkUtils.toIndexedMatrixBlock(kv);
            MatrixIndexes ix = in1.getIndexes();

            long blockIx = _cbind ? ix.getColumnIndex() : ix.getRowIndex();
            if (blockIx != _lastBlockColIndex) {
                // Not a last block: pass through unchanged.
                ret.add(kv);
                return ret.iterator();
            }

            int in1Len = _cbind ? in1.getValue().getNumColumns() : in1.getValue().getNumRows();
            if (in1Len == _blen) {
                // Last block is full: keep it and emit the broadcast block as the next block.
                ret.add(kv);
                if (_cbind) {
                    ret.add(new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), ix.getColumnIndex() + 1L),
                        _pm.getBlock((int) ix.getRowIndex(), 1)));
                } else {
                    ret.add(new Tuple2<>(new MatrixIndexes(ix.getRowIndex() + 1L, ix.getColumnIndex()),
                        _pm.getBlock(1, (int) ix.getColumnIndex())));
                }
                return ret.iterator();
            }

            // Partially filled last block: merge the broadcast block into it; if the combined size
            // exceeds blen, a second output block at the next index receives the overflow.
            ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(2);
            IndexedMatrixValue first = new IndexedMatrixValue(new MatrixIndexes(ix), new MatrixBlock());
            outlist.add(first);
            MatrixBlock value_in2;
            if (_cbind) {
                value_in2 = _pm.getBlock((int) ix.getRowIndex(), 1);
                if (in1.getValue().getNumColumns() + value_in2.getNumColumns() > _blen) {
                    IndexedMatrixValue second = new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock());
                    second.getIndexes().setIndexes(ix.getRowIndex(), ix.getColumnIndex() + 1L);
                    outlist.add(second);
                }
            } else {
                value_in2 = _pm.getBlock(1, (int) ix.getColumnIndex());
                if (in1.getValue().getNumRows() + value_in2.getNumRows() > _blen) {
                    IndexedMatrixValue second = new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock());
                    second.getIndexes().setIndexes(ix.getRowIndex() + 1L, ix.getColumnIndex());
                    outlist.add(second);
                }
            }
            in1.getValue().append(value_in2, outlist, _blen, _cbind, true, 0);
            ret.addAll(SparkUtils.fromIndexedMatrixBlock(outlist));
            return ret.iterator();
        }
    }
}

