/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that performs code motion of loop-invariant operations.
 * <p>
 * For a loop statement block ({@code while}/{@code for}), every operation inside the loop
 * body (including bodies of nested while/for loops) is tagged as invariant if all of its
 * inputs are literals, variables that the loop reads but never updates, or other invariant
 * operations. Each maximal invariant input of a non-invariant operation is deep-copied into
 * a new statement block that is placed in front of the loop and written to a fresh temporary
 * variable; inside the loop, the original operation is replaced by a transient read of that
 * variable.
 * <p>
 * Never tagged as invariant: literals (hoisting them would only add a temporary and hide the
 * literal from later rewrites), non-deterministic data generators ({@code RAND},
 * {@code SAMPLE}, {@code TIME}), transient reads/writes, and function calls unless functions
 * are declared side-effect free. If-statement blocks inside the loop are not inspected.
 */
public class RewriteHoistLoopInvariantOperations
extends StatementBlockRewriteRule {
    /** If true, {@link FunctionOp}s are not excluded from being tagged as loop-invariant. */
    private final boolean _sideEffectFreeFuns;

    /** Creates the rewrite treating function calls as potentially side-effecting. */
    public RewriteHoistLoopInvariantOperations() {
        this(false);
    }

    /**
     * Creates the rewrite.
     *
     * @param noSideEffects if true, function calls may be tagged as loop-invariant
     */
    public RewriteHoistLoopInvariantOperations(boolean noSideEffects) {
        this._sideEffectFreeFuns = noSideEffects;
    }

    /** This rewrite inserts a new statement block, i.e., it splits the program. */
    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Hoists loop-invariant operations out of the given loop statement block.
     *
     * @param sb    the statement block to rewrite; non-loop blocks (and null) are returned as-is
     * @param state rewrite status (unused)
     * @return either {@code [sb]} if nothing was hoisted, or {@code [preamble, sb]} where the
     *         preamble block computes the hoisted operations before the loop
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
        if (sb == null || !HopRewriteUtils.isLoopStatementBlock(sb)) {
            return Arrays.asList(sb);
        }
        // Candidate inputs: variables the loop reads but never modifies, i.e., values that are
        // identical in every iteration and already available before the loop.
        Set<String> candInputs = sb.variablesRead().getVariableNames().stream()
            .filter(v -> !sb.variablesUpdated().containsVariable((String) v))
            .collect(Collectors.toSet());
        Map<String, Hop> invariantOps = new HashMap<>();
        collectOperations(sb, candInputs, invariantOps);
        return invariantOps.isEmpty()
            ? Arrays.asList(sb)
            : Arrays.asList(createStatementBlock(sb, invariantOps), sb);
    }

    /** This rewrite operates on individual statement blocks only; the list is returned unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
        return sbs;
    }

    /**
     * Recursively walks the loop body, tags invariant operations in every last-level block,
     * and replaces them by transient reads, collecting the hoisted copies in {@code invariantOps}.
     *
     * @param sb           current statement block
     * @param candInputs   loop-invariant variable names
     * @param invariantOps output map from temporary variable name to hoisted hop DAG
     */
    private void collectOperations(StatementBlock sb, Set<String> candInputs, Map<String, Hop> invariantOps) {
        if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            for (StatementBlock csb : wstmt.getBody()) {
                collectOperations(csb, candInputs, invariantOps);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) sb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                collectOperations(csb, candInputs, invariantOps);
            }
        } else if (!(sb instanceof IfStatementBlock) && sb.getHops() != null) {
            // Pass 1: bottom-up tagging of invariant hops by id.
            Hop.resetVisitStatus(sb.getHops());
            Set<Long> memo = new HashSet<>();
            for (Hop hop : sb.getHops()) {
                rTagLoopInvariantOperations(hop, candInputs, memo);
            }
            // Pass 2: top-down replacement of maximal invariant sub-DAGs.
            Hop.resetVisitStatus(sb.getHops());
            for (Hop hop : sb.getHops()) {
                rCollectAndReplaceOperations(hop, candInputs, memo, invariantOps);
            }
            if (!memo.isEmpty()) {
                LOG.debug("Applied hoistLoopInvariantOperations (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + "): " + memo.size() + ".");
            }
        }
    }

    /**
     * Tags (via {@code memo}) all hops of the DAG rooted at {@code hop} that are loop-invariant.
     * A hop is invariant if it is a hoistable operation and every input is a candidate variable,
     * an already-tagged hop, or a literal.
     */
    private void rTagLoopInvariantOperations(Hop hop, Set<String> candInputs, Set<Long> memo) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop c : hop.getInput()) {
            rTagLoopInvariantOperations(c, candInputs, memo);
        }
        // Literals are never tagged themselves: hoisting them would only replace a constant by a
        // transient read. Parents still treat literal inputs as invariant (see loop below).
        boolean invariant = !(hop instanceof LiteralOp)
            && !isNonDeterministicDataGen(hop)
            && (!(hop instanceof FunctionOp) || _sideEffectFreeFuns)
            && !HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)
            && !HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE);
        for (Hop c : hop.getInput()) {
            invariant &= candInputs.contains(c.getName())
                || memo.contains(c.getHopID())
                || c instanceof LiteralOp;
        }
        if (invariant) {
            memo.add(hop.getHopID());
        }
        hop.setVisited();
    }

    /**
     * Returns true for data-generation operations whose result may differ between executions
     * (random matrices, sampling, current time); these must stay inside the loop.
     */
    private static boolean isNonDeterministicDataGen(Hop hop) {
        return HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.RAND)
            || HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.SAMPLE)
            || HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.TIME);
    }

    /**
     * Replaces every tagged input of a non-tagged hop by a transient read of a new temporary
     * variable and records a parent-free deep copy of the tagged sub-DAG in {@code invariantOps}.
     * Recursion stops at replaced inputs, so only maximal invariant sub-DAGs are hoisted.
     */
    private void rCollectAndReplaceOperations(Hop hop, Set<String> candInputs, Set<Long> memo, Map<String, Hop> invariantOps) {
        if (hop.isVisited()) {
            return;
        }
        // Index-based loop because replaceChildReference modifies the input list in place.
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c = hop.getInput().get(i);
            if (!memo.contains(c.getHopID())) {
                rCollectAndReplaceOperations(c, candInputs, memo, invariantOps);
                continue;
            }
            String tmpName = createCutVarName(false);
            Hop tmp = Recompiler.deepCopyHopsDag(c);
            tmp.getParent().clear();
            invariantOps.put(tmpName, tmp);
            // Redirect all consumers of the invariant hop (copy the list since it is modified).
            DataOp tread = HopRewriteUtils.createTransientRead(tmpName, c);
            List<Hop> parents = new ArrayList<>(c.getParent());
            for (Hop p : parents) {
                HopRewriteUtils.replaceChildReference(p, c, tread);
            }
        }
        hop.setVisited();
    }

    /**
     * Creates the preamble statement block that computes all hoisted operations via transient
     * writes, and registers the new temporary variables as live-out of the preamble and live-in
     * of the loop block {@code sb}.
     */
    private static StatementBlock createStatementBlock(StatementBlock sb, Map<String, Hop> invariantOps) {
        StatementBlock ret = new StatementBlock();
        ret.setDMLProg(sb.getDMLProg());
        ret.setParseInfo(sb);
        ret.setLiveIn(new VariableSet(sb.liveIn()));
        // The preamble runs directly before the loop, so everything live into the loop must
        // also be live out of the preamble (plus the new temporaries added below).
        ret.setLiveOut(new VariableSet(sb.liveIn()));
        List<Hop> hops = new ArrayList<>();
        for (Map.Entry<String, Hop> e : invariantOps.entrySet()) {
            String name = e.getKey();
            Hop h = e.getValue();
            DataOp twrite = HopRewriteUtils.createTransientWrite(name, h);
            hops.add(twrite);
            DataIdentifier diVar = new DataIdentifier(name);
            diVar.setDimensions(h.getDim1(), h.getDim2());
            diVar.setBlocksize(h.getBlocksize());
            diVar.setDataType(h.getDataType());
            diVar.setValueType(h.getValueType());
            ret.liveOut().addVariable(name, diVar);
            sb.liveIn().addVariable(name, diVar);
        }
        ret.setHops(new ArrayList<>(hops));
        return ret;
    }
}

