/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.codegen;

import java.util.ArrayList;
import jcuda.Pointer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;

public class SpoofCUDACellwise
extends SpoofCellwise
implements SpoofCUDAOperator {
    private static final long serialVersionUID = -5255791443086948200L;
    private static final Log LOG = LogFactory.getLog((String)SpoofCUDACellwise.class.getName());
    private final int ID;
    private final SpoofCUDAOperator.PrecisionProxy call;
    private Pointer ptr;
    private final SpoofCellwise fallback_java_op;

    public SpoofCUDACellwise(SpoofCellwise.CellType type, boolean sparseSafe, boolean containsSeq, SpoofCellwise.AggOp aggOp, int id, SpoofCUDAOperator.PrecisionProxy ep, SpoofCellwise fallback) {
        super(type, sparseSafe, containsSeq, aggOp);
        this.ID = id;
        this.call = ep;
        this.ptr = null;
        this.fallback_java_op = fallback;
    }

    @Override
    public ScalarObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects) {
        double[] result = new double[1];
        int NT = 256;
        long N = inputs.get(0).getNumRows() * inputs.get(0).getNumColumns();
        long num_blocks = (N + (long)(NT * 2) - 1L) / (long)(NT * 2);
        Pointer ptr = ec.getGPUContext(0).allocate(this.getName(), (long)LibMatrixCUDA.sizeOfDataType * num_blocks);
        long[] out = new long[]{1L, 1L, 1L, 0L, 0L, GPUObject.getPointerAddress(ptr)};
        int offset = 1;
        if (this.call.exec(ec, this, this.ID, this.prepareInputPointers(ec, inputs, offset), this.prepareSideInputPointers(ec, inputs, offset, false), out, scalarObjects, 0L) != 0) {
            LOG.error((Object)("SpoofCUDA " + this.getSpoofType() + " operator failed to execute. Trying Java fallback.\n"));
        }
        LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, result, this.getName(), false);
        return new DoubleObject(result[0]);
    }

    @Override
    public String getName() {
        return this.getSpoofType();
    }

    @Override
    public void setScalarPtr(Pointer _ptr) {
        this.ptr = _ptr;
    }

    @Override
    public Pointer getScalarPtr() {
        return this.ptr;
    }

    @Override
    public void releaseScalarGPUMemory(ExecutionContext ec) {
        if (this.ptr != null) {
            ec.getGPUContext(0).cudaFreeHelper(this.getSpoofType(), this.ptr, DMLScript.EAGER_CUDA_FREE);
            this.ptr = null;
        }
    }

    @Override
    public MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects, String outputName) {
        long out_rows = ec.getMatrixObject(outputName).getNumRows();
        long out_cols = ec.getMatrixObject(outputName).getNumColumns();
        MatrixObject a = inputs.get(0);
        GPUContext gctx = ec.getGPUContext(0);
        int m = (int)a.getNumRows();
        int n = (int)a.getNumColumns();
        double[] scalars = SpoofCUDACellwise.prepInputScalars(scalarObjects);
        if (this._type == SpoofCellwise.CellType.COL_AGG) {
            out_rows = 1L;
        } else if (this._type == SpoofCellwise.CellType.ROW_AGG) {
            out_cols = 1L;
        }
        boolean sparseSafe = this.isSparseSafe() || inputs.size() < 2 && this.genexec(0.0, new SpoofOperator.SideInput[0], scalars, m, n, 0, 0) == 0.0;
        GPUObject g = a.getGPUObject(gctx);
        boolean sparseOut = this._type == SpoofCellwise.CellType.NO_AGG && sparseSafe && g.isSparse();
        long nnz = g.getNnz("spoofCUDA" + this.getSpoofType(), false);
        if (sparseOut) {
            LOG.warn((Object)"sparse out");
        }
        MatrixObject out_obj = sparseOut ? ec.getSparseMatrixOutputForGPUInstruction(outputName, out_rows, out_cols, this.isSparseSafe() && nnz > 0L ? nnz : out_rows * out_cols).getKey() : ec.getDenseMatrixOutputForGPUInstruction(outputName, out_rows, out_cols).getKey();
        int offset = 1;
        if (!(SpoofCUDACellwise.inputIsEmpty(a.getGPUObject(gctx)) && sparseSafe || this.call.exec(ec, this, this.ID, this.prepareInputPointers(ec, inputs, offset), this.prepareSideInputPointers(ec, inputs, offset, false), this.prepareOutputPointers(ec, out_obj, sparseOut), scalarObjects, 0L) == 0)) {
            LOG.error((Object)("SpoofCUDA " + this.getSpoofType() + " operator failed to execute. Trying Java fallback.(ToDo)\n"));
        }
        return out_obj;
    }

    private static boolean inputIsEmpty(GPUObject g) {
        return g.getDensePointer() == null && g.getSparseMatrixCudaPointer() == null;
    }

    @Override
    protected double genexec(double a, SpoofOperator.SideInput[] b, double[] scalars, int m, int n, long gix, int rix, int cix) {
        return this.fallback_java_op.genexec(a, b, scalars, m, n, 0L, 0, 0);
    }

    @Override
    public int execute_sp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
        return SpoofCUDACellwise.execute_f(ctx, meta, in, sides, out, scalars);
    }

    @Override
    public int execute_dp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
        return SpoofCUDACellwise.execute_d(ctx, meta, in, sides, out, scalars);
    }

    public static native int execute_f(long var0, long[] var2, long[] var3, long[] var4, long[] var5, long var6);

    public static native int execute_d(long var0, long[] var2, long[] var3, long[] var4, long[] var5, long var6);
}

