zink: always use logical eq ops in ntv with 1bit inputs
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
172 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
353
354 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
355 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
356 }
357
358 static SpvDim
359 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
360 {
361 *is_ms = false;
362 switch (gdim) {
363 case GLSL_SAMPLER_DIM_1D:
364 return SpvDim1D;
365 case GLSL_SAMPLER_DIM_2D:
366 return SpvDim2D;
367 case GLSL_SAMPLER_DIM_3D:
368 return SpvDim3D;
369 case GLSL_SAMPLER_DIM_CUBE:
370 return SpvDimCube;
371 case GLSL_SAMPLER_DIM_RECT:
372 return SpvDim2D;
373 case GLSL_SAMPLER_DIM_BUF:
374 return SpvDimBuffer;
375 case GLSL_SAMPLER_DIM_EXTERNAL:
376 return SpvDim2D; /* seems dodgy... */
377 case GLSL_SAMPLER_DIM_MS:
378 *is_ms = true;
379 return SpvDim2D;
380 default:
381 fprintf(stderr, "unknown sampler type %d\n", gdim);
382 break;
383 }
384 return SpvDim2D;
385 }
386
387 uint32_t
388 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
389 {
390 if (stage == MESA_SHADER_NONE ||
391 stage >= MESA_SHADER_COMPUTE) {
392 unreachable("not supported");
393 } else {
394 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
395 PIPE_MAX_SHADER_SAMPLER_VIEWS);
396
397 switch (type) {
398 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
399 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
400 return stage_offset + index;
401
402 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
403 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
404 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
405
406 default:
407 unreachable("unexpected type");
408 }
409 }
410 }
411
412 static void
413 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
414 {
415 const struct glsl_type *type = glsl_without_array(var->type);
416
417 bool is_ms;
418 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
419
420 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
421 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
422 dimension, false,
423 glsl_sampler_type_is_array(type),
424 is_ms, 1,
425 SpvImageFormatUnknown);
426
427 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
428 image_type);
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniformConstant,
431 sampled_type);
432
433 if (glsl_type_is_array(var->type)) {
434 for (int i = 0; i < glsl_get_length(var->type); ++i) {
435 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
436 SpvStorageClassUniformConstant);
437
438 if (var->name) {
439 char element_name[100];
440 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
441 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
442 }
443
444 int index = var->data.binding + i;
445 assert(!(ctx->samplers_used & (1 << index)));
446 assert(!ctx->image_types[index]);
447 ctx->image_types[index] = image_type;
448 ctx->samplers[index] = var_id;
449 ctx->samplers_used |= 1 << index;
450
451 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
452 var->data.descriptor_set);
453 int binding = zink_binding(ctx->stage,
454 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
455 var->data.binding + i);
456 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
457 }
458 } else {
459 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
460 SpvStorageClassUniformConstant);
461
462 if (var->name)
463 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
464
465 int index = var->data.binding;
466 assert(!(ctx->samplers_used & (1 << index)));
467 assert(!ctx->image_types[index]);
468 ctx->image_types[index] = image_type;
469 ctx->samplers[index] = var_id;
470 ctx->samplers_used |= 1 << index;
471
472 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
473 var->data.descriptor_set);
474 int binding = zink_binding(ctx->stage,
475 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
476 var->data.binding);
477 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
478 }
479 }
480
481 static void
482 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
483 {
484 uint32_t size = glsl_count_attribute_slots(var->type, false);
485 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
486 SpvId array_length = emit_uint_const(ctx, 32, size);
487 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
488 array_length);
489 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
490
491 // wrap UBO-array in a struct
492 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
493 if (var->name) {
494 char struct_name[100];
495 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
496 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
497 }
498
499 spirv_builder_emit_decoration(&ctx->builder, struct_type,
500 SpvDecorationBlock);
501 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
502
503
504 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
505 SpvStorageClassUniform,
506 struct_type);
507
508 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
509 SpvStorageClassUniform);
510 if (var->name)
511 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
512
513 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
514 ctx->ubos[ctx->num_ubos++] = var_id;
515
516 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
517 var->data.descriptor_set);
518 int binding = zink_binding(ctx->stage,
519 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
520 var->data.binding);
521 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
522 }
523
524 static void
525 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
526 {
527 if (var->data.mode == nir_var_mem_ubo)
528 emit_ubo(ctx, var);
529 else {
530 assert(var->data.mode == nir_var_uniform);
531 if (glsl_type_is_sampler(glsl_without_array(var->type)))
532 emit_sampler(ctx, var);
533 }
534 }
535
536 static SpvId
537 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
538 {
539 assert(ssa->index < ctx->num_defs);
540 assert(ctx->defs[ssa->index] != 0);
541 return ctx->defs[ssa->index];
542 }
543
544 static SpvId
545 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
546 {
547 assert(reg->index < ctx->num_regs);
548 assert(ctx->regs[reg->index] != 0);
549 return ctx->regs[reg->index];
550 }
551
552 static SpvId
553 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
554 {
555 assert(reg->reg);
556 assert(!reg->indirect);
557 assert(!reg->base_offset);
558
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
561 return spirv_builder_emit_load(&ctx->builder, type, var);
562 }
563
564 static SpvId
565 get_src(struct ntv_context *ctx, nir_src *src)
566 {
567 if (src->is_ssa)
568 return get_src_ssa(ctx, src->ssa);
569 else
570 return get_src_reg(ctx, &src->reg);
571 }
572
573 static SpvId
574 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
575 {
576 assert(!alu->src[src].negate);
577 assert(!alu->src[src].abs);
578
579 SpvId def = get_src(ctx, &alu->src[src].src);
580
581 unsigned used_channels = 0;
582 bool need_swizzle = false;
583 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
584 if (!nir_alu_instr_channel_used(alu, src, i))
585 continue;
586
587 used_channels++;
588
589 if (alu->src[src].swizzle[i] != i)
590 need_swizzle = true;
591 }
592 assert(used_channels != 0);
593
594 unsigned live_channels = nir_src_num_components(alu->src[src].src);
595 if (used_channels != live_channels)
596 need_swizzle = true;
597
598 if (!need_swizzle)
599 return def;
600
601 int bit_size = nir_src_bit_size(alu->src[src].src);
602 assert(bit_size == 1 || bit_size == 32);
603
604 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
605 spirv_builder_type_uint(&ctx->builder, bit_size);
606
607 if (used_channels == 1) {
608 uint32_t indices[] = { alu->src[src].swizzle[0] };
609 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
610 def, indices,
611 ARRAY_SIZE(indices));
612 } else if (live_channels == 1) {
613 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
614 raw_type,
615 used_channels);
616
617 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
618 for (unsigned i = 0; i < used_channels; ++i)
619 constituents[i] = def;
620
621 return spirv_builder_emit_composite_construct(&ctx->builder,
622 raw_vec_type,
623 constituents,
624 used_channels);
625 } else {
626 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
627 raw_type,
628 used_channels);
629
630 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
631 size_t num_components = 0;
632 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
633 if (!nir_alu_instr_channel_used(alu, src, i))
634 continue;
635
636 components[num_components++] = alu->src[src].swizzle[i];
637 }
638
639 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
640 def, def, components,
641 num_components);
642 }
643 }
644
645 static void
646 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
647 {
648 assert(result != 0);
649 assert(ssa->index < ctx->num_defs);
650 ctx->defs[ssa->index] = result;
651 }
652
653 static SpvId
654 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
655 SpvId if_true, SpvId if_false)
656 {
657 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
658 }
659
660 static SpvId
661 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
662 {
663 SpvId type = get_bvec_type(ctx, num_components);
664 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
665 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
666 }
667
668 static SpvId
669 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
670 {
671 return emit_unop(ctx, SpvOpBitcast, type, value);
672 }
673
674 static SpvId
675 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
676 unsigned num_components)
677 {
678 SpvId type = get_uvec_type(ctx, bit_size, num_components);
679 return emit_bitcast(ctx, type, value);
680 }
681
682 static SpvId
683 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
684 unsigned num_components)
685 {
686 SpvId type = get_ivec_type(ctx, bit_size, num_components);
687 return emit_bitcast(ctx, type, value);
688 }
689
690 static SpvId
691 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
692 unsigned num_components)
693 {
694 SpvId type = get_fvec_type(ctx, bit_size, num_components);
695 return emit_bitcast(ctx, type, value);
696 }
697
698 static void
699 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
700 {
701 SpvId var = get_var_from_reg(ctx, reg->reg);
702 assert(var);
703 spirv_builder_emit_store(&ctx->builder, var, result);
704 }
705
706 static void
707 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
708 {
709 if (dest->is_ssa)
710 store_ssa_def(ctx, &dest->ssa, result);
711 else
712 store_reg_def(ctx, &dest->reg, result);
713 }
714
715 static SpvId
716 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
717 {
718 unsigned num_components = nir_dest_num_components(*dest);
719 unsigned bit_size = nir_dest_bit_size(*dest);
720
721 if (bit_size != 1) {
722 switch (nir_alu_type_get_base_type(type)) {
723 case nir_type_bool:
724 assert("bool should have bit-size 1");
725
726 case nir_type_uint:
727 break; /* nothing to do! */
728
729 case nir_type_int:
730 case nir_type_float:
731 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
732 break;
733
734 default:
735 unreachable("unsupported nir_alu_type");
736 }
737 }
738
739 store_dest_raw(ctx, dest, result);
740 return result;
741 }
742
743 static SpvId
744 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
745 {
746 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
747 }
748
749 static SpvId
750 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
751 SpvId src0, SpvId src1)
752 {
753 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
754 }
755
756 static SpvId
757 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
758 SpvId src0, SpvId src1, SpvId src2)
759 {
760 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
761 }
762
763 static SpvId
764 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
765 SpvId src)
766 {
767 SpvId args[] = { src };
768 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
769 op, args, ARRAY_SIZE(args));
770 }
771
772 static SpvId
773 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
774 SpvId src0, SpvId src1)
775 {
776 SpvId args[] = { src0, src1 };
777 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
778 op, args, ARRAY_SIZE(args));
779 }
780
781 static SpvId
782 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
783 SpvId src0, SpvId src1, SpvId src2)
784 {
785 SpvId args[] = { src0, src1, src2 };
786 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
787 op, args, ARRAY_SIZE(args));
788 }
789
790 static SpvId
791 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
792 unsigned num_components, float value)
793 {
794 assert(bit_size == 32);
795
796 SpvId result = emit_float_const(ctx, bit_size, value);
797 if (num_components == 1)
798 return result;
799
800 assert(num_components > 1);
801 SpvId components[num_components];
802 for (int i = 0; i < num_components; i++)
803 components[i] = result;
804
805 SpvId type = get_fvec_type(ctx, bit_size, num_components);
806 return spirv_builder_const_composite(&ctx->builder, type, components,
807 num_components);
808 }
809
810 static SpvId
811 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
812 unsigned num_components, uint32_t value)
813 {
814 assert(bit_size == 32);
815
816 SpvId result = emit_uint_const(ctx, bit_size, value);
817 if (num_components == 1)
818 return result;
819
820 assert(num_components > 1);
821 SpvId components[num_components];
822 for (int i = 0; i < num_components; i++)
823 components[i] = result;
824
825 SpvId type = get_uvec_type(ctx, bit_size, num_components);
826 return spirv_builder_const_composite(&ctx->builder, type, components,
827 num_components);
828 }
829
830 static SpvId
831 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
832 unsigned num_components, int32_t value)
833 {
834 assert(bit_size == 32);
835
836 SpvId result = emit_int_const(ctx, bit_size, value);
837 if (num_components == 1)
838 return result;
839
840 assert(num_components > 1);
841 SpvId components[num_components];
842 for (int i = 0; i < num_components; i++)
843 components[i] = result;
844
845 SpvId type = get_ivec_type(ctx, bit_size, num_components);
846 return spirv_builder_const_composite(&ctx->builder, type, components,
847 num_components);
848 }
849
850 static inline unsigned
851 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
852 {
853 if (nir_op_infos[instr->op].input_sizes[src] > 0)
854 return nir_op_infos[instr->op].input_sizes[src];
855
856 if (instr->dest.dest.is_ssa)
857 return instr->dest.dest.ssa.num_components;
858 else
859 return instr->dest.dest.reg.reg->num_components;
860 }
861
862 static SpvId
863 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
864 {
865 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
866
867 unsigned num_components = alu_instr_src_components(alu, src);
868 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
869 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
870
871 if (bit_size == 1)
872 return raw_value;
873 else {
874 switch (nir_alu_type_get_base_type(type)) {
875 case nir_type_bool:
876 unreachable("bool should have bit-size 1");
877
878 case nir_type_int:
879 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
880
881 case nir_type_uint:
882 return raw_value;
883
884 case nir_type_float:
885 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
886
887 default:
888 unreachable("unknown nir_alu_type");
889 }
890 }
891 }
892
893 static SpvId
894 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
895 {
896 assert(!alu->dest.saturate);
897 return store_dest(ctx, &alu->dest.dest, result,
898 nir_op_infos[alu->op].output_type);
899 }
900
901 static SpvId
902 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
903 {
904 unsigned num_components = nir_dest_num_components(*dest);
905 unsigned bit_size = nir_dest_bit_size(*dest);
906
907 if (bit_size == 1)
908 return get_bvec_type(ctx, num_components);
909
910 switch (nir_alu_type_get_base_type(type)) {
911 case nir_type_bool:
912 unreachable("bool should have bit-size 1");
913
914 case nir_type_int:
915 return get_ivec_type(ctx, bit_size, num_components);
916
917 case nir_type_uint:
918 return get_uvec_type(ctx, bit_size, num_components);
919
920 case nir_type_float:
921 return get_fvec_type(ctx, bit_size, num_components);
922
923 default:
924 unreachable("unsupported nir_alu_type");
925 }
926 }
927
928 static void
929 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
930 {
931 SpvId src[nir_op_infos[alu->op].num_inputs];
932 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
933 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
934 src[i] = get_alu_src(ctx, alu, i);
935 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
936 }
937
938 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
939 nir_op_infos[alu->op].output_type);
940 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
941 unsigned num_components = nir_dest_num_components(alu->dest.dest);
942
943 SpvId result = 0;
944 switch (alu->op) {
945 case nir_op_mov:
946 assert(nir_op_infos[alu->op].num_inputs == 1);
947 result = src[0];
948 break;
949
950 #define UNOP(nir_op, spirv_op) \
951 case nir_op: \
952 assert(nir_op_infos[alu->op].num_inputs == 1); \
953 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
954 break;
955
956 UNOP(nir_op_ineg, SpvOpSNegate)
957 UNOP(nir_op_fneg, SpvOpFNegate)
958 UNOP(nir_op_fddx, SpvOpDPdx)
959 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
960 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
961 UNOP(nir_op_fddy, SpvOpDPdy)
962 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
963 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
964 UNOP(nir_op_f2i32, SpvOpConvertFToS)
965 UNOP(nir_op_f2u32, SpvOpConvertFToU)
966 UNOP(nir_op_i2f32, SpvOpConvertSToF)
967 UNOP(nir_op_u2f32, SpvOpConvertUToF)
968 #undef UNOP
969
970 case nir_op_inot:
971 if (bit_size == 1)
972 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
973 else
974 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
975 break;
976
977 case nir_op_b2i32:
978 assert(nir_op_infos[alu->op].num_inputs == 1);
979 result = emit_select(ctx, dest_type, src[0],
980 get_ivec_constant(ctx, 32, num_components, 1),
981 get_ivec_constant(ctx, 32, num_components, 0));
982 break;
983
984 case nir_op_b2f32:
985 assert(nir_op_infos[alu->op].num_inputs == 1);
986 result = emit_select(ctx, dest_type, src[0],
987 get_fvec_constant(ctx, 32, num_components, 1),
988 get_fvec_constant(ctx, 32, num_components, 0));
989 break;
990
991 #define BUILTIN_UNOP(nir_op, spirv_op) \
992 case nir_op: \
993 assert(nir_op_infos[alu->op].num_inputs == 1); \
994 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
995 break;
996
997 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
998 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
999 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1000 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1001 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1002 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1003 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1004 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1005 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1006 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1007 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1008 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1009 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1010 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1011 #undef BUILTIN_UNOP
1012
1013 case nir_op_frcp:
1014 assert(nir_op_infos[alu->op].num_inputs == 1);
1015 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1016 get_fvec_constant(ctx, bit_size, num_components, 1),
1017 src[0]);
1018 break;
1019
1020 case nir_op_f2b1:
1021 assert(nir_op_infos[alu->op].num_inputs == 1);
1022 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1023 get_fvec_constant(ctx,
1024 nir_src_bit_size(alu->src[0].src),
1025 num_components, 0));
1026 break;
1027 case nir_op_i2b1:
1028 assert(nir_op_infos[alu->op].num_inputs == 1);
1029 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1030 get_ivec_constant(ctx,
1031 nir_src_bit_size(alu->src[0].src),
1032 num_components, 0));
1033 break;
1034
1035
1036 #define BINOP(nir_op, spirv_op) \
1037 case nir_op: \
1038 assert(nir_op_infos[alu->op].num_inputs == 2); \
1039 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1040 break;
1041
1042 BINOP(nir_op_iadd, SpvOpIAdd)
1043 BINOP(nir_op_isub, SpvOpISub)
1044 BINOP(nir_op_imul, SpvOpIMul)
1045 BINOP(nir_op_idiv, SpvOpSDiv)
1046 BINOP(nir_op_udiv, SpvOpUDiv)
1047 BINOP(nir_op_umod, SpvOpUMod)
1048 BINOP(nir_op_fadd, SpvOpFAdd)
1049 BINOP(nir_op_fsub, SpvOpFSub)
1050 BINOP(nir_op_fmul, SpvOpFMul)
1051 BINOP(nir_op_fdiv, SpvOpFDiv)
1052 BINOP(nir_op_fmod, SpvOpFMod)
1053 BINOP(nir_op_ilt, SpvOpSLessThan)
1054 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1055 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1056 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1057 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1058 BINOP(nir_op_feq, SpvOpFOrdEqual)
1059 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1060 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1061 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1062 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1063 #undef BINOP
1064
1065 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1066 case nir_op: \
1067 assert(nir_op_infos[alu->op].num_inputs == 2); \
1068 if (nir_src_bit_size(alu->src[0].src) == 1) \
1069 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1070 else \
1071 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1072 break;
1073
1074 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1075 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1076 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1077 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1078 #undef BINOP_LOG
1079
1080 #define BUILTIN_BINOP(nir_op, spirv_op) \
1081 case nir_op: \
1082 assert(nir_op_infos[alu->op].num_inputs == 2); \
1083 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1084 break;
1085
1086 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1087 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1088 #undef BUILTIN_BINOP
1089
1090 case nir_op_fdot2:
1091 case nir_op_fdot3:
1092 case nir_op_fdot4:
1093 assert(nir_op_infos[alu->op].num_inputs == 2);
1094 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1095 break;
1096
1097 case nir_op_fdph:
1098 unreachable("should already be lowered away");
1099
1100 case nir_op_seq:
1101 case nir_op_sne:
1102 case nir_op_slt:
1103 case nir_op_sge: {
1104 assert(nir_op_infos[alu->op].num_inputs == 2);
1105 int num_components = nir_dest_num_components(alu->dest.dest);
1106 SpvId bool_type = get_bvec_type(ctx, num_components);
1107
1108 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1109 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1110 if (num_components > 1) {
1111 SpvId zero_comps[num_components], one_comps[num_components];
1112 for (int i = 0; i < num_components; i++) {
1113 zero_comps[i] = zero;
1114 one_comps[i] = one;
1115 }
1116
1117 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1118 zero_comps, num_components);
1119 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1120 one_comps, num_components);
1121 }
1122
1123 SpvOp op;
1124 switch (alu->op) {
1125 case nir_op_seq: op = SpvOpFOrdEqual; break;
1126 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1127 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1128 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1129 default: unreachable("unexpected op");
1130 }
1131
1132 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1133 result = emit_select(ctx, dest_type, result, one, zero);
1134 }
1135 break;
1136
1137 case nir_op_flrp:
1138 assert(nir_op_infos[alu->op].num_inputs == 3);
1139 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1140 src[0], src[1], src[2]);
1141 break;
1142
1143 case nir_op_fcsel:
1144 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1145 get_bvec_type(ctx, num_components),
1146 src[0],
1147 get_fvec_constant(ctx,
1148 nir_src_bit_size(alu->src[0].src),
1149 num_components, 0));
1150 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1151 break;
1152
1153 case nir_op_bcsel:
1154 assert(nir_op_infos[alu->op].num_inputs == 3);
1155 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1156 break;
1157
1158 case nir_op_bany_fnequal2:
1159 case nir_op_bany_fnequal3:
1160 case nir_op_bany_fnequal4: {
1161 assert(nir_op_infos[alu->op].num_inputs == 2);
1162 assert(alu_instr_src_components(alu, 0) ==
1163 alu_instr_src_components(alu, 1));
1164 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1165 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1166 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1167 result = emit_binop(ctx, op,
1168 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1169 src[0], src[1]);
1170 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1171 break;
1172 }
1173
1174 case nir_op_ball_fequal2:
1175 case nir_op_ball_fequal3:
1176 case nir_op_ball_fequal4: {
1177 assert(nir_op_infos[alu->op].num_inputs == 2);
1178 assert(alu_instr_src_components(alu, 0) ==
1179 alu_instr_src_components(alu, 1));
1180 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1181 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1182 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1183 result = emit_binop(ctx, op,
1184 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1185 src[0], src[1]);
1186 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1187 break;
1188 }
1189
1190 case nir_op_bany_inequal2:
1191 case nir_op_bany_inequal3:
1192 case nir_op_bany_inequal4: {
1193 assert(nir_op_infos[alu->op].num_inputs == 2);
1194 assert(alu_instr_src_components(alu, 0) ==
1195 alu_instr_src_components(alu, 1));
1196 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1197 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1198 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1199 result = emit_binop(ctx, op,
1200 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1201 src[0], src[1]);
1202 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1203 break;
1204 }
1205
1206 case nir_op_ball_iequal2:
1207 case nir_op_ball_iequal3:
1208 case nir_op_ball_iequal4: {
1209 assert(nir_op_infos[alu->op].num_inputs == 2);
1210 assert(alu_instr_src_components(alu, 0) ==
1211 alu_instr_src_components(alu, 1));
1212 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1213 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1214 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1215 result = emit_binop(ctx, op,
1216 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1217 src[0], src[1]);
1218 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1219 break;
1220 }
1221
1222 case nir_op_vec2:
1223 case nir_op_vec3:
1224 case nir_op_vec4: {
1225 int num_inputs = nir_op_infos[alu->op].num_inputs;
1226 assert(2 <= num_inputs && num_inputs <= 4);
1227 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1228 src, num_inputs);
1229 }
1230 break;
1231
1232 default:
1233 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1234 nir_op_infos[alu->op].name);
1235
1236 unreachable("unsupported opcode");
1237 return;
1238 }
1239
1240 store_alu_result(ctx, alu, result);
1241 }
1242
1243 static void
1244 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1245 {
1246 unsigned bit_size = load_const->def.bit_size;
1247 unsigned num_components = load_const->def.num_components;
1248
1249 SpvId constant;
1250 if (num_components > 1) {
1251 SpvId components[num_components];
1252 SpvId type;
1253 if (bit_size == 1) {
1254 for (int i = 0; i < num_components; i++)
1255 components[i] = spirv_builder_const_bool(&ctx->builder,
1256 load_const->value[i].b);
1257
1258 type = get_bvec_type(ctx, num_components);
1259 } else {
1260 for (int i = 0; i < num_components; i++)
1261 components[i] = emit_uint_const(ctx, bit_size,
1262 load_const->value[i].u32);
1263
1264 type = get_uvec_type(ctx, bit_size, num_components);
1265 }
1266 constant = spirv_builder_const_composite(&ctx->builder, type,
1267 components, num_components);
1268 } else {
1269 assert(num_components == 1);
1270 if (bit_size == 1)
1271 constant = spirv_builder_const_bool(&ctx->builder,
1272 load_const->value[0].b);
1273 else
1274 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1275 }
1276
1277 store_ssa_def(ctx, &load_const->def, constant);
1278 }
1279
1280 static void
1281 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1282 {
1283 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1284 assert(const_block_index); // no dynamic indexing for now
1285 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1286
1287 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1288 if (const_offset) {
1289 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1290 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1291 SpvStorageClassUniform,
1292 uvec4_type);
1293
1294 unsigned idx = const_offset->u32;
1295 SpvId member = emit_uint_const(ctx, 32, 0);
1296 SpvId offset = emit_uint_const(ctx, 32, idx);
1297 SpvId offsets[] = { member, offset };
1298 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1299 ctx->ubos[0], offsets,
1300 ARRAY_SIZE(offsets));
1301 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1302
1303 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1304 unsigned num_components = nir_dest_num_components(intr->dest);
1305 if (num_components == 1) {
1306 uint32_t components[] = { 0 };
1307 result = spirv_builder_emit_composite_extract(&ctx->builder,
1308 type,
1309 result, components,
1310 1);
1311 } else if (num_components < 4) {
1312 SpvId constituents[num_components];
1313 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1314 for (uint32_t i = 0; i < num_components; ++i)
1315 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1316 uint_type,
1317 result, &i,
1318 1);
1319
1320 result = spirv_builder_emit_composite_construct(&ctx->builder,
1321 type,
1322 constituents,
1323 num_components);
1324 }
1325
1326 if (nir_dest_bit_size(intr->dest) == 1)
1327 result = uvec_to_bvec(ctx, result, num_components);
1328
1329 store_dest(ctx, &intr->dest, result, nir_type_uint);
1330 } else
1331 unreachable("uniform-addressing not yet supported");
1332 }
1333
1334 static void
1335 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1336 {
1337 assert(ctx->block_started);
1338 spirv_builder_emit_kill(&ctx->builder);
1339 /* discard is weird in NIR, so let's just create an unreachable block after
1340 it and hope that the vulkan driver will DCE any instructinos in it. */
1341 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1342 }
1343
1344 static void
1345 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1346 {
1347 SpvId ptr = get_src(ctx, intr->src);
1348
1349 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1350 SpvId result = spirv_builder_emit_load(&ctx->builder,
1351 get_glsl_type(ctx, var->type),
1352 ptr);
1353 unsigned num_components = nir_dest_num_components(intr->dest);
1354 unsigned bit_size = nir_dest_bit_size(intr->dest);
1355 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1356 store_dest(ctx, &intr->dest, result, nir_type_uint);
1357 }
1358
1359 static void
1360 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1361 {
1362 SpvId ptr = get_src(ctx, &intr->src[0]);
1363 SpvId src = get_src(ctx, &intr->src[1]);
1364
1365 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1366 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1367 SpvId result = emit_bitcast(ctx, type, src);
1368 spirv_builder_emit_store(&ctx->builder, ptr, result);
1369 }
1370
1371 static SpvId
1372 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1373 SpvStorageClass storage_class,
1374 const char *name, SpvBuiltIn builtin)
1375 {
1376 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1377 storage_class,
1378 var_type);
1379 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1380 storage_class);
1381 spirv_builder_emit_name(&ctx->builder, var, name);
1382 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1383
1384 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1385 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1386 return var;
1387 }
1388
1389 static void
1390 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1391 {
1392 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1393 if (!ctx->front_face_var)
1394 ctx->front_face_var = create_builtin_var(ctx, var_type,
1395 SpvStorageClassInput,
1396 "gl_FrontFacing",
1397 SpvBuiltInFrontFacing);
1398
1399 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1400 ctx->front_face_var);
1401 assert(1 == nir_dest_num_components(intr->dest));
1402 store_dest(ctx, &intr->dest, result, nir_type_bool);
1403 }
1404
1405 static void
1406 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1407 {
1408 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1409 if (!ctx->instance_id_var)
1410 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1411 SpvStorageClassInput,
1412 "gl_InstanceId",
1413 SpvBuiltInInstanceIndex);
1414
1415 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1416 ctx->instance_id_var);
1417 assert(1 == nir_dest_num_components(intr->dest));
1418 store_dest(ctx, &intr->dest, result, nir_type_uint);
1419 }
1420
1421 static void
1422 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1423 {
1424 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1425 if (!ctx->vertex_id_var)
1426 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1427 SpvStorageClassInput,
1428 "gl_VertexID",
1429 SpvBuiltInVertexIndex);
1430
1431 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1432 ctx->vertex_id_var);
1433 assert(1 == nir_dest_num_components(intr->dest));
1434 store_dest(ctx, &intr->dest, result, nir_type_uint);
1435 }
1436
1437 static void
1438 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1439 {
1440 switch (intr->intrinsic) {
1441 case nir_intrinsic_load_ubo:
1442 emit_load_ubo(ctx, intr);
1443 break;
1444
1445 case nir_intrinsic_discard:
1446 emit_discard(ctx, intr);
1447 break;
1448
1449 case nir_intrinsic_load_deref:
1450 emit_load_deref(ctx, intr);
1451 break;
1452
1453 case nir_intrinsic_store_deref:
1454 emit_store_deref(ctx, intr);
1455 break;
1456
1457 case nir_intrinsic_load_front_face:
1458 emit_load_front_face(ctx, intr);
1459 break;
1460
1461 case nir_intrinsic_load_instance_id:
1462 emit_load_instance_id(ctx, intr);
1463 break;
1464
1465 case nir_intrinsic_load_vertex_id:
1466 emit_load_vertex_id(ctx, intr);
1467 break;
1468
1469 default:
1470 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1471 nir_intrinsic_infos[intr->intrinsic].name);
1472 unreachable("unsupported intrinsic");
1473 }
1474 }
1475
1476 static void
1477 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1478 {
1479 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1480 undef->def.num_components);
1481
1482 store_ssa_def(ctx, &undef->def,
1483 spirv_builder_emit_undef(&ctx->builder, type));
1484 }
1485
1486 static SpvId
1487 get_src_float(struct ntv_context *ctx, nir_src *src)
1488 {
1489 SpvId def = get_src(ctx, src);
1490 unsigned num_components = nir_src_num_components(*src);
1491 unsigned bit_size = nir_src_bit_size(*src);
1492 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1493 }
1494
1495 static SpvId
1496 get_src_int(struct ntv_context *ctx, nir_src *src)
1497 {
1498 SpvId def = get_src(ctx, src);
1499 unsigned num_components = nir_src_num_components(*src);
1500 unsigned bit_size = nir_src_bit_size(*src);
1501 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1502 }
1503
1504 static void
1505 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1506 {
1507 assert(tex->op == nir_texop_tex ||
1508 tex->op == nir_texop_txb ||
1509 tex->op == nir_texop_txl ||
1510 tex->op == nir_texop_txd ||
1511 tex->op == nir_texop_txf ||
1512 tex->op == nir_texop_txf_ms ||
1513 tex->op == nir_texop_txs);
1514 assert(tex->texture_index == tex->sampler_index);
1515
1516 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1517 offset = 0, sample = 0;
1518 unsigned coord_components = 0;
1519 for (unsigned i = 0; i < tex->num_srcs; i++) {
1520 switch (tex->src[i].src_type) {
1521 case nir_tex_src_coord:
1522 if (tex->op == nir_texop_txf ||
1523 tex->op == nir_texop_txf_ms)
1524 coord = get_src_int(ctx, &tex->src[i].src);
1525 else
1526 coord = get_src_float(ctx, &tex->src[i].src);
1527 coord_components = nir_src_num_components(tex->src[i].src);
1528 break;
1529
1530 case nir_tex_src_projector:
1531 assert(nir_src_num_components(tex->src[i].src) == 1);
1532 proj = get_src_float(ctx, &tex->src[i].src);
1533 assert(proj != 0);
1534 break;
1535
1536 case nir_tex_src_offset:
1537 offset = get_src_int(ctx, &tex->src[i].src);
1538 break;
1539
1540 case nir_tex_src_bias:
1541 assert(tex->op == nir_texop_txb);
1542 bias = get_src_float(ctx, &tex->src[i].src);
1543 assert(bias != 0);
1544 break;
1545
1546 case nir_tex_src_lod:
1547 assert(nir_src_num_components(tex->src[i].src) == 1);
1548 if (tex->op == nir_texop_txf ||
1549 tex->op == nir_texop_txf_ms ||
1550 tex->op == nir_texop_txs)
1551 lod = get_src_int(ctx, &tex->src[i].src);
1552 else
1553 lod = get_src_float(ctx, &tex->src[i].src);
1554 assert(lod != 0);
1555 break;
1556
1557 case nir_tex_src_ms_index:
1558 assert(nir_src_num_components(tex->src[i].src) == 1);
1559 sample = get_src_int(ctx, &tex->src[i].src);
1560 break;
1561
1562 case nir_tex_src_comparator:
1563 assert(nir_src_num_components(tex->src[i].src) == 1);
1564 dref = get_src_float(ctx, &tex->src[i].src);
1565 assert(dref != 0);
1566 break;
1567
1568 case nir_tex_src_ddx:
1569 dx = get_src_float(ctx, &tex->src[i].src);
1570 assert(dx != 0);
1571 break;
1572
1573 case nir_tex_src_ddy:
1574 dy = get_src_float(ctx, &tex->src[i].src);
1575 assert(dy != 0);
1576 break;
1577
1578 default:
1579 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1580 unreachable("unknown texture source");
1581 }
1582 }
1583
1584 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1585 lod = emit_float_const(ctx, 32, 0.0f);
1586 assert(lod != 0);
1587 }
1588
1589 SpvId image_type = ctx->image_types[tex->texture_index];
1590 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1591 image_type);
1592
1593 assert(ctx->samplers_used & (1u << tex->texture_index));
1594 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1595 ctx->samplers[tex->texture_index]);
1596
1597 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1598
1599 if (tex->op == nir_texop_txs) {
1600 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1601 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1602 dest_type, image,
1603 lod);
1604 store_dest(ctx, &tex->dest, result, tex->dest_type);
1605 return;
1606 }
1607
1608 if (proj && coord_components > 0) {
1609 SpvId constituents[coord_components + 1];
1610 if (coord_components == 1)
1611 constituents[0] = coord;
1612 else {
1613 assert(coord_components > 1);
1614 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1615 for (uint32_t i = 0; i < coord_components; ++i)
1616 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1617 float_type,
1618 coord,
1619 &i, 1);
1620 }
1621
1622 constituents[coord_components++] = proj;
1623
1624 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1625 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1626 vec_type,
1627 constituents,
1628 coord_components);
1629 }
1630
1631 SpvId actual_dest_type = dest_type;
1632 if (dref)
1633 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1634
1635 SpvId result;
1636 if (tex->op == nir_texop_txf ||
1637 tex->op == nir_texop_txf_ms) {
1638 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1639 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1640 image, coord, lod, sample);
1641 } else {
1642 result = spirv_builder_emit_image_sample(&ctx->builder,
1643 actual_dest_type, load,
1644 coord,
1645 proj != 0,
1646 lod, bias, dref, dx, dy,
1647 offset);
1648 }
1649
1650 spirv_builder_emit_decoration(&ctx->builder, result,
1651 SpvDecorationRelaxedPrecision);
1652
1653 if (dref && nir_dest_num_components(tex->dest) > 1) {
1654 SpvId components[4] = { result, result, result, result };
1655 result = spirv_builder_emit_composite_construct(&ctx->builder,
1656 dest_type,
1657 components,
1658 4);
1659 }
1660
1661 store_dest(ctx, &tex->dest, result, tex->dest_type);
1662 }
1663
1664 static void
1665 start_block(struct ntv_context *ctx, SpvId label)
1666 {
1667 /* terminate previous block if needed */
1668 if (ctx->block_started)
1669 spirv_builder_emit_branch(&ctx->builder, label);
1670
1671 /* start new block */
1672 spirv_builder_label(&ctx->builder, label);
1673 ctx->block_started = true;
1674 }
1675
1676 static void
1677 branch(struct ntv_context *ctx, SpvId label)
1678 {
1679 assert(ctx->block_started);
1680 spirv_builder_emit_branch(&ctx->builder, label);
1681 ctx->block_started = false;
1682 }
1683
1684 static void
1685 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1686 SpvId else_id)
1687 {
1688 assert(ctx->block_started);
1689 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1690 then_id, else_id);
1691 ctx->block_started = false;
1692 }
1693
1694 static void
1695 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1696 {
1697 switch (jump->type) {
1698 case nir_jump_break:
1699 assert(ctx->loop_break);
1700 branch(ctx, ctx->loop_break);
1701 break;
1702
1703 case nir_jump_continue:
1704 assert(ctx->loop_cont);
1705 branch(ctx, ctx->loop_cont);
1706 break;
1707
1708 default:
1709 unreachable("Unsupported jump type\n");
1710 }
1711 }
1712
1713 static void
1714 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1715 {
1716 assert(deref->deref_type == nir_deref_type_var);
1717
1718 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1719 assert(he);
1720 SpvId result = (SpvId)(intptr_t)he->data;
1721 /* uint is a bit of a lie here, it's really just an opaque type */
1722 store_dest(ctx, &deref->dest, result, nir_type_uint);
1723 }
1724
1725 static void
1726 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1727 {
1728 assert(deref->deref_type == nir_deref_type_array);
1729 nir_variable *var = nir_deref_instr_get_variable(deref);
1730
1731 SpvStorageClass storage_class;
1732 switch (var->data.mode) {
1733 case nir_var_shader_in:
1734 storage_class = SpvStorageClassInput;
1735 break;
1736
1737 case nir_var_shader_out:
1738 storage_class = SpvStorageClassOutput;
1739 break;
1740
1741 default:
1742 unreachable("Unsupported nir_variable_mode\n");
1743 }
1744
1745 SpvId index = get_src(ctx, &deref->arr.index);
1746
1747 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1748 storage_class,
1749 get_glsl_type(ctx, deref->type));
1750
1751 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1752 ptr_type,
1753 get_src(ctx, &deref->parent),
1754 &index, 1);
1755 /* uint is a bit of a lie here, it's really just an opaque type */
1756 store_dest(ctx, &deref->dest, result, nir_type_uint);
1757 }
1758
1759 static void
1760 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1761 {
1762 switch (deref->deref_type) {
1763 case nir_deref_type_var:
1764 emit_deref_var(ctx, deref);
1765 break;
1766
1767 case nir_deref_type_array:
1768 emit_deref_array(ctx, deref);
1769 break;
1770
1771 default:
1772 unreachable("unexpected deref_type");
1773 }
1774 }
1775
1776 static void
1777 emit_block(struct ntv_context *ctx, struct nir_block *block)
1778 {
1779 start_block(ctx, block_label(ctx, block));
1780 nir_foreach_instr(instr, block) {
1781 switch (instr->type) {
1782 case nir_instr_type_alu:
1783 emit_alu(ctx, nir_instr_as_alu(instr));
1784 break;
1785 case nir_instr_type_intrinsic:
1786 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1787 break;
1788 case nir_instr_type_load_const:
1789 emit_load_const(ctx, nir_instr_as_load_const(instr));
1790 break;
1791 case nir_instr_type_ssa_undef:
1792 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1793 break;
1794 case nir_instr_type_tex:
1795 emit_tex(ctx, nir_instr_as_tex(instr));
1796 break;
1797 case nir_instr_type_phi:
1798 unreachable("nir_instr_type_phi not supported");
1799 break;
1800 case nir_instr_type_jump:
1801 emit_jump(ctx, nir_instr_as_jump(instr));
1802 break;
1803 case nir_instr_type_call:
1804 unreachable("nir_instr_type_call not supported");
1805 break;
1806 case nir_instr_type_parallel_copy:
1807 unreachable("nir_instr_type_parallel_copy not supported");
1808 break;
1809 case nir_instr_type_deref:
1810 emit_deref(ctx, nir_instr_as_deref(instr));
1811 break;
1812 }
1813 }
1814 }
1815
1816 static void
1817 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1818
1819 static SpvId
1820 get_src_bool(struct ntv_context *ctx, nir_src *src)
1821 {
1822 assert(nir_src_bit_size(*src) == 1);
1823 return get_src(ctx, src);
1824 }
1825
1826 static void
1827 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1828 {
1829 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1830
1831 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1832 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1833 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1834 SpvId else_id = endif_id;
1835
1836 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1837 if (has_else) {
1838 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1839 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1840 }
1841
1842 /* create a header-block */
1843 start_block(ctx, header_id);
1844 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1845 SpvSelectionControlMaskNone);
1846 branch_conditional(ctx, condition, then_id, else_id);
1847
1848 emit_cf_list(ctx, &if_stmt->then_list);
1849
1850 if (has_else) {
1851 if (ctx->block_started)
1852 branch(ctx, endif_id);
1853
1854 emit_cf_list(ctx, &if_stmt->else_list);
1855 }
1856
1857 start_block(ctx, endif_id);
1858 }
1859
1860 static void
1861 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1862 {
1863 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1864 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1865 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1866 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1867
1868 /* create a header-block */
1869 start_block(ctx, header_id);
1870 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1871 branch(ctx, begin_id);
1872
1873 SpvId save_break = ctx->loop_break;
1874 SpvId save_cont = ctx->loop_cont;
1875 ctx->loop_break = break_id;
1876 ctx->loop_cont = cont_id;
1877
1878 emit_cf_list(ctx, &loop->body);
1879
1880 ctx->loop_break = save_break;
1881 ctx->loop_cont = save_cont;
1882
1883 branch(ctx, cont_id);
1884 start_block(ctx, cont_id);
1885 branch(ctx, header_id);
1886
1887 start_block(ctx, break_id);
1888 }
1889
1890 static void
1891 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1892 {
1893 foreach_list_typed(nir_cf_node, node, node, list) {
1894 switch (node->type) {
1895 case nir_cf_node_block:
1896 emit_block(ctx, nir_cf_node_as_block(node));
1897 break;
1898
1899 case nir_cf_node_if:
1900 emit_if(ctx, nir_cf_node_as_if(node));
1901 break;
1902
1903 case nir_cf_node_loop:
1904 emit_loop(ctx, nir_cf_node_as_loop(node));
1905 break;
1906
1907 case nir_cf_node_function:
1908 unreachable("nir_cf_node_function not supported");
1909 break;
1910 }
1911 }
1912 }
1913
1914 struct spirv_shader *
1915 nir_to_spirv(struct nir_shader *s)
1916 {
1917 struct spirv_shader *ret = NULL;
1918
1919 struct ntv_context ctx = {};
1920
1921 switch (s->info.stage) {
1922 case MESA_SHADER_VERTEX:
1923 case MESA_SHADER_FRAGMENT:
1924 case MESA_SHADER_COMPUTE:
1925 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1926 break;
1927
1928 case MESA_SHADER_TESS_CTRL:
1929 case MESA_SHADER_TESS_EVAL:
1930 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1931 break;
1932
1933 case MESA_SHADER_GEOMETRY:
1934 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1935 break;
1936
1937 default:
1938 unreachable("invalid stage");
1939 }
1940
1941 // TODO: only enable when needed
1942 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1943 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1944 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1945 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1946 }
1947
1948 ctx.stage = s->info.stage;
1949 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1950 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1951
1952 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1953 SpvMemoryModelGLSL450);
1954
1955 SpvExecutionModel exec_model;
1956 switch (s->info.stage) {
1957 case MESA_SHADER_VERTEX:
1958 exec_model = SpvExecutionModelVertex;
1959 break;
1960 case MESA_SHADER_TESS_CTRL:
1961 exec_model = SpvExecutionModelTessellationControl;
1962 break;
1963 case MESA_SHADER_TESS_EVAL:
1964 exec_model = SpvExecutionModelTessellationEvaluation;
1965 break;
1966 case MESA_SHADER_GEOMETRY:
1967 exec_model = SpvExecutionModelGeometry;
1968 break;
1969 case MESA_SHADER_FRAGMENT:
1970 exec_model = SpvExecutionModelFragment;
1971 break;
1972 case MESA_SHADER_COMPUTE:
1973 exec_model = SpvExecutionModelGLCompute;
1974 break;
1975 default:
1976 unreachable("invalid stage");
1977 }
1978
1979 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1980 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1981 NULL, 0);
1982 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1983 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1984
1985 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1986 _mesa_key_pointer_equal);
1987
1988 nir_foreach_variable(var, &s->inputs)
1989 emit_input(&ctx, var);
1990
1991 nir_foreach_variable(var, &s->outputs)
1992 emit_output(&ctx, var);
1993
1994 nir_foreach_variable(var, &s->uniforms)
1995 emit_uniform(&ctx, var);
1996
1997 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1998 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1999 SpvExecutionModeOriginUpperLeft);
2000 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2001 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2002 SpvExecutionModeDepthReplacing);
2003 }
2004
2005
2006 spirv_builder_function(&ctx.builder, entry_point, type_void,
2007 SpvFunctionControlMaskNone,
2008 type_main);
2009
2010 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2011 nir_metadata_require(entry, nir_metadata_block_index);
2012
2013 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
2014 if (!ctx.defs)
2015 goto fail;
2016 ctx.num_defs = entry->ssa_alloc;
2017
2018 nir_index_local_regs(entry);
2019 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
2020 if (!ctx.regs)
2021 goto fail;
2022 ctx.num_regs = entry->reg_alloc;
2023
2024 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
2025 if (!block_ids)
2026 goto fail;
2027
2028 for (int i = 0; i < entry->num_blocks; ++i)
2029 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2030
2031 ctx.block_ids = block_ids;
2032 ctx.num_blocks = entry->num_blocks;
2033
2034 /* emit a block only for the variable declarations */
2035 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2036 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2037 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2038 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2039 SpvStorageClassFunction,
2040 type);
2041 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2042 SpvStorageClassFunction);
2043
2044 ctx.regs[reg->index] = var;
2045 }
2046
2047 emit_cf_list(&ctx, &entry->body);
2048
2049 free(ctx.defs);
2050
2051 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2052 spirv_builder_function_end(&ctx.builder);
2053
2054 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2055 "main", ctx.entry_ifaces,
2056 ctx.num_entry_ifaces);
2057
2058 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2059
2060 ret = CALLOC_STRUCT(spirv_shader);
2061 if (!ret)
2062 goto fail;
2063
2064 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2065 if (!ret->words)
2066 goto fail;
2067
2068 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2069 assert(ret->num_words == num_words);
2070
2071 return ret;
2072
2073 fail:
2074
2075 if (ret)
2076 spirv_shader_delete(ret);
2077
2078 if (ctx.vars)
2079 _mesa_hash_table_destroy(ctx.vars, NULL);
2080
2081 return NULL;
2082 }
2083
2084 void
2085 spirv_shader_delete(struct spirv_shader *s)
2086 {
2087 FREE(s->words);
2088 FREE(s);
2089 }