/*
 * Decompiled with CFR 0.152.
 */
package fr.inria.tapenade.representation;

import fr.inria.tapenade.representation.Block;
import fr.inria.tapenade.representation.CallGraph;
import fr.inria.tapenade.representation.ILUtils;
import fr.inria.tapenade.representation.TapEnv;
import fr.inria.tapenade.representation.TapList;
import fr.inria.tapenade.representation.TypeSpec;
import fr.inria.tapenade.representation.VariableDecl;
import fr.inria.tapenade.representation.WrapperTypeSpec;
import fr.inria.tapenade.utils.TapIntList;
import fr.inria.tapenade.utils.TapPair;
import fr.inria.tapenade.utils.Tree;
import java.util.Locale;

/**
 * MPI-specific information attached to one call of an MPI (or AMPI) primitive.
 * <p>
 * For a recognized call this records:
 * <ul>
 *   <li>which argument ranks are buffers read or written by the call;</li>
 *   <li>which ranks hold MPI datatypes;</li>
 *   <li>the rank of the reduction operator (or -1);</li>
 *   <li>the argument trees for source ({@code orig}), destination, tag and
 *       communicator, when they apply.</li>
 * </ul>
 * It also holds a "channel" tree, used to pair point-to-point sends with receives,
 * and whether every buffer involved is of a differentiable type.
 * <p>
 * Instances are built by {@link #getMessagePassingMPIcallInfo} and cached on the
 * call tree under the "MPIcallInfo" annotation.
 */
public final class MPIcallInfo {
    /** Name of the called primitive, in the case it was passed in. */
    private final String funcName;
    /** False as soon as one buffer argument or datatype argument is found non-differentiable. */
    private final boolean onDifferentiableType;
    private final Tree callTree;
    private final Block block;
    /**
     * Channel tree. For point-to-point and communicator calls it is the call's "channel"
     * annotation. Otherwise it is a freshly built tree. It may be reassigned later.
     */
    private Tree channel;
    private final Tree orig;
    private final Tree dest;
    private final Tree tag;
    private final Tree comm;
    private final TapIntList readBufferRks;
    private final TapIntList writtenBufferRks;
    private final TapIntList bufferTypeRks;
    /** Rank of the reduction-operator argument, or -1 when the call is not a reduction. */
    private final int reduceOpRk;

    /** @return the name of the called MPI/AMPI primitive, in its original case */
    public String funcName() {
        return this.funcName;
    }

    /** @return true unless a buffer or datatype argument was found to be non-differentiable */
    public boolean isOnDifferentiableType() {
        return this.onDifferentiableType;
    }

    private MPIcallInfo(String funcName, Tree callTree, Tree channel, Block block, Tree orig, Tree dest, Tree tag,
                        Tree comm, TapIntList readBufferRks, TapIntList writtenBufferRks, TapIntList bufferTypeRks,
                        int reduceOpRk) {
        this.funcName = funcName;
        this.callTree = callTree;
        this.block = block;
        this.channel = channel;
        this.orig = orig;
        this.dest = dest;
        this.tag = tag;
        this.comm = comm;
        this.readBufferRks = readBufferRks;
        this.writtenBufferRks = writtenBufferRks;
        this.bufferTypeRks = bufferTypeRks;
        this.reduceOpRk = reduceOpRk;
        this.onDifferentiableType = buffersAreDifferentiable(callTree, block,
                TapIntList.quickUnion(readBufferRks, writtenBufferRks), bufferTypeRks);
    }

    /**
     * Checks that the buffer arguments and the datatype arguments of a call are differentiable.
     * <p>
     * A buffer is non-differentiable when its type is known and
     * {@link TypeSpec#isDifferentiableType} rejects it. A datatype argument is
     * non-differentiable when it is a named tree (opCode 94, read via {@code stringValue()})
     * whose name denotes an integer/character/logical MPI type.
     *
     * @param callTree the MPI call
     * @param block    the block containing the call (its symbol table types the buffers)
     * @param bufferRks ranks of buffer arguments (may be null)
     * @param typeRks  ranks of MPI datatype arguments (may be null)
     * @return false as soon as one non-differentiable buffer or datatype is found
     */
    private static boolean buffersAreDifferentiable(Tree callTree, Block block, TapIntList bufferRks,
                                                    TapIntList typeRks) {
        Tree args = ILUtils.getArguments(callTree);
        for (TapIntList rks = bufferRks; rks != null; rks = rks.tail) {
            WrapperTypeSpec argType = block.symbolTable.typeOf(args.down(rks.head));
            if (argType != null && !TypeSpec.isDifferentiableType(argType, TapEnv.diffKind())) {
                return false;
            }
        }
        for (TapIntList rks = typeRks; rks != null; rks = rks.tail) {
            Tree typeArg = args.down(rks.head);
            // opCode 94 is presumably an identifier (its stringValue() is read) -- confirm in ILLang.
            if (typeArg.opCode() == 94 && isNonDifferentiableMpiTypeName(typeArg.stringValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * @param typeName an MPI datatype name, in any case
     * @return true when the name denotes a character, byte, logical or integer MPI datatype
     */
    private static boolean isNonDifferentiableMpiTypeName(String typeName) {
        // Locale.ROOT: with a Turkish default locale "MPI_INTEGER" would lowercase to "mpı_ınteger".
        String lc = typeName.toLowerCase(Locale.ROOT);
        return lc.startsWith("mpi_char")
                || lc.equals("mpi_byte")
                || lc.equals("mpi_logical")
                || lc.equals("mpi_short")
                || lc.startsWith("mpi_int")
                || lc.equals("mpi_long")
                || lc.equals("mpi_long_long_int")
                || lc.startsWith("mpi_unsigned");
    }

    /** Builds a TapIntList holding the given ranks in order; returns null when none are given. */
    private static TapIntList rankList(int... ranks) {
        TapIntList result = null;
        for (int i = ranks.length - 1; i >= 0; --i) {
            result = new TapIntList(ranks[i], result);
        }
        return result;
    }

    /**
     * Builds the MPIcallInfo for one call, by (case-insensitive) name of the primitive.
     * <p>
     * Argument ranks index the call's argument list through {@code Tree.down}. They follow
     * the MPI standard argument order, e.g. MPI_Send(buf, count, datatype, dest, tag, comm)
     * gives buffer 1, datatype 3, dest 4, tag 5, comm 6. The AMPI send/recv variants have
     * one more argument before the communicator, so their communicator is at rank 7.
     *
     * @return the built info; null for primitives that need no info (init, finalize,
     *         rank/size queries, barrier...), and also null, after a tool error, for an
     *         unknown primitive
     */
    private static MPIcallInfo buildMessagePassingMPIcallInfo(String funcName, Tree callTree, Tree channel,
                                                              Block block) {
        String lcFuncName = funcName.toLowerCase(Locale.ROOT);
        Tree args = ILUtils.getArguments(callTree);
        // Collective operations and waits get a freshly built opCode-136 tree as channel
        // (presumably a "none" placeholder -- confirm in ILLang).
        switch (lcFuncName) {
            case "mpi_gather":
            case "ampi_gather":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, args.down(7), null,
                        args.down(8), rankList(1), rankList(4), rankList(3, 6), -1);
            case "mpi_scatter":
            case "ampi_scatter":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, args.down(7), null, null,
                        args.down(8), rankList(1), rankList(4), rankList(3, 6), -1);
            case "mpi_allgather":
            case "ampi_allgather":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, null, null,
                        args.down(7), rankList(1), rankList(4), rankList(3, 6), -1);
            case "mpi_gatherv":
            case "ampi_gatherv":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, args.down(8), null,
                        args.down(9), rankList(1), rankList(4), rankList(3, 7), -1);
            case "mpi_scatterv":
            case "ampi_scatterv":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, args.down(8), null, null,
                        args.down(9), rankList(1), rankList(5), rankList(4, 7), -1);
            case "mpi_allgatherv":
            case "ampi_allgatherv":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, null, null,
                        args.down(8), rankList(1), rankList(4), rankList(3, 7), -1);
            case "mpi_bcast":
            case "ampi_bcast":
                // The same buffer is both read (at root) and written (elsewhere).
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, args.down(4), null, null,
                        args.down(5), rankList(1), rankList(1), rankList(3), -1);
            case "mpi_reduce":
            case "ampi_reduce":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, args.down(6), null,
                        args.down(7), rankList(1), rankList(2), rankList(4), 5);
            case "mpi_allreduce":
            case "ampi_allreduce":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, null, null,
                        args.down(6), rankList(1), rankList(2), rankList(4), 5);
            case "mpi_wait":
            case "ampi_wait":
                return new MPIcallInfo(funcName, callTree, ILUtils.build(136), block, null, null, null, null,
                        null, null, null, -1);
            case "mpi_send":
            case "mpi_isend":
                return new MPIcallInfo(funcName, callTree, channel, block, null, args.down(4), args.down(5),
                        args.down(6), rankList(1), null, rankList(3), -1);
            case "ampi_send":
            case "ampi_isend":
                return new MPIcallInfo(funcName, callTree, channel, block, null, args.down(4), args.down(5),
                        args.down(7), rankList(1), null, rankList(3), -1);
            case "mpi_recv":
            case "mpi_irecv":
                return new MPIcallInfo(funcName, callTree, channel, block, args.down(4), null, args.down(5),
                        args.down(6), null, rankList(1), rankList(3), -1);
            case "ampi_recv":
            case "ampi_irecv":
                return new MPIcallInfo(funcName, callTree, channel, block, args.down(4), null, args.down(5),
                        args.down(7), null, rankList(1), rankList(3), -1);
            case "mpi_comm_dup":
            case "ampi_comm_dup":
            case "mpi_comm_split":
            case "ampi_comm_split":
            case "mpi_comm_create":
            case "ampi_comm_create":
            case "mpi_comm_free":
            case "ampi_comm_free":
                // Communicator management: only the (first) communicator argument is recorded.
                return new MPIcallInfo(funcName, callTree, channel, block, null, null, null, args.down(1),
                        null, null, null, -1);
            case "mpi_init":
            case "ampi_init_nt":
            case "mpi_comm_group":
            case "ampi_comm_group":
            case "mpi_comm_spawn":
            case "ampi_comm_spawn":
            case "mpi_comm_rank":
            case "ampi_comm_rank":
            case "mpi_comm_size":
            case "ampi_comm_size":
            case "mpi_finalize":
            case "ampi_finalize_nt":
            case "mpi_barrier":
                // Known primitives that carry no message-passing information.
                return null;
            default:
                TapEnv.toolError("(Build MP call info) Unexpected MPI primitive " + funcName + " in " + callTree);
                return null;
        }
    }

    /**
     * Returns the MPIcallInfo of an MPI call, building and caching it on first request.
     * <p>
     * The result is cached on {@code callTree} under the "MPIcallInfo" annotation. If the
     * cached or built info has no channel yet, it receives the call's "channel" annotation.
     *
     * @param funcName name of the called function
     * @param callTree the call tree
     * @param language source language code, passed to {@link #isMessagePassingFunction}
     * @param block    block containing the call
     * @return the info, or null when the function is not an MPI primitive or carries no info
     */
    public static MPIcallInfo getMessagePassingMPIcallInfo(String funcName, Tree callTree, int language,
                                                           Block block) {
        MPIcallInfo result = (MPIcallInfo) callTree.getAnnotation("MPIcallInfo");
        Tree channelTree = (Tree) callTree.getAnnotation("channel");
        if (result == null && isMessagePassingFunction(funcName, language)) {
            result = buildMessagePassingMPIcallInfo(funcName, callTree, channelTree, block);
            if (result != null) {
                callTree.setAnnotation("MPIcallInfo", result);
            }
        }
        if (result != null && result.channel == null && channelTree != null) {
            result.channel = channelTree;
        }
        return result;
    }

    /**
     * Tells whether a function name looks like an MPI or AMPI primitive.
     * <p>
     * For Fortran, and for language -1 (presumably "unknown" -- confirm), the test is
     * case-insensitive on the "mpi_"/"ampi_" prefix. Otherwise the name must start
     * exactly with "MPI_" or "AMPI_".
     *
     * @param funcName the function name
     * @param language the source language code
     * @return true if the name has an MPI/AMPI prefix
     */
    public static boolean isMessagePassingFunction(String funcName, int language) {
        boolean caseIndependent = TapEnv.isFortran(language) || language == -1;
        if (caseIndependent) {
            // Locale.ROOT: with a Turkish default locale "MPI_" would lowercase to "mpı_".
            String lc = funcName.toLowerCase(Locale.ROOT);
            return lc.startsWith("mpi_") || lc.startsWith("ampi_");
        }
        return funcName.startsWith("MPI_") || funcName.startsWith("AMPI_");
    }

    /**
     * Refreshes the channel of every registered message-passing channel.
     * <p>
     * Each channel's channel tree is reset from the "channel" annotation of its call tree,
     * which may set it to null.
     */
    protected static void checkMessagePassingCalls(CallGraph callGraph) {
        for (TapList<TapPair<MPIcallInfo, TapIntList>> channels = callGraph.getAllMessagePassingChannels();
             channels != null; channels = channels.tail) {
            MPIcallInfo channelInfo = channels.head.first;
            channelInfo.channel = (Tree) channelInfo.callTree.getAnnotation("channel");
        }
    }

    /**
     * Removes registered channels whose channel directive name was already seen earlier in
     * the list, keeping the first one. Channels without a directive are always kept.
     * The call graph's channel list is replaced by the filtered list.
     *
     * @return the number of channels kept
     */
    protected static int checkNumberOfChannels(CallGraph callGraph) {
        int kept = 0;
        TapList<String> seenNames = null;
        // Dummy head cell: kept channels are appended after it, in original order.
        TapList<TapPair<MPIcallInfo, TapIntList>> hdCleaned = new TapList<>(null, null);
        TapList<TapPair<MPIcallInfo, TapIntList>> tlCleaned = hdCleaned;
        for (TapList<TapPair<MPIcallInfo, TapIntList>> channels = callGraph.getAllMessagePassingChannels();
             channels != null; channels = channels.tail) {
            Tree directive = channels.head.first.channel;
            boolean keep = true;
            if (directive != null) {
                String channelName = directive.stringValue();
                if (TapList.containsEquals(seenNames, channelName)) {
                    keep = false;
                } else {
                    seenNames = new TapList<>(channelName, seenNames);
                }
            }
            if (keep) {
                tlCleaned = tlCleaned.placdl(channels.head);
                ++kept;
            }
        }
        callGraph.setAllMessagePassingChannels(hdCleaned.tail);
        return kept;
    }

    /** Currently returns its argument unchanged. */
    public static Tree replaceMPIconstantsIfPossible(Tree preprocessedTree) {
        return preprocessedTree;
    }

    /** Returns the lower-cased (Locale.ROOT) function name, or null when there is none. */
    private String lowerCaseFuncName() {
        return this.funcName == null ? null : this.funcName.toLowerCase(Locale.ROOT);
    }

    /** @return true when the function name ends with "send" or "recv" (case-insensitive) */
    public boolean isPointToPoint() {
        String lc = this.lowerCaseFuncName();
        return lc != null && (lc.endsWith("send") || lc.endsWith("recv"));
    }

    /** @return true if the argument at this rank is a buffer read by the call */
    public boolean argumentIsRead(int argumentRank) {
        return TapIntList.contains(this.readBufferRks, argumentRank);
    }

    /** @return the first read buffer argument of a point-to-point call, or null */
    public Tree sentExprToChannel() {
        return this.readBufferRks != null && this.isPointToPoint()
                ? ILUtils.getArguments(this.callTree).down(this.readBufferRks.head)
                : null;
    }

    /** @return the first written buffer argument of a point-to-point call, or null */
    public Tree receivedExprFromChannel() {
        return this.writtenBufferRks != null && this.isPointToPoint()
                ? ILUtils.getArguments(this.callTree).down(this.writtenBufferRks.head)
                : null;
    }

    /** @return true if the argument at this rank is a read or written buffer */
    public boolean argumentIsABuffer(int argumentRank) {
        return TapIntList.contains(this.readBufferRks, argumentRank)
                || TapIntList.contains(this.writtenBufferRks, argumentRank);
    }

    /** @return true if the argument at this rank is an MPI datatype argument */
    public boolean argumentIsAType(int argumentRank) {
        return TapIntList.contains(this.bufferTypeRks, argumentRank);
    }

    /** @return true if the argument at this rank is the reduction operator */
    public boolean argumentIsAReduceOp(int argumentRank) {
        return this.reduceOpRk == argumentRank;
    }

    /** @return true if this call has a reduction-operator argument */
    public boolean isReduceCall() {
        return this.reduceOpRk != -1;
    }

    /** @return true for isend, irecv and wait (MPI or AMPI) */
    public boolean isNonBlocking() {
        String lc = this.lowerCaseFuncName();
        return lc != null && (lc.endsWith("mpi_isend") || lc.endsWith("mpi_irecv") || lc.endsWith("mpi_wait"));
    }

    /** @return true for comm_free (MPI or AMPI) */
    public boolean isCommFree() {
        String lc = this.lowerCaseFuncName();
        return lc != null && lc.endsWith("mpi_comm_free");
    }

    /** @return true for comm_dup, comm_split and comm_create (MPI or AMPI) */
    public boolean isCommCreation() {
        String lc = this.lowerCaseFuncName();
        return lc != null
                && (lc.endsWith("mpi_comm_dup") || lc.endsWith("mpi_comm_split") || lc.endsWith("mpi_comm_create"));
    }

    /**
     * Registers this call as a new channel in the call graph if it is point-to-point and
     * no matching channel is registered yet. The new channel is added with null zones.
     */
    public void registerMessagePassingChannel(CallGraph callGraph) {
        // The lookup runs unconditionally; comparing tags during it may emit MPI01 warnings.
        TapPair<MPIcallInfo, TapIntList> existing = this.findMessagePassingChannel(callGraph);
        if (existing == null && this.isPointToPoint()) {
            TapPair<MPIcallInfo, TapIntList> newChannel = new TapPair<MPIcallInfo, TapIntList>(this, null);
            callGraph.setAllMessagePassingChannels(
                    new TapList<TapPair<MPIcallInfo, TapIntList>>(newChannel, callGraph.getAllMessagePassingChannels()));
        }
    }

    /**
     * @return the first registered channel whose call matches this one (see
     *         {@link #equalMessagePassingChannel}), or null
     */
    public TapPair<MPIcallInfo, TapIntList> findMessagePassingChannel(CallGraph callGraph) {
        for (TapList<TapPair<MPIcallInfo, TapIntList>> channels = callGraph.getAllMessagePassingChannels();
             channels != null; channels = channels.tail) {
            if (this.equalMessagePassingChannel(channels.head.first)) {
                return channels.head;
            }
        }
        return null;
    }

    /** @return the zones attached to this point-to-point call's channel, or null */
    public TapIntList findMessagePassingChannelZones(CallGraph callGraph) {
        if (!this.isPointToPoint()) {
            return null;
        }
        TapPair<MPIcallInfo, TapIntList> channelTapPair = this.findMessagePassingChannel(callGraph);
        return channelTapPair == null ? null : channelTapPair.second;
    }

    /**
     * Decides whether this call and {@code channelInfo} may communicate through the same channel.
     * <p>
     * They do not when any of the following holds:
     * <ul>
     *   <li>both communicators are known and differ;</li>
     *   <li>both channel directives are known and differ;</li>
     *   <li>both calls share a symbol table, and both tags resolve to integer constants
     *       that differ.</li>
     * </ul>
     * An "(MPI01)" file warning is emitted when both directives and both constant tags are
     * known but disagree.
     */
    private boolean equalMessagePassingChannel(MPIcallInfo channelInfo) {
        if (this.comm != null && channelInfo.comm != null
                && !this.equalDirectiveOrCommunicatorOrTag(this.comm, channelInfo.comm)) {
            return false;
        }
        boolean equalDirective = true;
        if (this.channel != null && channelInfo.channel != null) {
            equalDirective = this.equalDirectiveOrCommunicatorOrTag(this.channel, channelInfo.channel);
        }
        boolean equalConstantTag = true;
        if (this.block.symbolTable == channelInfo.block.symbolTable) {
            Integer tag1 = this.constantTagValue(this.tag);
            Integer tag2 = this.constantTagValue(channelInfo.tag);
            if (tag1 != null && tag2 != null) {
                equalConstantTag = tag1.intValue() == tag2.intValue();
                if (this.channel != null && channelInfo.channel != null && equalDirective != equalConstantTag) {
                    TapEnv.fileWarning(15, channelInfo.callTree, "(MPI01) channel directives "
                            + ILUtils.toString(this.channel) + " " + ILUtils.toString(channelInfo.channel)
                            + " incompatible with tags " + ILUtils.toString(this.tag) + " "
                            + ILUtils.toString(channelInfo.tag));
                }
            }
        }
        return equalDirective && equalConstantTag;
    }

    /**
     * Resolves a tag argument to an integer constant, when possible.
     * <p>
     * opCode 101 (presumably an integer literal -- its intValue() is read) gives its value.
     * A tree with a base name is looked up as a constant in this call's symbol table.
     * Anything else triggers a tool warning.
     *
     * @param tagTree the tag argument, may be null
     * @return the constant value, or null when unknown
     */
    private Integer constantTagValue(Tree tagTree) {
        if (tagTree == null) {
            return null;
        }
        if (tagTree.opCode() == 101) {
            return tagTree.intValue();
        }
        if (ILUtils.baseName(tagTree) != null) {
            VariableDecl constDecl = this.block.symbolTable.getConstantDecl(ILUtils.baseName(tagTree));
            return constDecl == null ? null : constDecl.constantValue;
        }
        TapEnv.toolWarning(-1, "Analysis cannot finely compare tag " + tagTree);
        return null;
    }

    /** Structural equality of two directive, communicator or tag trees. */
    private boolean equalDirectiveOrCommunicatorOrTag(Tree tagOrComm1, Tree tagOrComm2) {
        return tagOrComm1.equalsTree(tagOrComm2);
    }

    /** Human-readable summary: what the call sends/receives or reads/writes, and its channel. */
    @Override
    public String toString() {
        String lcFuncName = this.funcName.toLowerCase(Locale.ROOT);
        String centralString;
        if (lcFuncName.endsWith("send")) {
            centralString = "sends " + this.readBufferRks + " to " + this.dest + " through channel " + this.channel;
        } else if (lcFuncName.endsWith("recv")) {
            centralString = "receives " + this.writtenBufferRks + " from " + this.orig
                    + " through channel " + this.channel;
        } else if (lcFuncName.startsWith("mpi_gather") || lcFuncName.startsWith("ampi_gather")
                || lcFuncName.startsWith("mpi_allgather") || lcFuncName.startsWith("ampi_allgather")
                || lcFuncName.startsWith("mpi_scatter") || lcFuncName.startsWith("ampi_scatter")
                || lcFuncName.endsWith("bcast") || lcFuncName.endsWith("reduce")) {
            centralString = "global, reads " + this.readBufferRks + " of "
                    + (this.orig == null ? "all procs" : "proc " + this.orig)
                    + ", writes " + this.writtenBufferRks + " of "
                    + (this.dest == null ? "all procs" : "proc " + this.dest);
        } else {
            centralString = "ancillary";
        }
        return "<MPIcallInfo " + this.funcName + ": " + centralString
                + (this.onDifferentiableType ? ">" : " (non-diff Type)>");
    }
}

