001 /* 002 * This file is part of the Jikes RVM project (http://jikesrvm.org). 003 * 004 * This file is licensed to You under the Eclipse Public License (EPL); 005 * You may not use this file except in compliance with the License. You 006 * may obtain a copy of the License at 007 * 008 * http://www.opensource.org/licenses/eclipse-1.0.php 009 * 010 * See the COPYRIGHT.txt file distributed with this work for information 011 * regarding copyright ownership. 012 */ 013 package org.jikesrvm.compilers.opt.ir; 014 015 import java.util.Enumeration; 016 017 import org.jikesrvm.VM; 018 import org.jikesrvm.compilers.opt.OptimizingCompilerException; 019 import org.jikesrvm.compilers.opt.ir.operand.BranchProfileOperand; 020 021 /** 022 * Used to iterate over the branch targets (including the fall through edge) 023 * and associated probabilites of a basic block. 024 * Takes into account the ordering of branch instructions when 025 * computing the edge weights such that the total target weight will always 026 * be equal to 1.0 (flow in == flow out). 027 */ 028 public final class WeightedBranchTargets { 029 private BasicBlock[] targets; 030 private float[] weights; 031 private int cur; 032 private int max; 033 034 public void reset() { cur = 0; } 035 036 public boolean hasMoreElements() { return cur < max; } 037 038 public void advance() { cur++; } 039 040 public BasicBlock curBlock() { return targets[cur]; } 041 042 public float curWeight() { return weights[cur]; } 043 044 public WeightedBranchTargets(BasicBlock bb) { 045 targets = new BasicBlock[3]; 046 weights = new float[3]; 047 cur = 0; 048 max = 0; 049 050 float prob = 1f; 051 for (Enumeration<Instruction> ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) { 052 Instruction s = ie.nextElement(); 053 if (IfCmp.conforms(s)) { 054 BasicBlock target = IfCmp.getTarget(s).target.getBasicBlock(); 055 BranchProfileOperand prof = IfCmp.getBranchProfile(s); 056 float taken = prob * prof.takenProbability; 057 prob = prob * (1f - prof.takenProbability); 058 addEdge(target, taken); 059 } else if (Goto.conforms(s)) { 060 BasicBlock target = Goto.getTarget(s).target.getBasicBlock(); 061 addEdge(target, prob); 062 } else if (InlineGuard.conforms(s)) { 063 BasicBlock target = InlineGuard.getTarget(s).target.getBasicBlock(); 064 BranchProfileOperand prof = InlineGuard.getBranchProfile(s); 065 float taken = prob * prof.takenProbability; 066 prob = prob * (1f - prof.takenProbability); 067 addEdge(target, taken); 068 } else if (IfCmp2.conforms(s)) { 069 BasicBlock target = IfCmp2.getTarget1(s).target.getBasicBlock(); 070 BranchProfileOperand prof = IfCmp2.getBranchProfile1(s); 071 float taken = prob * prof.takenProbability; 072 prob = prob * (1f - prof.takenProbability); 073 addEdge(target, taken); 074 target = IfCmp2.getTarget2(s).target.getBasicBlock(); 075 prof = IfCmp2.getBranchProfile2(s); 076 taken = prob * prof.takenProbability; 077 prob = prob * (1f - prof.takenProbability); 078 addEdge(target, taken); 079 } else if (TableSwitch.conforms(s)) { 080 int lowLimit = TableSwitch.getLow(s).value; 081 int highLimit = TableSwitch.getHigh(s).value; 082 int number = highLimit - lowLimit + 1; 083 float total = 0f; 084 for (int i = 0; i < number; i++) { 085 BasicBlock target = TableSwitch.getTarget(s, i).target.getBasicBlock(); 086 BranchProfileOperand prof = TableSwitch.getBranchProfile(s, i); 087 float taken = prob * prof.takenProbability; 088 total += prof.takenProbability; 089 addEdge(target, taken); 090 } 091 BasicBlock target = TableSwitch.getDefault(s).target.getBasicBlock(); 092 BranchProfileOperand prof = TableSwitch.getDefaultBranchProfile(s); 093 float taken = prob * prof.takenProbability; 094 total += prof.takenProbability; 095 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 096 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 097 } 098 addEdge(target, taken); 099 } else if (LowTableSwitch.conforms(s)) { 100 int number = LowTableSwitch.getNumberOfTargets(s); 101 float total = 0f; 102 for (int i = 0; i < number; i++) { 103 BasicBlock target = LowTableSwitch.getTarget(s, i).target.getBasicBlock(); 104 BranchProfileOperand prof = LowTableSwitch.getBranchProfile(s, i); 105 float taken = prob * prof.takenProbability; 106 total += prof.takenProbability; 107 addEdge(target, taken); 108 } 109 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 110 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 111 } 112 } else if (LookupSwitch.conforms(s)) { 113 int number = LookupSwitch.getNumberOfTargets(s); 114 float total = 0f; 115 for (int i = 0; i < number; i++) { 116 BasicBlock target = LookupSwitch.getTarget(s, i).target.getBasicBlock(); 117 BranchProfileOperand prof = LookupSwitch.getBranchProfile(s, i); 118 float taken = prob * prof.takenProbability; 119 total += prof.takenProbability; 120 addEdge(target, taken); 121 } 122 BasicBlock target = LookupSwitch.getDefault(s).target.getBasicBlock(); 123 BranchProfileOperand prof = LookupSwitch.getDefaultBranchProfile(s); 124 float taken = prob * prof.takenProbability; 125 total += prof.takenProbability; 126 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 127 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 128 } 129 addEdge(target, taken); 130 } else { 131 throw new OptimizingCompilerException("TODO " + s + "\n"); 132 } 133 } 134 BasicBlock ft = bb.getFallThroughBlock(); 135 if (ft != null) addEdge(ft, prob); 136 } 137 138 private void addEdge(BasicBlock target, float weight) { 139 if (max == targets.length) { 140 BasicBlock[] tmp = new BasicBlock[targets.length << 1]; 141 for (int i = 0; i < targets.length; i++) { 142 tmp[i] = targets[i]; 143 } 144 targets = tmp; 145 float[] tmp2 = new float[weights.length << 1]; 146 for (int i = 0; i < weights.length; i++) { 147 tmp2[i] = weights[i]; 148 } 149 weights = tmp2; 150 } 151 targets[max] = target; 152 weights[max] = weight; 153 max++; 154 } 155 156 private boolean epsilon(float a, float b) { 157 return Math.abs(a - b) < 0.1; 158 } 159 }