001    /*
002     *  This file is part of the Jikes RVM project (http://jikesrvm.org).
003     *
004     *  This file is licensed to You under the Eclipse Public License (EPL);
005     *  You may not use this file except in compliance with the License. You
006     *  may obtain a copy of the License at
007     *
008     *      http://www.opensource.org/licenses/eclipse-1.0.php
009     *
010     *  See the COPYRIGHT.txt file distributed with this work for information
011     *  regarding copyright ownership.
012     */
013    package org.jikesrvm.compilers.opt.ir;
014    
015    import java.util.Enumeration;
016    
017    import org.jikesrvm.VM;
018    import org.jikesrvm.compilers.opt.OptimizingCompilerException;
019    import org.jikesrvm.compilers.opt.ir.operand.BranchProfileOperand;
020    
021    /**
022     * Used to iterate over the branch targets (including the fall through edge)
023     * and associated probabilites of a basic block.
024     * Takes into account the ordering of branch instructions when
025     * computing the edge weights such that the total target weight will always
026     * be equal to 1.0 (flow in == flow out).
027     */
028    public final class WeightedBranchTargets {
029      private BasicBlock[] targets;
030      private float[] weights;
031      private int cur;
032      private int max;
033    
034      public void reset() { cur = 0; }
035    
036      public boolean hasMoreElements() { return cur < max; }
037    
038      public void advance() { cur++; }
039    
040      public BasicBlock curBlock() { return targets[cur]; }
041    
042      public float curWeight() { return weights[cur]; }
043    
044      public WeightedBranchTargets(BasicBlock bb) {
045        targets = new BasicBlock[3];
046        weights = new float[3];
047        cur = 0;
048        max = 0;
049    
050        float prob = 1f;
051        for (Enumeration<Instruction> ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) {
052          Instruction s = ie.nextElement();
053          if (IfCmp.conforms(s)) {
054            BasicBlock target = IfCmp.getTarget(s).target.getBasicBlock();
055            BranchProfileOperand prof = IfCmp.getBranchProfile(s);
056            float taken = prob * prof.takenProbability;
057            prob = prob * (1f - prof.takenProbability);
058            addEdge(target, taken);
059          } else if (Goto.conforms(s)) {
060            BasicBlock target = Goto.getTarget(s).target.getBasicBlock();
061            addEdge(target, prob);
062          } else if (InlineGuard.conforms(s)) {
063            BasicBlock target = InlineGuard.getTarget(s).target.getBasicBlock();
064            BranchProfileOperand prof = InlineGuard.getBranchProfile(s);
065            float taken = prob * prof.takenProbability;
066            prob = prob * (1f - prof.takenProbability);
067            addEdge(target, taken);
068          } else if (IfCmp2.conforms(s)) {
069            BasicBlock target = IfCmp2.getTarget1(s).target.getBasicBlock();
070            BranchProfileOperand prof = IfCmp2.getBranchProfile1(s);
071            float taken = prob * prof.takenProbability;
072            prob = prob * (1f - prof.takenProbability);
073            addEdge(target, taken);
074            target = IfCmp2.getTarget2(s).target.getBasicBlock();
075            prof = IfCmp2.getBranchProfile2(s);
076            taken = prob * prof.takenProbability;
077            prob = prob * (1f - prof.takenProbability);
078            addEdge(target, taken);
079          } else if (TableSwitch.conforms(s)) {
080            int lowLimit = TableSwitch.getLow(s).value;
081            int highLimit = TableSwitch.getHigh(s).value;
082            int number = highLimit - lowLimit + 1;
083            float total = 0f;
084            for (int i = 0; i < number; i++) {
085              BasicBlock target = TableSwitch.getTarget(s, i).target.getBasicBlock();
086              BranchProfileOperand prof = TableSwitch.getBranchProfile(s, i);
087              float taken = prob * prof.takenProbability;
088              total += prof.takenProbability;
089              addEdge(target, taken);
090            }
091            BasicBlock target = TableSwitch.getDefault(s).target.getBasicBlock();
092            BranchProfileOperand prof = TableSwitch.getDefaultBranchProfile(s);
093            float taken = prob * prof.takenProbability;
094            total += prof.takenProbability;
095            if (VM.VerifyAssertions && !epsilon(total, 1f)) {
096              VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
097            }
098            addEdge(target, taken);
099          } else if (LowTableSwitch.conforms(s)) {
100            int number = LowTableSwitch.getNumberOfTargets(s);
101            float total = 0f;
102            for (int i = 0; i < number; i++) {
103              BasicBlock target = LowTableSwitch.getTarget(s, i).target.getBasicBlock();
104              BranchProfileOperand prof = LowTableSwitch.getBranchProfile(s, i);
105              float taken = prob * prof.takenProbability;
106              total += prof.takenProbability;
107              addEdge(target, taken);
108            }
109            if (VM.VerifyAssertions && !epsilon(total, 1f)) {
110              VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
111            }
112          } else if (LookupSwitch.conforms(s)) {
113            int number = LookupSwitch.getNumberOfTargets(s);
114            float total = 0f;
115            for (int i = 0; i < number; i++) {
116              BasicBlock target = LookupSwitch.getTarget(s, i).target.getBasicBlock();
117              BranchProfileOperand prof = LookupSwitch.getBranchProfile(s, i);
118              float taken = prob * prof.takenProbability;
119              total += prof.takenProbability;
120              addEdge(target, taken);
121            }
122            BasicBlock target = LookupSwitch.getDefault(s).target.getBasicBlock();
123            BranchProfileOperand prof = LookupSwitch.getDefaultBranchProfile(s);
124            float taken = prob * prof.takenProbability;
125            total += prof.takenProbability;
126            if (VM.VerifyAssertions && !epsilon(total, 1f)) {
127              VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
128            }
129            addEdge(target, taken);
130          } else {
131            throw new OptimizingCompilerException("TODO " + s + "\n");
132          }
133        }
134        BasicBlock ft = bb.getFallThroughBlock();
135        if (ft != null) addEdge(ft, prob);
136      }
137    
138      private void addEdge(BasicBlock target, float weight) {
139        if (max == targets.length) {
140          BasicBlock[] tmp = new BasicBlock[targets.length << 1];
141          for (int i = 0; i < targets.length; i++) {
142            tmp[i] = targets[i];
143          }
144          targets = tmp;
145          float[] tmp2 = new float[weights.length << 1];
146          for (int i = 0; i < weights.length; i++) {
147            tmp2[i] = weights[i];
148          }
149          weights = tmp2;
150        }
151        targets[max] = target;
152        weights[max] = weight;
153        max++;
154      }
155    
156      private boolean epsilon(float a, float b) {
157        return Math.abs(a - b) < 0.1;
158      }
159    }