nir,spirv: Rework function calls
[mesa.git] / src / compiler / spirv / vtn_cfg.c
/*
 * Copyright © 2015 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"
#include "nir/nir_vla.h"

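/* Loads the NIR function parameter at param_idx and wraps it in a
 * vtn_pointer.  Image and sampler parameters do not carry a pointer type of
 * their own, so for those we synthesize a UniformConstant pointer type on
 * the fly; real pointer parameters already have the type we need.
 */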
static struct vtn_pointer *
vtn_load_param_pointer(struct vtn_builder *b,
                       struct vtn_type *param_type,
                       uint32_t param_idx)
{
   struct vtn_type *ptr_type = param_type;
   if (param_type->base_type != vtn_base_type_pointer) {
      assert(param_type->base_type == vtn_base_type_image ||
             param_type->base_type == vtn_base_type_sampler);
      ptr_type = rzalloc(b, struct vtn_type);
      ptr_type->base_type = vtn_base_type_pointer;
      ptr_type->deref = param_type;
      ptr_type->storage_class = SpvStorageClassUniformConstant;
   }

   return vtn_pointer_from_ssa(b, nir_load_param(&b->nb, param_idx), ptr_type);
}

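/* First pass over a function's instructions.  This builds the vtn_function
 * and its nir_function (including the parameter list), turns each
 * OpFunctionParameter into a vtn value backed by nir_load_param(), and
 * records every block's label, merge, and branch words.  The recorded CFG is
 * structurized later by vtn_cfg_walk_blocks() and only emitted as NIR in
 * vtn_emit_cf_list().
 */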
static bool
vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
                                   const uint32_t *w, unsigned count)
{
   switch (opcode) {
   case SpvOpFunction: {
      vtn_assert(b->func == NULL);
      b->func = rzalloc(b, struct vtn_function);

      list_inithead(&b->func->body);
      b->func->control = w[3];

      MAYBE_UNUSED const struct glsl_type *result_type =
         vtn_value(b, w[1], vtn_value_type_type)->type->type;
      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
      val->func = b->func;

      b->func->type = vtn_value(b, w[4], vtn_value_type_type)->type;
      const struct vtn_type *func_type = b->func->type;

      vtn_assert(func_type->return_type->type == result_type);

      nir_function *func =
         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));

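      /* Build the NIR parameter list.  With the reworked function-call
       * lowering, every parameter is either passed by value (vectors and
       * scalars) or as a single 32-bit pointer/handle: sampled images expand
       * into separate image and sampler parameters, and a non-void return
       * value is returned through an extra out-pointer that becomes
       * parameter 0.
       */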
      unsigned num_params = func_type->length;
      for (unsigned i = 0; i < func_type->length; i++) {
         /* Sampled images are actually two parameters */
         if (func_type->params[i]->base_type == vtn_base_type_sampled_image)
            num_params++;
      }

      /* Add one parameter for the function return value */
      if (func_type->return_type->base_type != vtn_base_type_void)
         num_params++;

      func->num_params = num_params;
      func->params = ralloc_array(b->shader, nir_parameter, num_params);

      unsigned idx = 0;
      if (func_type->return_type->base_type != vtn_base_type_void) {
         /* The return value is a regular pointer */
         func->params[idx++] = (nir_parameter) {
            .num_components = 1, .bit_size = 32,
         };
      }

      for (unsigned i = 0; i < func_type->length; i++) {
         if (func_type->params[i]->base_type == vtn_base_type_sampled_image) {
            /* Sampled images are two pointer parameters */
            func->params[idx++] = (nir_parameter) {
               .num_components = 1, .bit_size = 32,
            };
            func->params[idx++] = (nir_parameter) {
               .num_components = 1, .bit_size = 32,
            };
         } else if (func_type->params[i]->base_type == vtn_base_type_pointer &&
                    func_type->params[i]->type != NULL) {
            /* Pointers with an actual storage type get passed by value */
            assert(glsl_type_is_vector_or_scalar(func_type->params[i]->type));
            func->params[idx++] = (nir_parameter) {
               .num_components =
                  glsl_get_vector_elements(func_type->params[i]->type),
               .bit_size = glsl_get_bit_size(func_type->params[i]->type),
            };
         } else {
            /* Everything else is a regular pointer */
            func->params[idx++] = (nir_parameter) {
               .num_components = 1, .bit_size = 32,
            };
         }
      }
      assert(idx == num_params);

      b->func->impl = nir_function_impl_create(func);
      nir_builder_init(&b->nb, func->impl);
      b->nb.cursor = nir_before_cf_list(&b->func->impl->body);

      b->func_param_idx = 0;

      /* The return value is the first parameter */
      if (func_type->return_type->base_type != vtn_base_type_void)
         b->func_param_idx++;
      break;
   }

   case SpvOpFunctionEnd:
      b->func->end = w;
      b->func = NULL;
      break;

   case SpvOpFunctionParameter: {
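      /* Turn the parameter back into something the rest of the SPIR-V
       * front-end can consume: sampled images are split into an image and a
       * sampler pointer, pointer-like parameters become vtn_pointers loaded
       * with nir_load_param(), and plain values are loaded as SSA through a
       * deref cast of the parameter.
       */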
      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;

      vtn_assert(b->func_param_idx < b->func->impl->function->num_params);

      if (type->base_type == vtn_base_type_sampled_image) {
         /* Sampled images are actually two parameters.  The first is the
          * image and the second is the sampler.
          */
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_sampled_image);

         val->sampled_image = ralloc(b, struct vtn_sampled_image);
         val->sampled_image->type = type;

         struct vtn_type *sampler_type = rzalloc(b, struct vtn_type);
         sampler_type->base_type = vtn_base_type_sampler;
         sampler_type->type = glsl_bare_sampler_type();

         val->sampled_image->image =
            vtn_load_param_pointer(b, type, b->func_param_idx++);
         val->sampled_image->sampler =
            vtn_load_param_pointer(b, sampler_type, b->func_param_idx++);
      } else if (type->base_type == vtn_base_type_pointer &&
                 type->type != NULL) {
         /* This is a pointer with an actual storage type */
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_pointer);
         nir_ssa_def *ssa_ptr = nir_load_param(&b->nb, b->func_param_idx++);
         val->pointer = vtn_pointer_from_ssa(b, ssa_ptr, type);
      } else if (type->base_type == vtn_base_type_pointer ||
                 type->base_type == vtn_base_type_image ||
                 type->base_type == vtn_base_type_sampler) {
         struct vtn_value *val =
            vtn_push_value(b, w[2], vtn_value_type_pointer);
         val->pointer =
            vtn_load_param_pointer(b, type, b->func_param_idx++);
      } else {
         /* We're a regular SSA value. */
         nir_ssa_def *param_val = nir_load_param(&b->nb, b->func_param_idx++);
         nir_deref_instr *deref =
            nir_build_deref_cast(&b->nb, param_val, nir_var_local, type->type);
         vtn_push_ssa(b, w[2], type, vtn_local_load(b, deref));
      }
      break;
   }

   case SpvOpLabel: {
      vtn_assert(b->block == NULL);
      b->block = rzalloc(b, struct vtn_block);
      b->block->node.type = vtn_cf_node_type_block;
      b->block->label = w;
      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;

      if (b->func->start_block == NULL) {
         /* This is the first block encountered for this function.  In this
          * case, we set the start block and add it to the list of
          * implemented functions that we'll walk later.
          */
         b->func->start_block = b->block;
         exec_list_push_tail(&b->functions, &b->func->node);
      }
      break;
   }

   case SpvOpSelectionMerge:
   case SpvOpLoopMerge:
      vtn_assert(b->block && b->block->merge == NULL);
      b->block->merge = w;
      break;

   case SpvOpBranch:
   case SpvOpBranchConditional:
   case SpvOpSwitch:
   case SpvOpKill:
   case SpvOpReturn:
   case SpvOpReturnValue:
   case SpvOpUnreachable:
      vtn_assert(b->block && b->block->branch == NULL);
      b->block->branch = w;
      b->block = NULL;
      break;

   default:
      /* Continue on as per normal */
      return true;
   }

   return true;
}

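/* Registers the case that branches to block_id with the switch.  Cases that
 * just jump straight to the break (merge) block are skipped, and several
 * literal values targeting the same label all land in a single vtn_case.
 */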
static void
vtn_add_case(struct vtn_builder *b, struct vtn_switch *swtch,
             struct vtn_block *break_block,
             uint32_t block_id, uint64_t val, bool is_default)
{
   struct vtn_block *case_block =
      vtn_value(b, block_id, vtn_value_type_block)->block;

   /* Don't create dummy cases that just break */
   if (case_block == break_block)
      return;

   if (case_block->switch_case == NULL) {
      struct vtn_case *c = ralloc(b, struct vtn_case);

      list_inithead(&c->body);
      c->start_block = case_block;
      c->fallthrough = NULL;
      util_dynarray_init(&c->values, b);
      c->is_default = false;
      c->visited = false;

      list_addtail(&c->link, &swtch->cases);

      case_block->switch_case = c;
   }

   if (is_default) {
      case_block->switch_case->is_default = true;
   } else {
      util_dynarray_append(&case_block->switch_case->values, uint64_t, val);
   }
}

/* This function performs a depth-first search of the cases and puts them
 * in fall-through order.
 */
static void
vtn_order_case(struct vtn_switch *swtch, struct vtn_case *cse)
{
   if (cse->visited)
      return;

   cse->visited = true;

   list_del(&cse->link);

   if (cse->fallthrough) {
      vtn_order_case(swtch, cse->fallthrough);

      /* If we have a fall-through, place this case right before the case it
       * falls through to.  This ensures that fallthroughs come one after
       * the other.  These two can never get separated because that would
       * imply something else falling through to the same case.  Also, this
       * can't break ordering because the DFS ensures that this case is
       * visited before anything that falls through to it.
       */
      list_addtail(&cse->link, &cse->fallthrough->link);
   } else {
      list_add(&cse->link, &swtch->cases);
   }
}

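/* Classifies a branch to the given block relative to the enclosing switch
 * and loop: a branch into another case is a fallthrough (and gets recorded
 * on the current case), a branch to the switch or loop merge block is a
 * break, a branch to the continue block is a continue, and anything else is
 * vtn_branch_type_none, meaning the walk simply continues into the block.
 */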
static enum vtn_branch_type
vtn_get_branch_type(struct vtn_builder *b,
                    struct vtn_block *block,
                    struct vtn_case *swcase, struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont)
{
   if (block->switch_case) {
      /* This branch is actually a fallthrough */
      vtn_assert(swcase->fallthrough == NULL ||
                 swcase->fallthrough == block->switch_case);
      swcase->fallthrough = block->switch_case;
      return vtn_branch_type_switch_fallthrough;
   } else if (block == loop_break) {
      return vtn_branch_type_loop_break;
   } else if (block == loop_cont) {
      return vtn_branch_type_loop_continue;
   } else if (block == switch_break) {
      return vtn_branch_type_switch_break;
   } else {
      return vtn_branch_type_none;
   }
}

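/* Walks the blocks starting at "start" and structurizes them into a list of
 * vtn_cf_nodes (plain blocks, ifs, loops, and switches) using the structured
 * control-flow information SPIR-V provides via OpSelectionMerge and
 * OpLoopMerge.  The walk stops when it reaches "end" or when a branch is
 * classified as something other than vtn_branch_type_none.
 */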
static void
vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
                    struct vtn_block *start, struct vtn_case *switch_case,
                    struct vtn_block *switch_break,
                    struct vtn_block *loop_break, struct vtn_block *loop_cont,
                    struct vtn_block *end)
{
   struct vtn_block *block = start;
   while (block != end) {
      if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
          !block->loop) {
         struct vtn_loop *loop = ralloc(b, struct vtn_loop);

         loop->node.type = vtn_cf_node_type_loop;
         list_inithead(&loop->body);
         list_inithead(&loop->cont_body);
         loop->control = block->merge[3];

         list_addtail(&loop->node.link, cf_list);
         block->loop = loop;

         struct vtn_block *new_loop_break =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;
         struct vtn_block *new_loop_cont =
            vtn_value(b, block->merge[2], vtn_value_type_block)->block;

         /* Note: This recursive call will start with the current block as
          * its start block.  If we weren't careful, we would get here
          * again and end up in infinite recursion.  This is why we set
          * block->loop above and check for it before creating one.  This
          * way, we only create the loop once and the second call that
          * tries to handle this loop goes to the cases below and gets
          * handled as a regular block.
          *
          * Note: When we make the recursive walk calls, we pass NULL for
          * the switch break since you have to break out of the loop first.
          * We do, however, still pass the current switch case because it's
          * possible that the merge block for the loop is the start of
          * another case.
          */
         vtn_cfg_walk_blocks(b, &loop->body, block, switch_case, NULL,
                             new_loop_break, new_loop_cont, NULL);
         vtn_cfg_walk_blocks(b, &loop->cont_body, new_loop_cont, NULL, NULL,
                             new_loop_break, NULL, block);

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, new_loop_break, switch_case, switch_break,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* Stop walking through the CFG when this inner loop's break block
             * ends up as the same block as the outer loop's continue block
             * because we are already going to visit it.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = new_loop_break;
         continue;
      }

      vtn_assert(block->node.link.next == NULL);
      list_addtail(&block->node.link, cf_list);

      switch (*block->branch & SpvOpCodeMask) {
      case SpvOpBranch: {
         struct vtn_block *branch_block =
            vtn_value(b, block->branch[1], vtn_value_type_block)->block;

         block->branch_type = vtn_get_branch_type(b, branch_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (block->branch_type != vtn_branch_type_none)
            return;

         block = branch_block;
         continue;
      }

      case SpvOpReturn:
      case SpvOpReturnValue:
         block->branch_type = vtn_branch_type_return;
         return;

      case SpvOpKill:
         block->branch_type = vtn_branch_type_discard;
         return;

      case SpvOpBranchConditional: {
         struct vtn_block *then_block =
            vtn_value(b, block->branch[2], vtn_value_type_block)->block;
         struct vtn_block *else_block =
            vtn_value(b, block->branch[3], vtn_value_type_block)->block;

         struct vtn_if *if_stmt = ralloc(b, struct vtn_if);

         if_stmt->node.type = vtn_cf_node_type_if;
         if_stmt->condition = block->branch[1];
         list_inithead(&if_stmt->then_body);
         list_inithead(&if_stmt->else_body);

         list_addtail(&if_stmt->node.link, cf_list);

         if (block->merge &&
             (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
            if_stmt->control = block->merge[2];
         }

         if_stmt->then_type = vtn_get_branch_type(b, then_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);
         if_stmt->else_type = vtn_get_branch_type(b, else_block,
                                                  switch_case, switch_break,
                                                  loop_break, loop_cont);

         if (then_block == else_block) {
            block->branch_type = if_stmt->then_type;
            if (block->branch_type == vtn_branch_type_none) {
               block = then_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type == vtn_branch_type_none &&
                    if_stmt->else_type == vtn_branch_type_none) {
            /* Neither side of the if is something we can short-circuit. */
            vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
            struct vtn_block *merge_block =
               vtn_value(b, block->merge[1], vtn_value_type_block)->block;

            vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);
            vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
                                switch_case, switch_break,
                                loop_break, loop_cont, merge_block);

            enum vtn_branch_type merge_type =
               vtn_get_branch_type(b, merge_block, switch_case, switch_break,
                                   loop_break, loop_cont);
            if (merge_type == vtn_branch_type_none) {
               block = merge_block;
               continue;
            } else {
               return;
            }
         } else if (if_stmt->then_type != vtn_branch_type_none &&
                    if_stmt->else_type != vtn_branch_type_none) {
            /* Both sides were short-circuited.  We're done here. */
            return;
         } else {
            /* Exactly one side of the branch could be short-circuited.
             * We set the branch up as a predicated break/continue and we
             * continue on with the other side as if it were what comes
             * after the if.
             */
            if (if_stmt->then_type == vtn_branch_type_none) {
               block = then_block;
            } else {
               block = else_block;
            }
            continue;
         }
         vtn_fail("Should have returned or continued");
      }

      case SpvOpSwitch: {
         vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
         struct vtn_block *break_block =
            vtn_value(b, block->merge[1], vtn_value_type_block)->block;

         struct vtn_switch *swtch = ralloc(b, struct vtn_switch);

         swtch->node.type = vtn_cf_node_type_switch;
         swtch->selector = block->branch[1];
         list_inithead(&swtch->cases);

         list_addtail(&swtch->node.link, cf_list);

         /* First, we go through and record all of the cases. */
         const uint32_t *branch_end =
            block->branch + (block->branch[0] >> SpvWordCountShift);

         struct vtn_value *cond_val = vtn_untyped_value(b, block->branch[1]);
         vtn_fail_if(!cond_val->type ||
                     cond_val->type->base_type != vtn_base_type_scalar,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         nir_alu_type cond_type =
            nir_get_nir_type_for_glsl_type(cond_val->type->type);
         vtn_fail_if(nir_alu_type_get_base_type(cond_type) != nir_type_int &&
                     nir_alu_type_get_base_type(cond_type) != nir_type_uint,
                     "Selector of OpSwitch must have a type of OpTypeInt");

         bool is_default = true;
         const unsigned bitsize = nir_alu_type_get_type_size(cond_type);
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            uint64_t literal = 0;
            if (!is_default) {
               if (bitsize <= 32) {
                  literal = *(w++);
               } else {
                  assert(bitsize == 64);
                  literal = vtn_u64_literal(w);
                  w += 2;
               }
            }

            uint32_t block_id = *(w++);

            vtn_add_case(b, swtch, break_block, block_id, literal, is_default);
            is_default = false;
         }

         /* Now, we go through and walk the blocks.  While we walk through
          * the blocks, we also gather the much-needed fall-through
          * information.
          */
         list_for_each_entry(struct vtn_case, cse, &swtch->cases, link) {
            vtn_assert(cse->start_block != break_block);
            vtn_cfg_walk_blocks(b, &cse->body, cse->start_block, cse,
                                break_block, loop_break, loop_cont, NULL);
         }

         /* Finally, we walk over all of the cases one more time and put
          * them in fall-through order.
          */
         for (const uint32_t *w = block->branch + 2; w < branch_end;) {
            struct vtn_block *case_block =
               vtn_value(b, *w, vtn_value_type_block)->block;

            if (bitsize <= 32) {
               w += 2;
            } else {
               assert(bitsize == 64);
               w += 3;
            }

            if (case_block == break_block)
               continue;

            vtn_assert(case_block->switch_case);

            vtn_order_case(swtch, case_block->switch_case);
         }

         enum vtn_branch_type branch_type =
            vtn_get_branch_type(b, break_block, switch_case, NULL,
                                loop_break, loop_cont);

         if (branch_type != vtn_branch_type_none) {
            /* It is possible that the break is actually the continue block
             * for the containing loop.  In this case, we need to bail and let
             * the loop parsing code handle the continue properly.
             */
            vtn_assert(branch_type == vtn_branch_type_loop_continue);
            return;
         }

         block = break_block;
         continue;
      }

      case SpvOpUnreachable:
         return;

      default:
         vtn_fail("Unhandled opcode");
      }
   }
}


void
vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
{
   vtn_foreach_instruction(b, words, end,
                           vtn_cfg_handle_prepass_instruction);

   foreach_list_typed(struct vtn_function, func, node, &b->functions) {
      vtn_cfg_walk_blocks(b, &func->body, func->start_block,
                          NULL, NULL, NULL, NULL, NULL);
   }
}

static bool
vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode == SpvOpLabel)
      return true; /* Nothing to do */

   /* If this isn't a phi node, stop. */
   if (opcode != SpvOpPhi)
      return false;

   /* For handling phi nodes, we do a poor-man's out-of-ssa on the spot.
    * For each phi, we create a variable with the appropriate type and
    * do a load from that variable.  Then, in a second pass, we add
    * stores to that variable to each of the predecessor blocks.
    *
    * We could do something more intelligent here.  However, in order to
    * handle loops and things properly, we really need dominance
    * information.  It would end up basically being the into-SSA
    * algorithm all over again.  It's easier if we just let
    * lower_vars_to_ssa do that for us instead of repeating it here.
    */
   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   nir_variable *phi_var =
      nir_local_variable_create(b->nb.impl, type->type, "phi");
   _mesa_hash_table_insert(b->phi_table, w, phi_var);

   vtn_push_ssa(b, w[2], type,
                vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var)));

   return true;
}

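/* Second phi pass: now that every block has been emitted, go back over each
 * OpPhi and store the per-predecessor source values into the phi's variable
 * at the end of the corresponding predecessor block (right after its
 * end_nop marker).  Combined with the load emitted in the first pass, this
 * gives lower_vars_to_ssa everything it needs to rebuild real phis.
 */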
static bool
vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   if (opcode != SpvOpPhi)
      return true;

   struct hash_entry *phi_entry = _mesa_hash_table_search(b->phi_table, w);
   vtn_assert(phi_entry);
   nir_variable *phi_var = phi_entry->data;

   for (unsigned i = 3; i < count; i += 2) {
      struct vtn_block *pred =
         vtn_value(b, w[i + 1], vtn_value_type_block)->block;

      b->nb.cursor = nir_after_instr(&pred->end_nop->instr);

      struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);

      vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var));
   }

   return true;
}

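/* Emits the NIR for a non-trivial branch type: loop breaks and continues
 * become NIR jumps, returns become a return jump, OpKill becomes a discard
 * intrinsic, and a switch break simply clears the fall-through variable so
 * the cases that follow get skipped.
 */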
static void
vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
                nir_variable *switch_fall_var, bool *has_switch_break)
{
   switch (branch_type) {
   case vtn_branch_type_switch_break:
      nir_store_var(&b->nb, switch_fall_var, nir_imm_int(&b->nb, NIR_FALSE), 1);
      *has_switch_break = true;
      break;
   case vtn_branch_type_switch_fallthrough:
      break; /* Nothing to do */
   case vtn_branch_type_loop_break:
      nir_jump(&b->nb, nir_jump_break);
      break;
   case vtn_branch_type_loop_continue:
      nir_jump(&b->nb, nir_jump_continue);
      break;
   case vtn_branch_type_return:
      nir_jump(&b->nb, nir_jump_return);
      break;
   case vtn_branch_type_discard: {
      nir_intrinsic_instr *discard =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_discard);
      nir_builder_instr_insert(&b->nb, &discard->instr);
      break;
   }
   default:
      vtn_fail("Invalid branch type");
   }
}

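/* Emits NIR for a structurized CF list produced by vtn_cfg_walk_blocks().
 * The instruction handler is invoked for every instruction in each block
 * body; switch_fall_var and has_switch_break are only non-NULL while we are
 * inside a switch case and implement SPIR-V's fall-through semantics.
 */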
static void
vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
                 nir_variable *switch_fall_var, bool *has_switch_break,
                 vtn_instruction_handler handler)
{
   list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
      switch (node->type) {
      case vtn_cf_node_type_block: {
         struct vtn_block *block = (struct vtn_block *)node;

         const uint32_t *block_start = block->label;
         const uint32_t *block_end = block->merge ? block->merge :
                                                    block->branch;

         block_start = vtn_foreach_instruction(b, block_start, block_end,
                                               vtn_handle_phis_first_pass);

         vtn_foreach_instruction(b, block_start, block_end, handler);

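         /* Insert a nop so we have a stable instruction marking the end of
          * the block.  The phi second pass uses it as a cursor to insert the
          * per-predecessor stores after everything the block emitted.
          */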
         block->end_nop = nir_intrinsic_instr_create(b->nb.shader,
                                                     nir_intrinsic_nop);
         nir_builder_instr_insert(&b->nb, &block->end_nop->instr);

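         /* OpReturnValue is lowered by storing the value through the return
          * out-parameter (NIR parameter 0) that was added when the function
          * was created.
          */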
         if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
            vtn_fail_if(b->func->type->return_type->base_type ==
                        vtn_base_type_void,
                        "Return with a value from a function returning void");
            struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
            nir_deref_instr *ret_deref =
               nir_build_deref_cast(&b->nb, nir_load_param(&b->nb, 0),
                                    nir_var_local, src->type);
            vtn_local_store(b, src, ret_deref);
         }

         if (block->branch_type != vtn_branch_type_none) {
            vtn_emit_branch(b, block->branch_type,
                            switch_fall_var, has_switch_break);
         }

         break;
      }

      case vtn_cf_node_type_if: {
         struct vtn_if *vtn_if = (struct vtn_if *)node;
         bool sw_break = false;

         nir_if *nif =
            nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);
         if (vtn_if->then_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->then_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->then_type, switch_fall_var, &sw_break);
         }

         nir_push_else(&b->nb, nif);
         if (vtn_if->else_type == vtn_branch_type_none) {
            vtn_emit_cf_list(b, &vtn_if->else_body,
                             switch_fall_var, &sw_break, handler);
         } else {
            vtn_emit_branch(b, vtn_if->else_type, switch_fall_var, &sw_break);
         }

         nir_pop_if(&b->nb, nif);

         /* If we encountered a switch break somewhere inside of the if,
          * then it would have been handled correctly by calling
          * emit_cf_list or emit_branch for the interior.  However, we
          * need to predicate everything following on whether or not we're
          * still going.
          */
         if (sw_break) {
            *has_switch_break = true;
            nir_push_if(&b->nb, nir_load_var(&b->nb, switch_fall_var));
         }
         break;
      }

      case vtn_cf_node_type_loop: {
         struct vtn_loop *vtn_loop = (struct vtn_loop *)node;

         nir_loop *loop = nir_push_loop(&b->nb);
         vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);

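         /* Roughly, a loop with a continue body is emitted as:
          *
          *    do_cont = false;
          *    loop {
          *       if (do_cont)
          *          <continue body>
          *       do_cont = true;
          *       <loop body>
          *    }
          *
          * so the continue body runs at the top of every iteration except
          * the first, which matches SPIR-V's continue-block semantics.
          */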
         if (!list_empty(&vtn_loop->cont_body)) {
            /* If we have a non-trivial continue body then we need to put
             * it at the beginning of the loop with a flag to ensure that
             * it doesn't get executed in the first iteration.
             */
            nir_variable *do_cont =
               nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");

            b->nb.cursor = nir_before_cf_node(&loop->cf_node);
            nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_FALSE), 1);

            b->nb.cursor = nir_before_cf_list(&loop->body);

            nir_if *cont_if =
               nir_push_if(&b->nb, nir_load_var(&b->nb, do_cont));

            vtn_emit_cf_list(b, &vtn_loop->cont_body, NULL, NULL, handler);

            nir_pop_if(&b->nb, cont_if);

            nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_TRUE), 1);

            b->has_loop_continue = true;
         }

         nir_pop_loop(&b->nb, loop);
         break;
      }

      case vtn_cf_node_type_switch: {
         struct vtn_switch *vtn_switch = (struct vtn_switch *)node;

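         /* NIR has no native switch, so the SPIR-V switch is emitted as a
          * sequence of ifs, roughly:
          *
          *    fall = false;
          *    if ((sel == A) || fall) { fall = true; <case A> }
          *    if ((sel == B) || fall) { fall = true; <case B> }
          *    if (!any || fall)       { fall = true; <default> }
          *
          * where "any" is the OR of all the case conditions.  Because the
          * cases were sorted into fall-through order earlier, setting "fall"
          * implements fallthrough and a switch break simply clears it.
          */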
         /* First, we create a variable to keep track of whether or not the
          * switch is still going at any given point.  Any switch breaks
          * will set this variable to false.
          */
         nir_variable *fall_var =
            nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
         nir_store_var(&b->nb, fall_var, nir_imm_int(&b->nb, NIR_FALSE), 1);

         /* Next, we gather up all of the conditions.  We have to do this
          * up-front because we also need to build an "any" condition so
          * that we can use !any for default.
          */
         const int num_cases = list_length(&vtn_switch->cases);
         NIR_VLA(nir_ssa_def *, conditions, num_cases);

         nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;
         /* An accumulation of all conditions.  Used for the default */
         nir_ssa_def *any = NULL;

         int i = 0;
         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
            if (cse->is_default) {
               conditions[i++] = NULL;
               continue;
            }

            nir_ssa_def *cond = NULL;
            util_dynarray_foreach(&cse->values, uint64_t, val) {
               nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
               nir_ssa_def *is_val = nir_ieq(&b->nb, sel, imm);

               cond = cond ? nir_ior(&b->nb, cond, is_val) : is_val;
            }

            any = any ? nir_ior(&b->nb, any, cond) : cond;
            conditions[i++] = cond;
         }
         vtn_assert(i == num_cases);

         /* Now we can walk the list of cases and actually emit code */
         i = 0;
         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
            /* Figure out the condition */
            nir_ssa_def *cond = conditions[i++];
            if (cse->is_default) {
               vtn_assert(cond == NULL);
               cond = nir_inot(&b->nb, any);
            }
            /* Take fallthrough into account */
            cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));

            nir_if *case_if = nir_push_if(&b->nb, cond);

            bool has_break = false;
            nir_store_var(&b->nb, fall_var, nir_imm_int(&b->nb, NIR_TRUE), 1);
            vtn_emit_cf_list(b, &cse->body, fall_var, &has_break, handler);
            (void)has_break; /* We don't care */

            nir_pop_if(&b->nb, case_if);
         }
         vtn_assert(i == num_cases);

         break;
      }

      default:
         vtn_fail("Invalid CF node type");
      }
   }
}

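/* Emits the NIR for a single SPIR-V function: walks the structurized CF
 * list, runs the second phi pass to store phi sources in the predecessor
 * blocks, and, if any loop continue body was moved to the top of its loop,
 * repairs SSA so that uses of loop-body values in the continue body get the
 * phis they need.
 */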
void
vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
                  vtn_instruction_handler instruction_handler)
{
   nir_builder_init(&b->nb, func->impl);
   b->func = func;
   b->nb.cursor = nir_after_cf_list(&func->impl->body);
   b->has_loop_continue = false;
   b->phi_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
                                          _mesa_key_pointer_equal);

   vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);

   vtn_foreach_instruction(b, func->start_block->label, func->end,
                           vtn_handle_phi_second_pass);

   /* Continue blocks for loops get inserted before the body of the loop
    * but instructions in the continue may use SSA defs in the loop body.
    * Therefore, we need to repair SSA to insert the needed phi nodes.
    */
   if (b->has_loop_continue)
      nir_repair_ssa_impl(func->impl);

   func->emitted = true;
}