zink: emit interpolation decorations for ntv outputs
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
172 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 switch (var->data.interpolation) {
353 case INTERP_MODE_NONE:
354 case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
355 break;
356 case INTERP_MODE_FLAT:
357 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
358 break;
359 case INTERP_MODE_EXPLICIT:
360 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
361 break;
362 case INTERP_MODE_NOPERSPECTIVE:
363 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
364 break;
365 default:
366 unreachable("unknown interpolation value");
367 }
368
369 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
370
371 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
372 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
373 }
374
375 static SpvDim
376 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
377 {
378 *is_ms = false;
379 switch (gdim) {
380 case GLSL_SAMPLER_DIM_1D:
381 return SpvDim1D;
382 case GLSL_SAMPLER_DIM_2D:
383 return SpvDim2D;
384 case GLSL_SAMPLER_DIM_3D:
385 return SpvDim3D;
386 case GLSL_SAMPLER_DIM_CUBE:
387 return SpvDimCube;
388 case GLSL_SAMPLER_DIM_RECT:
389 return SpvDim2D;
390 case GLSL_SAMPLER_DIM_BUF:
391 return SpvDimBuffer;
392 case GLSL_SAMPLER_DIM_EXTERNAL:
393 return SpvDim2D; /* seems dodgy... */
394 case GLSL_SAMPLER_DIM_MS:
395 *is_ms = true;
396 return SpvDim2D;
397 default:
398 fprintf(stderr, "unknown sampler type %d\n", gdim);
399 break;
400 }
401 return SpvDim2D;
402 }
403
404 uint32_t
405 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
406 {
407 if (stage == MESA_SHADER_NONE ||
408 stage >= MESA_SHADER_COMPUTE) {
409 unreachable("not supported");
410 } else {
411 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
412 PIPE_MAX_SHADER_SAMPLER_VIEWS);
413
414 switch (type) {
415 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
416 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
417 return stage_offset + index;
418
419 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
420 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
421 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
422
423 default:
424 unreachable("unexpected type");
425 }
426 }
427 }
428
429 static void
430 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
431 {
432 const struct glsl_type *type = glsl_without_array(var->type);
433
434 bool is_ms;
435 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
436
437 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
438 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
439 dimension, false,
440 glsl_sampler_type_is_array(type),
441 is_ms, 1,
442 SpvImageFormatUnknown);
443
444 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
445 image_type);
446 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
447 SpvStorageClassUniformConstant,
448 sampled_type);
449
450 if (glsl_type_is_array(var->type)) {
451 for (int i = 0; i < glsl_get_length(var->type); ++i) {
452 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
453 SpvStorageClassUniformConstant);
454
455 if (var->name) {
456 char element_name[100];
457 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
458 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
459 }
460
461 int index = var->data.binding + i;
462 assert(!(ctx->samplers_used & (1 << index)));
463 assert(!ctx->image_types[index]);
464 ctx->image_types[index] = image_type;
465 ctx->samplers[index] = var_id;
466 ctx->samplers_used |= 1 << index;
467
468 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
469 var->data.descriptor_set);
470 int binding = zink_binding(ctx->stage,
471 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
472 var->data.binding + i);
473 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
474 }
475 } else {
476 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
477 SpvStorageClassUniformConstant);
478
479 if (var->name)
480 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
481
482 int index = var->data.binding;
483 assert(!(ctx->samplers_used & (1 << index)));
484 assert(!ctx->image_types[index]);
485 ctx->image_types[index] = image_type;
486 ctx->samplers[index] = var_id;
487 ctx->samplers_used |= 1 << index;
488
489 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
490 var->data.descriptor_set);
491 int binding = zink_binding(ctx->stage,
492 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
493 var->data.binding);
494 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
495 }
496 }
497
498 static void
499 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
500 {
501 uint32_t size = glsl_count_attribute_slots(var->type, false);
502 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
503 SpvId array_length = emit_uint_const(ctx, 32, size);
504 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
505 array_length);
506 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
507
508 // wrap UBO-array in a struct
509 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
510 if (var->name) {
511 char struct_name[100];
512 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
513 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
514 }
515
516 spirv_builder_emit_decoration(&ctx->builder, struct_type,
517 SpvDecorationBlock);
518 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
519
520
521 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
522 SpvStorageClassUniform,
523 struct_type);
524
525 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
526 SpvStorageClassUniform);
527 if (var->name)
528 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
529
530 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
531 ctx->ubos[ctx->num_ubos++] = var_id;
532
533 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
534 var->data.descriptor_set);
535 int binding = zink_binding(ctx->stage,
536 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
537 var->data.binding);
538 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
539 }
540
541 static void
542 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
543 {
544 if (var->data.mode == nir_var_mem_ubo)
545 emit_ubo(ctx, var);
546 else {
547 assert(var->data.mode == nir_var_uniform);
548 if (glsl_type_is_sampler(glsl_without_array(var->type)))
549 emit_sampler(ctx, var);
550 }
551 }
552
553 static SpvId
554 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
555 {
556 assert(ssa->index < ctx->num_defs);
557 assert(ctx->defs[ssa->index] != 0);
558 return ctx->defs[ssa->index];
559 }
560
561 static SpvId
562 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
563 {
564 assert(reg->index < ctx->num_regs);
565 assert(ctx->regs[reg->index] != 0);
566 return ctx->regs[reg->index];
567 }
568
569 static SpvId
570 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
571 {
572 assert(reg->reg);
573 assert(!reg->indirect);
574 assert(!reg->base_offset);
575
576 SpvId var = get_var_from_reg(ctx, reg->reg);
577 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
578 return spirv_builder_emit_load(&ctx->builder, type, var);
579 }
580
581 static SpvId
582 get_src(struct ntv_context *ctx, nir_src *src)
583 {
584 if (src->is_ssa)
585 return get_src_ssa(ctx, src->ssa);
586 else
587 return get_src_reg(ctx, &src->reg);
588 }
589
590 static SpvId
591 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
592 {
593 assert(!alu->src[src].negate);
594 assert(!alu->src[src].abs);
595
596 SpvId def = get_src(ctx, &alu->src[src].src);
597
598 unsigned used_channels = 0;
599 bool need_swizzle = false;
600 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
601 if (!nir_alu_instr_channel_used(alu, src, i))
602 continue;
603
604 used_channels++;
605
606 if (alu->src[src].swizzle[i] != i)
607 need_swizzle = true;
608 }
609 assert(used_channels != 0);
610
611 unsigned live_channels = nir_src_num_components(alu->src[src].src);
612 if (used_channels != live_channels)
613 need_swizzle = true;
614
615 if (!need_swizzle)
616 return def;
617
618 int bit_size = nir_src_bit_size(alu->src[src].src);
619 assert(bit_size == 1 || bit_size == 32);
620
621 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
622 spirv_builder_type_uint(&ctx->builder, bit_size);
623
624 if (used_channels == 1) {
625 uint32_t indices[] = { alu->src[src].swizzle[0] };
626 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
627 def, indices,
628 ARRAY_SIZE(indices));
629 } else if (live_channels == 1) {
630 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
631 raw_type,
632 used_channels);
633
634 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
635 for (unsigned i = 0; i < used_channels; ++i)
636 constituents[i] = def;
637
638 return spirv_builder_emit_composite_construct(&ctx->builder,
639 raw_vec_type,
640 constituents,
641 used_channels);
642 } else {
643 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
644 raw_type,
645 used_channels);
646
647 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
648 size_t num_components = 0;
649 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
650 if (!nir_alu_instr_channel_used(alu, src, i))
651 continue;
652
653 components[num_components++] = alu->src[src].swizzle[i];
654 }
655
656 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
657 def, def, components,
658 num_components);
659 }
660 }
661
662 static void
663 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
664 {
665 assert(result != 0);
666 assert(ssa->index < ctx->num_defs);
667 ctx->defs[ssa->index] = result;
668 }
669
670 static SpvId
671 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
672 SpvId if_true, SpvId if_false)
673 {
674 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
675 }
676
677 static SpvId
678 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
679 {
680 SpvId type = get_bvec_type(ctx, num_components);
681 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
682 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
683 }
684
685 static SpvId
686 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
687 {
688 return emit_unop(ctx, SpvOpBitcast, type, value);
689 }
690
691 static SpvId
692 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
693 unsigned num_components)
694 {
695 SpvId type = get_uvec_type(ctx, bit_size, num_components);
696 return emit_bitcast(ctx, type, value);
697 }
698
699 static SpvId
700 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
701 unsigned num_components)
702 {
703 SpvId type = get_ivec_type(ctx, bit_size, num_components);
704 return emit_bitcast(ctx, type, value);
705 }
706
707 static SpvId
708 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
709 unsigned num_components)
710 {
711 SpvId type = get_fvec_type(ctx, bit_size, num_components);
712 return emit_bitcast(ctx, type, value);
713 }
714
715 static void
716 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
717 {
718 SpvId var = get_var_from_reg(ctx, reg->reg);
719 assert(var);
720 spirv_builder_emit_store(&ctx->builder, var, result);
721 }
722
723 static void
724 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
725 {
726 if (dest->is_ssa)
727 store_ssa_def(ctx, &dest->ssa, result);
728 else
729 store_reg_def(ctx, &dest->reg, result);
730 }
731
732 static SpvId
733 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
734 {
735 unsigned num_components = nir_dest_num_components(*dest);
736 unsigned bit_size = nir_dest_bit_size(*dest);
737
738 if (bit_size != 1) {
739 switch (nir_alu_type_get_base_type(type)) {
740 case nir_type_bool:
741 assert("bool should have bit-size 1");
742
743 case nir_type_uint:
744 break; /* nothing to do! */
745
746 case nir_type_int:
747 case nir_type_float:
748 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
749 break;
750
751 default:
752 unreachable("unsupported nir_alu_type");
753 }
754 }
755
756 store_dest_raw(ctx, dest, result);
757 return result;
758 }
759
760 static SpvId
761 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
762 {
763 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
764 }
765
766 static SpvId
767 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
768 SpvId src0, SpvId src1)
769 {
770 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
771 }
772
773 static SpvId
774 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
775 SpvId src0, SpvId src1, SpvId src2)
776 {
777 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
778 }
779
780 static SpvId
781 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
782 SpvId src)
783 {
784 SpvId args[] = { src };
785 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
786 op, args, ARRAY_SIZE(args));
787 }
788
789 static SpvId
790 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
791 SpvId src0, SpvId src1)
792 {
793 SpvId args[] = { src0, src1 };
794 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
795 op, args, ARRAY_SIZE(args));
796 }
797
798 static SpvId
799 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
800 SpvId src0, SpvId src1, SpvId src2)
801 {
802 SpvId args[] = { src0, src1, src2 };
803 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
804 op, args, ARRAY_SIZE(args));
805 }
806
807 static SpvId
808 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
809 unsigned num_components, float value)
810 {
811 assert(bit_size == 32);
812
813 SpvId result = emit_float_const(ctx, bit_size, value);
814 if (num_components == 1)
815 return result;
816
817 assert(num_components > 1);
818 SpvId components[num_components];
819 for (int i = 0; i < num_components; i++)
820 components[i] = result;
821
822 SpvId type = get_fvec_type(ctx, bit_size, num_components);
823 return spirv_builder_const_composite(&ctx->builder, type, components,
824 num_components);
825 }
826
827 static SpvId
828 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
829 unsigned num_components, uint32_t value)
830 {
831 assert(bit_size == 32);
832
833 SpvId result = emit_uint_const(ctx, bit_size, value);
834 if (num_components == 1)
835 return result;
836
837 assert(num_components > 1);
838 SpvId components[num_components];
839 for (int i = 0; i < num_components; i++)
840 components[i] = result;
841
842 SpvId type = get_uvec_type(ctx, bit_size, num_components);
843 return spirv_builder_const_composite(&ctx->builder, type, components,
844 num_components);
845 }
846
847 static SpvId
848 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
849 unsigned num_components, int32_t value)
850 {
851 assert(bit_size == 32);
852
853 SpvId result = emit_int_const(ctx, bit_size, value);
854 if (num_components == 1)
855 return result;
856
857 assert(num_components > 1);
858 SpvId components[num_components];
859 for (int i = 0; i < num_components; i++)
860 components[i] = result;
861
862 SpvId type = get_ivec_type(ctx, bit_size, num_components);
863 return spirv_builder_const_composite(&ctx->builder, type, components,
864 num_components);
865 }
866
867 static inline unsigned
868 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
869 {
870 if (nir_op_infos[instr->op].input_sizes[src] > 0)
871 return nir_op_infos[instr->op].input_sizes[src];
872
873 if (instr->dest.dest.is_ssa)
874 return instr->dest.dest.ssa.num_components;
875 else
876 return instr->dest.dest.reg.reg->num_components;
877 }
878
879 static SpvId
880 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
881 {
882 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
883
884 unsigned num_components = alu_instr_src_components(alu, src);
885 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
886 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
887
888 if (bit_size == 1)
889 return raw_value;
890 else {
891 switch (nir_alu_type_get_base_type(type)) {
892 case nir_type_bool:
893 unreachable("bool should have bit-size 1");
894
895 case nir_type_int:
896 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
897
898 case nir_type_uint:
899 return raw_value;
900
901 case nir_type_float:
902 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
903
904 default:
905 unreachable("unknown nir_alu_type");
906 }
907 }
908 }
909
910 static SpvId
911 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
912 {
913 assert(!alu->dest.saturate);
914 return store_dest(ctx, &alu->dest.dest, result,
915 nir_op_infos[alu->op].output_type);
916 }
917
918 static SpvId
919 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
920 {
921 unsigned num_components = nir_dest_num_components(*dest);
922 unsigned bit_size = nir_dest_bit_size(*dest);
923
924 if (bit_size == 1)
925 return get_bvec_type(ctx, num_components);
926
927 switch (nir_alu_type_get_base_type(type)) {
928 case nir_type_bool:
929 unreachable("bool should have bit-size 1");
930
931 case nir_type_int:
932 return get_ivec_type(ctx, bit_size, num_components);
933
934 case nir_type_uint:
935 return get_uvec_type(ctx, bit_size, num_components);
936
937 case nir_type_float:
938 return get_fvec_type(ctx, bit_size, num_components);
939
940 default:
941 unreachable("unsupported nir_alu_type");
942 }
943 }
944
945 static void
946 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
947 {
948 SpvId src[nir_op_infos[alu->op].num_inputs];
949 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
950 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
951 src[i] = get_alu_src(ctx, alu, i);
952 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
953 }
954
955 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
956 nir_op_infos[alu->op].output_type);
957 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
958 unsigned num_components = nir_dest_num_components(alu->dest.dest);
959
960 SpvId result = 0;
961 switch (alu->op) {
962 case nir_op_mov:
963 assert(nir_op_infos[alu->op].num_inputs == 1);
964 result = src[0];
965 break;
966
967 #define UNOP(nir_op, spirv_op) \
968 case nir_op: \
969 assert(nir_op_infos[alu->op].num_inputs == 1); \
970 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
971 break;
972
973 UNOP(nir_op_ineg, SpvOpSNegate)
974 UNOP(nir_op_fneg, SpvOpFNegate)
975 UNOP(nir_op_fddx, SpvOpDPdx)
976 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
977 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
978 UNOP(nir_op_fddy, SpvOpDPdy)
979 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
980 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
981 UNOP(nir_op_f2i32, SpvOpConvertFToS)
982 UNOP(nir_op_f2u32, SpvOpConvertFToU)
983 UNOP(nir_op_i2f32, SpvOpConvertSToF)
984 UNOP(nir_op_u2f32, SpvOpConvertUToF)
985 #undef UNOP
986
987 case nir_op_inot:
988 if (bit_size == 1)
989 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
990 else
991 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
992 break;
993
994 case nir_op_b2i32:
995 assert(nir_op_infos[alu->op].num_inputs == 1);
996 result = emit_select(ctx, dest_type, src[0],
997 get_ivec_constant(ctx, 32, num_components, 1),
998 get_ivec_constant(ctx, 32, num_components, 0));
999 break;
1000
1001 case nir_op_b2f32:
1002 assert(nir_op_infos[alu->op].num_inputs == 1);
1003 result = emit_select(ctx, dest_type, src[0],
1004 get_fvec_constant(ctx, 32, num_components, 1),
1005 get_fvec_constant(ctx, 32, num_components, 0));
1006 break;
1007
1008 #define BUILTIN_UNOP(nir_op, spirv_op) \
1009 case nir_op: \
1010 assert(nir_op_infos[alu->op].num_inputs == 1); \
1011 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1012 break;
1013
1014 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1015 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1016 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1017 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1018 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1019 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1020 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1021 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1022 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1023 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1024 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1025 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1026 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1027 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1028 #undef BUILTIN_UNOP
1029
1030 case nir_op_frcp:
1031 assert(nir_op_infos[alu->op].num_inputs == 1);
1032 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1033 get_fvec_constant(ctx, bit_size, num_components, 1),
1034 src[0]);
1035 break;
1036
1037 case nir_op_f2b1:
1038 assert(nir_op_infos[alu->op].num_inputs == 1);
1039 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1040 get_fvec_constant(ctx,
1041 nir_src_bit_size(alu->src[0].src),
1042 num_components, 0));
1043 break;
1044 case nir_op_i2b1:
1045 assert(nir_op_infos[alu->op].num_inputs == 1);
1046 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1047 get_ivec_constant(ctx,
1048 nir_src_bit_size(alu->src[0].src),
1049 num_components, 0));
1050 break;
1051
1052
1053 #define BINOP(nir_op, spirv_op) \
1054 case nir_op: \
1055 assert(nir_op_infos[alu->op].num_inputs == 2); \
1056 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1057 break;
1058
1059 BINOP(nir_op_iadd, SpvOpIAdd)
1060 BINOP(nir_op_isub, SpvOpISub)
1061 BINOP(nir_op_imul, SpvOpIMul)
1062 BINOP(nir_op_idiv, SpvOpSDiv)
1063 BINOP(nir_op_udiv, SpvOpUDiv)
1064 BINOP(nir_op_umod, SpvOpUMod)
1065 BINOP(nir_op_fadd, SpvOpFAdd)
1066 BINOP(nir_op_fsub, SpvOpFSub)
1067 BINOP(nir_op_fmul, SpvOpFMul)
1068 BINOP(nir_op_fdiv, SpvOpFDiv)
1069 BINOP(nir_op_fmod, SpvOpFMod)
1070 BINOP(nir_op_ilt, SpvOpSLessThan)
1071 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1072 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1073 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1074 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1075 BINOP(nir_op_feq, SpvOpFOrdEqual)
1076 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1077 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1078 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1079 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1080 #undef BINOP
1081
1082 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1083 case nir_op: \
1084 assert(nir_op_infos[alu->op].num_inputs == 2); \
1085 if (nir_src_bit_size(alu->src[0].src) == 1) \
1086 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1087 else \
1088 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1089 break;
1090
1091 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1092 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1093 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1094 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1095 #undef BINOP_LOG
1096
1097 #define BUILTIN_BINOP(nir_op, spirv_op) \
1098 case nir_op: \
1099 assert(nir_op_infos[alu->op].num_inputs == 2); \
1100 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1101 break;
1102
1103 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1104 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1105 #undef BUILTIN_BINOP
1106
1107 case nir_op_fdot2:
1108 case nir_op_fdot3:
1109 case nir_op_fdot4:
1110 assert(nir_op_infos[alu->op].num_inputs == 2);
1111 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1112 break;
1113
1114 case nir_op_fdph:
1115 unreachable("should already be lowered away");
1116
1117 case nir_op_seq:
1118 case nir_op_sne:
1119 case nir_op_slt:
1120 case nir_op_sge: {
1121 assert(nir_op_infos[alu->op].num_inputs == 2);
1122 int num_components = nir_dest_num_components(alu->dest.dest);
1123 SpvId bool_type = get_bvec_type(ctx, num_components);
1124
1125 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1126 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1127 if (num_components > 1) {
1128 SpvId zero_comps[num_components], one_comps[num_components];
1129 for (int i = 0; i < num_components; i++) {
1130 zero_comps[i] = zero;
1131 one_comps[i] = one;
1132 }
1133
1134 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1135 zero_comps, num_components);
1136 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1137 one_comps, num_components);
1138 }
1139
1140 SpvOp op;
1141 switch (alu->op) {
1142 case nir_op_seq: op = SpvOpFOrdEqual; break;
1143 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1144 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1145 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1146 default: unreachable("unexpected op");
1147 }
1148
1149 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1150 result = emit_select(ctx, dest_type, result, one, zero);
1151 }
1152 break;
1153
1154 case nir_op_flrp:
1155 assert(nir_op_infos[alu->op].num_inputs == 3);
1156 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1157 src[0], src[1], src[2]);
1158 break;
1159
1160 case nir_op_fcsel:
1161 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1162 get_bvec_type(ctx, num_components),
1163 src[0],
1164 get_fvec_constant(ctx,
1165 nir_src_bit_size(alu->src[0].src),
1166 num_components, 0));
1167 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1168 break;
1169
1170 case nir_op_bcsel:
1171 assert(nir_op_infos[alu->op].num_inputs == 3);
1172 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1173 break;
1174
1175 case nir_op_bany_fnequal2:
1176 case nir_op_bany_fnequal3:
1177 case nir_op_bany_fnequal4: {
1178 assert(nir_op_infos[alu->op].num_inputs == 2);
1179 assert(alu_instr_src_components(alu, 0) ==
1180 alu_instr_src_components(alu, 1));
1181 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1182 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1183 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1184 result = emit_binop(ctx, op,
1185 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1186 src[0], src[1]);
1187 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1188 break;
1189 }
1190
1191 case nir_op_ball_fequal2:
1192 case nir_op_ball_fequal3:
1193 case nir_op_ball_fequal4: {
1194 assert(nir_op_infos[alu->op].num_inputs == 2);
1195 assert(alu_instr_src_components(alu, 0) ==
1196 alu_instr_src_components(alu, 1));
1197 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1198 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1199 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1200 result = emit_binop(ctx, op,
1201 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1202 src[0], src[1]);
1203 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1204 break;
1205 }
1206
1207 case nir_op_bany_inequal2:
1208 case nir_op_bany_inequal3:
1209 case nir_op_bany_inequal4: {
1210 assert(nir_op_infos[alu->op].num_inputs == 2);
1211 assert(alu_instr_src_components(alu, 0) ==
1212 alu_instr_src_components(alu, 1));
1213 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1214 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1215 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1216 result = emit_binop(ctx, op,
1217 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1218 src[0], src[1]);
1219 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1220 break;
1221 }
1222
1223 case nir_op_ball_iequal2:
1224 case nir_op_ball_iequal3:
1225 case nir_op_ball_iequal4: {
1226 assert(nir_op_infos[alu->op].num_inputs == 2);
1227 assert(alu_instr_src_components(alu, 0) ==
1228 alu_instr_src_components(alu, 1));
1229 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1230 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1231 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1232 result = emit_binop(ctx, op,
1233 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1234 src[0], src[1]);
1235 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1236 break;
1237 }
1238
1239 case nir_op_vec2:
1240 case nir_op_vec3:
1241 case nir_op_vec4: {
1242 int num_inputs = nir_op_infos[alu->op].num_inputs;
1243 assert(2 <= num_inputs && num_inputs <= 4);
1244 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1245 src, num_inputs);
1246 }
1247 break;
1248
1249 default:
1250 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1251 nir_op_infos[alu->op].name);
1252
1253 unreachable("unsupported opcode");
1254 return;
1255 }
1256
1257 store_alu_result(ctx, alu, result);
1258 }
1259
1260 static void
1261 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1262 {
1263 unsigned bit_size = load_const->def.bit_size;
1264 unsigned num_components = load_const->def.num_components;
1265
1266 SpvId constant;
1267 if (num_components > 1) {
1268 SpvId components[num_components];
1269 SpvId type;
1270 if (bit_size == 1) {
1271 for (int i = 0; i < num_components; i++)
1272 components[i] = spirv_builder_const_bool(&ctx->builder,
1273 load_const->value[i].b);
1274
1275 type = get_bvec_type(ctx, num_components);
1276 } else {
1277 for (int i = 0; i < num_components; i++)
1278 components[i] = emit_uint_const(ctx, bit_size,
1279 load_const->value[i].u32);
1280
1281 type = get_uvec_type(ctx, bit_size, num_components);
1282 }
1283 constant = spirv_builder_const_composite(&ctx->builder, type,
1284 components, num_components);
1285 } else {
1286 assert(num_components == 1);
1287 if (bit_size == 1)
1288 constant = spirv_builder_const_bool(&ctx->builder,
1289 load_const->value[0].b);
1290 else
1291 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1292 }
1293
1294 store_ssa_def(ctx, &load_const->def, constant);
1295 }
1296
1297 static void
1298 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1299 {
1300 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1301 assert(const_block_index); // no dynamic indexing for now
1302 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1303
1304 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1305 if (const_offset) {
1306 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1307 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1308 SpvStorageClassUniform,
1309 uvec4_type);
1310
1311 unsigned idx = const_offset->u32;
1312 SpvId member = emit_uint_const(ctx, 32, 0);
1313 SpvId offset = emit_uint_const(ctx, 32, idx);
1314 SpvId offsets[] = { member, offset };
1315 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1316 ctx->ubos[0], offsets,
1317 ARRAY_SIZE(offsets));
1318 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1319
1320 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1321 unsigned num_components = nir_dest_num_components(intr->dest);
1322 if (num_components == 1) {
1323 uint32_t components[] = { 0 };
1324 result = spirv_builder_emit_composite_extract(&ctx->builder,
1325 type,
1326 result, components,
1327 1);
1328 } else if (num_components < 4) {
1329 SpvId constituents[num_components];
1330 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1331 for (uint32_t i = 0; i < num_components; ++i)
1332 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1333 uint_type,
1334 result, &i,
1335 1);
1336
1337 result = spirv_builder_emit_composite_construct(&ctx->builder,
1338 type,
1339 constituents,
1340 num_components);
1341 }
1342
1343 if (nir_dest_bit_size(intr->dest) == 1)
1344 result = uvec_to_bvec(ctx, result, num_components);
1345
1346 store_dest(ctx, &intr->dest, result, nir_type_uint);
1347 } else
1348 unreachable("uniform-addressing not yet supported");
1349 }
1350
1351 static void
1352 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1353 {
1354 assert(ctx->block_started);
1355 spirv_builder_emit_kill(&ctx->builder);
1356 /* discard is weird in NIR, so let's just create an unreachable block after
1357 it and hope that the vulkan driver will DCE any instructinos in it. */
1358 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1359 }
1360
1361 static void
1362 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1363 {
1364 SpvId ptr = get_src(ctx, intr->src);
1365
1366 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1367 SpvId result = spirv_builder_emit_load(&ctx->builder,
1368 get_glsl_type(ctx, var->type),
1369 ptr);
1370 unsigned num_components = nir_dest_num_components(intr->dest);
1371 unsigned bit_size = nir_dest_bit_size(intr->dest);
1372 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1373 store_dest(ctx, &intr->dest, result, nir_type_uint);
1374 }
1375
1376 static void
1377 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1378 {
1379 SpvId ptr = get_src(ctx, &intr->src[0]);
1380 SpvId src = get_src(ctx, &intr->src[1]);
1381
1382 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1383 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1384 SpvId result = emit_bitcast(ctx, type, src);
1385 spirv_builder_emit_store(&ctx->builder, ptr, result);
1386 }
1387
1388 static SpvId
1389 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1390 SpvStorageClass storage_class,
1391 const char *name, SpvBuiltIn builtin)
1392 {
1393 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1394 storage_class,
1395 var_type);
1396 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1397 storage_class);
1398 spirv_builder_emit_name(&ctx->builder, var, name);
1399 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1400
1401 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1402 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1403 return var;
1404 }
1405
1406 static void
1407 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1408 {
1409 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1410 if (!ctx->front_face_var)
1411 ctx->front_face_var = create_builtin_var(ctx, var_type,
1412 SpvStorageClassInput,
1413 "gl_FrontFacing",
1414 SpvBuiltInFrontFacing);
1415
1416 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1417 ctx->front_face_var);
1418 assert(1 == nir_dest_num_components(intr->dest));
1419 store_dest(ctx, &intr->dest, result, nir_type_bool);
1420 }
1421
1422 static void
1423 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1424 {
1425 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1426 if (!ctx->instance_id_var)
1427 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1428 SpvStorageClassInput,
1429 "gl_InstanceId",
1430 SpvBuiltInInstanceIndex);
1431
1432 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1433 ctx->instance_id_var);
1434 assert(1 == nir_dest_num_components(intr->dest));
1435 store_dest(ctx, &intr->dest, result, nir_type_uint);
1436 }
1437
1438 static void
1439 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1440 {
1441 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1442 if (!ctx->vertex_id_var)
1443 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1444 SpvStorageClassInput,
1445 "gl_VertexID",
1446 SpvBuiltInVertexIndex);
1447
1448 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1449 ctx->vertex_id_var);
1450 assert(1 == nir_dest_num_components(intr->dest));
1451 store_dest(ctx, &intr->dest, result, nir_type_uint);
1452 }
1453
1454 static void
1455 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1456 {
1457 switch (intr->intrinsic) {
1458 case nir_intrinsic_load_ubo:
1459 emit_load_ubo(ctx, intr);
1460 break;
1461
1462 case nir_intrinsic_discard:
1463 emit_discard(ctx, intr);
1464 break;
1465
1466 case nir_intrinsic_load_deref:
1467 emit_load_deref(ctx, intr);
1468 break;
1469
1470 case nir_intrinsic_store_deref:
1471 emit_store_deref(ctx, intr);
1472 break;
1473
1474 case nir_intrinsic_load_front_face:
1475 emit_load_front_face(ctx, intr);
1476 break;
1477
1478 case nir_intrinsic_load_instance_id:
1479 emit_load_instance_id(ctx, intr);
1480 break;
1481
1482 case nir_intrinsic_load_vertex_id:
1483 emit_load_vertex_id(ctx, intr);
1484 break;
1485
1486 default:
1487 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1488 nir_intrinsic_infos[intr->intrinsic].name);
1489 unreachable("unsupported intrinsic");
1490 }
1491 }
1492
1493 static void
1494 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1495 {
1496 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1497 undef->def.num_components);
1498
1499 store_ssa_def(ctx, &undef->def,
1500 spirv_builder_emit_undef(&ctx->builder, type));
1501 }
1502
1503 static SpvId
1504 get_src_float(struct ntv_context *ctx, nir_src *src)
1505 {
1506 SpvId def = get_src(ctx, src);
1507 unsigned num_components = nir_src_num_components(*src);
1508 unsigned bit_size = nir_src_bit_size(*src);
1509 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1510 }
1511
1512 static SpvId
1513 get_src_int(struct ntv_context *ctx, nir_src *src)
1514 {
1515 SpvId def = get_src(ctx, src);
1516 unsigned num_components = nir_src_num_components(*src);
1517 unsigned bit_size = nir_src_bit_size(*src);
1518 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1519 }
1520
1521 static void
1522 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1523 {
1524 assert(tex->op == nir_texop_tex ||
1525 tex->op == nir_texop_txb ||
1526 tex->op == nir_texop_txl ||
1527 tex->op == nir_texop_txd ||
1528 tex->op == nir_texop_txf ||
1529 tex->op == nir_texop_txf_ms ||
1530 tex->op == nir_texop_txs);
1531 assert(tex->texture_index == tex->sampler_index);
1532
1533 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1534 offset = 0, sample = 0;
1535 unsigned coord_components = 0;
1536 for (unsigned i = 0; i < tex->num_srcs; i++) {
1537 switch (tex->src[i].src_type) {
1538 case nir_tex_src_coord:
1539 if (tex->op == nir_texop_txf ||
1540 tex->op == nir_texop_txf_ms)
1541 coord = get_src_int(ctx, &tex->src[i].src);
1542 else
1543 coord = get_src_float(ctx, &tex->src[i].src);
1544 coord_components = nir_src_num_components(tex->src[i].src);
1545 break;
1546
1547 case nir_tex_src_projector:
1548 assert(nir_src_num_components(tex->src[i].src) == 1);
1549 proj = get_src_float(ctx, &tex->src[i].src);
1550 assert(proj != 0);
1551 break;
1552
1553 case nir_tex_src_offset:
1554 offset = get_src_int(ctx, &tex->src[i].src);
1555 break;
1556
1557 case nir_tex_src_bias:
1558 assert(tex->op == nir_texop_txb);
1559 bias = get_src_float(ctx, &tex->src[i].src);
1560 assert(bias != 0);
1561 break;
1562
1563 case nir_tex_src_lod:
1564 assert(nir_src_num_components(tex->src[i].src) == 1);
1565 if (tex->op == nir_texop_txf ||
1566 tex->op == nir_texop_txf_ms ||
1567 tex->op == nir_texop_txs)
1568 lod = get_src_int(ctx, &tex->src[i].src);
1569 else
1570 lod = get_src_float(ctx, &tex->src[i].src);
1571 assert(lod != 0);
1572 break;
1573
1574 case nir_tex_src_ms_index:
1575 assert(nir_src_num_components(tex->src[i].src) == 1);
1576 sample = get_src_int(ctx, &tex->src[i].src);
1577 break;
1578
1579 case nir_tex_src_comparator:
1580 assert(nir_src_num_components(tex->src[i].src) == 1);
1581 dref = get_src_float(ctx, &tex->src[i].src);
1582 assert(dref != 0);
1583 break;
1584
1585 case nir_tex_src_ddx:
1586 dx = get_src_float(ctx, &tex->src[i].src);
1587 assert(dx != 0);
1588 break;
1589
1590 case nir_tex_src_ddy:
1591 dy = get_src_float(ctx, &tex->src[i].src);
1592 assert(dy != 0);
1593 break;
1594
1595 default:
1596 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1597 unreachable("unknown texture source");
1598 }
1599 }
1600
1601 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1602 lod = emit_float_const(ctx, 32, 0.0f);
1603 assert(lod != 0);
1604 }
1605
1606 SpvId image_type = ctx->image_types[tex->texture_index];
1607 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1608 image_type);
1609
1610 assert(ctx->samplers_used & (1u << tex->texture_index));
1611 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1612 ctx->samplers[tex->texture_index]);
1613
1614 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1615
1616 if (tex->op == nir_texop_txs) {
1617 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1618 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1619 dest_type, image,
1620 lod);
1621 store_dest(ctx, &tex->dest, result, tex->dest_type);
1622 return;
1623 }
1624
1625 if (proj && coord_components > 0) {
1626 SpvId constituents[coord_components + 1];
1627 if (coord_components == 1)
1628 constituents[0] = coord;
1629 else {
1630 assert(coord_components > 1);
1631 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1632 for (uint32_t i = 0; i < coord_components; ++i)
1633 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1634 float_type,
1635 coord,
1636 &i, 1);
1637 }
1638
1639 constituents[coord_components++] = proj;
1640
1641 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1642 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1643 vec_type,
1644 constituents,
1645 coord_components);
1646 }
1647
1648 SpvId actual_dest_type = dest_type;
1649 if (dref)
1650 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1651
1652 SpvId result;
1653 if (tex->op == nir_texop_txf ||
1654 tex->op == nir_texop_txf_ms) {
1655 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1656 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1657 image, coord, lod, sample);
1658 } else {
1659 result = spirv_builder_emit_image_sample(&ctx->builder,
1660 actual_dest_type, load,
1661 coord,
1662 proj != 0,
1663 lod, bias, dref, dx, dy,
1664 offset);
1665 }
1666
1667 spirv_builder_emit_decoration(&ctx->builder, result,
1668 SpvDecorationRelaxedPrecision);
1669
1670 if (dref && nir_dest_num_components(tex->dest) > 1) {
1671 SpvId components[4] = { result, result, result, result };
1672 result = spirv_builder_emit_composite_construct(&ctx->builder,
1673 dest_type,
1674 components,
1675 4);
1676 }
1677
1678 store_dest(ctx, &tex->dest, result, tex->dest_type);
1679 }
1680
1681 static void
1682 start_block(struct ntv_context *ctx, SpvId label)
1683 {
1684 /* terminate previous block if needed */
1685 if (ctx->block_started)
1686 spirv_builder_emit_branch(&ctx->builder, label);
1687
1688 /* start new block */
1689 spirv_builder_label(&ctx->builder, label);
1690 ctx->block_started = true;
1691 }
1692
1693 static void
1694 branch(struct ntv_context *ctx, SpvId label)
1695 {
1696 assert(ctx->block_started);
1697 spirv_builder_emit_branch(&ctx->builder, label);
1698 ctx->block_started = false;
1699 }
1700
1701 static void
1702 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1703 SpvId else_id)
1704 {
1705 assert(ctx->block_started);
1706 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1707 then_id, else_id);
1708 ctx->block_started = false;
1709 }
1710
1711 static void
1712 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1713 {
1714 switch (jump->type) {
1715 case nir_jump_break:
1716 assert(ctx->loop_break);
1717 branch(ctx, ctx->loop_break);
1718 break;
1719
1720 case nir_jump_continue:
1721 assert(ctx->loop_cont);
1722 branch(ctx, ctx->loop_cont);
1723 break;
1724
1725 default:
1726 unreachable("Unsupported jump type\n");
1727 }
1728 }
1729
1730 static void
1731 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1732 {
1733 assert(deref->deref_type == nir_deref_type_var);
1734
1735 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1736 assert(he);
1737 SpvId result = (SpvId)(intptr_t)he->data;
1738 store_dest_raw(ctx, &deref->dest, result);
1739 }
1740
1741 static void
1742 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1743 {
1744 assert(deref->deref_type == nir_deref_type_array);
1745 nir_variable *var = nir_deref_instr_get_variable(deref);
1746
1747 SpvStorageClass storage_class;
1748 switch (var->data.mode) {
1749 case nir_var_shader_in:
1750 storage_class = SpvStorageClassInput;
1751 break;
1752
1753 case nir_var_shader_out:
1754 storage_class = SpvStorageClassOutput;
1755 break;
1756
1757 default:
1758 unreachable("Unsupported nir_variable_mode\n");
1759 }
1760
1761 SpvId index = get_src(ctx, &deref->arr.index);
1762
1763 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1764 storage_class,
1765 get_glsl_type(ctx, deref->type));
1766
1767 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1768 ptr_type,
1769 get_src(ctx, &deref->parent),
1770 &index, 1);
1771 /* uint is a bit of a lie here, it's really just an opaque type */
1772 store_dest(ctx, &deref->dest, result, nir_type_uint);
1773 }
1774
1775 static void
1776 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1777 {
1778 switch (deref->deref_type) {
1779 case nir_deref_type_var:
1780 emit_deref_var(ctx, deref);
1781 break;
1782
1783 case nir_deref_type_array:
1784 emit_deref_array(ctx, deref);
1785 break;
1786
1787 default:
1788 unreachable("unexpected deref_type");
1789 }
1790 }
1791
1792 static void
1793 emit_block(struct ntv_context *ctx, struct nir_block *block)
1794 {
1795 start_block(ctx, block_label(ctx, block));
1796 nir_foreach_instr(instr, block) {
1797 switch (instr->type) {
1798 case nir_instr_type_alu:
1799 emit_alu(ctx, nir_instr_as_alu(instr));
1800 break;
1801 case nir_instr_type_intrinsic:
1802 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1803 break;
1804 case nir_instr_type_load_const:
1805 emit_load_const(ctx, nir_instr_as_load_const(instr));
1806 break;
1807 case nir_instr_type_ssa_undef:
1808 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1809 break;
1810 case nir_instr_type_tex:
1811 emit_tex(ctx, nir_instr_as_tex(instr));
1812 break;
1813 case nir_instr_type_phi:
1814 unreachable("nir_instr_type_phi not supported");
1815 break;
1816 case nir_instr_type_jump:
1817 emit_jump(ctx, nir_instr_as_jump(instr));
1818 break;
1819 case nir_instr_type_call:
1820 unreachable("nir_instr_type_call not supported");
1821 break;
1822 case nir_instr_type_parallel_copy:
1823 unreachable("nir_instr_type_parallel_copy not supported");
1824 break;
1825 case nir_instr_type_deref:
1826 emit_deref(ctx, nir_instr_as_deref(instr));
1827 break;
1828 }
1829 }
1830 }
1831
/* Forward declaration: emit_if and emit_loop below recurse into nested
 * control-flow lists through this function. */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1834
1835 static SpvId
1836 get_src_bool(struct ntv_context *ctx, nir_src *src)
1837 {
1838 assert(nir_src_bit_size(*src) == 1);
1839 return get_src(ctx, src);
1840 }
1841
1842 static void
1843 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1844 {
1845 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1846
1847 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1848 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1849 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1850 SpvId else_id = endif_id;
1851
1852 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1853 if (has_else) {
1854 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1855 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1856 }
1857
1858 /* create a header-block */
1859 start_block(ctx, header_id);
1860 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1861 SpvSelectionControlMaskNone);
1862 branch_conditional(ctx, condition, then_id, else_id);
1863
1864 emit_cf_list(ctx, &if_stmt->then_list);
1865
1866 if (has_else) {
1867 if (ctx->block_started)
1868 branch(ctx, endif_id);
1869
1870 emit_cf_list(ctx, &if_stmt->else_list);
1871 }
1872
1873 start_block(ctx, endif_id);
1874 }
1875
1876 static void
1877 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1878 {
1879 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1880 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1881 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1882 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1883
1884 /* create a header-block */
1885 start_block(ctx, header_id);
1886 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1887 branch(ctx, begin_id);
1888
1889 SpvId save_break = ctx->loop_break;
1890 SpvId save_cont = ctx->loop_cont;
1891 ctx->loop_break = break_id;
1892 ctx->loop_cont = cont_id;
1893
1894 emit_cf_list(ctx, &loop->body);
1895
1896 ctx->loop_break = save_break;
1897 ctx->loop_cont = save_cont;
1898
1899 branch(ctx, cont_id);
1900 start_block(ctx, cont_id);
1901 branch(ctx, header_id);
1902
1903 start_block(ctx, break_id);
1904 }
1905
1906 static void
1907 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1908 {
1909 foreach_list_typed(nir_cf_node, node, node, list) {
1910 switch (node->type) {
1911 case nir_cf_node_block:
1912 emit_block(ctx, nir_cf_node_as_block(node));
1913 break;
1914
1915 case nir_cf_node_if:
1916 emit_if(ctx, nir_cf_node_as_if(node));
1917 break;
1918
1919 case nir_cf_node_loop:
1920 emit_loop(ctx, nir_cf_node_as_loop(node));
1921 break;
1922
1923 case nir_cf_node_function:
1924 unreachable("nir_cf_node_function not supported");
1925 break;
1926 }
1927 }
1928 }
1929
1930 struct spirv_shader *
1931 nir_to_spirv(struct nir_shader *s)
1932 {
1933 struct spirv_shader *ret = NULL;
1934
1935 struct ntv_context ctx = {};
1936
1937 switch (s->info.stage) {
1938 case MESA_SHADER_VERTEX:
1939 case MESA_SHADER_FRAGMENT:
1940 case MESA_SHADER_COMPUTE:
1941 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1942 break;
1943
1944 case MESA_SHADER_TESS_CTRL:
1945 case MESA_SHADER_TESS_EVAL:
1946 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1947 break;
1948
1949 case MESA_SHADER_GEOMETRY:
1950 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1951 break;
1952
1953 default:
1954 unreachable("invalid stage");
1955 }
1956
1957 // TODO: only enable when needed
1958 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1959 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1960 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1961 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1962 }
1963
1964 ctx.stage = s->info.stage;
1965 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1966 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1967
1968 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1969 SpvMemoryModelGLSL450);
1970
1971 SpvExecutionModel exec_model;
1972 switch (s->info.stage) {
1973 case MESA_SHADER_VERTEX:
1974 exec_model = SpvExecutionModelVertex;
1975 break;
1976 case MESA_SHADER_TESS_CTRL:
1977 exec_model = SpvExecutionModelTessellationControl;
1978 break;
1979 case MESA_SHADER_TESS_EVAL:
1980 exec_model = SpvExecutionModelTessellationEvaluation;
1981 break;
1982 case MESA_SHADER_GEOMETRY:
1983 exec_model = SpvExecutionModelGeometry;
1984 break;
1985 case MESA_SHADER_FRAGMENT:
1986 exec_model = SpvExecutionModelFragment;
1987 break;
1988 case MESA_SHADER_COMPUTE:
1989 exec_model = SpvExecutionModelGLCompute;
1990 break;
1991 default:
1992 unreachable("invalid stage");
1993 }
1994
1995 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1996 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1997 NULL, 0);
1998 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1999 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2000
2001 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
2002 _mesa_key_pointer_equal);
2003
2004 nir_foreach_variable(var, &s->inputs)
2005 emit_input(&ctx, var);
2006
2007 nir_foreach_variable(var, &s->outputs)
2008 emit_output(&ctx, var);
2009
2010 nir_foreach_variable(var, &s->uniforms)
2011 emit_uniform(&ctx, var);
2012
2013 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2014 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2015 SpvExecutionModeOriginUpperLeft);
2016 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2017 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2018 SpvExecutionModeDepthReplacing);
2019 }
2020
2021
2022 spirv_builder_function(&ctx.builder, entry_point, type_void,
2023 SpvFunctionControlMaskNone,
2024 type_main);
2025
2026 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2027 nir_metadata_require(entry, nir_metadata_block_index);
2028
2029 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
2030 if (!ctx.defs)
2031 goto fail;
2032 ctx.num_defs = entry->ssa_alloc;
2033
2034 nir_index_local_regs(entry);
2035 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
2036 if (!ctx.regs)
2037 goto fail;
2038 ctx.num_regs = entry->reg_alloc;
2039
2040 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
2041 if (!block_ids)
2042 goto fail;
2043
2044 for (int i = 0; i < entry->num_blocks; ++i)
2045 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2046
2047 ctx.block_ids = block_ids;
2048 ctx.num_blocks = entry->num_blocks;
2049
2050 /* emit a block only for the variable declarations */
2051 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2052 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2053 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2054 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2055 SpvStorageClassFunction,
2056 type);
2057 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2058 SpvStorageClassFunction);
2059
2060 ctx.regs[reg->index] = var;
2061 }
2062
2063 emit_cf_list(&ctx, &entry->body);
2064
2065 free(ctx.defs);
2066
2067 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2068 spirv_builder_function_end(&ctx.builder);
2069
2070 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2071 "main", ctx.entry_ifaces,
2072 ctx.num_entry_ifaces);
2073
2074 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2075
2076 ret = CALLOC_STRUCT(spirv_shader);
2077 if (!ret)
2078 goto fail;
2079
2080 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2081 if (!ret->words)
2082 goto fail;
2083
2084 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2085 assert(ret->num_words == num_words);
2086
2087 return ret;
2088
2089 fail:
2090
2091 if (ret)
2092 spirv_shader_delete(ret);
2093
2094 if (ctx.vars)
2095 _mesa_hash_table_destroy(ctx.vars, NULL);
2096
2097 return NULL;
2098 }
2099
2100 void
2101 spirv_shader_delete(struct spirv_shader *s)
2102 {
2103 FREE(s->words);
2104 FREE(s);
2105 }