zink/spirv: do not reinvent store_dest
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
172 nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
353
354 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
355 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
356 }
357
358 static SpvDim
359 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
360 {
361 *is_ms = false;
362 switch (gdim) {
363 case GLSL_SAMPLER_DIM_1D:
364 return SpvDim1D;
365 case GLSL_SAMPLER_DIM_2D:
366 return SpvDim2D;
367 case GLSL_SAMPLER_DIM_3D:
368 return SpvDim3D;
369 case GLSL_SAMPLER_DIM_CUBE:
370 return SpvDimCube;
371 case GLSL_SAMPLER_DIM_RECT:
372 return SpvDim2D;
373 case GLSL_SAMPLER_DIM_BUF:
374 return SpvDimBuffer;
375 case GLSL_SAMPLER_DIM_EXTERNAL:
376 return SpvDim2D; /* seems dodgy... */
377 case GLSL_SAMPLER_DIM_MS:
378 *is_ms = true;
379 return SpvDim2D;
380 default:
381 fprintf(stderr, "unknown sampler type %d\n", gdim);
382 break;
383 }
384 return SpvDim2D;
385 }
386
387 uint32_t
388 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
389 {
390 if (stage == MESA_SHADER_NONE ||
391 stage >= MESA_SHADER_COMPUTE) {
392 unreachable("not supported");
393 } else {
394 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
395 PIPE_MAX_SHADER_SAMPLER_VIEWS);
396
397 switch (type) {
398 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
399 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
400 return stage_offset + index;
401
402 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
403 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
404 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
405
406 default:
407 unreachable("unexpected type");
408 }
409 }
410 }
411
412 static void
413 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
414 {
415 const struct glsl_type *type = glsl_without_array(var->type);
416
417 bool is_ms;
418 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
419
420 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
421 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
422 dimension, false,
423 glsl_sampler_type_is_array(type),
424 is_ms, 1,
425 SpvImageFormatUnknown);
426
427 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
428 image_type);
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniformConstant,
431 sampled_type);
432
433 if (glsl_type_is_array(var->type)) {
434 for (int i = 0; i < glsl_get_length(var->type); ++i) {
435 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
436 SpvStorageClassUniformConstant);
437
438 if (var->name) {
439 char element_name[100];
440 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
441 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
442 }
443
444 int index = var->data.binding + i;
445 assert(!(ctx->samplers_used & (1 << index)));
446 assert(!ctx->image_types[index]);
447 ctx->image_types[index] = image_type;
448 ctx->samplers[index] = var_id;
449 ctx->samplers_used |= 1 << index;
450
451 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
452 var->data.descriptor_set);
453 int binding = zink_binding(ctx->stage,
454 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
455 var->data.binding + i);
456 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
457 }
458 } else {
459 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
460 SpvStorageClassUniformConstant);
461
462 if (var->name)
463 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
464
465 int index = var->data.binding;
466 assert(!(ctx->samplers_used & (1 << index)));
467 assert(!ctx->image_types[index]);
468 ctx->image_types[index] = image_type;
469 ctx->samplers[index] = var_id;
470 ctx->samplers_used |= 1 << index;
471
472 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
473 var->data.descriptor_set);
474 int binding = zink_binding(ctx->stage,
475 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
476 var->data.binding);
477 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
478 }
479 }
480
481 static void
482 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
483 {
484 uint32_t size = glsl_count_attribute_slots(var->type, false);
485 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
486 SpvId array_length = emit_uint_const(ctx, 32, size);
487 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
488 array_length);
489 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
490
491 // wrap UBO-array in a struct
492 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
493 if (var->name) {
494 char struct_name[100];
495 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
496 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
497 }
498
499 spirv_builder_emit_decoration(&ctx->builder, struct_type,
500 SpvDecorationBlock);
501 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
502
503
504 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
505 SpvStorageClassUniform,
506 struct_type);
507
508 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
509 SpvStorageClassUniform);
510 if (var->name)
511 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
512
513 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
514 ctx->ubos[ctx->num_ubos++] = var_id;
515
516 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
517 var->data.descriptor_set);
518 int binding = zink_binding(ctx->stage,
519 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
520 var->data.binding);
521 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
522 }
523
524 static void
525 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
526 {
527 if (var->data.mode == nir_var_mem_ubo)
528 emit_ubo(ctx, var);
529 else {
530 assert(var->data.mode == nir_var_uniform);
531 if (glsl_type_is_sampler(glsl_without_array(var->type)))
532 emit_sampler(ctx, var);
533 }
534 }
535
536 static SpvId
537 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
538 {
539 assert(ssa->index < ctx->num_defs);
540 assert(ctx->defs[ssa->index] != 0);
541 return ctx->defs[ssa->index];
542 }
543
544 static SpvId
545 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
546 {
547 assert(reg->index < ctx->num_regs);
548 assert(ctx->regs[reg->index] != 0);
549 return ctx->regs[reg->index];
550 }
551
552 static SpvId
553 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
554 {
555 assert(reg->reg);
556 assert(!reg->indirect);
557 assert(!reg->base_offset);
558
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
561 return spirv_builder_emit_load(&ctx->builder, type, var);
562 }
563
564 static SpvId
565 get_src_uint(struct ntv_context *ctx, nir_src *src)
566 {
567 if (src->is_ssa)
568 return get_src_uint_ssa(ctx, src->ssa);
569 else
570 return get_src_uint_reg(ctx, &src->reg);
571 }
572
573 static SpvId
574 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
575 {
576 assert(!alu->src[src].negate);
577 assert(!alu->src[src].abs);
578
579 SpvId def = get_src_uint(ctx, &alu->src[src].src);
580
581 unsigned used_channels = 0;
582 bool need_swizzle = false;
583 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
584 if (!nir_alu_instr_channel_used(alu, src, i))
585 continue;
586
587 used_channels++;
588
589 if (alu->src[src].swizzle[i] != i)
590 need_swizzle = true;
591 }
592 assert(used_channels != 0);
593
594 unsigned live_channels = nir_src_num_components(alu->src[src].src);
595 if (used_channels != live_channels)
596 need_swizzle = true;
597
598 if (!need_swizzle)
599 return def;
600
601 int bit_size = nir_src_bit_size(alu->src[src].src);
602 assert(bit_size == 1 || bit_size == 32);
603
604 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
605 if (used_channels == 1) {
606 uint32_t indices[] = { alu->src[src].swizzle[0] };
607 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
608 def, indices,
609 ARRAY_SIZE(indices));
610 } else if (live_channels == 1) {
611 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
612 used_channels);
613
614 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
615 for (unsigned i = 0; i < used_channels; ++i)
616 constituents[i] = def;
617
618 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
619 constituents,
620 used_channels);
621 } else {
622 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
623 used_channels);
624
625 uint32_t components[NIR_MAX_VEC_COMPONENTS];
626 size_t num_components = 0;
627 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
628 if (!nir_alu_instr_channel_used(alu, src, i))
629 continue;
630
631 components[num_components++] = alu->src[src].swizzle[i];
632 }
633
634 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
635 def, def, components, num_components);
636 }
637 }
638
639 static void
640 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
641 {
642 assert(result != 0);
643 assert(ssa->index < ctx->num_defs);
644 ctx->defs[ssa->index] = result;
645 }
646
647 static SpvId
648 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
649 SpvId if_true, SpvId if_false)
650 {
651 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
652 }
653
654 static SpvId
655 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
656 {
657 SpvId otype = get_uvec_type(ctx, 32, num_components);
658 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
659 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
660 return emit_select(ctx, otype, value, one, zero);
661 }
662
663 static SpvId
664 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
665 {
666 SpvId type = get_bvec_type(ctx, num_components);
667 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
668 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
669 }
670
671 static SpvId
672 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
673 {
674 return emit_unop(ctx, SpvOpBitcast, type, value);
675 }
676
677 static SpvId
678 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
679 unsigned num_components)
680 {
681 SpvId type = get_uvec_type(ctx, bit_size, num_components);
682 return emit_bitcast(ctx, type, value);
683 }
684
685 static SpvId
686 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
687 unsigned num_components)
688 {
689 SpvId type = get_ivec_type(ctx, bit_size, num_components);
690 return emit_bitcast(ctx, type, value);
691 }
692
693 static SpvId
694 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
695 unsigned num_components)
696 {
697 SpvId type = get_fvec_type(ctx, bit_size, num_components);
698 return emit_bitcast(ctx, type, value);
699 }
700
701 static void
702 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
703 {
704 SpvId var = get_var_from_reg(ctx, reg->reg);
705 assert(var);
706 spirv_builder_emit_store(&ctx->builder, var, result);
707 }
708
709 static void
710 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
711 {
712 if (dest->is_ssa)
713 store_ssa_def_uint(ctx, &dest->ssa, result);
714 else
715 store_reg_def(ctx, &dest->reg, result);
716 }
717
718 static void
719 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
720 {
721 unsigned num_components = nir_dest_num_components(*dest);
722 unsigned bit_size = nir_dest_bit_size(*dest);
723
724 switch (nir_alu_type_get_base_type(type)) {
725 case nir_type_bool:
726 assert(bit_size == 1);
727 result = bvec_to_uvec(ctx, result, num_components);
728 break;
729
730 case nir_type_uint:
731 break; /* nothing to do! */
732
733 case nir_type_int:
734 case nir_type_float:
735 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
736 break;
737
738 default:
739 unreachable("unsupported nir_alu_type");
740 }
741
742 store_dest_uint(ctx, dest, result);
743 }
744
745 static SpvId
746 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
747 {
748 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
749 }
750
751 static SpvId
752 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
753 SpvId src0, SpvId src1)
754 {
755 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
756 }
757
758 static SpvId
759 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
760 SpvId src0, SpvId src1, SpvId src2)
761 {
762 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
763 }
764
765 static SpvId
766 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
767 SpvId src)
768 {
769 SpvId args[] = { src };
770 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
771 op, args, ARRAY_SIZE(args));
772 }
773
774 static SpvId
775 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
776 SpvId src0, SpvId src1)
777 {
778 SpvId args[] = { src0, src1 };
779 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
780 op, args, ARRAY_SIZE(args));
781 }
782
783 static SpvId
784 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
785 SpvId src0, SpvId src1, SpvId src2)
786 {
787 SpvId args[] = { src0, src1, src2 };
788 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
789 op, args, ARRAY_SIZE(args));
790 }
791
792 static SpvId
793 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
794 unsigned num_components, float value)
795 {
796 assert(bit_size == 32);
797
798 SpvId result = emit_float_const(ctx, bit_size, value);
799 if (num_components == 1)
800 return result;
801
802 assert(num_components > 1);
803 SpvId components[num_components];
804 for (int i = 0; i < num_components; i++)
805 components[i] = result;
806
807 SpvId type = get_fvec_type(ctx, bit_size, num_components);
808 return spirv_builder_const_composite(&ctx->builder, type, components,
809 num_components);
810 }
811
812 static SpvId
813 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
814 unsigned num_components, uint32_t value)
815 {
816 assert(bit_size == 32);
817
818 SpvId result = emit_uint_const(ctx, bit_size, value);
819 if (num_components == 1)
820 return result;
821
822 assert(num_components > 1);
823 SpvId components[num_components];
824 for (int i = 0; i < num_components; i++)
825 components[i] = result;
826
827 SpvId type = get_uvec_type(ctx, bit_size, num_components);
828 return spirv_builder_const_composite(&ctx->builder, type, components,
829 num_components);
830 }
831
832 static SpvId
833 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
834 unsigned num_components, int32_t value)
835 {
836 assert(bit_size == 32);
837
838 SpvId result = emit_int_const(ctx, bit_size, value);
839 if (num_components == 1)
840 return result;
841
842 assert(num_components > 1);
843 SpvId components[num_components];
844 for (int i = 0; i < num_components; i++)
845 components[i] = result;
846
847 SpvId type = get_ivec_type(ctx, bit_size, num_components);
848 return spirv_builder_const_composite(&ctx->builder, type, components,
849 num_components);
850 }
851
852 static inline unsigned
853 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
854 {
855 if (nir_op_infos[instr->op].input_sizes[src] > 0)
856 return nir_op_infos[instr->op].input_sizes[src];
857
858 if (instr->dest.dest.is_ssa)
859 return instr->dest.dest.ssa.num_components;
860 else
861 return instr->dest.dest.reg.reg->num_components;
862 }
863
864 static SpvId
865 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
866 {
867 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
868
869 unsigned num_components = alu_instr_src_components(alu, src);
870 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
871 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
872
873 switch (nir_alu_type_get_base_type(type)) {
874 case nir_type_bool:
875 assert(bit_size == 1);
876 return uvec_to_bvec(ctx, uint_value, num_components);
877
878 case nir_type_int:
879 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
880
881 case nir_type_uint:
882 return uint_value;
883
884 case nir_type_float:
885 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
886
887 default:
888 unreachable("unknown nir_alu_type");
889 }
890 }
891
892 static void
893 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
894 {
895 assert(!alu->dest.saturate);
896 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
897 }
898
899 static SpvId
900 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
901 {
902 unsigned num_components = nir_dest_num_components(*dest);
903 unsigned bit_size = nir_dest_bit_size(*dest);
904
905 switch (nir_alu_type_get_base_type(type)) {
906 case nir_type_bool:
907 return get_bvec_type(ctx, num_components);
908
909 case nir_type_int:
910 return get_ivec_type(ctx, bit_size, num_components);
911
912 case nir_type_uint:
913 return get_uvec_type(ctx, bit_size, num_components);
914
915 case nir_type_float:
916 return get_fvec_type(ctx, bit_size, num_components);
917
918 default:
919 unreachable("unsupported nir_alu_type");
920 }
921 }
922
923 static void
924 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
925 {
926 SpvId src[nir_op_infos[alu->op].num_inputs];
927 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
928 src[i] = get_alu_src(ctx, alu, i);
929
930 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
931 nir_op_infos[alu->op].output_type);
932 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
933 unsigned num_components = nir_dest_num_components(alu->dest.dest);
934
935 SpvId result = 0;
936 switch (alu->op) {
937 case nir_op_mov:
938 assert(nir_op_infos[alu->op].num_inputs == 1);
939 result = src[0];
940 break;
941
942 #define UNOP(nir_op, spirv_op) \
943 case nir_op: \
944 assert(nir_op_infos[alu->op].num_inputs == 1); \
945 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
946 break;
947
948 UNOP(nir_op_ineg, SpvOpSNegate)
949 UNOP(nir_op_fneg, SpvOpFNegate)
950 UNOP(nir_op_fddx, SpvOpDPdx)
951 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
952 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
953 UNOP(nir_op_fddy, SpvOpDPdy)
954 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
955 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
956 UNOP(nir_op_f2i32, SpvOpConvertFToS)
957 UNOP(nir_op_f2u32, SpvOpConvertFToU)
958 UNOP(nir_op_i2f32, SpvOpConvertSToF)
959 UNOP(nir_op_u2f32, SpvOpConvertUToF)
960 UNOP(nir_op_inot, SpvOpNot)
961 #undef UNOP
962
963 case nir_op_b2i32:
964 assert(nir_op_infos[alu->op].num_inputs == 1);
965 result = emit_select(ctx, dest_type, src[0],
966 get_ivec_constant(ctx, 32, num_components, 1),
967 get_ivec_constant(ctx, 32, num_components, 0));
968 break;
969
970 case nir_op_b2f32:
971 assert(nir_op_infos[alu->op].num_inputs == 1);
972 result = emit_select(ctx, dest_type, src[0],
973 get_fvec_constant(ctx, 32, num_components, 1),
974 get_fvec_constant(ctx, 32, num_components, 0));
975 break;
976
977 #define BUILTIN_UNOP(nir_op, spirv_op) \
978 case nir_op: \
979 assert(nir_op_infos[alu->op].num_inputs == 1); \
980 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
981 break;
982
983 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
984 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
985 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
986 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
987 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
988 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
989 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
990 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
991 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
992 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
993 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
994 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
995 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
996 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
997 #undef BUILTIN_UNOP
998
999 case nir_op_frcp:
1000 assert(nir_op_infos[alu->op].num_inputs == 1);
1001 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1002 get_fvec_constant(ctx, bit_size, num_components, 1),
1003 src[0]);
1004 break;
1005
1006 case nir_op_f2b1:
1007 assert(nir_op_infos[alu->op].num_inputs == 1);
1008 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1009 get_fvec_constant(ctx,
1010 nir_src_bit_size(alu->src[0].src),
1011 num_components, 0));
1012 break;
1013
1014
1015 #define BINOP(nir_op, spirv_op) \
1016 case nir_op: \
1017 assert(nir_op_infos[alu->op].num_inputs == 2); \
1018 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1019 break;
1020
1021 BINOP(nir_op_iadd, SpvOpIAdd)
1022 BINOP(nir_op_isub, SpvOpISub)
1023 BINOP(nir_op_imul, SpvOpIMul)
1024 BINOP(nir_op_idiv, SpvOpSDiv)
1025 BINOP(nir_op_udiv, SpvOpUDiv)
1026 BINOP(nir_op_umod, SpvOpUMod)
1027 BINOP(nir_op_fadd, SpvOpFAdd)
1028 BINOP(nir_op_fsub, SpvOpFSub)
1029 BINOP(nir_op_fmul, SpvOpFMul)
1030 BINOP(nir_op_fdiv, SpvOpFDiv)
1031 BINOP(nir_op_fmod, SpvOpFMod)
1032 BINOP(nir_op_ilt, SpvOpSLessThan)
1033 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1034 BINOP(nir_op_ieq, SpvOpIEqual)
1035 BINOP(nir_op_ine, SpvOpINotEqual)
1036 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1037 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1038 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1039 BINOP(nir_op_feq, SpvOpFOrdEqual)
1040 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1041 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1042 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1043 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1044 BINOP(nir_op_iand, SpvOpBitwiseAnd)
1045 BINOP(nir_op_ior, SpvOpBitwiseOr)
1046 #undef BINOP
1047
1048 #define BUILTIN_BINOP(nir_op, spirv_op) \
1049 case nir_op: \
1050 assert(nir_op_infos[alu->op].num_inputs == 2); \
1051 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1052 break;
1053
1054 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1055 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1056 #undef BUILTIN_BINOP
1057
1058 case nir_op_fdot2:
1059 case nir_op_fdot3:
1060 case nir_op_fdot4:
1061 assert(nir_op_infos[alu->op].num_inputs == 2);
1062 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1063 break;
1064
1065 case nir_op_fdph:
1066 unreachable("should already be lowered away");
1067
1068 case nir_op_seq:
1069 case nir_op_sne:
1070 case nir_op_slt:
1071 case nir_op_sge: {
1072 assert(nir_op_infos[alu->op].num_inputs == 2);
1073 int num_components = nir_dest_num_components(alu->dest.dest);
1074 SpvId bool_type = get_bvec_type(ctx, num_components);
1075
1076 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1077 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1078 if (num_components > 1) {
1079 SpvId zero_comps[num_components], one_comps[num_components];
1080 for (int i = 0; i < num_components; i++) {
1081 zero_comps[i] = zero;
1082 one_comps[i] = one;
1083 }
1084
1085 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1086 zero_comps, num_components);
1087 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1088 one_comps, num_components);
1089 }
1090
1091 SpvOp op;
1092 switch (alu->op) {
1093 case nir_op_seq: op = SpvOpFOrdEqual; break;
1094 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1095 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1096 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1097 default: unreachable("unexpected op");
1098 }
1099
1100 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1101 result = emit_select(ctx, dest_type, result, one, zero);
1102 }
1103 break;
1104
1105 case nir_op_flrp:
1106 assert(nir_op_infos[alu->op].num_inputs == 3);
1107 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1108 src[0], src[1], src[2]);
1109 break;
1110
1111 case nir_op_fcsel:
1112 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1113 get_bvec_type(ctx, num_components),
1114 src[0],
1115 get_fvec_constant(ctx,
1116 nir_src_bit_size(alu->src[0].src),
1117 num_components, 0));
1118 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1119 break;
1120
1121 case nir_op_bcsel:
1122 assert(nir_op_infos[alu->op].num_inputs == 3);
1123 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1124 break;
1125
1126 case nir_op_bany_fnequal2:
1127 case nir_op_bany_fnequal3:
1128 case nir_op_bany_fnequal4:
1129 assert(nir_op_infos[alu->op].num_inputs == 2);
1130 assert(alu_instr_src_components(alu, 0) ==
1131 alu_instr_src_components(alu, 1));
1132 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1133 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1134 src[0], src[1]);
1135 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1136 break;
1137
1138 case nir_op_ball_fequal2:
1139 case nir_op_ball_fequal3:
1140 case nir_op_ball_fequal4:
1141 assert(nir_op_infos[alu->op].num_inputs == 2);
1142 assert(alu_instr_src_components(alu, 0) ==
1143 alu_instr_src_components(alu, 1));
1144 result = emit_binop(ctx, SpvOpFOrdEqual,
1145 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1146 src[0], src[1]);
1147 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1148 break;
1149
1150 case nir_op_bany_inequal2:
1151 case nir_op_bany_inequal3:
1152 case nir_op_bany_inequal4:
1153 assert(nir_op_infos[alu->op].num_inputs == 2);
1154 assert(alu_instr_src_components(alu, 0) ==
1155 alu_instr_src_components(alu, 1));
1156 result = emit_binop(ctx, SpvOpINotEqual,
1157 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1158 src[0], src[1]);
1159 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1160 break;
1161
1162 case nir_op_ball_iequal2:
1163 case nir_op_ball_iequal3:
1164 case nir_op_ball_iequal4:
1165 assert(nir_op_infos[alu->op].num_inputs == 2);
1166 assert(alu_instr_src_components(alu, 0) ==
1167 alu_instr_src_components(alu, 1));
1168 result = emit_binop(ctx, SpvOpIEqual,
1169 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1170 src[0], src[1]);
1171 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1172 break;
1173
1174 case nir_op_vec2:
1175 case nir_op_vec3:
1176 case nir_op_vec4: {
1177 int num_inputs = nir_op_infos[alu->op].num_inputs;
1178 assert(2 <= num_inputs && num_inputs <= 4);
1179 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1180 src, num_inputs);
1181 }
1182 break;
1183
1184 default:
1185 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1186 nir_op_infos[alu->op].name);
1187
1188 unreachable("unsupported opcode");
1189 return;
1190 }
1191
1192 store_alu_result(ctx, alu, result);
1193 }
1194
1195 static void
1196 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1197 {
1198 unsigned bit_size = load_const->def.bit_size;
1199 unsigned num_components = load_const->def.num_components;
1200
1201 SpvId constant;
1202 if (num_components > 1) {
1203 SpvId components[num_components];
1204 SpvId type;
1205 if (bit_size == 1) {
1206 for (int i = 0; i < num_components; i++)
1207 components[i] = spirv_builder_const_bool(&ctx->builder,
1208 load_const->value[i].b);
1209
1210 type = get_bvec_type(ctx, num_components);
1211 } else {
1212 for (int i = 0; i < num_components; i++)
1213 components[i] = emit_uint_const(ctx, bit_size,
1214 load_const->value[i].u32);
1215
1216 type = get_uvec_type(ctx, bit_size, num_components);
1217 }
1218 constant = spirv_builder_const_composite(&ctx->builder, type,
1219 components, num_components);
1220 } else {
1221 assert(num_components == 1);
1222 if (bit_size == 1)
1223 constant = spirv_builder_const_bool(&ctx->builder,
1224 load_const->value[0].b);
1225 else
1226 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1227 }
1228
1229 if (bit_size == 1)
1230 constant = bvec_to_uvec(ctx, constant, num_components);
1231
1232 store_ssa_def_uint(ctx, &load_const->def, constant);
1233 }
1234
1235 static void
1236 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1237 {
1238 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1239 assert(const_block_index); // no dynamic indexing for now
1240 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1241
1242 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1243 if (const_offset) {
1244 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1245 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1246 SpvStorageClassUniform,
1247 uvec4_type);
1248
1249 unsigned idx = const_offset->u32;
1250 SpvId member = emit_uint_const(ctx, 32, 0);
1251 SpvId offset = emit_uint_const(ctx, 32, idx);
1252 SpvId offsets[] = { member, offset };
1253 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1254 ctx->ubos[0], offsets,
1255 ARRAY_SIZE(offsets));
1256 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1257
1258 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1259 unsigned num_components = nir_dest_num_components(intr->dest);
1260 if (num_components == 1) {
1261 uint32_t components[] = { 0 };
1262 result = spirv_builder_emit_composite_extract(&ctx->builder,
1263 type,
1264 result, components,
1265 1);
1266 } else if (num_components < 4) {
1267 SpvId constituents[num_components];
1268 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1269 for (uint32_t i = 0; i < num_components; ++i)
1270 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1271 uint_type,
1272 result, &i,
1273 1);
1274
1275 result = spirv_builder_emit_composite_construct(&ctx->builder,
1276 type,
1277 constituents,
1278 num_components);
1279 }
1280
1281 store_dest_uint(ctx, &intr->dest, result);
1282 } else
1283 unreachable("uniform-addressing not yet supported");
1284 }
1285
1286 static void
1287 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1288 {
1289 assert(ctx->block_started);
1290 spirv_builder_emit_kill(&ctx->builder);
1291 /* discard is weird in NIR, so let's just create an unreachable block after
1292 it and hope that the vulkan driver will DCE any instructinos in it. */
1293 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1294 }
1295
1296 static void
1297 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1298 {
1299 /* uint is a bit of a lie here; it's really just a pointer */
1300 SpvId ptr = get_src_uint(ctx, intr->src);
1301
1302 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1303 SpvId result = spirv_builder_emit_load(&ctx->builder,
1304 get_glsl_type(ctx, var->type),
1305 ptr);
1306 unsigned num_components = nir_dest_num_components(intr->dest);
1307 unsigned bit_size = nir_dest_bit_size(intr->dest);
1308 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1309 store_dest_uint(ctx, &intr->dest, result);
1310 }
1311
1312 static void
1313 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1314 {
1315 /* uint is a bit of a lie here; it's really just a pointer */
1316 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1317 SpvId src = get_src_uint(ctx, &intr->src[1]);
1318
1319 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1320 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1321 SpvId result = emit_bitcast(ctx, type, src);
1322 spirv_builder_emit_store(&ctx->builder, ptr, result);
1323 }
1324
1325 static SpvId
1326 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1327 SpvStorageClass storage_class,
1328 const char *name, SpvBuiltIn builtin)
1329 {
1330 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1331 storage_class,
1332 var_type);
1333 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1334 storage_class);
1335 spirv_builder_emit_name(&ctx->builder, var, name);
1336 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1337
1338 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1339 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1340 return var;
1341 }
1342
1343 static void
1344 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1345 {
1346 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1347 if (!ctx->front_face_var)
1348 ctx->front_face_var = create_builtin_var(ctx, var_type,
1349 SpvStorageClassInput,
1350 "gl_FrontFacing",
1351 SpvBuiltInFrontFacing);
1352
1353 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1354 ctx->front_face_var);
1355 assert(1 == nir_dest_num_components(intr->dest));
1356 store_dest(ctx, &intr->dest, result, nir_type_bool);
1357 }
1358
1359 static void
1360 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1361 {
1362 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1363 if (!ctx->instance_id_var)
1364 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1365 SpvStorageClassInput,
1366 "gl_InstanceId",
1367 SpvBuiltInInstanceIndex);
1368
1369 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1370 ctx->instance_id_var);
1371 assert(1 == nir_dest_num_components(intr->dest));
1372 store_dest_uint(ctx, &intr->dest, result);
1373 }
1374
1375 static void
1376 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1377 {
1378 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1379 if (!ctx->vertex_id_var)
1380 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1381 SpvStorageClassInput,
1382 "gl_VertexID",
1383 SpvBuiltInVertexIndex);
1384
1385 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1386 ctx->vertex_id_var);
1387 assert(1 == nir_dest_num_components(intr->dest));
1388 store_dest_uint(ctx, &intr->dest, result);
1389 }
1390
1391 static void
1392 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1393 {
1394 switch (intr->intrinsic) {
1395 case nir_intrinsic_load_ubo:
1396 emit_load_ubo(ctx, intr);
1397 break;
1398
1399 case nir_intrinsic_discard:
1400 emit_discard(ctx, intr);
1401 break;
1402
1403 case nir_intrinsic_load_deref:
1404 emit_load_deref(ctx, intr);
1405 break;
1406
1407 case nir_intrinsic_store_deref:
1408 emit_store_deref(ctx, intr);
1409 break;
1410
1411 case nir_intrinsic_load_front_face:
1412 emit_load_front_face(ctx, intr);
1413 break;
1414
1415 case nir_intrinsic_load_instance_id:
1416 emit_load_instance_id(ctx, intr);
1417 break;
1418
1419 case nir_intrinsic_load_vertex_id:
1420 emit_load_vertex_id(ctx, intr);
1421 break;
1422
1423 default:
1424 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1425 nir_intrinsic_infos[intr->intrinsic].name);
1426 unreachable("unsupported intrinsic");
1427 }
1428 }
1429
1430 static void
1431 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1432 {
1433 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1434 undef->def.num_components);
1435
1436 store_ssa_def_uint(ctx, &undef->def,
1437 spirv_builder_emit_undef(&ctx->builder, type));
1438 }
1439
1440 static SpvId
1441 get_src_float(struct ntv_context *ctx, nir_src *src)
1442 {
1443 SpvId def = get_src_uint(ctx, src);
1444 unsigned num_components = nir_src_num_components(*src);
1445 unsigned bit_size = nir_src_bit_size(*src);
1446 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1447 }
1448
1449 static SpvId
1450 get_src_int(struct ntv_context *ctx, nir_src *src)
1451 {
1452 SpvId def = get_src_uint(ctx, src);
1453 unsigned num_components = nir_src_num_components(*src);
1454 unsigned bit_size = nir_src_bit_size(*src);
1455 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1456 }
1457
1458 static void
1459 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1460 {
1461 assert(tex->op == nir_texop_tex ||
1462 tex->op == nir_texop_txb ||
1463 tex->op == nir_texop_txl ||
1464 tex->op == nir_texop_txd ||
1465 tex->op == nir_texop_txf ||
1466 tex->op == nir_texop_txs);
1467 assert(tex->texture_index == tex->sampler_index);
1468
1469 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1470 offset = 0;
1471 unsigned coord_components = 0;
1472 for (unsigned i = 0; i < tex->num_srcs; i++) {
1473 switch (tex->src[i].src_type) {
1474 case nir_tex_src_coord:
1475 if (tex->op == nir_texop_txf)
1476 coord = get_src_int(ctx, &tex->src[i].src);
1477 else
1478 coord = get_src_float(ctx, &tex->src[i].src);
1479 coord_components = nir_src_num_components(tex->src[i].src);
1480 break;
1481
1482 case nir_tex_src_projector:
1483 assert(nir_src_num_components(tex->src[i].src) == 1);
1484 proj = get_src_float(ctx, &tex->src[i].src);
1485 assert(proj != 0);
1486 break;
1487
1488 case nir_tex_src_offset:
1489 offset = get_src_int(ctx, &tex->src[i].src);
1490 break;
1491
1492 case nir_tex_src_bias:
1493 assert(tex->op == nir_texop_txb);
1494 bias = get_src_float(ctx, &tex->src[i].src);
1495 assert(bias != 0);
1496 break;
1497
1498 case nir_tex_src_lod:
1499 assert(nir_src_num_components(tex->src[i].src) == 1);
1500 if (tex->op == nir_texop_txf ||
1501 tex->op == nir_texop_txs)
1502 lod = get_src_int(ctx, &tex->src[i].src);
1503 else
1504 lod = get_src_float(ctx, &tex->src[i].src);
1505 assert(lod != 0);
1506 break;
1507
1508 case nir_tex_src_comparator:
1509 assert(nir_src_num_components(tex->src[i].src) == 1);
1510 dref = get_src_float(ctx, &tex->src[i].src);
1511 assert(dref != 0);
1512 break;
1513
1514 case nir_tex_src_ddx:
1515 dx = get_src_float(ctx, &tex->src[i].src);
1516 assert(dx != 0);
1517 break;
1518
1519 case nir_tex_src_ddy:
1520 dy = get_src_float(ctx, &tex->src[i].src);
1521 assert(dy != 0);
1522 break;
1523
1524 default:
1525 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1526 unreachable("unknown texture source");
1527 }
1528 }
1529
1530 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1531 lod = emit_float_const(ctx, 32, 0.0f);
1532 assert(lod != 0);
1533 }
1534
1535 SpvId image_type = ctx->image_types[tex->texture_index];
1536 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1537 image_type);
1538
1539 assert(ctx->samplers_used & (1u << tex->texture_index));
1540 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1541 ctx->samplers[tex->texture_index]);
1542
1543 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1544
1545 if (tex->op == nir_texop_txs) {
1546 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1547 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1548 dest_type, image,
1549 lod);
1550 store_dest(ctx, &tex->dest, result, tex->dest_type);
1551 return;
1552 }
1553
1554 if (proj && coord_components > 0) {
1555 SpvId constituents[coord_components + 1];
1556 if (coord_components == 1)
1557 constituents[0] = coord;
1558 else {
1559 assert(coord_components > 1);
1560 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1561 for (uint32_t i = 0; i < coord_components; ++i)
1562 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1563 float_type,
1564 coord,
1565 &i, 1);
1566 }
1567
1568 constituents[coord_components++] = proj;
1569
1570 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1571 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1572 vec_type,
1573 constituents,
1574 coord_components);
1575 }
1576
1577 SpvId actual_dest_type = dest_type;
1578 if (dref)
1579 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1580
1581 SpvId result;
1582 if (tex->op == nir_texop_txf) {
1583 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1584 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1585 image, coord, lod);
1586 } else {
1587 result = spirv_builder_emit_image_sample(&ctx->builder,
1588 actual_dest_type, load,
1589 coord,
1590 proj != 0,
1591 lod, bias, dref, dx, dy,
1592 offset);
1593 }
1594
1595 spirv_builder_emit_decoration(&ctx->builder, result,
1596 SpvDecorationRelaxedPrecision);
1597
1598 if (dref && nir_dest_num_components(tex->dest) > 1) {
1599 SpvId components[4] = { result, result, result, result };
1600 result = spirv_builder_emit_composite_construct(&ctx->builder,
1601 dest_type,
1602 components,
1603 4);
1604 }
1605
1606 store_dest(ctx, &tex->dest, result, tex->dest_type);
1607 }
1608
1609 static void
1610 start_block(struct ntv_context *ctx, SpvId label)
1611 {
1612 /* terminate previous block if needed */
1613 if (ctx->block_started)
1614 spirv_builder_emit_branch(&ctx->builder, label);
1615
1616 /* start new block */
1617 spirv_builder_label(&ctx->builder, label);
1618 ctx->block_started = true;
1619 }
1620
1621 static void
1622 branch(struct ntv_context *ctx, SpvId label)
1623 {
1624 assert(ctx->block_started);
1625 spirv_builder_emit_branch(&ctx->builder, label);
1626 ctx->block_started = false;
1627 }
1628
1629 static void
1630 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1631 SpvId else_id)
1632 {
1633 assert(ctx->block_started);
1634 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1635 then_id, else_id);
1636 ctx->block_started = false;
1637 }
1638
1639 static void
1640 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1641 {
1642 switch (jump->type) {
1643 case nir_jump_break:
1644 assert(ctx->loop_break);
1645 branch(ctx, ctx->loop_break);
1646 break;
1647
1648 case nir_jump_continue:
1649 assert(ctx->loop_cont);
1650 branch(ctx, ctx->loop_cont);
1651 break;
1652
1653 default:
1654 unreachable("Unsupported jump type\n");
1655 }
1656 }
1657
1658 static void
1659 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1660 {
1661 assert(deref->deref_type == nir_deref_type_var);
1662
1663 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1664 assert(he);
1665 SpvId result = (SpvId)(intptr_t)he->data;
1666 /* uint is a bit of a lie here, it's really just an opaque type */
1667 store_dest_uint(ctx, &deref->dest, result);
1668 }
1669
1670 static void
1671 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1672 {
1673 assert(deref->deref_type == nir_deref_type_array);
1674 nir_variable *var = nir_deref_instr_get_variable(deref);
1675
1676 SpvStorageClass storage_class;
1677 switch (var->data.mode) {
1678 case nir_var_shader_in:
1679 storage_class = SpvStorageClassInput;
1680 break;
1681
1682 case nir_var_shader_out:
1683 storage_class = SpvStorageClassOutput;
1684 break;
1685
1686 default:
1687 unreachable("Unsupported nir_variable_mode\n");
1688 }
1689
1690 SpvId index = get_src_uint(ctx, &deref->arr.index);
1691
1692 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1693 storage_class,
1694 get_glsl_type(ctx, deref->type));
1695
1696 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1697 ptr_type,
1698 get_src_uint(ctx, &deref->parent),
1699 &index, 1);
1700 /* uint is a bit of a lie here, it's really just an opaque type */
1701 store_dest_uint(ctx, &deref->dest, result);
1702 }
1703
1704 static void
1705 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1706 {
1707 switch (deref->deref_type) {
1708 case nir_deref_type_var:
1709 emit_deref_var(ctx, deref);
1710 break;
1711
1712 case nir_deref_type_array:
1713 emit_deref_array(ctx, deref);
1714 break;
1715
1716 default:
1717 unreachable("unexpected deref_type");
1718 }
1719 }
1720
1721 static void
1722 emit_block(struct ntv_context *ctx, struct nir_block *block)
1723 {
1724 start_block(ctx, block_label(ctx, block));
1725 nir_foreach_instr(instr, block) {
1726 switch (instr->type) {
1727 case nir_instr_type_alu:
1728 emit_alu(ctx, nir_instr_as_alu(instr));
1729 break;
1730 case nir_instr_type_intrinsic:
1731 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1732 break;
1733 case nir_instr_type_load_const:
1734 emit_load_const(ctx, nir_instr_as_load_const(instr));
1735 break;
1736 case nir_instr_type_ssa_undef:
1737 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1738 break;
1739 case nir_instr_type_tex:
1740 emit_tex(ctx, nir_instr_as_tex(instr));
1741 break;
1742 case nir_instr_type_phi:
1743 unreachable("nir_instr_type_phi not supported");
1744 break;
1745 case nir_instr_type_jump:
1746 emit_jump(ctx, nir_instr_as_jump(instr));
1747 break;
1748 case nir_instr_type_call:
1749 unreachable("nir_instr_type_call not supported");
1750 break;
1751 case nir_instr_type_parallel_copy:
1752 unreachable("nir_instr_type_parallel_copy not supported");
1753 break;
1754 case nir_instr_type_deref:
1755 emit_deref(ctx, nir_instr_as_deref(instr));
1756 break;
1757 }
1758 }
1759 }
1760
1761 static void
1762 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1763
1764 static SpvId
1765 get_src_bool(struct ntv_context *ctx, nir_src *src)
1766 {
1767 SpvId def = get_src_uint(ctx, src);
1768 assert(nir_src_bit_size(*src) == 1);
1769 unsigned num_components = nir_src_num_components(*src);
1770 return uvec_to_bvec(ctx, def, num_components);
1771 }
1772
1773 static void
1774 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1775 {
1776 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1777
1778 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1779 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1780 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1781 SpvId else_id = endif_id;
1782
1783 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1784 if (has_else) {
1785 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1786 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1787 }
1788
1789 /* create a header-block */
1790 start_block(ctx, header_id);
1791 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1792 SpvSelectionControlMaskNone);
1793 branch_conditional(ctx, condition, then_id, else_id);
1794
1795 emit_cf_list(ctx, &if_stmt->then_list);
1796
1797 if (has_else) {
1798 if (ctx->block_started)
1799 branch(ctx, endif_id);
1800
1801 emit_cf_list(ctx, &if_stmt->else_list);
1802 }
1803
1804 start_block(ctx, endif_id);
1805 }
1806
1807 static void
1808 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1809 {
1810 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1811 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1812 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1813 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1814
1815 /* create a header-block */
1816 start_block(ctx, header_id);
1817 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1818 branch(ctx, begin_id);
1819
1820 SpvId save_break = ctx->loop_break;
1821 SpvId save_cont = ctx->loop_cont;
1822 ctx->loop_break = break_id;
1823 ctx->loop_cont = cont_id;
1824
1825 emit_cf_list(ctx, &loop->body);
1826
1827 ctx->loop_break = save_break;
1828 ctx->loop_cont = save_cont;
1829
1830 branch(ctx, cont_id);
1831 start_block(ctx, cont_id);
1832 branch(ctx, header_id);
1833
1834 start_block(ctx, break_id);
1835 }
1836
1837 static void
1838 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1839 {
1840 foreach_list_typed(nir_cf_node, node, node, list) {
1841 switch (node->type) {
1842 case nir_cf_node_block:
1843 emit_block(ctx, nir_cf_node_as_block(node));
1844 break;
1845
1846 case nir_cf_node_if:
1847 emit_if(ctx, nir_cf_node_as_if(node));
1848 break;
1849
1850 case nir_cf_node_loop:
1851 emit_loop(ctx, nir_cf_node_as_loop(node));
1852 break;
1853
1854 case nir_cf_node_function:
1855 unreachable("nir_cf_node_function not supported");
1856 break;
1857 }
1858 }
1859 }
1860
1861 struct spirv_shader *
1862 nir_to_spirv(struct nir_shader *s)
1863 {
1864 struct spirv_shader *ret = NULL;
1865
1866 struct ntv_context ctx = {};
1867
1868 switch (s->info.stage) {
1869 case MESA_SHADER_VERTEX:
1870 case MESA_SHADER_FRAGMENT:
1871 case MESA_SHADER_COMPUTE:
1872 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1873 break;
1874
1875 case MESA_SHADER_TESS_CTRL:
1876 case MESA_SHADER_TESS_EVAL:
1877 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1878 break;
1879
1880 case MESA_SHADER_GEOMETRY:
1881 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1882 break;
1883
1884 default:
1885 unreachable("invalid stage");
1886 }
1887
1888 // TODO: only enable when needed
1889 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1890 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1891 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1892 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1893 }
1894
1895 ctx.stage = s->info.stage;
1896 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1897 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1898
1899 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1900 SpvMemoryModelGLSL450);
1901
1902 SpvExecutionModel exec_model;
1903 switch (s->info.stage) {
1904 case MESA_SHADER_VERTEX:
1905 exec_model = SpvExecutionModelVertex;
1906 break;
1907 case MESA_SHADER_TESS_CTRL:
1908 exec_model = SpvExecutionModelTessellationControl;
1909 break;
1910 case MESA_SHADER_TESS_EVAL:
1911 exec_model = SpvExecutionModelTessellationEvaluation;
1912 break;
1913 case MESA_SHADER_GEOMETRY:
1914 exec_model = SpvExecutionModelGeometry;
1915 break;
1916 case MESA_SHADER_FRAGMENT:
1917 exec_model = SpvExecutionModelFragment;
1918 break;
1919 case MESA_SHADER_COMPUTE:
1920 exec_model = SpvExecutionModelGLCompute;
1921 break;
1922 default:
1923 unreachable("invalid stage");
1924 }
1925
1926 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1927 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1928 NULL, 0);
1929 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1930 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1931
1932 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1933 _mesa_key_pointer_equal);
1934
1935 nir_foreach_variable(var, &s->inputs)
1936 emit_input(&ctx, var);
1937
1938 nir_foreach_variable(var, &s->outputs)
1939 emit_output(&ctx, var);
1940
1941 nir_foreach_variable(var, &s->uniforms)
1942 emit_uniform(&ctx, var);
1943
1944 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1945 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1946 SpvExecutionModeOriginUpperLeft);
1947 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1948 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1949 SpvExecutionModeDepthReplacing);
1950 }
1951
1952
1953 spirv_builder_function(&ctx.builder, entry_point, type_void,
1954 SpvFunctionControlMaskNone,
1955 type_main);
1956
1957 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1958 nir_metadata_require(entry, nir_metadata_block_index);
1959
1960 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1961 if (!ctx.defs)
1962 goto fail;
1963 ctx.num_defs = entry->ssa_alloc;
1964
1965 nir_index_local_regs(entry);
1966 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1967 if (!ctx.regs)
1968 goto fail;
1969 ctx.num_regs = entry->reg_alloc;
1970
1971 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1972 if (!block_ids)
1973 goto fail;
1974
1975 for (int i = 0; i < entry->num_blocks; ++i)
1976 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1977
1978 ctx.block_ids = block_ids;
1979 ctx.num_blocks = entry->num_blocks;
1980
1981 /* emit a block only for the variable declarations */
1982 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1983 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1984 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1985 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1986 SpvStorageClassFunction,
1987 type);
1988 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1989 SpvStorageClassFunction);
1990
1991 ctx.regs[reg->index] = var;
1992 }
1993
1994 emit_cf_list(&ctx, &entry->body);
1995
1996 free(ctx.defs);
1997
1998 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1999 spirv_builder_function_end(&ctx.builder);
2000
2001 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2002 "main", ctx.entry_ifaces,
2003 ctx.num_entry_ifaces);
2004
2005 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2006
2007 ret = CALLOC_STRUCT(spirv_shader);
2008 if (!ret)
2009 goto fail;
2010
2011 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2012 if (!ret->words)
2013 goto fail;
2014
2015 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2016 assert(ret->num_words == num_words);
2017
2018 return ret;
2019
2020 fail:
2021
2022 if (ret)
2023 spirv_shader_delete(ret);
2024
2025 if (ctx.vars)
2026 _mesa_hash_table_destroy(ctx.vars, NULL);
2027
2028 return NULL;
2029 }
2030
2031 void
2032 spirv_shader_delete(struct spirv_shader *s)
2033 {
2034 FREE(s->words);
2035 FREE(s);
2036 }