aco: improve jump threading with wave32
[mesa.git] / src / amd / compiler / aco_ssa_elimination.cpp
1 /*
2 * Copyright © 2018 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25
26 #include "aco_ir.h"
27
28 #include <map>
29
30 namespace aco {
31 namespace {
32
33 /* map: block-id -> pair (dest, src) to store phi information */
34 typedef std::map<uint32_t, std::vector<std::pair<Definition, Operand>>> phi_info;
35
36 struct ssa_elimination_ctx {
37 phi_info logical_phi_info;
38 phi_info linear_phi_info;
39 std::vector<bool> empty_blocks;
40 Program* program;
41
42 ssa_elimination_ctx(Program* program) : empty_blocks(program->blocks.size(), true), program(program) {}
43 };
44
45 void collect_phi_info(ssa_elimination_ctx& ctx)
46 {
47 for (Block& block : ctx.program->blocks) {
48 for (aco_ptr<Instruction>& phi : block.instructions) {
49 if (phi->opcode != aco_opcode::p_phi && phi->opcode != aco_opcode::p_linear_phi)
50 break;
51
52 for (unsigned i = 0; i < phi->operands.size(); i++) {
53 if (phi->operands[i].isUndefined())
54 continue;
55 if (phi->operands[i].isTemp() && phi->operands[i].physReg() == phi->definitions[0].physReg())
56 continue;
57
58 std::vector<unsigned>& preds = phi->opcode == aco_opcode::p_phi ? block.logical_preds : block.linear_preds;
59 phi_info& info = phi->opcode == aco_opcode::p_phi ? ctx.logical_phi_info : ctx.linear_phi_info;
60 const auto result = info.emplace(preds[i], std::vector<std::pair<Definition, Operand>>());
61 assert(phi->definitions[0].size() == phi->operands[i].size());
62 result.first->second.emplace_back(phi->definitions[0], phi->operands[i]);
63 ctx.empty_blocks[preds[i]] = false;
64 }
65 }
66 }
67 }
68
69 void insert_parallelcopies(ssa_elimination_ctx& ctx)
70 {
71 /* insert the parallelcopies from logical phis before p_logical_end */
72 for (auto&& entry : ctx.logical_phi_info) {
73 Block& block = ctx.program->blocks[entry.first];
74 unsigned idx = block.instructions.size() - 1;
75 while (block.instructions[idx]->opcode != aco_opcode::p_logical_end) {
76 assert(idx > 0);
77 idx--;
78 }
79
80 std::vector<aco_ptr<Instruction>>::iterator it = std::next(block.instructions.begin(), idx);
81 aco_ptr<Pseudo_instruction> pc{create_instruction<Pseudo_instruction>(aco_opcode::p_parallelcopy, Format::PSEUDO, entry.second.size(), entry.second.size())};
82 unsigned i = 0;
83 for (std::pair<Definition, Operand>& pair : entry.second)
84 {
85 pc->definitions[i] = pair.first;
86 pc->operands[i] = pair.second;
87 i++;
88 }
89 /* this shouldn't be needed since we're only copying vgprs */
90 pc->tmp_in_scc = false;
91 block.instructions.insert(it, std::move(pc));
92 }
93
94 /* insert parallelcopies for the linear phis at the end of blocks just before the branch */
95 for (auto&& entry : ctx.linear_phi_info) {
96 Block& block = ctx.program->blocks[entry.first];
97 std::vector<aco_ptr<Instruction>>::iterator it = block.instructions.end();
98 --it;
99 assert((*it)->format == Format::PSEUDO_BRANCH);
100 aco_ptr<Pseudo_instruction> pc{create_instruction<Pseudo_instruction>(aco_opcode::p_parallelcopy, Format::PSEUDO, entry.second.size(), entry.second.size())};
101 unsigned i = 0;
102 for (std::pair<Definition, Operand>& pair : entry.second)
103 {
104 pc->definitions[i] = pair.first;
105 pc->operands[i] = pair.second;
106 i++;
107 }
108 pc->tmp_in_scc = block.scc_live_out;
109 pc->scratch_sgpr = block.scratch_sgpr;
110 block.instructions.insert(it, std::move(pc));
111 }
112 }
113
114
115 void try_remove_merge_block(ssa_elimination_ctx& ctx, Block* block)
116 {
117 /* check if the successor is another merge block which restores exec */
118 // TODO: divergent loops also restore exec
119 if (block->linear_succs.size() != 1 ||
120 !(ctx.program->blocks[block->linear_succs[0]].kind & block_kind_merge))
121 return;
122
123 /* check if this block is empty and the exec mask is not needed */
124 for (aco_ptr<Instruction>& instr : block->instructions) {
125 if (instr->opcode == aco_opcode::p_parallelcopy) {
126 if (instr->definitions[0].physReg() == exec)
127 continue;
128 else
129 return;
130 }
131
132 if (instr->opcode != aco_opcode::p_linear_phi &&
133 instr->opcode != aco_opcode::p_phi &&
134 instr->opcode != aco_opcode::p_logical_start &&
135 instr->opcode != aco_opcode::p_logical_end &&
136 instr->opcode != aco_opcode::p_branch)
137 return;
138 }
139
140 /* keep the branch instruction and remove the rest */
141 aco_ptr<Instruction> branch = std::move(block->instructions.back());
142 block->instructions.clear();
143 block->instructions.emplace_back(std::move(branch));
144 }
145
146 void try_remove_invert_block(ssa_elimination_ctx& ctx, Block* block)
147 {
148 assert(block->linear_succs.size() == 2);
149 if (block->linear_succs[0] != block->linear_succs[1])
150 return;
151
152 /* check if we can remove this block */
153 for (aco_ptr<Instruction>& instr : block->instructions) {
154 if (instr->opcode != aco_opcode::p_linear_phi &&
155 instr->opcode != aco_opcode::p_phi &&
156 (instr->opcode != aco_opcode::s_andn2_b64 || ctx.program->wave_size != 64) &&
157 (instr->opcode != aco_opcode::s_andn2_b32 || ctx.program->wave_size != 32) &&
158 instr->opcode != aco_opcode::p_branch)
159 return;
160 }
161
162 unsigned succ_idx = block->linear_succs[0];
163 assert(block->linear_preds.size() == 2);
164 for (unsigned i = 0; i < 2; i++) {
165 Block *pred = &ctx.program->blocks[block->linear_preds[i]];
166 pred->linear_succs[0] = succ_idx;
167 ctx.program->blocks[succ_idx].linear_preds[i] = pred->index;
168
169 Pseudo_branch_instruction *branch = static_cast<Pseudo_branch_instruction*>(pred->instructions.back().get());
170 assert(branch->format == Format::PSEUDO_BRANCH);
171 branch->target[0] = succ_idx;
172 branch->target[1] = succ_idx;
173 }
174
175 block->instructions.clear();
176 block->linear_preds.clear();
177 block->linear_succs.clear();
178 }
179
180 void try_remove_simple_block(ssa_elimination_ctx& ctx, Block* block)
181 {
182 for (aco_ptr<Instruction>& instr : block->instructions) {
183 if (instr->opcode != aco_opcode::p_logical_start &&
184 instr->opcode != aco_opcode::p_logical_end &&
185 instr->opcode != aco_opcode::p_branch)
186 return;
187 }
188
189 Block& pred = ctx.program->blocks[block->linear_preds[0]];
190 Block& succ = ctx.program->blocks[block->linear_succs[0]];
191 Pseudo_branch_instruction* branch = static_cast<Pseudo_branch_instruction*>(pred.instructions.back().get());
192 if (branch->opcode == aco_opcode::p_branch) {
193 branch->target[0] = succ.index;
194 branch->target[1] = succ.index;
195 } else if (branch->target[0] == block->index) {
196 branch->target[0] = succ.index;
197 } else if (branch->target[0] == succ.index) {
198 assert(branch->target[1] == block->index);
199 branch->target[1] = succ.index;
200 branch->opcode = aco_opcode::p_branch;
201 } else if (branch->target[1] == block->index) {
202 /* check if there is a fall-through path from block to succ */
203 bool falls_through = true;
204 for (unsigned j = block->index + 1; falls_through && j < succ.index; j++) {
205 assert(ctx.program->blocks[j].index == j);
206 if (!ctx.program->blocks[j].instructions.empty())
207 falls_through = false;
208 }
209 if (falls_through) {
210 branch->target[1] = succ.index;
211 } else {
212 /* check if there is a fall-through path for the alternative target */
213 for (unsigned j = block->index + 1; j < branch->target[0]; j++) {
214 if (!ctx.program->blocks[j].instructions.empty())
215 return;
216 }
217
218 /* This is a (uniform) break or continue block. The branch condition has to be inverted. */
219 if (branch->opcode == aco_opcode::p_cbranch_z)
220 branch->opcode = aco_opcode::p_cbranch_nz;
221 else if (branch->opcode == aco_opcode::p_cbranch_nz)
222 branch->opcode = aco_opcode::p_cbranch_z;
223 else
224 assert(false);
225 /* also invert the linear successors */
226 pred.linear_succs[0] = pred.linear_succs[1];
227 pred.linear_succs[1] = succ.index;
228 branch->target[1] = branch->target[0];
229 branch->target[0] = succ.index;
230 }
231 } else {
232 assert(false);
233 }
234
235 if (branch->target[0] == branch->target[1])
236 branch->opcode = aco_opcode::p_branch;
237
238 for (unsigned i = 0; i < pred.linear_succs.size(); i++)
239 if (pred.linear_succs[i] == block->index)
240 pred.linear_succs[i] = succ.index;
241
242 for (unsigned i = 0; i < succ.linear_preds.size(); i++)
243 if (succ.linear_preds[i] == block->index)
244 succ.linear_preds[i] = pred.index;
245
246 block->instructions.clear();
247 block->linear_preds.clear();
248 block->linear_succs.clear();
249 }
250
251 void jump_threading(ssa_elimination_ctx& ctx)
252 {
253 for (int i = ctx.program->blocks.size() - 1; i >= 0; i--) {
254 Block* block = &ctx.program->blocks[i];
255
256 if (!ctx.empty_blocks[i])
257 continue;
258
259 if (block->kind & block_kind_invert) {
260 try_remove_invert_block(ctx, block);
261 continue;
262 }
263
264 if (block->linear_succs.size() > 1)
265 continue;
266
267 if (block->kind & block_kind_merge ||
268 block->kind & block_kind_loop_exit)
269 try_remove_merge_block(ctx, block);
270
271 if (block->linear_preds.size() == 1)
272 try_remove_simple_block(ctx, block);
273 }
274 }
275
276 } /* end namespace */
277
278
279 void ssa_elimination(Program* program)
280 {
281 ssa_elimination_ctx ctx(program);
282
283 /* Collect information about every phi-instruction */
284 collect_phi_info(ctx);
285
286 /* eliminate empty blocks */
287 jump_threading(ctx);
288
289 /* insert parallelcopies from SSA elimination */
290 insert_parallelcopies(ctx);
291
292 }
293 }