nir/builder: Add support for easily building control-flow
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_control_flow.h"
26
27 /**
28 * This optimization detects if statements at the tops of loops where the
29 * condition is a phi node of two constants and moves half of the if to above
30 * the loop and the other half of the if to the end of the loop. A simple for
31 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
32 * ends up looking something like this:
33 *
34 * vec1 32 ssa_0 = load_const (0x00000000)
35 * vec1 32 ssa_1 = load_const (0xffffffff)
36 * loop {
37 * block block_1:
38 * vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
39 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
40 * if ssa_2 {
41 * block block_2:
42 * vec1 32 ssa_4 = load_const (0x00000001)
43 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
44 * } else {
45 * block block_3:
46 * }
47 * block block_4:
48 * vec1 32 ssa_6 = load_const (0x00000004)
49 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
50 * if ssa_7 {
51 * block block_5:
52 * } else {
53 * block block_6:
54 * break
55 * }
56 * block block_7:
57 * }
58 *
59 * This turns it into something like this:
60 *
61 * // Stuff from block 1
62 * // Stuff from block 3
63 * loop {
64 * block block_1:
65 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
66 * vec1 32 ssa_6 = load_const (0x00000004)
67 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
68 * if ssa_7 {
69 * block block_5:
70 * } else {
71 * block block_6:
72 * break
73 * }
74 * block block_7:
75 * // Stuff from block 1
76 * // Stuff from block 2
77 * vec1 32 ssa_4 = load_const (0x00000001)
78 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
79 * }
80 */
81 static bool
82 opt_peel_loop_initial_if(nir_loop *loop)
83 {
84 nir_block *header_block = nir_loop_first_block(loop);
85 nir_block *prev_block =
86 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
87
88 /* It would be insane if this were not true */
89 assert(_mesa_set_search(header_block->predecessors, prev_block));
90
91 /* The loop must have exactly one continue block which could be a block
92 * ending in a continue instruction or the "natural" continue from the
93 * last block in the loop back to the top.
94 */
95 if (header_block->predecessors->entries != 2)
96 return false;
97
98 nir_block *continue_block = NULL;
99 struct set_entry *pred_entry;
100 set_foreach(header_block->predecessors, pred_entry) {
101 if (pred_entry->key != prev_block)
102 continue_block = (void *)pred_entry->key;
103 }
104
105 nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
106 if (!if_node || if_node->type != nir_cf_node_if)
107 return false;
108
109 nir_if *nif = nir_cf_node_as_if(if_node);
110 assert(nif->condition.is_ssa);
111
112 nir_ssa_def *cond = nif->condition.ssa;
113 if (cond->parent_instr->type != nir_instr_type_phi)
114 return false;
115
116 nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
117 if (cond->parent_instr->block != header_block)
118 return false;
119
120 /* We already know we have exactly one continue */
121 assert(exec_list_length(&cond_phi->srcs) == 2);
122
123 uint32_t entry_val = 0, continue_val = 0;
124 nir_foreach_phi_src(src, cond_phi) {
125 assert(src->src.is_ssa);
126 nir_const_value *const_src = nir_src_as_const_value(src->src);
127 if (!const_src)
128 return false;
129
130 if (src->pred == continue_block) {
131 continue_val = const_src->u32[0];
132 } else {
133 assert(src->pred == prev_block);
134 entry_val = const_src->u32[0];
135 }
136 }
137
138 /* If they both execute or both don't execute, this is a job for
139 * nir_dead_cf, not this pass.
140 */
141 if ((entry_val && continue_val) || (!entry_val && !continue_val))
142 return false;
143
144 struct exec_list *continue_list, *entry_list;
145 if (continue_val) {
146 continue_list = &nif->then_list;
147 entry_list = &nif->else_list;
148 } else {
149 continue_list = &nif->else_list;
150 entry_list = &nif->then_list;
151 }
152
153 /* We want to be moving the contents of entry_list to above the loop so it
154 * can't contain any break or continue instructions.
155 */
156 foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
157 nir_foreach_block_in_cf_node(block, cf_node) {
158 nir_instr *last_instr = nir_block_last_instr(block);
159 if (last_instr && last_instr->type == nir_instr_type_jump)
160 return false;
161 }
162 }
163
164 /* Before we do anything, convert the loop to LCSSA. We're about to
165 * replace a bunch of SSA defs with registers and this will prevent any of
166 * it from leaking outside the loop.
167 */
168 nir_convert_loop_to_lcssa(loop);
169
170 nir_block *after_if_block =
171 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
172
173 /* Get rid of phis in the header block since we will be duplicating it */
174 nir_lower_phis_to_regs_block(header_block);
175 /* Get rid of phis after the if since dominance will change */
176 nir_lower_phis_to_regs_block(after_if_block);
177
178 /* Get rid of SSA defs in the pieces we're about to move around */
179 nir_lower_ssa_defs_to_regs_block(header_block);
180 nir_foreach_block_in_cf_node(block, &nif->cf_node)
181 nir_lower_ssa_defs_to_regs_block(block);
182
183 nir_cf_list header, tmp;
184 nir_cf_extract(&header, nir_before_block(header_block),
185 nir_after_block(header_block));
186
187 nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
188 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
189 nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
190 nir_after_cf_list(entry_list));
191 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
192
193 nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));
194 nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
195 nir_after_cf_list(continue_list));
196 nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));
197
198 nir_cf_node_remove(&nif->cf_node);
199
200 return true;
201 }
202
203 static bool
204 opt_if_cf_list(struct exec_list *cf_list)
205 {
206 bool progress = false;
207 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
208 switch (cf_node->type) {
209 case nir_cf_node_block:
210 break;
211
212 case nir_cf_node_if: {
213 nir_if *nif = nir_cf_node_as_if(cf_node);
214 progress |= opt_if_cf_list(&nif->then_list);
215 progress |= opt_if_cf_list(&nif->else_list);
216 break;
217 }
218
219 case nir_cf_node_loop: {
220 nir_loop *loop = nir_cf_node_as_loop(cf_node);
221 progress |= opt_if_cf_list(&loop->body);
222 progress |= opt_peel_loop_initial_if(loop);
223 break;
224 }
225
226 case nir_cf_node_function:
227 unreachable("Invalid cf type");
228 }
229 }
230
231 return progress;
232 }
233
234 bool
235 nir_opt_if(nir_shader *shader)
236 {
237 bool progress = false;
238
239 nir_foreach_function(function, shader) {
240 if (function->impl == NULL)
241 continue;
242
243 if (opt_if_cf_list(&function->impl->body)) {
244 nir_metadata_preserve(function->impl, nir_metadata_none);
245
246 /* If that made progress, we're no longer really in SSA form. We
247 * need to convert registers back into SSA defs and clean up SSA defs
248 * that don't dominate their uses.
249 */
250 nir_lower_regs_to_ssa_impl(function->impl);
251 progress = true;
252 }
253 }
254
255 return progress;
256 }