/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibLeftMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName());

    /**
     * Computes {@code t(m2) %*% m1}: the compressed matrix multiplied from the left by the transpose of an
     * uncompressed matrix.
     *
     * @param m1  compressed right-hand side
     * @param m2  uncompressed matrix whose transpose is the left-hand side
     * @param ret output block to reuse; may be {@code null}, in which case a new block is created
     * @param k   degree of parallelism (values {@code <= 1} run sequentially)
     * @return the result block, sized {@code m2.getNumColumns() x m1.getNumColumns()}
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
        int k) {
        if (m2.isEmpty()) {
            // The product is empty, but the caller still needs a block with the output dimensions
            // (ret may be null or sized for a different product).
            return prepareReturnMatrix(m1, m2, ret, true);
        }
        final MatrixBlock transposed = new MatrixBlock(m2.getNumColumns(), m2.getNumRows(), false);
        LibMatrixReorg.transpose(m2, transposed);
        // leftMultByMatrix already recomputes the non-zero count of the result.
        return leftMultByMatrix(m1, transposed, ret, k);
    }

    /**
     * Computes {@code t(m2) %*% m1} where both operands are compressed.
     *
     * @param m1  compressed right-hand side
     * @param m2  compressed matrix whose transpose is the left-hand side
     * @param ret output block to reuse; may be {@code null}, in which case a new block is created
     * @param k   degree of parallelism (values {@code <= 1} run sequentially)
     * @return the result block, sized {@code m2.getNumColumns() x m1.getNumColumns()}, with its non-zero count set
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock m1, CompressedMatrixBlock m2,
        MatrixBlock ret, int k) {
        // Keep the returned block: prepareReturnMatrix allocates a new one when ret is null.
        ret = prepareReturnMatrix(m1, m2, ret, true);
        // The helper recomputes the non-zero count itself.
        return leftMultByCompressedTransposedMatrix(m1.getColGroups(), m2, ret, k, m1.getNumColumns(),
            m1.getMaxNumValues(), m1.isOverlapping());
    }

    /**
     * Computes {@code m2 %*% m1}: the compressed matrix multiplied from the left by an uncompressed matrix.
     *
     * @param m1  compressed right-hand side
     * @param m2  uncompressed left-hand side
     * @param ret output block to reuse; may be {@code null}, in which case a new block is created
     * @param k   degree of parallelism (values {@code <= 1} run sequentially)
     * @return the result block, sized {@code m2.getNumRows() x m1.getNumColumns()}
     */
    public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        // Keep the returned block: prepareReturnMatrix allocates a new one when ret is null.
        ret = prepareReturnMatrix(m1, m2, ret, false);
        if (m2.isEmpty()) {
            return ret;
        }
        // The helper recomputes the non-zero count itself.
        return leftMultByMatrix(m1.getColGroups(), m2, ret, k, m1.getNumColumns(), m1.getMaxNumValues(),
            m1.isOverlapping());
    }

    /**
     * Returns a dense output block for a left multiplication by {@code m2}.
     *
     * The output has {@code m1.getNumColumns()} columns. It has {@code m2.getNumColumns()} rows when
     * {@code doTranspose} is set and {@code m2.getNumRows()} rows otherwise. When {@code ret} is null, a new block is
     * created. When {@code ret} has other dimensions or is not allocated, it is reset. Otherwise it is returned
     * unchanged.
     */
    private static MatrixBlock prepareReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
        boolean doTranspose) {
        final int numRowsOutput = doTranspose ? m2.getNumColumns() : m2.getNumRows();
        final int numColumnsOutput = m1.getNumColumns();
        // Widen before multiplying so the dense cell-count estimate cannot overflow int.
        final long estNnz = (long) numRowsOutput * numColumnsOutput;
        if (ret == null) {
            ret = new MatrixBlock(numRowsOutput, numColumnsOutput, false, estNnz);
        }
        else if (ret.getNumColumns() != numColumnsOutput || ret.getNumRows() != numRowsOutput ||
            !ret.isAllocated()) {
            ret.reset(numRowsOutput, numColumnsOutput, false, estNnz);
        }
        return ret;
    }

    /**
     * Computes the transpose-self product of the matrix represented by {@code groups} into {@code result}.
     *
     * For non-overlapping groups, only the pairs (i, j) with j &gt;= i are multiplied. The two triangles are then
     * mirrored so that the result is symmetric.
     *
     * @param groups      column groups of the compressed matrix
     * @param result      output block; a dense block is allocated and treated as {@code numColumns x numColumns}
     * @param k           degree of parallelism (values {@code <= 1} run sequentially)
     * @param numColumns  number of columns of the compressed matrix
     * @param v           not used by this method
     * @param overlapping whether the column groups overlap (forces the sequential all-pairs path)
     */
    public static void leftMultByTransposeSelf(List<AColGroup> groups, MatrixBlock result, int k, int numColumns,
        Pair<Integer, int[]> v, boolean overlapping) {
        result.allocateDenseBlock();
        final int numGroups = groups.size();
        if (overlapping) {
            LOG.warn("Inefficient TSMM with overlapping matrix could be implemented multi-threaded but is not yet.");
            leftMultByCompressedTransposedMatrix(groups, groups, result);
        }
        else if (k <= 1) {
            for (int i = 0; i < numGroups; i++) {
                leftMultByCompressedTransposedMatrix(groups.get(i), groups, result, i, numGroups);
            }
        }
        else {
            final List<LeftMultByCompressedTransposedMatrixTask> tasks = new ArrayList<>(numGroups);
            for (int i = 0; i < numGroups; i++) {
                tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, groups.get(i), result, i, numGroups));
            }
            runTasks(tasks, k);
        }
        copyToUpperTriangle(result.getDenseBlockValues(), numColumns);
        final long nnz = LinearAlgebraUtils.copyUpperToLowerTriangle(result);
        result.setNonZeros(nnz);
        result.examSparsity();
    }

    /**
     * Fills zero cells of the upper triangle from the lower triangle.
     *
     * {@code c} is treated as a row-major {@code cols x cols} array. For every (i, j) with j &gt;= i, a cell
     * {@code c[i * cols + j]} that is 0 receives the mirrored value {@code c[j * cols + i]}.
     */
    private static void copyToUpperTriangle(double[] c, int cols) {
        for (int i = 0, offC = 0; i < cols; i++, offC += cols) {
            for (int j = i, offR = i * cols; j < cols; j++, offR += cols) {
                if (c[offC + j] == 0.0) {
                    c[offC + j] = c[offR + i];
                }
            }
        }
    }

    /**
     * Multiplies the groups of {@code that} (left-hand side) with {@code colGroups} (right-hand side) into
     * {@code ret}.
     *
     * It runs in parallel over the groups of {@code that} only when {@code k > 1} and neither side overlaps.
     *
     * @return {@code ret}, with a dense block allocated and its non-zero count recomputed
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrix(List<AColGroup> colGroups,
        CompressedMatrixBlock that, MatrixBlock ret, int k, int numColumns, Pair<Integer, int[]> v,
        boolean overlapping) {
        ret.allocateDenseBlock();
        final List<AColGroup> thatCGs = that.getColGroups();
        final boolean anyOverlapping = overlapping || that.isOverlapping();
        if (k <= 1 || anyOverlapping) {
            if (anyOverlapping) {
                LOG.warn("Inefficient Compressed multiplication with overlapping matrix could be implemented "
                    + "multi-threaded but is not yet.");
            }
            leftMultByCompressedTransposedMatrix(colGroups, thatCGs, ret);
        }
        else {
            final List<LeftMultByCompressedTransposedMatrixTask> tasks = new ArrayList<>(thatCGs.size());
            for (AColGroup lhs : thatCGs) {
                tasks.add(new LeftMultByCompressedTransposedMatrixTask(colGroups, lhs, ret));
            }
            runTasks(tasks, k);
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /** Sequentially multiplies every group of {@code thatCG} (as lhs) with all groups of {@code thisCG}. */
    private static void leftMultByCompressedTransposedMatrix(List<AColGroup> thisCG, List<AColGroup> thatCG,
        MatrixBlock ret) {
        for (AColGroup lhs : thatCG) {
            leftMultByCompressedTransposedMatrix(lhs, thisCG, ret, 0, thisCG.size());
        }
    }

    /**
     * Multiplies {@code lhs} with the groups {@code thisCG[colGroupStart, colGroupEnd)} into {@code ret}.
     *
     * When a group is the very same object as {@code lhs}, its {@code tsmm} is used instead of
     * {@code leftMultByAColGroup}.
     */
    private static void leftMultByCompressedTransposedMatrix(AColGroup lhs, List<AColGroup> thisCG, MatrixBlock ret,
        int colGroupStart, int colGroupEnd) {
        for (int g = colGroupStart; g < colGroupEnd; g++) {
            final AColGroup rhs = thisCG.get(g);
            if (rhs != lhs) {
                rhs.leftMultByAColGroup(lhs, ret);
            }
            else {
                rhs.tsmm(ret.getDenseBlockValues(), ret.getNumColumns());
            }
        }
    }

    /**
     * Multiplies the uncompressed {@code that} from the left onto {@code colGroups}, writing into {@code ret}.
     *
     * @return {@code ret}, with a dense block allocated and its non-zero count recomputed; if {@code that} is empty,
     *         only the non-zero count is set to 0
     */
    private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k,
        int numColumns, Pair<Integer, int[]> v, boolean overlapping) {
        if (that.isEmpty()) {
            ret.setNonZeros(0L);
            return ret;
        }
        ret.allocateDenseBlock();
        if (k <= 1) {
            for (AColGroup g : colGroups) {
                g.leftMultByMatrix(that, ret);
            }
        }
        else {
            // NOTE(review): a block size of 1 creates one task per row (per group in the non-overlapping case);
            // a larger block may reduce scheduling overhead.
            final int rowBlockSize = 1;
            final int numRows = that.getNumRows();
            final List<Callable<Object>> tasks = new ArrayList<>();
            if (overlapping) {
                // Each task handles one row range for all groups. Presumably this is because overlapping groups can
                // contribute to the same output cells, so groups must not run concurrently on one range.
                for (int blo = 0; blo < numRows; blo += rowBlockSize) {
                    tasks.add(new LeftMatrixMatrixMultTask(colGroups, that, ret, blo,
                        Math.min(blo + rowBlockSize, numRows), v));
                }
            }
            else {
                for (AColGroup g : colGroups) {
                    for (int blo = 0; blo < numRows; blo += rowBlockSize) {
                        tasks.add(new LeftMatrixColGroupMultTask(g, that, ret, blo,
                            Math.min(blo + rowBlockSize, numRows), v));
                    }
                }
            }
            runTasks(tasks, k);
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Runs all tasks on a pool of parallelism {@code k} and waits for them to complete.
     *
     * The pool is always shut down, including when a task fails.
     *
     * @throws DMLRuntimeException wrapping a task failure or an interruption (the interrupt flag is restored)
     */
    private static void runTasks(List<? extends Callable<Object>> tasks, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            for (Future<Object> future : pool.invokeAll(tasks)) {
                future.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /** Task: left-multiplies rows [rl, ru) of {@code that} with a single column group. */
    private static class LeftMatrixColGroupMultTask implements Callable<Object> {
        private final AColGroup _group;
        private final MatrixBlock _that;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;
        private final Pair<Integer, int[]> _v;

        protected LeftMatrixColGroupMultTask(AColGroup group, MatrixBlock that, MatrixBlock ret, int rl, int ru,
            Pair<Integer, int[]> v) {
            _group = group;
            _that = that;
            _ret = ret;
            _rl = rl;
            _ru = ru;
            _v = v;
        }

        @Override
        public Object call() {
            try {
                ColGroupValue.setupThreadLocalMemory(_v.getLeft() * (_ru - _rl));
                _group.leftMultByMatrix(_that, _ret, _rl, _ru);
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        }
    }

    /** Task: left-multiplies rows [rl, ru) of {@code that} with every column group in the list. */
    private static class LeftMatrixMatrixMultTask implements Callable<Object> {
        private final List<AColGroup> _group;
        private final MatrixBlock _that;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;
        private final Pair<Integer, int[]> _v;

        protected LeftMatrixMatrixMultTask(List<AColGroup> group, MatrixBlock that, MatrixBlock ret, int rl, int ru,
            Pair<Integer, int[]> v) {
            _group = group;
            _that = that;
            _ret = ret;
            _rl = rl;
            _ru = ru;
            _v = v;
        }

        @Override
        public Object call() {
            try {
                // Sized per processed row, like LeftMatrixColGroupMultTask.
                ColGroupValue.setupThreadLocalMemory(_v.getLeft() * (_ru - _rl));
                for (AColGroup g : _group) {
                    g.leftMultByMatrix(_that, _ret, _rl, _ru);
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        }
    }

    /** Task: multiplies one left-hand group with the groups {@code [start, end)} of a list. */
    private static class LeftMultByCompressedTransposedMatrixTask implements Callable<Object> {
        private final List<AColGroup> _groups;
        private final AColGroup _left;
        private final MatrixBlock _ret;
        private final int _start;
        private final int _end;

        protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret,
            int start, int end) {
            _groups = groups;
            _left = left;
            _ret = ret;
            _start = start;
            _end = end;
        }

        protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret) {
            this(groups, left, ret, 0, groups.size());
        }

        @Override
        public Object call() {
            try {
                leftMultByCompressedTransposedMatrix(_left, _groups, _ret, _start, _end);
            }
            catch (Exception e) {
                // Propagated to the caller through Future.get(); no need to print it here as well.
                throw new DMLRuntimeException(e);
            }
            return null;
        }
    }
}

