/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderBin;
import org.apache.sysds.runtime.transform.encode.EncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
import org.apache.sysds.runtime.util.IndexRange;

public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFEDInstruction {
    protected final ArrayList<CPOperand> _outputs;

    private MultiReturnParameterizedBuiltinFEDInstruction(Operator op, CPOperand input1, CPOperand input2, ArrayList<CPOperand> outputs, String opcode, String istr) {
        super(FEDInstruction.FEDType.MultiReturnParameterizedBuiltin, op, input1, input2, null, opcode, istr);
        this._outputs = outputs;
    }

    public CPOperand getOutput(int i) {
        return this._outputs.get(i);
    }

    public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        ArrayList<CPOperand> outputs = new ArrayList<CPOperand>();
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("transformencode")) {
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            outputs.add(new CPOperand(parts[3], Types.ValueType.FP64, Types.DataType.MATRIX));
            outputs.add(new CPOperand(parts[4], Types.ValueType.STRING, Types.DataType.FRAME));
            return new MultiReturnParameterizedBuiltinFEDInstruction(null, in1, in2, outputs, opcode, str);
        }
        throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        FrameObject fin = ec.getFrameObject(this.input1.getName());
        String spec = ec.getScalarInput(this.input2).getStringValue();
        Object[] colNames = new String[(int)fin.getNumColumns()];
        Arrays.fill(colNames, "");
        EncoderComposite globalEncoder = new EncoderComposite(Arrays.asList(new EncoderRecode(), new EncoderFeatureHash(), new EncoderPassThrough(), new EncoderBin(), new EncoderDummycode(), new EncoderOmit(true), new EncoderMVImpute()));
        FederationMap fedMapping = fin.getFedMapping();
        fedMapping.forEachParallel((arg_0, arg_1) -> MultiReturnParameterizedBuiltinFEDInstruction.lambda$processInstruction$0(spec, globalEncoder, (String[])colNames, arg_0, arg_1));
        FrameBlock meta = new FrameBlock((int)fin.getNumColumns(), Types.ValueType.STRING);
        meta.setColumnNames((String[])colNames);
        globalEncoder.getMetaData(meta);
        globalEncoder.initMetaData(meta);
        MultiReturnParameterizedBuiltinFEDInstruction.encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(this.getOutput(0)));
        ec.setFrameOutput(this.getOutput(1).getName(), meta);
    }

    public static void encodeFederatedFrames(FederationMap fedMapping, Encoder globalEncoder, MatrixObject transformedMat) {
        long varID = FederationUtils.getNextFedDataID();
        FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> {
            long[] beginDims = range.getBeginDims();
            long[] endDims = range.getEndDims();
            IndexRange ixRange = new IndexRange(beginDims[0], endDims[0], beginDims[1], endDims[1]).add(1);
            globalEncoder.updateIndexRanges(beginDims, endDims);
            Encoder encoder = globalEncoder.subRangeEncoder(ixRange);
            try {
                FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        });
        transformedMat.getDataCharacteristics().setDimension(transformedFedMapping.getMaxIndexInRange(0), transformedFedMapping.getMaxIndexInRange(1));
        transformedMat.setFedMapping(transformedFedMapping);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static /* synthetic */ Void lambda$processInstruction$0(String spec, EncoderComposite globalEncoder, String[] colNames, FederatedRange range, FederatedData data) {
        int columnOffset = (int)range.getBeginDims()[1] + 1;
        Future<FederatedResponse> responseFuture = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new CreateFrameEncoder(data.getVarID(), spec, columnOffset)));
        try {
            FederatedResponse response = responseFuture.get();
            Encoder encoder = (Encoder)response.getData()[0];
            EncoderComposite encoderComposite = globalEncoder;
            synchronized (encoderComposite) {
                globalEncoder.mergeAt(encoder, (int)(range.getBeginDims()[0] + 1L), columnOffset);
            }
            String[] subRangeColNames = (String[])response.getData()[1];
            System.arraycopy(subRangeColNames, 0, colNames, (int)range.getBeginDims()[1], subRangeColNames.length);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Federated encoder creation failed: " + e.getMessage());
        }
        return null;
    }

    public static class ExecuteFrameEncoder
    extends FederatedUDF {
        private static final long serialVersionUID = 6034440964680578276L;
        private final long _outputID;
        private final Encoder _encoder;

        public ExecuteFrameEncoder(long input, long output, Encoder encoder) {
            super(new long[]{input});
            this._outputID = output;
            this._encoder = encoder;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameBlock fb = (FrameBlock)((FrameObject)data[0]).acquireReadAndRelease();
            MatrixBlock mbout = this._encoder.apply(fb, new MatrixBlock(fb.getNumRows(), fb.getNumColumns(), false));
            MatrixObject mo = ExecutionContext.createMatrixObject(mbout);
            ec.setVariable(String.valueOf(this._outputID), mo);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
        }
    }

    public static class CreateFrameEncoder
    extends FederatedUDF {
        private static final long serialVersionUID = 2376756757742169692L;
        private final String _spec;
        private final int _offset;

        public CreateFrameEncoder(long input, String spec, int offset) {
            super(new long[]{input});
            this._spec = spec;
            this._offset = offset;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameObject fo = (FrameObject)data[0];
            FrameBlock fb = (FrameBlock)fo.acquireRead();
            String[] colNames = fb.getColumnNames();
            Encoder encoder = EncoderFactory.createEncoder(this._spec, colNames, fb.getNumColumns(), null, this._offset, this._offset + fb.getNumColumns());
            encoder.build(fb);
            fo.release();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{encoder, fb.getColumnNames()});
        }
    }
}

