zink/spirv: do not use bitwise operations on booleans
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
172 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
353
354 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
355 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
356 }
357
358 static SpvDim
359 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
360 {
361 *is_ms = false;
362 switch (gdim) {
363 case GLSL_SAMPLER_DIM_1D:
364 return SpvDim1D;
365 case GLSL_SAMPLER_DIM_2D:
366 return SpvDim2D;
367 case GLSL_SAMPLER_DIM_3D:
368 return SpvDim3D;
369 case GLSL_SAMPLER_DIM_CUBE:
370 return SpvDimCube;
371 case GLSL_SAMPLER_DIM_RECT:
372 return SpvDim2D;
373 case GLSL_SAMPLER_DIM_BUF:
374 return SpvDimBuffer;
375 case GLSL_SAMPLER_DIM_EXTERNAL:
376 return SpvDim2D; /* seems dodgy... */
377 case GLSL_SAMPLER_DIM_MS:
378 *is_ms = true;
379 return SpvDim2D;
380 default:
381 fprintf(stderr, "unknown sampler type %d\n", gdim);
382 break;
383 }
384 return SpvDim2D;
385 }
386
387 uint32_t
388 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
389 {
390 if (stage == MESA_SHADER_NONE ||
391 stage >= MESA_SHADER_COMPUTE) {
392 unreachable("not supported");
393 } else {
394 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
395 PIPE_MAX_SHADER_SAMPLER_VIEWS);
396
397 switch (type) {
398 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
399 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
400 return stage_offset + index;
401
402 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
403 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
404 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
405
406 default:
407 unreachable("unexpected type");
408 }
409 }
410 }
411
412 static void
413 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
414 {
415 const struct glsl_type *type = glsl_without_array(var->type);
416
417 bool is_ms;
418 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
419
420 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
421 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
422 dimension, false,
423 glsl_sampler_type_is_array(type),
424 is_ms, 1,
425 SpvImageFormatUnknown);
426
427 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
428 image_type);
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniformConstant,
431 sampled_type);
432
433 if (glsl_type_is_array(var->type)) {
434 for (int i = 0; i < glsl_get_length(var->type); ++i) {
435 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
436 SpvStorageClassUniformConstant);
437
438 if (var->name) {
439 char element_name[100];
440 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
441 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
442 }
443
444 int index = var->data.binding + i;
445 assert(!(ctx->samplers_used & (1 << index)));
446 assert(!ctx->image_types[index]);
447 ctx->image_types[index] = image_type;
448 ctx->samplers[index] = var_id;
449 ctx->samplers_used |= 1 << index;
450
451 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
452 var->data.descriptor_set);
453 int binding = zink_binding(ctx->stage,
454 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
455 var->data.binding + i);
456 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
457 }
458 } else {
459 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
460 SpvStorageClassUniformConstant);
461
462 if (var->name)
463 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
464
465 int index = var->data.binding;
466 assert(!(ctx->samplers_used & (1 << index)));
467 assert(!ctx->image_types[index]);
468 ctx->image_types[index] = image_type;
469 ctx->samplers[index] = var_id;
470 ctx->samplers_used |= 1 << index;
471
472 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
473 var->data.descriptor_set);
474 int binding = zink_binding(ctx->stage,
475 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
476 var->data.binding);
477 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
478 }
479 }
480
481 static void
482 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
483 {
484 uint32_t size = glsl_count_attribute_slots(var->type, false);
485 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
486 SpvId array_length = emit_uint_const(ctx, 32, size);
487 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
488 array_length);
489 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
490
491 // wrap UBO-array in a struct
492 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
493 if (var->name) {
494 char struct_name[100];
495 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
496 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
497 }
498
499 spirv_builder_emit_decoration(&ctx->builder, struct_type,
500 SpvDecorationBlock);
501 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
502
503
504 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
505 SpvStorageClassUniform,
506 struct_type);
507
508 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
509 SpvStorageClassUniform);
510 if (var->name)
511 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
512
513 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
514 ctx->ubos[ctx->num_ubos++] = var_id;
515
516 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
517 var->data.descriptor_set);
518 int binding = zink_binding(ctx->stage,
519 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
520 var->data.binding);
521 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
522 }
523
524 static void
525 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
526 {
527 if (var->data.mode == nir_var_mem_ubo)
528 emit_ubo(ctx, var);
529 else {
530 assert(var->data.mode == nir_var_uniform);
531 if (glsl_type_is_sampler(glsl_without_array(var->type)))
532 emit_sampler(ctx, var);
533 }
534 }
535
536 static SpvId
537 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
538 {
539 assert(ssa->index < ctx->num_defs);
540 assert(ctx->defs[ssa->index] != 0);
541 return ctx->defs[ssa->index];
542 }
543
544 static SpvId
545 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
546 {
547 assert(reg->index < ctx->num_regs);
548 assert(ctx->regs[reg->index] != 0);
549 return ctx->regs[reg->index];
550 }
551
552 static SpvId
553 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
554 {
555 assert(reg->reg);
556 assert(!reg->indirect);
557 assert(!reg->base_offset);
558
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
561 return spirv_builder_emit_load(&ctx->builder, type, var);
562 }
563
564 static SpvId
565 get_src(struct ntv_context *ctx, nir_src *src)
566 {
567 if (src->is_ssa)
568 return get_src_ssa(ctx, src->ssa);
569 else
570 return get_src_reg(ctx, &src->reg);
571 }
572
573 static SpvId
574 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
575 {
576 assert(!alu->src[src].negate);
577 assert(!alu->src[src].abs);
578
579 SpvId def = get_src(ctx, &alu->src[src].src);
580
581 unsigned used_channels = 0;
582 bool need_swizzle = false;
583 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
584 if (!nir_alu_instr_channel_used(alu, src, i))
585 continue;
586
587 used_channels++;
588
589 if (alu->src[src].swizzle[i] != i)
590 need_swizzle = true;
591 }
592 assert(used_channels != 0);
593
594 unsigned live_channels = nir_src_num_components(alu->src[src].src);
595 if (used_channels != live_channels)
596 need_swizzle = true;
597
598 if (!need_swizzle)
599 return def;
600
601 int bit_size = nir_src_bit_size(alu->src[src].src);
602 assert(bit_size == 1 || bit_size == 32);
603
604 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
605 spirv_builder_type_uint(&ctx->builder, bit_size);
606
607 if (used_channels == 1) {
608 uint32_t indices[] = { alu->src[src].swizzle[0] };
609 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
610 def, indices,
611 ARRAY_SIZE(indices));
612 } else if (live_channels == 1) {
613 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
614 raw_type,
615 used_channels);
616
617 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
618 for (unsigned i = 0; i < used_channels; ++i)
619 constituents[i] = def;
620
621 return spirv_builder_emit_composite_construct(&ctx->builder,
622 raw_vec_type,
623 constituents,
624 used_channels);
625 } else {
626 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
627 raw_type,
628 used_channels);
629
630 uint32_t components[NIR_MAX_VEC_COMPONENTS];
631 size_t num_components = 0;
632 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
633 if (!nir_alu_instr_channel_used(alu, src, i))
634 continue;
635
636 components[num_components++] = alu->src[src].swizzle[i];
637 }
638
639 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
640 def, def, components,
641 num_components);
642 }
643 }
644
645 static void
646 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
647 {
648 assert(result != 0);
649 assert(ssa->index < ctx->num_defs);
650 ctx->defs[ssa->index] = result;
651 }
652
653 static SpvId
654 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
655 SpvId if_true, SpvId if_false)
656 {
657 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
658 }
659
660 static SpvId
661 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
662 {
663 SpvId type = get_bvec_type(ctx, num_components);
664 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
665 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
666 }
667
668 static SpvId
669 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
670 {
671 return emit_unop(ctx, SpvOpBitcast, type, value);
672 }
673
674 static SpvId
675 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
676 unsigned num_components)
677 {
678 SpvId type = get_uvec_type(ctx, bit_size, num_components);
679 return emit_bitcast(ctx, type, value);
680 }
681
682 static SpvId
683 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
684 unsigned num_components)
685 {
686 SpvId type = get_ivec_type(ctx, bit_size, num_components);
687 return emit_bitcast(ctx, type, value);
688 }
689
690 static SpvId
691 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
692 unsigned num_components)
693 {
694 SpvId type = get_fvec_type(ctx, bit_size, num_components);
695 return emit_bitcast(ctx, type, value);
696 }
697
698 static void
699 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
700 {
701 SpvId var = get_var_from_reg(ctx, reg->reg);
702 assert(var);
703 spirv_builder_emit_store(&ctx->builder, var, result);
704 }
705
706 static void
707 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
708 {
709 if (dest->is_ssa)
710 store_ssa_def(ctx, &dest->ssa, result);
711 else
712 store_reg_def(ctx, &dest->reg, result);
713 }
714
715 static void
716 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
717 {
718 unsigned num_components = nir_dest_num_components(*dest);
719 unsigned bit_size = nir_dest_bit_size(*dest);
720
721 if (bit_size != 1) {
722 switch (nir_alu_type_get_base_type(type)) {
723 case nir_type_bool:
724 assert("bool should have bit-size 1");
725
726 case nir_type_uint:
727 break; /* nothing to do! */
728
729 case nir_type_int:
730 case nir_type_float:
731 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
732 break;
733
734 default:
735 unreachable("unsupported nir_alu_type");
736 }
737 }
738
739 store_dest_raw(ctx, dest, result);
740 }
741
742 static SpvId
743 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
744 {
745 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
746 }
747
748 static SpvId
749 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
750 SpvId src0, SpvId src1)
751 {
752 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
753 }
754
755 static SpvId
756 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
757 SpvId src0, SpvId src1, SpvId src2)
758 {
759 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
760 }
761
762 static SpvId
763 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
764 SpvId src)
765 {
766 SpvId args[] = { src };
767 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
768 op, args, ARRAY_SIZE(args));
769 }
770
771 static SpvId
772 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
773 SpvId src0, SpvId src1)
774 {
775 SpvId args[] = { src0, src1 };
776 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
777 op, args, ARRAY_SIZE(args));
778 }
779
780 static SpvId
781 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
782 SpvId src0, SpvId src1, SpvId src2)
783 {
784 SpvId args[] = { src0, src1, src2 };
785 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
786 op, args, ARRAY_SIZE(args));
787 }
788
789 static SpvId
790 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
791 unsigned num_components, float value)
792 {
793 assert(bit_size == 32);
794
795 SpvId result = emit_float_const(ctx, bit_size, value);
796 if (num_components == 1)
797 return result;
798
799 assert(num_components > 1);
800 SpvId components[num_components];
801 for (int i = 0; i < num_components; i++)
802 components[i] = result;
803
804 SpvId type = get_fvec_type(ctx, bit_size, num_components);
805 return spirv_builder_const_composite(&ctx->builder, type, components,
806 num_components);
807 }
808
809 static SpvId
810 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
811 unsigned num_components, uint32_t value)
812 {
813 assert(bit_size == 32);
814
815 SpvId result = emit_uint_const(ctx, bit_size, value);
816 if (num_components == 1)
817 return result;
818
819 assert(num_components > 1);
820 SpvId components[num_components];
821 for (int i = 0; i < num_components; i++)
822 components[i] = result;
823
824 SpvId type = get_uvec_type(ctx, bit_size, num_components);
825 return spirv_builder_const_composite(&ctx->builder, type, components,
826 num_components);
827 }
828
829 static SpvId
830 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
831 unsigned num_components, int32_t value)
832 {
833 assert(bit_size == 32);
834
835 SpvId result = emit_int_const(ctx, bit_size, value);
836 if (num_components == 1)
837 return result;
838
839 assert(num_components > 1);
840 SpvId components[num_components];
841 for (int i = 0; i < num_components; i++)
842 components[i] = result;
843
844 SpvId type = get_ivec_type(ctx, bit_size, num_components);
845 return spirv_builder_const_composite(&ctx->builder, type, components,
846 num_components);
847 }
848
849 static inline unsigned
850 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
851 {
852 if (nir_op_infos[instr->op].input_sizes[src] > 0)
853 return nir_op_infos[instr->op].input_sizes[src];
854
855 if (instr->dest.dest.is_ssa)
856 return instr->dest.dest.ssa.num_components;
857 else
858 return instr->dest.dest.reg.reg->num_components;
859 }
860
861 static SpvId
862 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
863 {
864 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
865
866 unsigned num_components = alu_instr_src_components(alu, src);
867 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
868 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
869
870 if (bit_size == 1)
871 return raw_value;
872 else {
873 switch (nir_alu_type_get_base_type(type)) {
874 case nir_type_bool:
875 unreachable("bool should have bit-size 1");
876
877 case nir_type_int:
878 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
879
880 case nir_type_uint:
881 return raw_value;
882
883 case nir_type_float:
884 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
885
886 default:
887 unreachable("unknown nir_alu_type");
888 }
889 }
890 }
891
892 static void
893 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
894 {
895 assert(!alu->dest.saturate);
896 return store_dest(ctx, &alu->dest.dest, result,
897 nir_op_infos[alu->op].output_type);
898 }
899
900 static SpvId
901 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
902 {
903 unsigned num_components = nir_dest_num_components(*dest);
904 unsigned bit_size = nir_dest_bit_size(*dest);
905
906 if (bit_size == 1)
907 return get_bvec_type(ctx, num_components);
908
909 switch (nir_alu_type_get_base_type(type)) {
910 case nir_type_bool:
911 unreachable("bool should have bit-size 1");
912
913 case nir_type_int:
914 return get_ivec_type(ctx, bit_size, num_components);
915
916 case nir_type_uint:
917 return get_uvec_type(ctx, bit_size, num_components);
918
919 case nir_type_float:
920 return get_fvec_type(ctx, bit_size, num_components);
921
922 default:
923 unreachable("unsupported nir_alu_type");
924 }
925 }
926
927 static void
928 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
929 {
930 SpvId src[nir_op_infos[alu->op].num_inputs];
931 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
932 src[i] = get_alu_src(ctx, alu, i);
933
934 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
935 nir_op_infos[alu->op].output_type);
936 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
937 unsigned num_components = nir_dest_num_components(alu->dest.dest);
938
939 SpvId result = 0;
940 switch (alu->op) {
941 case nir_op_mov:
942 assert(nir_op_infos[alu->op].num_inputs == 1);
943 result = src[0];
944 break;
945
946 #define UNOP(nir_op, spirv_op) \
947 case nir_op: \
948 assert(nir_op_infos[alu->op].num_inputs == 1); \
949 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
950 break;
951
952 UNOP(nir_op_ineg, SpvOpSNegate)
953 UNOP(nir_op_fneg, SpvOpFNegate)
954 UNOP(nir_op_fddx, SpvOpDPdx)
955 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
956 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
957 UNOP(nir_op_fddy, SpvOpDPdy)
958 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
959 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
960 UNOP(nir_op_f2i32, SpvOpConvertFToS)
961 UNOP(nir_op_f2u32, SpvOpConvertFToU)
962 UNOP(nir_op_i2f32, SpvOpConvertSToF)
963 UNOP(nir_op_u2f32, SpvOpConvertUToF)
964 #undef UNOP
965
966 case nir_op_inot:
967 if (bit_size == 1)
968 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
969 else
970 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
971 break;
972
973 case nir_op_b2i32:
974 assert(nir_op_infos[alu->op].num_inputs == 1);
975 result = emit_select(ctx, dest_type, src[0],
976 get_ivec_constant(ctx, 32, num_components, 1),
977 get_ivec_constant(ctx, 32, num_components, 0));
978 break;
979
980 case nir_op_b2f32:
981 assert(nir_op_infos[alu->op].num_inputs == 1);
982 result = emit_select(ctx, dest_type, src[0],
983 get_fvec_constant(ctx, 32, num_components, 1),
984 get_fvec_constant(ctx, 32, num_components, 0));
985 break;
986
987 #define BUILTIN_UNOP(nir_op, spirv_op) \
988 case nir_op: \
989 assert(nir_op_infos[alu->op].num_inputs == 1); \
990 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
991 break;
992
993 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
994 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
995 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
996 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
997 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
998 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
999 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1000 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1001 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1002 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1003 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1004 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1005 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1006 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1007 #undef BUILTIN_UNOP
1008
1009 case nir_op_frcp:
1010 assert(nir_op_infos[alu->op].num_inputs == 1);
1011 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1012 get_fvec_constant(ctx, bit_size, num_components, 1),
1013 src[0]);
1014 break;
1015
1016 case nir_op_f2b1:
1017 assert(nir_op_infos[alu->op].num_inputs == 1);
1018 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1019 get_fvec_constant(ctx,
1020 nir_src_bit_size(alu->src[0].src),
1021 num_components, 0));
1022 break;
1023
1024
1025 #define BINOP(nir_op, spirv_op) \
1026 case nir_op: \
1027 assert(nir_op_infos[alu->op].num_inputs == 2); \
1028 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1029 break;
1030
1031 BINOP(nir_op_iadd, SpvOpIAdd)
1032 BINOP(nir_op_isub, SpvOpISub)
1033 BINOP(nir_op_imul, SpvOpIMul)
1034 BINOP(nir_op_idiv, SpvOpSDiv)
1035 BINOP(nir_op_udiv, SpvOpUDiv)
1036 BINOP(nir_op_umod, SpvOpUMod)
1037 BINOP(nir_op_fadd, SpvOpFAdd)
1038 BINOP(nir_op_fsub, SpvOpFSub)
1039 BINOP(nir_op_fmul, SpvOpFMul)
1040 BINOP(nir_op_fdiv, SpvOpFDiv)
1041 BINOP(nir_op_fmod, SpvOpFMod)
1042 BINOP(nir_op_ilt, SpvOpSLessThan)
1043 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1044 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1045 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1046 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1047 BINOP(nir_op_feq, SpvOpFOrdEqual)
1048 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1049 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1050 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1051 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1052 #undef BINOP
1053
1054 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1055 case nir_op: \
1056 assert(nir_op_infos[alu->op].num_inputs == 2); \
1057 if (nir_src_bit_size(alu->src[0].src) == 1) \
1058 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1059 else \
1060 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1061 break;
1062
1063 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1064 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1065 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1066 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1067 #undef BINOP_LOG
1068
1069 #define BUILTIN_BINOP(nir_op, spirv_op) \
1070 case nir_op: \
1071 assert(nir_op_infos[alu->op].num_inputs == 2); \
1072 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1073 break;
1074
1075 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1076 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1077 #undef BUILTIN_BINOP
1078
1079 case nir_op_fdot2:
1080 case nir_op_fdot3:
1081 case nir_op_fdot4:
1082 assert(nir_op_infos[alu->op].num_inputs == 2);
1083 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1084 break;
1085
1086 case nir_op_fdph:
1087 unreachable("should already be lowered away");
1088
1089 case nir_op_seq:
1090 case nir_op_sne:
1091 case nir_op_slt:
1092 case nir_op_sge: {
1093 assert(nir_op_infos[alu->op].num_inputs == 2);
1094 int num_components = nir_dest_num_components(alu->dest.dest);
1095 SpvId bool_type = get_bvec_type(ctx, num_components);
1096
1097 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1098 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1099 if (num_components > 1) {
1100 SpvId zero_comps[num_components], one_comps[num_components];
1101 for (int i = 0; i < num_components; i++) {
1102 zero_comps[i] = zero;
1103 one_comps[i] = one;
1104 }
1105
1106 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1107 zero_comps, num_components);
1108 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1109 one_comps, num_components);
1110 }
1111
1112 SpvOp op;
1113 switch (alu->op) {
1114 case nir_op_seq: op = SpvOpFOrdEqual; break;
1115 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1116 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1117 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1118 default: unreachable("unexpected op");
1119 }
1120
1121 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1122 result = emit_select(ctx, dest_type, result, one, zero);
1123 }
1124 break;
1125
1126 case nir_op_flrp:
1127 assert(nir_op_infos[alu->op].num_inputs == 3);
1128 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1129 src[0], src[1], src[2]);
1130 break;
1131
1132 case nir_op_fcsel:
1133 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1134 get_bvec_type(ctx, num_components),
1135 src[0],
1136 get_fvec_constant(ctx,
1137 nir_src_bit_size(alu->src[0].src),
1138 num_components, 0));
1139 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1140 break;
1141
1142 case nir_op_bcsel:
1143 assert(nir_op_infos[alu->op].num_inputs == 3);
1144 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1145 break;
1146
1147 case nir_op_bany_fnequal2:
1148 case nir_op_bany_fnequal3:
1149 case nir_op_bany_fnequal4:
1150 assert(nir_op_infos[alu->op].num_inputs == 2);
1151 assert(alu_instr_src_components(alu, 0) ==
1152 alu_instr_src_components(alu, 1));
1153 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1154 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1155 src[0], src[1]);
1156 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1157 break;
1158
1159 case nir_op_ball_fequal2:
1160 case nir_op_ball_fequal3:
1161 case nir_op_ball_fequal4:
1162 assert(nir_op_infos[alu->op].num_inputs == 2);
1163 assert(alu_instr_src_components(alu, 0) ==
1164 alu_instr_src_components(alu, 1));
1165 result = emit_binop(ctx, SpvOpFOrdEqual,
1166 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1167 src[0], src[1]);
1168 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1169 break;
1170
1171 case nir_op_bany_inequal2:
1172 case nir_op_bany_inequal3:
1173 case nir_op_bany_inequal4:
1174 assert(nir_op_infos[alu->op].num_inputs == 2);
1175 assert(alu_instr_src_components(alu, 0) ==
1176 alu_instr_src_components(alu, 1));
1177 result = emit_binop(ctx, SpvOpINotEqual,
1178 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1179 src[0], src[1]);
1180 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1181 break;
1182
1183 case nir_op_ball_iequal2:
1184 case nir_op_ball_iequal3:
1185 case nir_op_ball_iequal4:
1186 assert(nir_op_infos[alu->op].num_inputs == 2);
1187 assert(alu_instr_src_components(alu, 0) ==
1188 alu_instr_src_components(alu, 1));
1189 result = emit_binop(ctx, SpvOpIEqual,
1190 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1191 src[0], src[1]);
1192 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1193 break;
1194
1195 case nir_op_vec2:
1196 case nir_op_vec3:
1197 case nir_op_vec4: {
1198 int num_inputs = nir_op_infos[alu->op].num_inputs;
1199 assert(2 <= num_inputs && num_inputs <= 4);
1200 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1201 src, num_inputs);
1202 }
1203 break;
1204
1205 default:
1206 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1207 nir_op_infos[alu->op].name);
1208
1209 unreachable("unsupported opcode");
1210 return;
1211 }
1212
1213 store_alu_result(ctx, alu, result);
1214 }
1215
1216 static void
1217 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1218 {
1219 unsigned bit_size = load_const->def.bit_size;
1220 unsigned num_components = load_const->def.num_components;
1221
1222 SpvId constant;
1223 if (num_components > 1) {
1224 SpvId components[num_components];
1225 SpvId type;
1226 if (bit_size == 1) {
1227 for (int i = 0; i < num_components; i++)
1228 components[i] = spirv_builder_const_bool(&ctx->builder,
1229 load_const->value[i].b);
1230
1231 type = get_bvec_type(ctx, num_components);
1232 } else {
1233 for (int i = 0; i < num_components; i++)
1234 components[i] = emit_uint_const(ctx, bit_size,
1235 load_const->value[i].u32);
1236
1237 type = get_uvec_type(ctx, bit_size, num_components);
1238 }
1239 constant = spirv_builder_const_composite(&ctx->builder, type,
1240 components, num_components);
1241 } else {
1242 assert(num_components == 1);
1243 if (bit_size == 1)
1244 constant = spirv_builder_const_bool(&ctx->builder,
1245 load_const->value[0].b);
1246 else
1247 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1248 }
1249
1250 store_ssa_def(ctx, &load_const->def, constant);
1251 }
1252
1253 static void
1254 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1255 {
1256 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1257 assert(const_block_index); // no dynamic indexing for now
1258 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1259
1260 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1261 if (const_offset) {
1262 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1263 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1264 SpvStorageClassUniform,
1265 uvec4_type);
1266
1267 unsigned idx = const_offset->u32;
1268 SpvId member = emit_uint_const(ctx, 32, 0);
1269 SpvId offset = emit_uint_const(ctx, 32, idx);
1270 SpvId offsets[] = { member, offset };
1271 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1272 ctx->ubos[0], offsets,
1273 ARRAY_SIZE(offsets));
1274 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1275
1276 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1277 unsigned num_components = nir_dest_num_components(intr->dest);
1278 if (num_components == 1) {
1279 uint32_t components[] = { 0 };
1280 result = spirv_builder_emit_composite_extract(&ctx->builder,
1281 type,
1282 result, components,
1283 1);
1284 } else if (num_components < 4) {
1285 SpvId constituents[num_components];
1286 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1287 for (uint32_t i = 0; i < num_components; ++i)
1288 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1289 uint_type,
1290 result, &i,
1291 1);
1292
1293 result = spirv_builder_emit_composite_construct(&ctx->builder,
1294 type,
1295 constituents,
1296 num_components);
1297 }
1298
1299 if (nir_dest_bit_size(intr->dest) == 1)
1300 result = uvec_to_bvec(ctx, result, num_components);
1301
1302 store_dest(ctx, &intr->dest, result, nir_type_uint);
1303 } else
1304 unreachable("uniform-addressing not yet supported");
1305 }
1306
1307 static void
1308 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1309 {
1310 assert(ctx->block_started);
1311 spirv_builder_emit_kill(&ctx->builder);
1312 /* discard is weird in NIR, so let's just create an unreachable block after
1313 it and hope that the vulkan driver will DCE any instructinos in it. */
1314 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1315 }
1316
1317 static void
1318 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1319 {
1320 SpvId ptr = get_src(ctx, intr->src);
1321
1322 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1323 SpvId result = spirv_builder_emit_load(&ctx->builder,
1324 get_glsl_type(ctx, var->type),
1325 ptr);
1326 unsigned num_components = nir_dest_num_components(intr->dest);
1327 unsigned bit_size = nir_dest_bit_size(intr->dest);
1328 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1329 store_dest(ctx, &intr->dest, result, nir_type_uint);
1330 }
1331
1332 static void
1333 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1334 {
1335 SpvId ptr = get_src(ctx, &intr->src[0]);
1336 SpvId src = get_src(ctx, &intr->src[1]);
1337
1338 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1339 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1340 SpvId result = emit_bitcast(ctx, type, src);
1341 spirv_builder_emit_store(&ctx->builder, ptr, result);
1342 }
1343
1344 static SpvId
1345 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1346 SpvStorageClass storage_class,
1347 const char *name, SpvBuiltIn builtin)
1348 {
1349 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1350 storage_class,
1351 var_type);
1352 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1353 storage_class);
1354 spirv_builder_emit_name(&ctx->builder, var, name);
1355 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1356
1357 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1358 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1359 return var;
1360 }
1361
1362 static void
1363 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1364 {
1365 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1366 if (!ctx->front_face_var)
1367 ctx->front_face_var = create_builtin_var(ctx, var_type,
1368 SpvStorageClassInput,
1369 "gl_FrontFacing",
1370 SpvBuiltInFrontFacing);
1371
1372 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1373 ctx->front_face_var);
1374 assert(1 == nir_dest_num_components(intr->dest));
1375 store_dest(ctx, &intr->dest, result, nir_type_bool);
1376 }
1377
1378 static void
1379 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1380 {
1381 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1382 if (!ctx->instance_id_var)
1383 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1384 SpvStorageClassInput,
1385 "gl_InstanceId",
1386 SpvBuiltInInstanceIndex);
1387
1388 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1389 ctx->instance_id_var);
1390 assert(1 == nir_dest_num_components(intr->dest));
1391 store_dest(ctx, &intr->dest, result, nir_type_uint);
1392 }
1393
1394 static void
1395 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1396 {
1397 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1398 if (!ctx->vertex_id_var)
1399 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1400 SpvStorageClassInput,
1401 "gl_VertexID",
1402 SpvBuiltInVertexIndex);
1403
1404 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1405 ctx->vertex_id_var);
1406 assert(1 == nir_dest_num_components(intr->dest));
1407 store_dest(ctx, &intr->dest, result, nir_type_uint);
1408 }
1409
1410 static void
1411 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1412 {
1413 switch (intr->intrinsic) {
1414 case nir_intrinsic_load_ubo:
1415 emit_load_ubo(ctx, intr);
1416 break;
1417
1418 case nir_intrinsic_discard:
1419 emit_discard(ctx, intr);
1420 break;
1421
1422 case nir_intrinsic_load_deref:
1423 emit_load_deref(ctx, intr);
1424 break;
1425
1426 case nir_intrinsic_store_deref:
1427 emit_store_deref(ctx, intr);
1428 break;
1429
1430 case nir_intrinsic_load_front_face:
1431 emit_load_front_face(ctx, intr);
1432 break;
1433
1434 case nir_intrinsic_load_instance_id:
1435 emit_load_instance_id(ctx, intr);
1436 break;
1437
1438 case nir_intrinsic_load_vertex_id:
1439 emit_load_vertex_id(ctx, intr);
1440 break;
1441
1442 default:
1443 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1444 nir_intrinsic_infos[intr->intrinsic].name);
1445 unreachable("unsupported intrinsic");
1446 }
1447 }
1448
1449 static void
1450 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1451 {
1452 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1453 undef->def.num_components);
1454
1455 store_ssa_def(ctx, &undef->def,
1456 spirv_builder_emit_undef(&ctx->builder, type));
1457 }
1458
1459 static SpvId
1460 get_src_float(struct ntv_context *ctx, nir_src *src)
1461 {
1462 SpvId def = get_src(ctx, src);
1463 unsigned num_components = nir_src_num_components(*src);
1464 unsigned bit_size = nir_src_bit_size(*src);
1465 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1466 }
1467
1468 static SpvId
1469 get_src_int(struct ntv_context *ctx, nir_src *src)
1470 {
1471 SpvId def = get_src(ctx, src);
1472 unsigned num_components = nir_src_num_components(*src);
1473 unsigned bit_size = nir_src_bit_size(*src);
1474 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1475 }
1476
1477 static void
1478 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1479 {
1480 assert(tex->op == nir_texop_tex ||
1481 tex->op == nir_texop_txb ||
1482 tex->op == nir_texop_txl ||
1483 tex->op == nir_texop_txd ||
1484 tex->op == nir_texop_txf ||
1485 tex->op == nir_texop_txs);
1486 assert(tex->texture_index == tex->sampler_index);
1487
1488 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1489 offset = 0;
1490 unsigned coord_components = 0;
1491 for (unsigned i = 0; i < tex->num_srcs; i++) {
1492 switch (tex->src[i].src_type) {
1493 case nir_tex_src_coord:
1494 if (tex->op == nir_texop_txf)
1495 coord = get_src_int(ctx, &tex->src[i].src);
1496 else
1497 coord = get_src_float(ctx, &tex->src[i].src);
1498 coord_components = nir_src_num_components(tex->src[i].src);
1499 break;
1500
1501 case nir_tex_src_projector:
1502 assert(nir_src_num_components(tex->src[i].src) == 1);
1503 proj = get_src_float(ctx, &tex->src[i].src);
1504 assert(proj != 0);
1505 break;
1506
1507 case nir_tex_src_offset:
1508 offset = get_src_int(ctx, &tex->src[i].src);
1509 break;
1510
1511 case nir_tex_src_bias:
1512 assert(tex->op == nir_texop_txb);
1513 bias = get_src_float(ctx, &tex->src[i].src);
1514 assert(bias != 0);
1515 break;
1516
1517 case nir_tex_src_lod:
1518 assert(nir_src_num_components(tex->src[i].src) == 1);
1519 if (tex->op == nir_texop_txf ||
1520 tex->op == nir_texop_txs)
1521 lod = get_src_int(ctx, &tex->src[i].src);
1522 else
1523 lod = get_src_float(ctx, &tex->src[i].src);
1524 assert(lod != 0);
1525 break;
1526
1527 case nir_tex_src_comparator:
1528 assert(nir_src_num_components(tex->src[i].src) == 1);
1529 dref = get_src_float(ctx, &tex->src[i].src);
1530 assert(dref != 0);
1531 break;
1532
1533 case nir_tex_src_ddx:
1534 dx = get_src_float(ctx, &tex->src[i].src);
1535 assert(dx != 0);
1536 break;
1537
1538 case nir_tex_src_ddy:
1539 dy = get_src_float(ctx, &tex->src[i].src);
1540 assert(dy != 0);
1541 break;
1542
1543 default:
1544 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1545 unreachable("unknown texture source");
1546 }
1547 }
1548
1549 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1550 lod = emit_float_const(ctx, 32, 0.0f);
1551 assert(lod != 0);
1552 }
1553
1554 SpvId image_type = ctx->image_types[tex->texture_index];
1555 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1556 image_type);
1557
1558 assert(ctx->samplers_used & (1u << tex->texture_index));
1559 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1560 ctx->samplers[tex->texture_index]);
1561
1562 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1563
1564 if (tex->op == nir_texop_txs) {
1565 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1566 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1567 dest_type, image,
1568 lod);
1569 store_dest(ctx, &tex->dest, result, tex->dest_type);
1570 return;
1571 }
1572
1573 if (proj && coord_components > 0) {
1574 SpvId constituents[coord_components + 1];
1575 if (coord_components == 1)
1576 constituents[0] = coord;
1577 else {
1578 assert(coord_components > 1);
1579 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1580 for (uint32_t i = 0; i < coord_components; ++i)
1581 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1582 float_type,
1583 coord,
1584 &i, 1);
1585 }
1586
1587 constituents[coord_components++] = proj;
1588
1589 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1590 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1591 vec_type,
1592 constituents,
1593 coord_components);
1594 }
1595
1596 SpvId actual_dest_type = dest_type;
1597 if (dref)
1598 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1599
1600 SpvId result;
1601 if (tex->op == nir_texop_txf) {
1602 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1603 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1604 image, coord, lod);
1605 } else {
1606 result = spirv_builder_emit_image_sample(&ctx->builder,
1607 actual_dest_type, load,
1608 coord,
1609 proj != 0,
1610 lod, bias, dref, dx, dy,
1611 offset);
1612 }
1613
1614 spirv_builder_emit_decoration(&ctx->builder, result,
1615 SpvDecorationRelaxedPrecision);
1616
1617 if (dref && nir_dest_num_components(tex->dest) > 1) {
1618 SpvId components[4] = { result, result, result, result };
1619 result = spirv_builder_emit_composite_construct(&ctx->builder,
1620 dest_type,
1621 components,
1622 4);
1623 }
1624
1625 store_dest(ctx, &tex->dest, result, tex->dest_type);
1626 }
1627
1628 static void
1629 start_block(struct ntv_context *ctx, SpvId label)
1630 {
1631 /* terminate previous block if needed */
1632 if (ctx->block_started)
1633 spirv_builder_emit_branch(&ctx->builder, label);
1634
1635 /* start new block */
1636 spirv_builder_label(&ctx->builder, label);
1637 ctx->block_started = true;
1638 }
1639
1640 static void
1641 branch(struct ntv_context *ctx, SpvId label)
1642 {
1643 assert(ctx->block_started);
1644 spirv_builder_emit_branch(&ctx->builder, label);
1645 ctx->block_started = false;
1646 }
1647
1648 static void
1649 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1650 SpvId else_id)
1651 {
1652 assert(ctx->block_started);
1653 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1654 then_id, else_id);
1655 ctx->block_started = false;
1656 }
1657
1658 static void
1659 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1660 {
1661 switch (jump->type) {
1662 case nir_jump_break:
1663 assert(ctx->loop_break);
1664 branch(ctx, ctx->loop_break);
1665 break;
1666
1667 case nir_jump_continue:
1668 assert(ctx->loop_cont);
1669 branch(ctx, ctx->loop_cont);
1670 break;
1671
1672 default:
1673 unreachable("Unsupported jump type\n");
1674 }
1675 }
1676
1677 static void
1678 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1679 {
1680 assert(deref->deref_type == nir_deref_type_var);
1681
1682 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1683 assert(he);
1684 SpvId result = (SpvId)(intptr_t)he->data;
1685 /* uint is a bit of a lie here, it's really just an opaque type */
1686 store_dest(ctx, &deref->dest, result, nir_type_uint);
1687 }
1688
1689 static void
1690 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1691 {
1692 assert(deref->deref_type == nir_deref_type_array);
1693 nir_variable *var = nir_deref_instr_get_variable(deref);
1694
1695 SpvStorageClass storage_class;
1696 switch (var->data.mode) {
1697 case nir_var_shader_in:
1698 storage_class = SpvStorageClassInput;
1699 break;
1700
1701 case nir_var_shader_out:
1702 storage_class = SpvStorageClassOutput;
1703 break;
1704
1705 default:
1706 unreachable("Unsupported nir_variable_mode\n");
1707 }
1708
1709 SpvId index = get_src(ctx, &deref->arr.index);
1710
1711 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1712 storage_class,
1713 get_glsl_type(ctx, deref->type));
1714
1715 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1716 ptr_type,
1717 get_src(ctx, &deref->parent),
1718 &index, 1);
1719 /* uint is a bit of a lie here, it's really just an opaque type */
1720 store_dest(ctx, &deref->dest, result, nir_type_uint);
1721 }
1722
1723 static void
1724 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1725 {
1726 switch (deref->deref_type) {
1727 case nir_deref_type_var:
1728 emit_deref_var(ctx, deref);
1729 break;
1730
1731 case nir_deref_type_array:
1732 emit_deref_array(ctx, deref);
1733 break;
1734
1735 default:
1736 unreachable("unexpected deref_type");
1737 }
1738 }
1739
1740 static void
1741 emit_block(struct ntv_context *ctx, struct nir_block *block)
1742 {
1743 start_block(ctx, block_label(ctx, block));
1744 nir_foreach_instr(instr, block) {
1745 switch (instr->type) {
1746 case nir_instr_type_alu:
1747 emit_alu(ctx, nir_instr_as_alu(instr));
1748 break;
1749 case nir_instr_type_intrinsic:
1750 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1751 break;
1752 case nir_instr_type_load_const:
1753 emit_load_const(ctx, nir_instr_as_load_const(instr));
1754 break;
1755 case nir_instr_type_ssa_undef:
1756 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1757 break;
1758 case nir_instr_type_tex:
1759 emit_tex(ctx, nir_instr_as_tex(instr));
1760 break;
1761 case nir_instr_type_phi:
1762 unreachable("nir_instr_type_phi not supported");
1763 break;
1764 case nir_instr_type_jump:
1765 emit_jump(ctx, nir_instr_as_jump(instr));
1766 break;
1767 case nir_instr_type_call:
1768 unreachable("nir_instr_type_call not supported");
1769 break;
1770 case nir_instr_type_parallel_copy:
1771 unreachable("nir_instr_type_parallel_copy not supported");
1772 break;
1773 case nir_instr_type_deref:
1774 emit_deref(ctx, nir_instr_as_deref(instr));
1775 break;
1776 }
1777 }
1778 }
1779
1780 static void
1781 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1782
1783 static SpvId
1784 get_src_bool(struct ntv_context *ctx, nir_src *src)
1785 {
1786 assert(nir_src_bit_size(*src) == 1);
1787 return get_src(ctx, src);
1788 }
1789
1790 static void
1791 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1792 {
1793 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1794
1795 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1796 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1797 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1798 SpvId else_id = endif_id;
1799
1800 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1801 if (has_else) {
1802 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1803 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1804 }
1805
1806 /* create a header-block */
1807 start_block(ctx, header_id);
1808 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1809 SpvSelectionControlMaskNone);
1810 branch_conditional(ctx, condition, then_id, else_id);
1811
1812 emit_cf_list(ctx, &if_stmt->then_list);
1813
1814 if (has_else) {
1815 if (ctx->block_started)
1816 branch(ctx, endif_id);
1817
1818 emit_cf_list(ctx, &if_stmt->else_list);
1819 }
1820
1821 start_block(ctx, endif_id);
1822 }
1823
1824 static void
1825 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1826 {
1827 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1828 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1829 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1830 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1831
1832 /* create a header-block */
1833 start_block(ctx, header_id);
1834 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1835 branch(ctx, begin_id);
1836
1837 SpvId save_break = ctx->loop_break;
1838 SpvId save_cont = ctx->loop_cont;
1839 ctx->loop_break = break_id;
1840 ctx->loop_cont = cont_id;
1841
1842 emit_cf_list(ctx, &loop->body);
1843
1844 ctx->loop_break = save_break;
1845 ctx->loop_cont = save_cont;
1846
1847 branch(ctx, cont_id);
1848 start_block(ctx, cont_id);
1849 branch(ctx, header_id);
1850
1851 start_block(ctx, break_id);
1852 }
1853
1854 static void
1855 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1856 {
1857 foreach_list_typed(nir_cf_node, node, node, list) {
1858 switch (node->type) {
1859 case nir_cf_node_block:
1860 emit_block(ctx, nir_cf_node_as_block(node));
1861 break;
1862
1863 case nir_cf_node_if:
1864 emit_if(ctx, nir_cf_node_as_if(node));
1865 break;
1866
1867 case nir_cf_node_loop:
1868 emit_loop(ctx, nir_cf_node_as_loop(node));
1869 break;
1870
1871 case nir_cf_node_function:
1872 unreachable("nir_cf_node_function not supported");
1873 break;
1874 }
1875 }
1876 }
1877
1878 struct spirv_shader *
1879 nir_to_spirv(struct nir_shader *s)
1880 {
1881 struct spirv_shader *ret = NULL;
1882
1883 struct ntv_context ctx = {};
1884
1885 switch (s->info.stage) {
1886 case MESA_SHADER_VERTEX:
1887 case MESA_SHADER_FRAGMENT:
1888 case MESA_SHADER_COMPUTE:
1889 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1890 break;
1891
1892 case MESA_SHADER_TESS_CTRL:
1893 case MESA_SHADER_TESS_EVAL:
1894 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1895 break;
1896
1897 case MESA_SHADER_GEOMETRY:
1898 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1899 break;
1900
1901 default:
1902 unreachable("invalid stage");
1903 }
1904
1905 // TODO: only enable when needed
1906 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1907 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1908 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1909 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1910 }
1911
1912 ctx.stage = s->info.stage;
1913 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1914 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1915
1916 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1917 SpvMemoryModelGLSL450);
1918
1919 SpvExecutionModel exec_model;
1920 switch (s->info.stage) {
1921 case MESA_SHADER_VERTEX:
1922 exec_model = SpvExecutionModelVertex;
1923 break;
1924 case MESA_SHADER_TESS_CTRL:
1925 exec_model = SpvExecutionModelTessellationControl;
1926 break;
1927 case MESA_SHADER_TESS_EVAL:
1928 exec_model = SpvExecutionModelTessellationEvaluation;
1929 break;
1930 case MESA_SHADER_GEOMETRY:
1931 exec_model = SpvExecutionModelGeometry;
1932 break;
1933 case MESA_SHADER_FRAGMENT:
1934 exec_model = SpvExecutionModelFragment;
1935 break;
1936 case MESA_SHADER_COMPUTE:
1937 exec_model = SpvExecutionModelGLCompute;
1938 break;
1939 default:
1940 unreachable("invalid stage");
1941 }
1942
1943 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1944 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1945 NULL, 0);
1946 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1947 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1948
1949 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1950 _mesa_key_pointer_equal);
1951
1952 nir_foreach_variable(var, &s->inputs)
1953 emit_input(&ctx, var);
1954
1955 nir_foreach_variable(var, &s->outputs)
1956 emit_output(&ctx, var);
1957
1958 nir_foreach_variable(var, &s->uniforms)
1959 emit_uniform(&ctx, var);
1960
1961 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1962 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1963 SpvExecutionModeOriginUpperLeft);
1964 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1965 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1966 SpvExecutionModeDepthReplacing);
1967 }
1968
1969
1970 spirv_builder_function(&ctx.builder, entry_point, type_void,
1971 SpvFunctionControlMaskNone,
1972 type_main);
1973
1974 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1975 nir_metadata_require(entry, nir_metadata_block_index);
1976
1977 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1978 if (!ctx.defs)
1979 goto fail;
1980 ctx.num_defs = entry->ssa_alloc;
1981
1982 nir_index_local_regs(entry);
1983 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1984 if (!ctx.regs)
1985 goto fail;
1986 ctx.num_regs = entry->reg_alloc;
1987
1988 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1989 if (!block_ids)
1990 goto fail;
1991
1992 for (int i = 0; i < entry->num_blocks; ++i)
1993 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1994
1995 ctx.block_ids = block_ids;
1996 ctx.num_blocks = entry->num_blocks;
1997
1998 /* emit a block only for the variable declarations */
1999 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2000 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2001 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2002 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2003 SpvStorageClassFunction,
2004 type);
2005 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2006 SpvStorageClassFunction);
2007
2008 ctx.regs[reg->index] = var;
2009 }
2010
2011 emit_cf_list(&ctx, &entry->body);
2012
2013 free(ctx.defs);
2014
2015 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2016 spirv_builder_function_end(&ctx.builder);
2017
2018 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2019 "main", ctx.entry_ifaces,
2020 ctx.num_entry_ifaces);
2021
2022 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2023
2024 ret = CALLOC_STRUCT(spirv_shader);
2025 if (!ret)
2026 goto fail;
2027
2028 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2029 if (!ret->words)
2030 goto fail;
2031
2032 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2033 assert(ret->num_words == num_words);
2034
2035 return ret;
2036
2037 fail:
2038
2039 if (ret)
2040 spirv_shader_delete(ret);
2041
2042 if (ctx.vars)
2043 _mesa_hash_table_destroy(ctx.vars, NULL);
2044
2045 return NULL;
2046 }
2047
2048 void
2049 spirv_shader_delete(struct spirv_shader *s)
2050 {
2051 FREE(s->words);
2052 FREE(s);
2053 }