package harpoon.Analysis.Tree;

import harpoon.Analysis.BasicBlock;
import harpoon.Analysis.BasicBlockInterf;
import harpoon.Analysis.DataFlow.ForwardDataFlowBasicBlockVisitor;
import harpoon.Analysis.DataFlow.ReversePostOrderIterator;
import harpoon.Analysis.DataFlow.TreeSolver;
import harpoon.Analysis.EdgesIterator;
import harpoon.IR.Properties.CFGrapher;
import harpoon.IR.Properties.UseDefer;
import harpoon.IR.Tree.Code;
import harpoon.IR.Tree.Exp;
import harpoon.IR.Tree.ExpList;
import harpoon.IR.Tree.MOVE;
import harpoon.IR.Tree.SEQ;
import harpoon.IR.Tree.Stm;
import harpoon.IR.Tree.TEMP;
import harpoon.IR.Tree.Tree;
import harpoon.Temp.Temp;
import harpoon.Util.Tuple;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import net.cscott.jutil.BitString;

/* loaded from: input_file:harpoon/Analysis/Tree/TreeFolding.class */
public class TreeFolding extends ForwardDataFlowBasicBlockVisitor {
    /** False while the constructor runs the data-flow solver; set to true as the constructor's last step. */
    private boolean initialized;
    /** Value returned by TreeSolver.getMaxID for this code; every BitString here is sized maxTreeID + 1. */
    private final int maxTreeID;
    /** The tree code being analyzed; fold() returns this same instance. */
    private final Code code;
    /** BasicBlock -> TreeFoldingInfo holding that block's gen/prsv/in/out bit sets. */
    private final Map bb2tfi;
    /** Tuple(defining tree ID, temp) -> Set of Tuple(using tree ID, temp). */
    private final Map DUChains;
    /** Tuple(using tree ID, temp) -> Set of Tuple(defining tree ID, temp). */
    private final Map UDChains;
    /** Temp -> BitString with the bit of every tree that touches the temp cleared (see initTempsToPrsvs). */
    private final Map tempsToPrsvs;
    /** Root element of {@link #code}, cast to Stm. */
    private final Stm root;
    /** Basic-block factory built over {@link #code} with {@link #grapher}. */
    private final BasicBlock.Factory bbfactory;
    /** Control-flow grapher obtained from code.getGrapher(). */
    private CFGrapher grapher;
    /** Use/def oracle obtained from code.getUseDefer(). */
    private UseDefer usedefer;
    /** Decompiled assertion switch; assigned in the static initializer and checked before each explicit AssertionError. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:harpoon/Analysis/Tree/TreeFolding$TreeFoldingInfo.class */
    public class TreeFoldingInfo {
        // Each array holds two parallel bit sets: slot 0 is updated by non-memory
        // MOVE statements, slot 1 additionally reacts to statements for which
        // mayWriteMem() is true.
        final BitString[] genSet = new BitString[2];
        final BitString[] prsvSet = new BitString[2];
        final BitString[] inSet = new BitString[2];
        final BitString[] outSet = new BitString[2];

        /**
         * Allocates all eight bit sets (maxTreeID + 1 bits each) and then derives
         * the gen and prsv sets from the statements of the given block.
         *
         * @param basicBlock block whose statements seed the gen/prsv sets
         */
        TreeFoldingInfo(BasicBlock basicBlock) {
            final int width = TreeFolding.this.maxTreeID + 1;
            for (int slot = 0; slot < 2; slot++) {
                this.genSet[slot] = new BitString(width);
                this.inSet[slot] = new BitString(width);
                this.outSet[slot] = new BitString(width);
                this.prsvSet[slot] = new BitString(width);
            }
            computeGenPrsvSets(basicBlock);
        }

        /**
         * Walks the block's statements in order and fills genSet / prsvSet.
         * prsvSet[0], prsvSet[1] and genSet[1] start fully set; genSet[0] keeps
         * whatever state a freshly constructed BitString has.
         */
        private void computeGenPrsvSets(BasicBlock basicBlock) {
            this.prsvSet[0].setAll();
            this.prsvSet[1].setAll();
            this.genSet[1].setAll();
            for (ListIterator cursor = basicBlock.statements().listIterator(); cursor.hasNext();) {
                Stm current = (Stm) cursor.next();
                if (TreeFolding.mayWriteMem(current)) {
                    // NOTE(review): prsvSet[1] was fully set above and nothing in this
                    // method clears it, so this looks like a no-op -- confirm intent.
                    this.prsvSet[1].setAll();
                    continue;
                }
                // Kind 12 is the kind mayWriteMem() casts to MOVE.
                if (current.kind() != 12) {
                    continue;
                }
                Temp defined = TreeFolding.this.usedefer.def(current)[0];
                BitString survivors = (BitString) TreeFolding.this.tempsToPrsvs.get(defined);
                this.prsvSet[0].and(survivors);
                int id = current.getID();
                this.genSet[0].set(id);
                this.genSet[1].clear(id);
            }
        }

        /** Renders the slot-0 gen, prsv, in and out sets, one per line. */
        public String toString() {
            return new StringBuilder()
                    .append("tGen  set: ").append(this.genSet[0].toString())
                    .append("\n\tPrsv set: ").append(this.prsvSet[0].toString())
                    .append("\n\tIn   set: ").append(this.inSet[0].toString())
                    .append("\n\tOut  set: ").append(this.outSet[0].toString()).append("\n")
                    .toString();
        }
    }

    /**
     * Builds the analysis for the given code: records its root, grapher and
     * use/def oracle, creates one TreeFoldingInfo per basic block, runs the
     * forward reverse-post-order solver with this object as the visitor, and
     * finally computes the DU/UD chains.
     *
     * @param code tree code to analyze; asserted to be canonical when assertions are enabled
     */
    public TreeFolding(Code code) {
        this.initialized = false;
        Map tempsToDefs = new HashMap();
        if (!$assertionsDisabled && !code.isCanonical()) {
            throw new AssertionError();
        }

        this.code = code;
        this.root = (Stm) this.code.getRootElement2();
        this.grapher = code.getGrapher();
        this.usedefer = code.getUseDefer();
        this.maxTreeID = TreeSolver.getMaxID(RS(this.root));

        this.bb2tfi = new HashMap();
        this.DUChains = new HashMap();
        this.UDChains = new HashMap();
        this.tempsToPrsvs = new HashMap();
        this.bbfactory = new BasicBlock.Factory(code, this.grapher);

        BasicBlock entryBlock = this.bbfactory.getRoot();
        initTempsToPrsvs(this.tempsToPrsvs);
        initTempsToDefs(tempsToDefs);

        // Every block needs its info record before the solver starts calling visit()/merge().
        for (ReversePostOrderIterator order = new ReversePostOrderIterator(entryBlock); order.hasNext();) {
            BasicBlock block = (BasicBlock) order.next();
            this.bb2tfi.put(block, new TreeFoldingInfo(block));
        }

        TreeSolver.forward_rpo_solver(entryBlock, this);
        computeUseDef(this.bbfactory, tempsToDefs);
        // From here on visit() and merge() trip their assertions if called again.
        this.initialized = true;
    }

    /**
     * Runs the folding pass over the code handed to the constructor, starting at
     * the root basic block and using the DU/UD chains computed there.
     *
     * @return the same Code instance that was passed to the constructor
     */
    public Code fold() {
        Map idsToTrees = new HashMap();
        initIDsToTrees(idsToTrees);
        BasicBlock entryBlock = this.bbfactory.getRoot();
        fold(entryBlock, idsToTrees, this.DUChains, this.UDChains);
        return this.code;
    }

    /**
     * Transfer function for one basic block: recomputes both out sets from the
     * block's in, prsv and gen sets.
     *
     * After the initial clearUpTo the operations amount to
     * out[0] = (prsv[0] AND in[0]) OR gen[0] and
     * out[1] = (prsv[1] OR in[1]) AND gen[1]
     * (assuming clearUpTo(maxTreeID) empties the bits that matter -- TODO confirm).
     *
     * @param basicBlock block to process; must have an entry in bb2tfi
     */
    @Override // harpoon.Analysis.DataFlow.DataFlowBasicBlockVisitor, harpoon.Analysis.BasicBlockInterfVisitor
    public void visit(BasicBlock basicBlock) {
        if (!$assertionsDisabled && basicBlock == null) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && this.initialized) {
            throw new AssertionError();
        }
        TreeFoldingInfo info = (TreeFoldingInfo) this.bb2tfi.get(basicBlock);
        if (!$assertionsDisabled && info == null) {
            throw new AssertionError();
        }

        final int limit = this.maxTreeID;

        BitString out0 = info.outSet[0];
        out0.clearUpTo(limit);
        out0.or(info.prsvSet[0]);
        out0.and(info.inSet[0]);
        out0.or(info.genSet[0]);

        BitString out1 = info.outSet[1];
        out1.clearUpTo(limit);
        out1.or(info.prsvSet[1]);
        out1.or(info.inSet[1]);
        out1.and(info.genSet[1]);
    }

    /**
     * Merges the out sets of one block into the in sets of another by OR-ing
     * them (up to maxTreeID) and reports whether either in set changed.
     *
     * Both slots are always merged. Combining the two or_upTo calls with a
     * short-circuit operator would skip the slot-1 merge whenever slot 0
     * changed, leaving inSet[1] without the bits contributed by this edge.
     *
     * @param basicBlockInterf  block whose out sets are read
     * @param basicBlockInterf2 block whose in sets are updated
     * @return true if inSet[0] or inSet[1] of the second block changed
     */
    @Override // harpoon.Analysis.DataFlow.DataFlowBasicBlockVisitor
    public boolean merge(BasicBlockInterf basicBlockInterf, BasicBlockInterf basicBlockInterf2) {
        if (!$assertionsDisabled && this.initialized) {
            throw new AssertionError();
        }
        TreeFoldingInfo fromInfo = (TreeFoldingInfo) this.bb2tfi.get(basicBlockInterf);
        TreeFoldingInfo toInfo = (TreeFoldingInfo) this.bb2tfi.get(basicBlockInterf2);
        // Evaluate each merge into its own local so neither call can be skipped.
        boolean changedSlot0 = toInfo.inSet[0].or_upTo(fromInfo.outSet[0], this.maxTreeID);
        boolean changedSlot1 = toInfo.inSet[1].or_upTo(fromInfo.outSet[1], this.maxTreeID);
        return changedSlot0 || changedSlot1;
    }

    /**
     * Fills the given map with one entry per statement reached by an
     * EdgesIterator started at RS(root): boxed tree ID -> statement.
     *
     * @param map destination map, mutated in place
     */
    private void initIDsToTrees(Map map) {
        for (EdgesIterator walker = new EdgesIterator(RS(this.root), this.grapher); walker.hasNext();) {
            Stm visited = (Stm) walker.next();
            map.put(Integer.valueOf(visited.getID()), visited);
        }
    }

    /**
     * Builds, for every temp, a bit set that starts fully set and has the bit of
     * each tree touching that temp cleared. For a Stm the temps come from
     * usedefer.def, for any other tree from usedefer.use.
     *
     * @param map destination map (Temp -> BitString), mutated in place
     */
    private void initTempsToPrsvs(Map map) {
        Code.TreeFactory factory = (Code.TreeFactory) this.root.getFactory();
        for (Iterator<Tree> trees = factory.getParent().getElementsI(); trees.hasNext();) {
            Tree tree = trees.next();
            Temp[] temps;
            if (tree instanceof Stm) {
                temps = this.usedefer.def(tree);
            } else {
                temps = this.usedefer.use(tree);
            }
            for (Temp temp : temps) {
                BitString preserved = (BitString) map.get(temp);
                if (preserved == null) {
                    // First sighting of this temp: start from "everything preserved".
                    preserved = new BitString(this.maxTreeID + 1);
                    map.put(temp, preserved);
                    preserved.setAll();
                }
                preserved.clear(tree.getID());
            }
        }
    }

    /**
     * Records, for each temp, the IDs of the statements that define it.
     * Statements of kind 2 or 14 (the kinds mayWriteMem() treats as
     * unconditional memory writers) are skipped; every other statement is
     * asserted to define at most one temp.
     *
     * @param map destination map (Temp -> Set of Integer tree IDs), mutated in place
     */
    private void initTempsToDefs(Map map) {
        for (EdgesIterator walker = new EdgesIterator(RS(this.root), this.grapher); walker.hasNext();) {
            Stm visited = (Stm) walker.next();
            int kind = visited.kind();
            if (kind == 2 || kind == 14) {
                continue;
            }
            Temp[] defined = this.usedefer.def(visited);
            if (!$assertionsDisabled && defined.length > 1) {
                throw new AssertionError();
            }
            if (defined.length != 1) {
                continue;
            }
            MAP_TO_SET(defined[0], Integer.valueOf(visited.getID()), map);
        }
    }

    /**
     * Populates UDChains and DUChains. For each block, starts from a copy of
     * inSet[0] and walks the statements in order: every used temp is linked to
     * each of its defining statements whose bit is currently set, then the
     * running set is updated for the temp the statement itself defines.
     *
     * @param factory source of the basic blocks to scan
     * @param map     Temp -> Set of Integer IDs of defining statements
     */
    private void computeUseDef(BasicBlock.Factory factory, Map map) {
        for (Iterator blocks = factory.blocksIterator(); blocks.hasNext();) {
            BasicBlock block = (BasicBlock) blocks.next();
            TreeFoldingInfo info = (TreeFoldingInfo) this.bb2tfi.get(block);
            BitString reaching = info.inSet[0].clone();

            for (ListIterator stms = block.statements().listIterator(); stms.hasNext();) {
                Stm stm = (Stm) stms.next();
                Integer useId = Integer.valueOf(stm.getID());

                Temp[] used = this.usedefer.use(stm);
                for (int u = 0; u < used.length; u++) {
                    Temp temp = used[u];
                    if (!map.containsKey(temp)) {
                        continue;
                    }
                    Set<Integer> defIds = (Set) map.get(temp);
                    for (Integer defId : defIds) {
                        if (!reaching.get(defId.intValue())) {
                            continue;
                        }
                        Tuple defKey = new Tuple(new Object[]{defId, temp});
                        Tuple useKey = new Tuple(new Object[]{useId, temp});
                        MAP_TO_SET(useKey, defKey, this.UDChains);
                        MAP_TO_SET(defKey, useKey, this.DUChains);
                    }
                }

                Temp[] defined = this.usedefer.def(stm);
                if (!$assertionsDisabled && defined.length > 1 && stm.kind() != 2 && stm.kind() != 14) {
                    throw new AssertionError();
                }
                if (defined.length == 1) {
                    // Drop everything the redefinition invalidates, then mark this statement.
                    reaching.and((BitString) this.tempsToPrsvs.get(defined[0]));
                    reaching.set(stm.getID());
                }
            }
        }
    }

    /**
     * Performs the actual folding. Blocks are walked in reverse post-order from
     * basicBlock; for each temp used by a statement, if its UD chain holds
     * exactly one definition and that definition's DU chain holds exactly one
     * use, the defining statement is removed from the code and the using
     * statement is rebuilt with the definition's source expression spliced in.
     *
     * @param basicBlock block at which the reverse post-order walk starts
     * @param map        Integer tree ID -> Stm, as filled by initIDsToTrees
     * @param map2       DU chains: Tuple(def ID, temp) -> Set of Tuple(use ID, temp)
     * @param map3       UD chains: Tuple(use ID, temp) -> Set of Tuple(def ID, temp)
     */
    private void fold(BasicBlock basicBlock, Map map, Map map2, Map map3) {
        // Original tree -> its replacement; GET_TREE follows these links so that
        // statements rebuilt earlier in the pass are resolved to their current version.
        Map hashMap = new HashMap();
        ReversePostOrderIterator reversePostOrderIterator = new ReversePostOrderIterator(basicBlock);
        while (reversePostOrderIterator.hasNext()) {
            BasicBlock basicBlock2 = (BasicBlock) reversePostOrderIterator.next();
            TreeFoldingInfo treeFoldingInfo = (TreeFoldingInfo) this.bb2tfi.get(basicBlock2);
            // Working copies of the block's in sets, updated statement by statement below.
            BitString clone = treeFoldingInfo.inSet[0].clone();
            BitString clone2 = treeFoldingInfo.inSet[1].clone();
            ListIterator listIterator = basicBlock2.statements().listIterator();
            while (listIterator.hasNext()) {
                Stm stm = (Stm) listIterator.next();
                Temp[] use = this.usedefer.use(stm);
                for (int i = 0; i < use.length; i++) {
                    // Definitions recorded for this (statement, temp) use.
                    Set set = (Set) map3.get(new Tuple(new Object[]{new Integer(stm.getID()), use[i]}));
                    if (set != null && set.size() == 1) {
                        Tuple tuple = (Tuple) set.iterator().next();
                        // The single definition must itself have exactly one recorded use.
                        if (((Set) map2.get(tuple)).size() != 1) {
                            continue;
                        } else {
                            if (!$assertionsDisabled && !map.containsKey(tuple.proj(0))) {
                                throw new AssertionError();
                            }
                            // Fold only when the definition's bit is clear in the slot-1 set
                            // (the one a memory write fills) and set in the slot-0 set.
                            if (!clone2.get(((Integer) tuple.proj(0)).intValue()) && clone.get(((Integer) tuple.proj(0)).intValue())) {
                                Stm stm2 = (Stm) GET_TREE(hashMap, (Stm) map.get(tuple.proj(0)));
                                this.code.remove(stm2);
                                // NOTE(review): replace() stops at the first matching TEMP in a kid
                                // list; if the using statement mentions the temp more than once the
                                // other occurrences would outlive the removed definition -- confirm.
                                Tree build = stm.build(replace(((MOVE) stm2).getSrc(), GET_TREE(hashMap, stm).kids(), use[i]));
                                GET_TREE(hashMap, stm).replace(build);
                                MAP_TREE(hashMap, GET_TREE(hashMap, stm), build);
                            }
                        }
                    }
                }
                // Mirror of TreeFoldingInfo.computeGenPrsvSets, applied to the working copies.
                if (mayWriteMem(stm)) {
                    clone2.setAll();
                } else if (stm.kind() == 12) {
                    clone.and((BitString) this.tempsToPrsvs.get(this.usedefer.def(stm)[0]));
                    clone.set(stm.getID());
                    clone2.clear(stm.getID());
                }
            }
        }
    }

    /**
     * Records that {@code tree} has been superseded by {@code tree2}, so that
     * GET_TREE can later follow the link.
     *
     * @param map   replacement map, mutated in place
     * @param tree  tree that was replaced
     * @param tree2 tree that replaced it
     */
    private void MAP_TREE(Map map, Tree tree, Tree tree2) {
        map.put(tree, tree2);
    }

    /**
     * Follows the replacement links in {@code map} starting at {@code tree}
     * until reaching a tree that is not a key of the map.
     *
     * @param map  replacement map (tree -> replacement)
     * @param tree starting tree
     * @return the last tree in the replacement chain; {@code tree} itself if it has no entry
     */
    private Tree GET_TREE(Map map, Tree tree) {
        Tree current = tree;
        while (true) {
            if (!map.containsKey(current)) {
                return current;
            }
            current = (Tree) map.get(current);
        }
    }

    /**
     * Adds {@code obj2} to the set stored under key {@code obj}, creating and
     * registering a new HashSet first when the key is not yet present.
     *
     * @param obj  key
     * @param obj2 value to add to the key's set
     * @param map  multimap-style map (key -> Set), mutated in place
     */
    private void MAP_TO_SET(Object obj, Object obj2, Map map) {
        if (!map.containsKey(obj)) {
            map.put(obj, new HashSet());
        }
        Set bucket = (Set) map.get(obj);
        bucket.add(obj2);
    }

    /**
     * Descends the left spine of nested SEQ nodes and returns the first
     * statement that is not a SEQ.
     *
     * The loop tests the type explicitly rather than relying on a caught
     * ClassCastException, so a ClassCastException raised inside getLeft()
     * is no longer swallowed. A null argument (or a null left child) is
     * returned as null.
     *
     * @param stm statement to start from
     * @return the leftmost non-SEQ statement reachable through getLeft()
     */
    private Stm RS(Stm stm) {
        Stm leftmost = stm;
        while (leftmost instanceof SEQ) {
            leftmost = ((SEQ) leftmost).getLeft();
        }
        return leftmost;
    }

    /**
     * Returns a copy of {@code expList} in which the first TEMP node (kind 18)
     * referring to {@code temp} is replaced by {@code exp}. Non-TEMP elements
     * are rebuilt from their own kids, processed the same way. Once a matching
     * TEMP is found at some level, the rest of that list is reused unchanged.
     *
     * @param exp     expression to substitute
     * @param expList list to rewrite; may be null
     * @param temp    temp whose TEMP node is to be replaced (compared by identity)
     * @return the rewritten list, or null if the list or its head is null
     */
    private ExpList replace(Exp exp, ExpList expList, Temp temp) {
        if (expList == null || expList.head == null) {
            return null;
        }
        if (expList.head.kind() == 18) {
            TEMP leaf = (TEMP) expList.head;
            if (leaf.temp == temp) {
                // Match: substitute here and keep the remaining elements as they are.
                return new ExpList(exp, expList.tail);
            }
            return new ExpList(leaf, replace(exp, expList.tail, temp));
        }
        // Not a TEMP: rebuild the node from its rewritten kids, then continue with the tail.
        return new ExpList(expList.head.build(replace(exp, expList.head.kids(), temp)), replace(exp, expList.tail, temp));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Tells whether a statement is treated as a possible memory write:
     * kinds 2 and 14 always are, and a MOVE (kind 12) is when its destination
     * has kind 10.
     *
     * @param stm statement to classify
     * @return true if the statement may write memory according to the rules above
     */
    public static boolean mayWriteMem(Stm stm) {
        switch (stm.kind()) {
            case 2:
            case 14:
                return true;
            case 12:
                return ((MOVE) stm).getDst().kind() == 10;
            default:
                return false;
        }
    }

    // Initializes the decompiled assertion switch: the explicit
    // "if (!$assertionsDisabled && ...) throw new AssertionError()" checks in this
    // class are active only when assertions are enabled for TreeFolding.
    static {
        $assertionsDisabled = !TreeFolding.class.desiredAssertionStatus();
    }
}
