P4C
The P4 Compiler
Loading...
Searching...
No Matches
parser_extract_balance_score.h
1
19#ifndef BF_P4C_PHV_PARSER_EXTRACT_BALANCE_SCORE_H_
20#define BF_P4C_PHV_PARSER_EXTRACT_BALANCE_SCORE_H_
21
22/* Note: this is an incorrect header. It contains function implementations, hence it can
23 * be included just in one translation unit. Otherwise, linking conflicts will
24 * happen. */
25
26#include <algorithm>
27#include <iostream>
28#include <map>
29#include <set>
30#include <vector>
31
32#include "bf-p4c/common/table_printer.h"
33#include "bf-p4c/phv/phv.h"
34
// Extractor sizes iterated by total_extracts(), total_bytes(), print() and
// sorted(). Declared here without an initializer; its definition is not
// visible in this part of the file.
36 static const std::vector<PHV::Size> extractor_sizes;
37
38 // TODO would be nice to get this from PardeSpec. JBay has one size,
39 // 16-bit extractors, so the extractor balance may not even be relevant.
40
// Number of extracts per container size. Starts with a zero entry for each of
// b8, b16 and b32; the container-set constructor increments the entry matching
// each container's size.
41 std::map<PHV::Size, unsigned> use = {
42 {PHV::Size::b8, 0}, {PHV::Size::b16, 0}, {PHV::Size::b32, 0}};
43
45
46 explicit StateExtractUsage(const std::set<PHV::Container> &containers) {
47 for (auto c : containers) use[c.type().size()]++;
48 }
49
50 // Given two state extract usages of same number of extracted bytes, which is
51 // better in terms of bandwidth utilization?
52 //
53 // This is to be used before PHV allocation, where we assume each state has
54 // infinite bandwidth; After PHV allocation, big states are split in order to
55 // satisfy each state's capacity constraint (4xB, 4xH, 4xW).
56 //
57 // The objective is to avoid spilling extracts into a next state whilst they
58 // can be extracted upfront. This requires a balanced usage of the three extractor
59 // sizes.
60 //
61 // We implemented the tie break based on the following heuristics:
62 // 1. Compare the delta of max used and min used extractor size
63 // 2. Compare the number of extract sizes that exceed 4
64 // 3. Compare the total number of extracts
65 // 4. Compare the number of extracts of each size in descending order
66 //
67 // e.g. below is all possible combinations of 20 bytes of extraction, and associated
68 // score (with 0 being the best score, and the more negative the worse).
69 //
70 // -----------------------
71 // | B| H| W| Score|
72 // -----------------------
73 // | 2| 3| 3| 0|
74 // | 4| 2| 3| -1|
75 // | 4| 4| 2| -2|
76 // | 2| 1| 4| -3|
77 // | 2| 5| 2| -4|
78 // | 0| 2| 4| -5|
79 // | 0| 4| 3| -6|
80 // | 4| 0| 4| -7|
81 // | 6| 3| 2| -8|
82 // | 0| 0| 5| -9|
83 // | 6| 1| 3| -10|
84 // | 4| 6| 1| -11|
85 // | 6| 5| 1| -12|
86 // | 0| 6| 2| -13|
87 // | 2| 7| 1| -14|
88 // | 8| 2| 2| -15|
89 // | 8| 4| 1| -16|
90 // | 6| 7| 0| -17|
91 // | 0| 8| 1| -18|
92 // | 8| 0| 3| -19|
93 // | 4| 8| 0| -20|
94 // | 8| 6| 0| -21|
95 // | 2| 9| 0| -22|
96 // | 10| 1| 2| -23|
97 // | 10| 3| 1| -24|
98 // | 0| 10| 0| -25|
99 // | 10| 5| 0| -26|
100 // | 12| 2| 1| -27|
101 // | 12| 0| 2| -28|
102 // | 12| 4| 0| -29|
103 // | 14| 1| 1| -30|
104 // | 14| 3| 0| -31|
105 // | 16| 0| 1| -32|
106 // | 16| 2| 0| -33|
107 // | 18| 1| 0| -34|
108 // | 20| 0| 0| -35|
109 // -----------------------
110 //
111 bool operator<(StateExtractUsage b) const {
112 BUG_CHECK(use.size() == 3 && b.use.size() == 3, "malformed extractor use");
113
114 if (total_bytes() < b.total_bytes()) return true;
115 if (total_bytes() > b.total_bytes()) return false;
116
117 auto a_sorted = sorted();
118 auto b_sorted = b.sorted();
119
120 unsigned a_delta = a_sorted[2] - a_sorted[0];
121 unsigned b_delta = b_sorted[2] - b_sorted[0];
122
123 if (a_delta < b_delta) return true;
124 if (a_delta > b_delta) return false;
125
126 unsigned a_over_four = 0;
127 unsigned b_over_four = 0;
128
129 for (auto u : use) {
130 if (u.second > 4) a_over_four++;
131 }
132
133 for (auto u : b.use) {
134 if (u.second > 4) b_over_four++;
135 }
136
137 if (a_over_four < b_over_four) return true;
138 if (a_over_four > b_over_four) return false;
139
140 if (total_extracts() < b.total_extracts()) return true;
141 if (total_extracts() > b.total_extracts()) return false;
142
143 for (int i = 2; i >= 0; i--) {
144 if (a_sorted[i] > b_sorted[i]) return true;
145 if (a_sorted[i] < b_sorted[i]) return false;
146 }
147
148 return false;
149 }
150
151 bool operator==(const StateExtractUsage &c) const { return use == c.use; }
152
153 unsigned total_extracts() const {
154 unsigned total = 0;
155 for (auto sz : extractor_sizes) total += use.at(sz);
156
157 return total;
158 }
159
160 unsigned total_bytes() const {
161 unsigned total = 0;
162 for (auto sz : extractor_sizes) total += (unsigned)sz / 8 * use.at(sz);
163
164 return total;
165 }
166
167 void print() const {
168 for (auto sz : extractor_sizes) std::cout << use.at(sz) << " ";
169
170 std::cout << std::endl;
171 }
172
173 std::vector<unsigned> sorted() const {
174 std::vector<unsigned> rv;
175
176 for (auto sz : extractor_sizes) rv.push_back(use.at(sz));
177
178 std::sort(rv.begin(), rv.end());
179 return rv;
180 }
181};
182
183namespace ParserExtractScore {
184
185void verify(const StateExtractUsage &use, unsigned num_bytes) {
186 BUG_CHECK(use.total_bytes() == num_bytes, "number of bytes don't add up");
187}
188
189std::string print_scoreboard(unsigned num_bytes, const std::set<StateExtractUsage> &combos) {
190 std::stringstream ss;
191 ss << "Scoreboard for " << num_bytes << " bytes:" << std::endl;
192
193 TablePrinter tp(ss, {"B", "H", "W", "Score"});
194
195 int score = 0;
196 for (auto &use : combos) {
197 tp.addRow({std::to_string(use.use.at(PHV::Size::b8)),
198 std::to_string(use.use.at(PHV::Size::b16)),
199 std::to_string(use.use.at(PHV::Size::b32)), std::to_string(score--)});
200 }
201
202 tp.print();
203 return ss.str();
204}
205
206// What are all possible combinations of extracts (B, H, W) that add up to N bytes?
207// This is a textbook dynamic programming problem (memoization + optimal substructure).
208std::set<StateExtractUsage> enumerate_extract_combinations(
209 unsigned num_bytes, std::map<unsigned, std::set<StateExtractUsage>> &all_usages) {
210 std::set<StateExtractUsage> usages;
211
212 if (num_bytes == 0) {
214 usages.insert(use);
215 return usages;
216 }
217
218 if (all_usages.count(num_bytes)) return all_usages.at(num_bytes);
219
220 for (auto sz : {PHV::Size::b8, PHV::Size::b16, PHV::Size::b32}) {
221 unsigned bytes = (unsigned)sz / 8;
222
223 if (num_bytes >= bytes) {
224 auto opt_sub = enumerate_extract_combinations(num_bytes - bytes, all_usages);
225
226 for (auto &u : opt_sub) {
227 auto use = u;
228 use.use[sz]++;
229 usages.insert(use);
230 }
231 }
232 }
233
234 for (auto use : usages) verify(use, num_bytes);
235
236 all_usages[num_bytes] = usages; // memoize
237
238 LOG4(print_scoreboard(num_bytes, usages));
239
240 return usages;
241}
242
243std::set<StateExtractUsage> enumerate_extract_combinations(unsigned num_bytes) {
244 static std::map<unsigned, std::set<StateExtractUsage>> all_usages;
245
246 if (all_usages.count(num_bytes)) return all_usages.at(num_bytes);
247
248 auto res = enumerate_extract_combinations(num_bytes, all_usages);
249 return res;
250}
251
252int get_score(const StateExtractUsage &use) {
253 unsigned num_bytes = use.total_bytes();
254 auto combos = enumerate_extract_combinations(num_bytes);
255
256 auto it = combos.find(use);
257
258 BUG_CHECK(it != combos.end(), "invalid extractor use?");
259
260 unsigned dis = std::distance(combos.begin(), it);
261
262 // We use negative number for the score, with 0 being the best score,
263 // and the more negative the worse. Essentially, the score indicates
264 // the penalty incurred because of bandwidth imbalance.
265 return -dis;
266}
267
268} // namespace ParserExtractScore
269
270#endif /* BF_P4C_PHV_PARSER_EXTRACT_BALANCE_SCORE_H_ */
Definition table_printer.h:40
Definition parser_extract_balance_score.h:35
static const std::vector< PHV::Size > extractor_sizes
Definition parser_extract_balance_score.h:57