spirv: Pass SSA values through functions
[mesa.git] / src / compiler / spirv / vtn_cfg.c
1 /*
2 * Copyright © 2015 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25 #include "nir/nir_vla.h"
26
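/* Rebuild a vtn_pointer from the NIR function parameter at param_idx.
 * Images and samplers are passed as UniformConstant pointers, so a pointer
 * type is synthesized around them here before wrapping the loaded SSA def
 * with vtn_pointer_from_ssa().
 */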
27 static struct vtn_pointer *
28 vtn_load_param_pointer(struct vtn_builder *b,
29 struct vtn_type *param_type,
30 uint32_t param_idx)
31 {
32 struct vtn_type *ptr_type = param_type;
33 if (param_type->base_type != vtn_base_type_pointer) {
34 assert(param_type->base_type == vtn_base_type_image ||
35 param_type->base_type == vtn_base_type_sampler);
36 ptr_type = rzalloc(b, struct vtn_type);
37 ptr_type->base_type = vtn_base_type_pointer;
38 ptr_type->deref = param_type;
39 ptr_type->storage_class = SpvStorageClassUniformConstant;
40 }
41
42 return vtn_pointer_from_ssa(b, nir_load_param(&b->nb, param_idx), ptr_type);
43 }
44
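/* Count how many NIR parameters a SPIR-V function parameter of the given
 * type expands to: arrays and structs are flattened element by element,
 * sampled images become two parameters (image + sampler), and everything
 * else maps to a single parameter.
 */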
45 static unsigned
46 vtn_type_count_function_params(struct vtn_type *type)
47 {
48 switch (type->base_type) {
49 case vtn_base_type_array:
50 return type->length * vtn_type_count_function_params(type->array_element);
51
52 case vtn_base_type_struct: {
53 unsigned count = 0;
54 for (unsigned i = 0; i < type->length; i++)
55 count += vtn_type_count_function_params(type->members[i]);
56 return count;
57 }
58
59 case vtn_base_type_sampled_image:
60 return 2;
61
62 default:
63 return 1;
64 }
65 }
66
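/* Append the NIR parameters for one SPIR-V parameter type to func->params,
 * mirroring the flattening done by vtn_type_count_function_params().
 * Opaque handles (images, samplers and pointers without a backing
 * glsl_type) are passed as a single 32-bit value; other leaves use the
 * vector size and bit size of their type.
 */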
67 static void
68 vtn_type_add_to_function_params(struct vtn_type *type,
69 nir_function *func,
70 unsigned *param_idx)
71 {
72 static const nir_parameter nir_deref_param = {
73 .num_components = 1,
74 .bit_size = 32,
75 };
76
77 switch (type->base_type) {
78 case vtn_base_type_array:
79 for (unsigned i = 0; i < type->length; i++)
80 vtn_type_add_to_function_params(type->array_element, func, param_idx);
81 break;
82
83 case vtn_base_type_struct:
84 for (unsigned i = 0; i < type->length; i++)
85 vtn_type_add_to_function_params(type->members[i], func, param_idx);
86 break;
87
88 case vtn_base_type_sampled_image:
89 func->params[(*param_idx)++] = nir_deref_param;
90 func->params[(*param_idx)++] = nir_deref_param;
91 break;
92
93 case vtn_base_type_image:
94 case vtn_base_type_sampler:
95 func->params[(*param_idx)++] = nir_deref_param;
96 break;
97
98 case vtn_base_type_pointer:
99 if (type->type) {
100 func->params[(*param_idx)++] = (nir_parameter) {
101 .num_components = glsl_get_vector_elements(type->type),
102 .bit_size = glsl_get_bit_size(type->type),
103 };
104 } else {
105 func->params[(*param_idx)++] = nir_deref_param;
106 }
107 break;
108
109 default:
110 func->params[(*param_idx)++] = (nir_parameter) {
111 .num_components = glsl_get_vector_elements(type->type),
112 .bit_size = glsl_get_bit_size(type->type),
113 };
114 }
115 }
116
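/* Walk an SSA value and its type in lock-step and add one call source per
 * leaf (non-aggregate) value, in the same order the parameters were added
 * to the function.
 */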
117 static void
118 vtn_ssa_value_add_to_call_params(struct vtn_builder *b,
119 struct vtn_ssa_value *value,
120 struct vtn_type *type,
121 nir_call_instr *call,
122 unsigned *param_idx)
123 {
124 switch (type->base_type) {
125 case vtn_base_type_array:
126 for (unsigned i = 0; i < type->length; i++) {
127 vtn_ssa_value_add_to_call_params(b, value->elems[i],
128 type->array_element,
129 call, param_idx);
130 }
131 break;
132
133 case vtn_base_type_struct:
134 for (unsigned i = 0; i < type->length; i++) {
135 vtn_ssa_value_add_to_call_params(b, value->elems[i],
136 type->members[i],
137 call, param_idx);
138 }
139 break;
140
141 default:
142 call->params[(*param_idx)++] = nir_src_for_ssa(value->def);
143 break;
144 }
145 }
146
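/* The callee-side counterpart of vtn_ssa_value_add_to_call_params(): fill
 * in an SSA value tree by loading each leaf from consecutive NIR function
 * parameters.
 */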
147 static void
148 vtn_ssa_value_load_function_param(struct vtn_builder *b,
149 struct vtn_ssa_value *value,
150 struct vtn_type *type,
151 unsigned *param_idx)
152 {
153 switch (type->base_type) {
154 case vtn_base_type_array:
155 for (unsigned i = 0; i < type->length; i++) {
156 vtn_ssa_value_load_function_param(b, value->elems[i],
157 type->array_element, param_idx);
158 }
159 break;
160
161 case vtn_base_type_struct:
162 for (unsigned i = 0; i < type->length; i++) {
163 vtn_ssa_value_load_function_param(b, value->elems[i],
164 type->members[i], param_idx);
165 }
166 break;
167
168 default:
169 value->def = nir_load_param(&b->nb, (*param_idx)++);
170 break;
171 }
172 }
173
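/* Handle OpFunctionCall.  The operands are w[1] = result type, w[2] =
 * result id, w[3] = callee and w[4]... = arguments.  Calls with a non-void
 * result pass a deref to a "return_tmp" local as the first NIR parameter,
 * so, roughly sketched,
 *
 *    %r = OpFunctionCall %T %f %a %b
 *
 * becomes
 *
 *    call f(&return_tmp, a..., b...)
 *    %r = load return_tmp
 *
 * with aggregate and sampled-image arguments flattened into multiple
 * parameters.  Void calls just push an undef value for w[2].
 */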
174 void
175 vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
176 const uint32_t *w, unsigned count)
177 {
178 struct vtn_type *res_type = vtn_value(b, w[1], vtn_value_type_type)->type;
179 struct vtn_function *vtn_callee =
180 vtn_value(b, w[3], vtn_value_type_function)->func;
181 struct nir_function *callee = vtn_callee->impl->function;
182
183 vtn_callee->referenced = true;
184
185 nir_call_instr *call = nir_call_instr_create(b->nb.shader, callee);
186
187 unsigned param_idx = 0;
188
189 nir_deref_instr *ret_deref = NULL;
190 struct vtn_type *ret_type = vtn_callee->type->return_type;
191 if (ret_type->base_type != vtn_base_type_void) {
192 nir_variable *ret_tmp =
193 nir_local_variable_create(b->nb.impl, ret_type->type, "return_tmp");
194 ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
195 call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
196 }
197
198 for (unsigned i = 0; i < vtn_callee->type->length; i++) {
199 struct vtn_type *arg_type = vtn_callee->type->params[i];
200 unsigned arg_id = w[4 + i];
201
202 if (arg_type->base_type == vtn_base_type_sampled_image) {
203 struct vtn_sampled_image *sampled_image =
204 vtn_value(b, arg_id, vtn_value_type_sampled_image)->sampled_image;
205
206 call->params[param_idx++] =
207 nir_src_for_ssa(&sampled_image->image->deref->dest.ssa);
208 call->params[param_idx++] =
209 nir_src_for_ssa(&sampled_image->sampler->deref->dest.ssa);
210 } else if (arg_type->base_type == vtn_base_type_pointer ||
211 arg_type->base_type == vtn_base_type_image ||
212 arg_type->base_type == vtn_base_type_sampler) {
213 struct vtn_pointer *pointer =
214 vtn_value(b, arg_id, vtn_value_type_pointer)->pointer;
215 call->params[param_idx++] =
216 nir_src_for_ssa(vtn_pointer_to_ssa(b, pointer));
217 } else {
218 vtn_ssa_value_add_to_call_params(b, vtn_ssa_value(b, arg_id),
219 arg_type, call, &param_idx);
220 }
221 }
222 assert(param_idx == call->num_params);
223
224 nir_builder_instr_insert(&b->nb, &call->instr);
225
226 if (ret_type->base_type == vtn_base_type_void) {
227 vtn_push_value(b, w[2], vtn_value_type_undef);
228 } else {
229 vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref));
230 }
231 }
232
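/* First pass over a function's instructions.  It creates the vtn_function
 * and its nir_function (including the flattened parameter list), turns
 * OpFunctionParameter into loads of those parameters, and records each
 * block's label, merge and branch words so vtn_cfg_walk_blocks() can build
 * the structured CFG afterwards.  Regular body instructions are emitted
 * later, from vtn_function_emit().
 */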
233 static bool
234 vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
235 const uint32_t *w, unsigned count)
236 {
237 switch (opcode) {
238 case SpvOpFunction: {
239 vtn_assert(b->func == NULL);
240 b->func = rzalloc(b, struct vtn_function);
241
242 list_inithead(&b->func->body);
243 b->func->control = w[3];
244
245 MAYBE_UNUSED const struct glsl_type *result_type =
246 vtn_value(b, w[1], vtn_value_type_type)->type->type;
247 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
248 val->func = b->func;
249
250 b->func->type = vtn_value(b, w[4], vtn_value_type_type)->type;
251 const struct vtn_type *func_type = b->func->type;
252
253 vtn_assert(func_type->return_type->type == result_type);
254
255 nir_function *func =
256 nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
257
258 unsigned num_params = 0;
259 for (unsigned i = 0; i < func_type->length; i++)
260 num_params += vtn_type_count_function_params(func_type->params[i]);
261
262 /* Add one parameter for the function return value */
263 if (func_type->return_type->base_type != vtn_base_type_void)
264 num_params++;
265
266 func->num_params = num_params;
267 func->params = ralloc_array(b->shader, nir_parameter, num_params);
268
269 unsigned idx = 0;
270 if (func_type->return_type->base_type != vtn_base_type_void) {
271 /* The return value is a regular pointer */
272 func->params[idx++] = (nir_parameter) {
273 .num_components = 1, .bit_size = 32,
274 };
275 }
276
277 for (unsigned i = 0; i < func_type->length; i++)
278 vtn_type_add_to_function_params(func_type->params[i], func, &idx);
279 assert(idx == num_params);
280
281 b->func->impl = nir_function_impl_create(func);
282 nir_builder_init(&b->nb, func->impl);
283 b->nb.cursor = nir_before_cf_list(&b->func->impl->body);
284
285 b->func_param_idx = 0;
286
287 /* The return value is the first parameter */
288 if (func_type->return_type->base_type != vtn_base_type_void)
289 b->func_param_idx++;
290 break;
291 }
292
293 case SpvOpFunctionEnd:
294 b->func->end = w;
295 b->func = NULL;
296 break;
297
298 case SpvOpFunctionParameter: {
299 struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
300
301 vtn_assert(b->func_param_idx < b->func->impl->function->num_params);
302
303 if (type->base_type == vtn_base_type_sampled_image) {
304 /* Sampled images are actually two parameters. The first is the
305 * image and the second is the sampler.
306 */
307 struct vtn_value *val =
308 vtn_push_value(b, w[2], vtn_value_type_sampled_image);
309
310 val->sampled_image = ralloc(b, struct vtn_sampled_image);
311 val->sampled_image->type = type;
312
313 struct vtn_type *sampler_type = rzalloc(b, struct vtn_type);
314 sampler_type->base_type = vtn_base_type_sampler;
315 sampler_type->type = glsl_bare_sampler_type();
316
317 val->sampled_image->image =
318 vtn_load_param_pointer(b, type, b->func_param_idx++);
319 val->sampled_image->sampler =
320 vtn_load_param_pointer(b, sampler_type, b->func_param_idx++);
321 } else if (type->base_type == vtn_base_type_pointer &&
322 type->type != NULL) {
323 /* This is a pointer with an actual storage type */
324 struct vtn_value *val =
325 vtn_push_value(b, w[2], vtn_value_type_pointer);
326 nir_ssa_def *ssa_ptr = nir_load_param(&b->nb, b->func_param_idx++);
327 val->pointer = vtn_pointer_from_ssa(b, ssa_ptr, type);
328 } else if (type->base_type == vtn_base_type_pointer ||
329 type->base_type == vtn_base_type_image ||
330 type->base_type == vtn_base_type_sampler) {
331 struct vtn_value *val =
332 vtn_push_value(b, w[2], vtn_value_type_pointer);
333 val->pointer =
334 vtn_load_param_pointer(b, type, b->func_param_idx++);
335 } else {
336 /* We're a regular SSA value. */
337 struct vtn_ssa_value *value = vtn_create_ssa_value(b, type->type);
338 vtn_ssa_value_load_function_param(b, value, type, &b->func_param_idx);
339 vtn_push_ssa(b, w[2], type, value);
340 }
341 break;
342 }
343
344 case SpvOpLabel: {
345 vtn_assert(b->block == NULL);
346 b->block = rzalloc(b, struct vtn_block);
347 b->block->node.type = vtn_cf_node_type_block;
348 b->block->label = w;
349 vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
350
351 if (b->func->start_block == NULL) {
352 /* This is the first block encountered for this function. In this
353 * case, we set the start block and add it to the list of
354 * implemented functions that we'll walk later.
355 */
356 b->func->start_block = b->block;
357 exec_list_push_tail(&b->functions, &b->func->node);
358 }
359 break;
360 }
361
362 case SpvOpSelectionMerge:
363 case SpvOpLoopMerge:
364 vtn_assert(b->block && b->block->merge == NULL);
365 b->block->merge = w;
366 break;
367
368 case SpvOpBranch:
369 case SpvOpBranchConditional:
370 case SpvOpSwitch:
371 case SpvOpKill:
372 case SpvOpReturn:
373 case SpvOpReturnValue:
374 case SpvOpUnreachable:
375 vtn_assert(b->block && b->block->branch == NULL);
376 b->block->branch = w;
377 b->block = NULL;
378 break;
379
380 default:
381 /* Continue on as per normal */
382 return true;
383 }
384
385 return true;
386 }
387
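/* Record one OpSwitch target.  Targets that just branch to the break block
 * are skipped entirely; otherwise a vtn_case is created for the target
 * block (if it doesn't already have one) and the literal value, or the
 * default flag, is attached to it.
 */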
388 static void
389 vtn_add_case(struct vtn_builder *b, struct vtn_switch *swtch,
390 struct vtn_block *break_block,
391 uint32_t block_id, uint64_t val, bool is_default)
392 {
393 struct vtn_block *case_block =
394 vtn_value(b, block_id, vtn_value_type_block)->block;
395
396 /* Don't create dummy cases that just break */
397 if (case_block == break_block)
398 return;
399
400 if (case_block->switch_case == NULL) {
401 struct vtn_case *c = ralloc(b, struct vtn_case);
402
403 list_inithead(&c->body);
404 c->start_block = case_block;
405 c->fallthrough = NULL;
406 util_dynarray_init(&c->values, b);
407 c->is_default = false;
408 c->visited = false;
409
410 list_addtail(&c->link, &swtch->cases);
411
412 case_block->switch_case = c;
413 }
414
415 if (is_default) {
416 case_block->switch_case->is_default = true;
417 } else {
418 util_dynarray_append(&case_block->switch_case->values, uint64_t, val);
419 }
420 }
421
422 /* This function performs a depth-first search of the cases and puts them
423 * in fall-through order.
424 */
425 static void
426 vtn_order_case(struct vtn_switch *swtch, struct vtn_case *cse)
427 {
428 if (cse->visited)
429 return;
430
431 cse->visited = true;
432
433 list_del(&cse->link);
434
435 if (cse->fallthrough) {
436 vtn_order_case(swtch, cse->fallthrough);
437
438 /* If we have a fall-through, place this case right before the case it
439 * falls through to. This ensures that fallthroughs come one after
440 * the other. These two can never get separated because that would
441 * imply something else falling through to the same case. Also, this
442 * can't break ordering because the DFS ensures that this case is
443 * visited before anything that falls through to it.
444 */
445 list_addtail(&cse->link, &cse->fallthrough->link);
446 } else {
447 list_add(&cse->link, &swtch->cases);
448 }
449 }
450
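/* Classify a branch target relative to the enclosing switch and loop: a
 * branch into a block that already belongs to a case is recorded as a
 * fallthrough on the current case, a branch to the loop merge or continue
 * block is a loop break/continue, a branch to the switch merge block is a
 * switch break, and anything else is a normal edge.
 */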
451 static enum vtn_branch_type
452 vtn_get_branch_type(struct vtn_builder *b,
453 struct vtn_block *block,
454 struct vtn_case *swcase, struct vtn_block *switch_break,
455 struct vtn_block *loop_break, struct vtn_block *loop_cont)
456 {
457 if (block->switch_case) {
458 /* This branch is actually a fallthrough */
459 vtn_assert(swcase->fallthrough == NULL ||
460 swcase->fallthrough == block->switch_case);
461 swcase->fallthrough = block->switch_case;
462 return vtn_branch_type_switch_fallthrough;
463 } else if (block == loop_break) {
464 return vtn_branch_type_loop_break;
465 } else if (block == loop_cont) {
466 return vtn_branch_type_loop_continue;
467 } else if (block == switch_break) {
468 return vtn_branch_type_switch_break;
469 } else {
470 return vtn_branch_type_none;
471 }
472 }
473
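/* Walk the blocks starting at 'start' and append the resulting structured
 * control flow (plain blocks, ifs, loops and switches) to cf_list, stopping
 * at 'end' or when a branch short-circuits out of the current construct.
 * The switch_case/switch_break and loop_break/loop_cont arguments describe
 * the enclosing constructs so branches can be classified with
 * vtn_get_branch_type().  For OpLoopMerge, merge[1] is the merge (break)
 * block, merge[2] the continue target and merge[3] the loop control word.
 */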
474 static void
475 vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
476 struct vtn_block *start, struct vtn_case *switch_case,
477 struct vtn_block *switch_break,
478 struct vtn_block *loop_break, struct vtn_block *loop_cont,
479 struct vtn_block *end)
480 {
481 struct vtn_block *block = start;
482 while (block != end) {
483 if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
484 !block->loop) {
485 struct vtn_loop *loop = ralloc(b, struct vtn_loop);
486
487 loop->node.type = vtn_cf_node_type_loop;
488 list_inithead(&loop->body);
489 list_inithead(&loop->cont_body);
490 loop->control = block->merge[3];
491
492 list_addtail(&loop->node.link, cf_list);
493 block->loop = loop;
494
495 struct vtn_block *new_loop_break =
496 vtn_value(b, block->merge[1], vtn_value_type_block)->block;
497 struct vtn_block *new_loop_cont =
498 vtn_value(b, block->merge[2], vtn_value_type_block)->block;
499
500 /* Note: This recursive call will start with the current block as
501 * its start block. If we weren't careful, we would get here
502 * again and end up in infinite recursion. This is why we set
503 * block->loop above and check for it before creating one. This
504 * way, we only create the loop once and the second call that
505 * tries to handle this loop goes to the cases below and gets
506 * handled as a regular block.
507 *
508 * Note: When we make the recursive walk calls, we pass NULL for
509 * the switch break since you have to break out of the loop first.
510 * We do, however, still pass the current switch case because it's
511 * possible that the merge block for the loop is the start of
512 * another case.
513 */
514 vtn_cfg_walk_blocks(b, &loop->body, block, switch_case, NULL,
515 new_loop_break, new_loop_cont, NULL);
516 vtn_cfg_walk_blocks(b, &loop->cont_body, new_loop_cont, NULL, NULL,
517 new_loop_break, NULL, block);
518
519 enum vtn_branch_type branch_type =
520 vtn_get_branch_type(b, new_loop_break, switch_case, switch_break,
521 loop_break, loop_cont);
522
523 if (branch_type != vtn_branch_type_none) {
524 /* Stop walking through the CFG when this inner loop's break block
525 * ends up as the same block as the outer loop's continue block
526 * because we are already going to visit it.
527 */
528 vtn_assert(branch_type == vtn_branch_type_loop_continue);
529 return;
530 }
531
532 block = new_loop_break;
533 continue;
534 }
535
536 vtn_assert(block->node.link.next == NULL);
537 list_addtail(&block->node.link, cf_list);
538
539 switch (*block->branch & SpvOpCodeMask) {
540 case SpvOpBranch: {
541 struct vtn_block *branch_block =
542 vtn_value(b, block->branch[1], vtn_value_type_block)->block;
543
544 block->branch_type = vtn_get_branch_type(b, branch_block,
545 switch_case, switch_break,
546 loop_break, loop_cont);
547
548 if (block->branch_type != vtn_branch_type_none)
549 return;
550
551 block = branch_block;
552 continue;
553 }
554
555 case SpvOpReturn:
556 case SpvOpReturnValue:
557 block->branch_type = vtn_branch_type_return;
558 return;
559
560 case SpvOpKill:
561 block->branch_type = vtn_branch_type_discard;
562 return;
563
564 case SpvOpBranchConditional: {
565 struct vtn_block *then_block =
566 vtn_value(b, block->branch[2], vtn_value_type_block)->block;
567 struct vtn_block *else_block =
568 vtn_value(b, block->branch[3], vtn_value_type_block)->block;
569
570 struct vtn_if *if_stmt = ralloc(b, struct vtn_if);
571
572 if_stmt->node.type = vtn_cf_node_type_if;
573 if_stmt->condition = block->branch[1];
574 list_inithead(&if_stmt->then_body);
575 list_inithead(&if_stmt->else_body);
576
577 list_addtail(&if_stmt->node.link, cf_list);
578
579 if (block->merge &&
580 (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
581 if_stmt->control = block->merge[2];
582 }
583
584 if_stmt->then_type = vtn_get_branch_type(b, then_block,
585 switch_case, switch_break,
586 loop_break, loop_cont);
587 if_stmt->else_type = vtn_get_branch_type(b, else_block,
588 switch_case, switch_break,
589 loop_break, loop_cont);
590
591 if (then_block == else_block) {
592 block->branch_type = if_stmt->then_type;
593 if (block->branch_type == vtn_branch_type_none) {
594 block = then_block;
595 continue;
596 } else {
597 return;
598 }
599 } else if (if_stmt->then_type == vtn_branch_type_none &&
600 if_stmt->else_type == vtn_branch_type_none) {
601 /* Neither side of the if is something we can short-circuit. */
602 vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
603 struct vtn_block *merge_block =
604 vtn_value(b, block->merge[1], vtn_value_type_block)->block;
605
606 vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
607 switch_case, switch_break,
608 loop_break, loop_cont, merge_block);
609 vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
610 switch_case, switch_break,
611 loop_break, loop_cont, merge_block);
612
613 enum vtn_branch_type merge_type =
614 vtn_get_branch_type(b, merge_block, switch_case, switch_break,
615 loop_break, loop_cont);
616 if (merge_type == vtn_branch_type_none) {
617 block = merge_block;
618 continue;
619 } else {
620 return;
621 }
622 } else if (if_stmt->then_type != vtn_branch_type_none &&
623 if_stmt->else_type != vtn_branch_type_none) {
624 /* Both sides were short-circuited. We're done here. */
625 return;
626 } else {
627 /* Exactly one side of the branch could be short-circuited.
628 * We set the branch up as a predicated break/continue and we
629 * continue on with the other side as if it were what comes
630 * after the if.
631 */
632 if (if_stmt->then_type == vtn_branch_type_none) {
633 block = then_block;
634 } else {
635 block = else_block;
636 }
637 continue;
638 }
639 vtn_fail("Should have returned or continued");
640 }
641
642 case SpvOpSwitch: {
643 vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
644 struct vtn_block *break_block =
645 vtn_value(b, block->merge[1], vtn_value_type_block)->block;
646
647 struct vtn_switch *swtch = ralloc(b, struct vtn_switch);
648
649 swtch->node.type = vtn_cf_node_type_switch;
650 swtch->selector = block->branch[1];
651 list_inithead(&swtch->cases);
652
653 list_addtail(&swtch->node.link, cf_list);
654
655 /* First, we go through and record all of the cases. */
656 const uint32_t *branch_end =
657 block->branch + (block->branch[0] >> SpvWordCountShift);
658
659 struct vtn_value *cond_val = vtn_untyped_value(b, block->branch[1]);
660 vtn_fail_if(!cond_val->type ||
661 cond_val->type->base_type != vtn_base_type_scalar,
662 "Selector of OpSwitch must have a type of OpTypeInt");
663
664 nir_alu_type cond_type =
665 nir_get_nir_type_for_glsl_type(cond_val->type->type);
666 vtn_fail_if(nir_alu_type_get_base_type(cond_type) != nir_type_int &&
667 nir_alu_type_get_base_type(cond_type) != nir_type_uint,
668 "Selector of OpSwitch must have a type of OpTypeInt");
669
670 bool is_default = true;
671 const unsigned bitsize = nir_alu_type_get_type_size(cond_type);
672 for (const uint32_t *w = block->branch + 2; w < branch_end;) {
673 uint64_t literal = 0;
674 if (!is_default) {
675 if (bitsize <= 32) {
676 literal = *(w++);
677 } else {
678 assert(bitsize == 64);
679 literal = vtn_u64_literal(w);
680 w += 2;
681 }
682 }
683
684 uint32_t block_id = *(w++);
685
686 vtn_add_case(b, swtch, break_block, block_id, literal, is_default);
687 is_default = false;
688 }
689
690 /* Now, we go through and walk the blocks. While we walk through
691 * the blocks, we also gather the much-needed fall-through
692 * information.
693 */
694 list_for_each_entry(struct vtn_case, cse, &swtch->cases, link) {
695 vtn_assert(cse->start_block != break_block);
696 vtn_cfg_walk_blocks(b, &cse->body, cse->start_block, cse,
697 break_block, loop_break, loop_cont, NULL);
698 }
699
700 /* Finally, we walk over all of the cases one more time and put
701 * them in fall-through order.
702 */
703 for (const uint32_t *w = block->branch + 2; w < branch_end;) {
704 struct vtn_block *case_block =
705 vtn_value(b, *w, vtn_value_type_block)->block;
706
707 if (bitsize <= 32) {
708 w += 2;
709 } else {
710 assert(bitsize == 64);
711 w += 3;
712 }
713
714 if (case_block == break_block)
715 continue;
716
717 vtn_assert(case_block->switch_case);
718
719 vtn_order_case(swtch, case_block->switch_case);
720 }
721
722 enum vtn_branch_type branch_type =
723 vtn_get_branch_type(b, break_block, switch_case, NULL,
724 loop_break, loop_cont);
725
726 if (branch_type != vtn_branch_type_none) {
727 /* It is possible that the break is actually the continue block
728 * for the containing loop. In this case, we need to bail and let
729 * the loop parsing code handle the continue properly.
730 */
731 vtn_assert(branch_type == vtn_branch_type_loop_continue);
732 return;
733 }
734
735 block = break_block;
736 continue;
737 }
738
739 case SpvOpUnreachable:
740 return;
741
742 default:
743 vtn_fail("Unhandled opcode");
744 }
745 }
746 }
747
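/* Build the CFG data structures for the whole module: run the prepass over
 * all function instructions, then structure the blocks of every implemented
 * function with vtn_cfg_walk_blocks().
 */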
748 void
749 vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
750 {
751 vtn_foreach_instruction(b, words, end,
752 vtn_cfg_handle_prepass_instruction);
753
754 foreach_list_typed(struct vtn_function, func, node, &b->functions) {
755 vtn_cfg_walk_blocks(b, &func->body, func->start_block,
756 NULL, NULL, NULL, NULL, NULL);
757 }
758 }
759
760 static bool
761 vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
762 const uint32_t *w, unsigned count)
763 {
764 if (opcode == SpvOpLabel)
765 return true; /* Nothing to do */
766
767 /* If this isn't a phi node, stop. */
768 if (opcode != SpvOpPhi)
769 return false;
770
771 /* For handling phi nodes, we do a poor-man's out-of-ssa on the spot.
772 * For each phi, we create a variable with the appropriate type and
773 * do a load from that variable. Then, in a second pass, we add
774 * stores to that variable to each of the predecessor blocks.
775 *
776 * We could do something more intelligent here. However, in order to
777 * handle loops and things properly, we really need dominance
778 * information. It would end up basically being the into-SSA
779 * algorithm all over again. It's easier if we just let
780 * lower_vars_to_ssa do that for us instead of repeating it here.
781 */
782 struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
783 nir_variable *phi_var =
784 nir_local_variable_create(b->nb.impl, type->type, "phi");
785 _mesa_hash_table_insert(b->phi_table, w, phi_var);
786
787 vtn_push_ssa(b, w[2], type,
788 vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var)));
789
790 return true;
791 }
792
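/* Second phi pass: for every OpPhi, store each source into the phi's
 * temporary variable at the end of the corresponding predecessor block,
 * right after that block's end_nop marker.
 */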
793 static bool
794 vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
795 const uint32_t *w, unsigned count)
796 {
797 if (opcode != SpvOpPhi)
798 return true;
799
800 struct hash_entry *phi_entry = _mesa_hash_table_search(b->phi_table, w);
801 vtn_assert(phi_entry);
802 nir_variable *phi_var = phi_entry->data;
803
804 for (unsigned i = 3; i < count; i += 2) {
805 struct vtn_block *pred =
806 vtn_value(b, w[i + 1], vtn_value_type_block)->block;
807
808 b->nb.cursor = nir_after_instr(&pred->end_nop->instr);
809
810 struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);
811
812 vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var));
813 }
814
815 return true;
816 }
817
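/* Emit the NIR for a short-circuited branch: switch breaks store false to
 * the fall-through variable and flag has_switch_break, loop breaks,
 * continues and returns become the corresponding nir_jump, and OpKill
 * becomes a discard intrinsic.
 */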
818 static void
819 vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
820 nir_variable *switch_fall_var, bool *has_switch_break)
821 {
822 switch (branch_type) {
823 case vtn_branch_type_switch_break:
824 nir_store_var(&b->nb, switch_fall_var, nir_imm_false(&b->nb), 1);
825 *has_switch_break = true;
826 break;
827 case vtn_branch_type_switch_fallthrough:
828 break; /* Nothing to do */
829 case vtn_branch_type_loop_break:
830 nir_jump(&b->nb, nir_jump_break);
831 break;
832 case vtn_branch_type_loop_continue:
833 nir_jump(&b->nb, nir_jump_continue);
834 break;
835 case vtn_branch_type_return:
836 nir_jump(&b->nb, nir_jump_return);
837 break;
838 case vtn_branch_type_discard: {
839 nir_intrinsic_instr *discard =
840 nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_discard);
841 nir_builder_instr_insert(&b->nb, &discard->instr);
842 break;
843 }
844 default:
845 vtn_fail("Invalid branch type");
846 }
847 }
848
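/* Emit NIR for a list of vtn_cf_nodes.  Blocks run their instructions
 * through 'handler' (with OpPhi handled separately), ifs and loops map onto
 * nir_push_if()/nir_push_loop(), and switches are lowered to a chain of ifs
 * guarded by a per-switch "fall" variable since NIR has no native switch
 * construct.  Roughly sketched:
 *
 *    fall = false;
 *    if (sel == case0_vals || fall) { fall = true; ...case 0 body...; }
 *    if (sel == case1_vals || fall) { fall = true; ...case 1 body...; }
 *    if (!any_case_matched   || fall) { fall = true; ...default body...; }
 *
 * where a switch break inside a body stores false to "fall" and everything
 * after it is predicated on "fall" still being set.
 */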
849 static void
850 vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
851 nir_variable *switch_fall_var, bool *has_switch_break,
852 vtn_instruction_handler handler)
853 {
854 list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
855 switch (node->type) {
856 case vtn_cf_node_type_block: {
857 struct vtn_block *block = (struct vtn_block *)node;
858
859 const uint32_t *block_start = block->label;
860 const uint32_t *block_end = block->merge ? block->merge :
861 block->branch;
862
863 block_start = vtn_foreach_instruction(b, block_start, block_end,
864 vtn_handle_phis_first_pass);
865
866 vtn_foreach_instruction(b, block_start, block_end, handler);
867
868 block->end_nop = nir_intrinsic_instr_create(b->nb.shader,
869 nir_intrinsic_nop);
870 nir_builder_instr_insert(&b->nb, &block->end_nop->instr);
871
872 if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
873 vtn_fail_if(b->func->type->return_type->base_type ==
874 vtn_base_type_void,
875 "Return with a value from a function returning void");
876 struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
877 nir_deref_instr *ret_deref =
878 nir_build_deref_cast(&b->nb, nir_load_param(&b->nb, 0),
879 nir_var_local, src->type);
880 vtn_local_store(b, src, ret_deref);
881 }
882
883 if (block->branch_type != vtn_branch_type_none) {
884 vtn_emit_branch(b, block->branch_type,
885 switch_fall_var, has_switch_break);
886 }
887
888 break;
889 }
890
891 case vtn_cf_node_type_if: {
892 struct vtn_if *vtn_if = (struct vtn_if *)node;
893 bool sw_break = false;
894
895 nir_if *nif =
896 nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);
897 if (vtn_if->then_type == vtn_branch_type_none) {
898 vtn_emit_cf_list(b, &vtn_if->then_body,
899 switch_fall_var, &sw_break, handler);
900 } else {
901 vtn_emit_branch(b, vtn_if->then_type, switch_fall_var, &sw_break);
902 }
903
904 nir_push_else(&b->nb, nif);
905 if (vtn_if->else_type == vtn_branch_type_none) {
906 vtn_emit_cf_list(b, &vtn_if->else_body,
907 switch_fall_var, &sw_break, handler);
908 } else {
909 vtn_emit_branch(b, vtn_if->else_type, switch_fall_var, &sw_break);
910 }
911
912 nir_pop_if(&b->nb, nif);
913
914 /* If we encountered a switch break somewhere inside of the if,
915 * then it would have been handled correctly by calling
916 * emit_cf_list or emit_branch for the interior. However, we
917 * need to predicate everything following on whether or not we're
918 * still going.
919 */
920 if (sw_break) {
921 *has_switch_break = true;
922 nir_push_if(&b->nb, nir_load_var(&b->nb, switch_fall_var));
923 }
924 break;
925 }
926
927 case vtn_cf_node_type_loop: {
928 struct vtn_loop *vtn_loop = (struct vtn_loop *)node;
929
930 nir_loop *loop = nir_push_loop(&b->nb);
931 vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);
932
933 if (!list_empty(&vtn_loop->cont_body)) {
934 /* If we have a non-trivial continue body then we need to put
935 * it at the beginning of the loop with a flag to ensure that
936 * it doesn't get executed in the first iteration.
937 */
938 nir_variable *do_cont =
939 nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");
940
941 b->nb.cursor = nir_before_cf_node(&loop->cf_node);
942 nir_store_var(&b->nb, do_cont, nir_imm_false(&b->nb), 1);
943
944 b->nb.cursor = nir_before_cf_list(&loop->body);
945
946 nir_if *cont_if =
947 nir_push_if(&b->nb, nir_load_var(&b->nb, do_cont));
948
949 vtn_emit_cf_list(b, &vtn_loop->cont_body, NULL, NULL, handler);
950
951 nir_pop_if(&b->nb, cont_if);
952
953 nir_store_var(&b->nb, do_cont, nir_imm_true(&b->nb), 1);
954
955 b->has_loop_continue = true;
956 }
957
958 nir_pop_loop(&b->nb, loop);
959 break;
960 }
961
962 case vtn_cf_node_type_switch: {
963 struct vtn_switch *vtn_switch = (struct vtn_switch *)node;
964
965 /* First, we create a variable to keep track of whether or not the
966 * switch is still going at any given point. Any switch breaks
967 * will set this variable to false.
968 */
969 nir_variable *fall_var =
970 nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
971 nir_store_var(&b->nb, fall_var, nir_imm_false(&b->nb), 1);
972
973 /* Next, we gather up all of the conditions. We have to do this
974 * up-front because we also need to build an "any" condition so
975 * that we can use !any for default.
976 */
977 const int num_cases = list_length(&vtn_switch->cases);
978 NIR_VLA(nir_ssa_def *, conditions, num_cases);
979
980 nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;
981 /* An accumulation of all conditions. Used for the default */
982 nir_ssa_def *any = NULL;
983
984 int i = 0;
985 list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
986 if (cse->is_default) {
987 conditions[i++] = NULL;
988 continue;
989 }
990
991 nir_ssa_def *cond = NULL;
992 util_dynarray_foreach(&cse->values, uint64_t, val) {
993 nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
994 nir_ssa_def *is_val = nir_ieq(&b->nb, sel, imm);
995
996 cond = cond ? nir_ior(&b->nb, cond, is_val) : is_val;
997 }
998
999 any = any ? nir_ior(&b->nb, any, cond) : cond;
1000 conditions[i++] = cond;
1001 }
1002 vtn_assert(i == num_cases);
1003
1004 /* Now we can walk the list of cases and actually emit code */
1005 i = 0;
1006 list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
1007 /* Figure out the condition */
1008 nir_ssa_def *cond = conditions[i++];
1009 if (cse->is_default) {
1010 vtn_assert(cond == NULL);
1011 cond = nir_inot(&b->nb, any);
1012 }
1013 /* Take fallthrough into account */
1014 cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));
1015
1016 nir_if *case_if = nir_push_if(&b->nb, cond);
1017
1018 bool has_break = false;
1019 nir_store_var(&b->nb, fall_var, nir_imm_true(&b->nb), 1);
1020 vtn_emit_cf_list(b, &cse->body, fall_var, &has_break, handler);
1021 (void)has_break; /* We don't care */
1022
1023 nir_pop_if(&b->nb, case_if);
1024 }
1025 vtn_assert(i == num_cases);
1026
1027 break;
1028 }
1029
1030 default:
1031 vtn_fail("Invalid CF node type");
1032 }
1033 }
1034 }
1035
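/* Emit the NIR for one implemented function: walk its structured CFG with
 * the given instruction handler, run the second phi pass to store phi
 * sources in their predecessor blocks, and repair SSA if any loop continue
 * bodies were moved to the top of their loops.
 */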
1036 void
1037 vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
1038 vtn_instruction_handler instruction_handler)
1039 {
1040 nir_builder_init(&b->nb, func->impl);
1041 b->func = func;
1042 b->nb.cursor = nir_after_cf_list(&func->impl->body);
1043 b->has_loop_continue = false;
1044 b->phi_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
1045 _mesa_key_pointer_equal);
1046
1047 vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);
1048
1049 vtn_foreach_instruction(b, func->start_block->label, func->end,
1050 vtn_handle_phi_second_pass);
1051
1052 /* Continue blocks for loops get inserted before the body of the loop
1053 * but instructions in the continue may use SSA defs in the loop body.
1054 * Therefore, we need to repair SSA to insert the needed phi nodes.
1055 */
1056 if (b->has_loop_continue)
1057 nir_repair_ssa_impl(func->impl);
1058
1059 func->emitted = true;
1060 }