zink/spirv: rename functions a bit
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
172 nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
353
354 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
355 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
356 }
357
358 static SpvDim
359 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
360 {
361 *is_ms = false;
362 switch (gdim) {
363 case GLSL_SAMPLER_DIM_1D:
364 return SpvDim1D;
365 case GLSL_SAMPLER_DIM_2D:
366 return SpvDim2D;
367 case GLSL_SAMPLER_DIM_3D:
368 return SpvDim3D;
369 case GLSL_SAMPLER_DIM_CUBE:
370 return SpvDimCube;
371 case GLSL_SAMPLER_DIM_RECT:
372 return SpvDim2D;
373 case GLSL_SAMPLER_DIM_BUF:
374 return SpvDimBuffer;
375 case GLSL_SAMPLER_DIM_EXTERNAL:
376 return SpvDim2D; /* seems dodgy... */
377 case GLSL_SAMPLER_DIM_MS:
378 *is_ms = true;
379 return SpvDim2D;
380 default:
381 fprintf(stderr, "unknown sampler type %d\n", gdim);
382 break;
383 }
384 return SpvDim2D;
385 }
386
387 uint32_t
388 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
389 {
390 if (stage == MESA_SHADER_NONE ||
391 stage >= MESA_SHADER_COMPUTE) {
392 unreachable("not supported");
393 } else {
394 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
395 PIPE_MAX_SHADER_SAMPLER_VIEWS);
396
397 switch (type) {
398 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
399 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
400 return stage_offset + index;
401
402 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
403 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
404 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
405
406 default:
407 unreachable("unexpected type");
408 }
409 }
410 }
411
412 static void
413 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
414 {
415 const struct glsl_type *type = glsl_without_array(var->type);
416
417 bool is_ms;
418 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
419
420 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
421 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
422 dimension, false,
423 glsl_sampler_type_is_array(type),
424 is_ms, 1,
425 SpvImageFormatUnknown);
426
427 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
428 image_type);
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniformConstant,
431 sampled_type);
432
433 if (glsl_type_is_array(var->type)) {
434 for (int i = 0; i < glsl_get_length(var->type); ++i) {
435 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
436 SpvStorageClassUniformConstant);
437
438 if (var->name) {
439 char element_name[100];
440 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
441 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
442 }
443
444 int index = var->data.binding + i;
445 assert(!(ctx->samplers_used & (1 << index)));
446 assert(!ctx->image_types[index]);
447 ctx->image_types[index] = image_type;
448 ctx->samplers[index] = var_id;
449 ctx->samplers_used |= 1 << index;
450
451 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
452 var->data.descriptor_set);
453 int binding = zink_binding(ctx->stage,
454 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
455 var->data.binding + i);
456 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
457 }
458 } else {
459 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
460 SpvStorageClassUniformConstant);
461
462 if (var->name)
463 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
464
465 int index = var->data.binding;
466 assert(!(ctx->samplers_used & (1 << index)));
467 assert(!ctx->image_types[index]);
468 ctx->image_types[index] = image_type;
469 ctx->samplers[index] = var_id;
470 ctx->samplers_used |= 1 << index;
471
472 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
473 var->data.descriptor_set);
474 int binding = zink_binding(ctx->stage,
475 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
476 var->data.binding);
477 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
478 }
479 }
480
481 static void
482 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
483 {
484 uint32_t size = glsl_count_attribute_slots(var->type, false);
485 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
486 SpvId array_length = emit_uint_const(ctx, 32, size);
487 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
488 array_length);
489 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
490
491 // wrap UBO-array in a struct
492 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
493 if (var->name) {
494 char struct_name[100];
495 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
496 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
497 }
498
499 spirv_builder_emit_decoration(&ctx->builder, struct_type,
500 SpvDecorationBlock);
501 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
502
503
504 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
505 SpvStorageClassUniform,
506 struct_type);
507
508 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
509 SpvStorageClassUniform);
510 if (var->name)
511 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
512
513 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
514 ctx->ubos[ctx->num_ubos++] = var_id;
515
516 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
517 var->data.descriptor_set);
518 int binding = zink_binding(ctx->stage,
519 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
520 var->data.binding);
521 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
522 }
523
524 static void
525 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
526 {
527 if (var->data.mode == nir_var_mem_ubo)
528 emit_ubo(ctx, var);
529 else {
530 assert(var->data.mode == nir_var_uniform);
531 if (glsl_type_is_sampler(glsl_without_array(var->type)))
532 emit_sampler(ctx, var);
533 }
534 }
535
536 static SpvId
537 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
538 {
539 assert(ssa->index < ctx->num_defs);
540 assert(ctx->defs[ssa->index] != 0);
541 return ctx->defs[ssa->index];
542 }
543
544 static SpvId
545 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
546 {
547 assert(reg->index < ctx->num_regs);
548 assert(ctx->regs[reg->index] != 0);
549 return ctx->regs[reg->index];
550 }
551
552 static SpvId
553 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
554 {
555 assert(reg->reg);
556 assert(!reg->indirect);
557 assert(!reg->base_offset);
558
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
561 return spirv_builder_emit_load(&ctx->builder, type, var);
562 }
563
564 static SpvId
565 get_src(struct ntv_context *ctx, nir_src *src)
566 {
567 if (src->is_ssa)
568 return get_src_ssa(ctx, src->ssa);
569 else
570 return get_src_reg(ctx, &src->reg);
571 }
572
573 static SpvId
574 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
575 {
576 assert(!alu->src[src].negate);
577 assert(!alu->src[src].abs);
578
579 SpvId def = get_src(ctx, &alu->src[src].src);
580
581 unsigned used_channels = 0;
582 bool need_swizzle = false;
583 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
584 if (!nir_alu_instr_channel_used(alu, src, i))
585 continue;
586
587 used_channels++;
588
589 if (alu->src[src].swizzle[i] != i)
590 need_swizzle = true;
591 }
592 assert(used_channels != 0);
593
594 unsigned live_channels = nir_src_num_components(alu->src[src].src);
595 if (used_channels != live_channels)
596 need_swizzle = true;
597
598 if (!need_swizzle)
599 return def;
600
601 int bit_size = nir_src_bit_size(alu->src[src].src);
602 assert(bit_size == 1 || bit_size == 32);
603
604 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
605 if (used_channels == 1) {
606 uint32_t indices[] = { alu->src[src].swizzle[0] };
607 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
608 def, indices,
609 ARRAY_SIZE(indices));
610 } else if (live_channels == 1) {
611 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
612 used_channels);
613
614 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
615 for (unsigned i = 0; i < used_channels; ++i)
616 constituents[i] = def;
617
618 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
619 constituents,
620 used_channels);
621 } else {
622 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
623 used_channels);
624
625 uint32_t components[NIR_MAX_VEC_COMPONENTS];
626 size_t num_components = 0;
627 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
628 if (!nir_alu_instr_channel_used(alu, src, i))
629 continue;
630
631 components[num_components++] = alu->src[src].swizzle[i];
632 }
633
634 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
635 def, def, components, num_components);
636 }
637 }
638
639 static void
640 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
641 {
642 assert(result != 0);
643 assert(ssa->index < ctx->num_defs);
644 ctx->defs[ssa->index] = result;
645 }
646
647 static SpvId
648 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
649 SpvId if_true, SpvId if_false)
650 {
651 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
652 }
653
654 static SpvId
655 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
656 {
657 SpvId otype = get_uvec_type(ctx, 32, num_components);
658 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
659 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
660 return emit_select(ctx, otype, value, one, zero);
661 }
662
663 static SpvId
664 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
665 {
666 SpvId type = get_bvec_type(ctx, num_components);
667 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
668 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
669 }
670
671 static SpvId
672 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
673 {
674 return emit_unop(ctx, SpvOpBitcast, type, value);
675 }
676
677 static SpvId
678 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
679 unsigned num_components)
680 {
681 SpvId type = get_uvec_type(ctx, bit_size, num_components);
682 return emit_bitcast(ctx, type, value);
683 }
684
685 static SpvId
686 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
687 unsigned num_components)
688 {
689 SpvId type = get_ivec_type(ctx, bit_size, num_components);
690 return emit_bitcast(ctx, type, value);
691 }
692
693 static SpvId
694 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
695 unsigned num_components)
696 {
697 SpvId type = get_fvec_type(ctx, bit_size, num_components);
698 return emit_bitcast(ctx, type, value);
699 }
700
701 static void
702 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
703 {
704 SpvId var = get_var_from_reg(ctx, reg->reg);
705 assert(var);
706 spirv_builder_emit_store(&ctx->builder, var, result);
707 }
708
709 static void
710 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
711 {
712 if (dest->is_ssa)
713 store_ssa_def(ctx, &dest->ssa, result);
714 else
715 store_reg_def(ctx, &dest->reg, result);
716 }
717
718 static void
719 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
720 {
721 unsigned num_components = nir_dest_num_components(*dest);
722 unsigned bit_size = nir_dest_bit_size(*dest);
723
724 switch (nir_alu_type_get_base_type(type)) {
725 case nir_type_bool:
726 assert(bit_size == 1);
727 result = bvec_to_uvec(ctx, result, num_components);
728 break;
729
730 case nir_type_uint:
731 break; /* nothing to do! */
732
733 case nir_type_int:
734 case nir_type_float:
735 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
736 break;
737
738 default:
739 unreachable("unsupported nir_alu_type");
740 }
741
742 store_dest_raw(ctx, dest, result);
743 }
744
745 static SpvId
746 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
747 {
748 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
749 }
750
751 static SpvId
752 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
753 SpvId src0, SpvId src1)
754 {
755 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
756 }
757
758 static SpvId
759 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
760 SpvId src0, SpvId src1, SpvId src2)
761 {
762 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
763 }
764
765 static SpvId
766 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
767 SpvId src)
768 {
769 SpvId args[] = { src };
770 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
771 op, args, ARRAY_SIZE(args));
772 }
773
774 static SpvId
775 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
776 SpvId src0, SpvId src1)
777 {
778 SpvId args[] = { src0, src1 };
779 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
780 op, args, ARRAY_SIZE(args));
781 }
782
783 static SpvId
784 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
785 SpvId src0, SpvId src1, SpvId src2)
786 {
787 SpvId args[] = { src0, src1, src2 };
788 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
789 op, args, ARRAY_SIZE(args));
790 }
791
792 static SpvId
793 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
794 unsigned num_components, float value)
795 {
796 assert(bit_size == 32);
797
798 SpvId result = emit_float_const(ctx, bit_size, value);
799 if (num_components == 1)
800 return result;
801
802 assert(num_components > 1);
803 SpvId components[num_components];
804 for (int i = 0; i < num_components; i++)
805 components[i] = result;
806
807 SpvId type = get_fvec_type(ctx, bit_size, num_components);
808 return spirv_builder_const_composite(&ctx->builder, type, components,
809 num_components);
810 }
811
812 static SpvId
813 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
814 unsigned num_components, uint32_t value)
815 {
816 assert(bit_size == 32);
817
818 SpvId result = emit_uint_const(ctx, bit_size, value);
819 if (num_components == 1)
820 return result;
821
822 assert(num_components > 1);
823 SpvId components[num_components];
824 for (int i = 0; i < num_components; i++)
825 components[i] = result;
826
827 SpvId type = get_uvec_type(ctx, bit_size, num_components);
828 return spirv_builder_const_composite(&ctx->builder, type, components,
829 num_components);
830 }
831
832 static SpvId
833 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
834 unsigned num_components, int32_t value)
835 {
836 assert(bit_size == 32);
837
838 SpvId result = emit_int_const(ctx, bit_size, value);
839 if (num_components == 1)
840 return result;
841
842 assert(num_components > 1);
843 SpvId components[num_components];
844 for (int i = 0; i < num_components; i++)
845 components[i] = result;
846
847 SpvId type = get_ivec_type(ctx, bit_size, num_components);
848 return spirv_builder_const_composite(&ctx->builder, type, components,
849 num_components);
850 }
851
852 static inline unsigned
853 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
854 {
855 if (nir_op_infos[instr->op].input_sizes[src] > 0)
856 return nir_op_infos[instr->op].input_sizes[src];
857
858 if (instr->dest.dest.is_ssa)
859 return instr->dest.dest.ssa.num_components;
860 else
861 return instr->dest.dest.reg.reg->num_components;
862 }
863
864 static SpvId
865 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
866 {
867 SpvId uint_value = get_alu_src_raw(ctx, alu, src);
868
869 unsigned num_components = alu_instr_src_components(alu, src);
870 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
871 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
872
873 switch (nir_alu_type_get_base_type(type)) {
874 case nir_type_bool:
875 assert(bit_size == 1);
876 return uvec_to_bvec(ctx, uint_value, num_components);
877
878 case nir_type_int:
879 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
880
881 case nir_type_uint:
882 return uint_value;
883
884 case nir_type_float:
885 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
886
887 default:
888 unreachable("unknown nir_alu_type");
889 }
890 }
891
892 static void
893 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
894 {
895 assert(!alu->dest.saturate);
896 return store_dest(ctx, &alu->dest.dest, result,
897 nir_op_infos[alu->op].output_type);
898 }
899
900 static SpvId
901 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
902 {
903 unsigned num_components = nir_dest_num_components(*dest);
904 unsigned bit_size = nir_dest_bit_size(*dest);
905
906 switch (nir_alu_type_get_base_type(type)) {
907 case nir_type_bool:
908 return get_bvec_type(ctx, num_components);
909
910 case nir_type_int:
911 return get_ivec_type(ctx, bit_size, num_components);
912
913 case nir_type_uint:
914 return get_uvec_type(ctx, bit_size, num_components);
915
916 case nir_type_float:
917 return get_fvec_type(ctx, bit_size, num_components);
918
919 default:
920 unreachable("unsupported nir_alu_type");
921 }
922 }
923
924 static void
925 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
926 {
927 SpvId src[nir_op_infos[alu->op].num_inputs];
928 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
929 src[i] = get_alu_src(ctx, alu, i);
930
931 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
932 nir_op_infos[alu->op].output_type);
933 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
934 unsigned num_components = nir_dest_num_components(alu->dest.dest);
935
936 SpvId result = 0;
937 switch (alu->op) {
938 case nir_op_mov:
939 assert(nir_op_infos[alu->op].num_inputs == 1);
940 result = src[0];
941 break;
942
943 #define UNOP(nir_op, spirv_op) \
944 case nir_op: \
945 assert(nir_op_infos[alu->op].num_inputs == 1); \
946 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
947 break;
948
949 UNOP(nir_op_ineg, SpvOpSNegate)
950 UNOP(nir_op_fneg, SpvOpFNegate)
951 UNOP(nir_op_fddx, SpvOpDPdx)
952 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
953 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
954 UNOP(nir_op_fddy, SpvOpDPdy)
955 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
956 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
957 UNOP(nir_op_f2i32, SpvOpConvertFToS)
958 UNOP(nir_op_f2u32, SpvOpConvertFToU)
959 UNOP(nir_op_i2f32, SpvOpConvertSToF)
960 UNOP(nir_op_u2f32, SpvOpConvertUToF)
961 UNOP(nir_op_inot, SpvOpNot)
962 #undef UNOP
963
964 case nir_op_b2i32:
965 assert(nir_op_infos[alu->op].num_inputs == 1);
966 result = emit_select(ctx, dest_type, src[0],
967 get_ivec_constant(ctx, 32, num_components, 1),
968 get_ivec_constant(ctx, 32, num_components, 0));
969 break;
970
971 case nir_op_b2f32:
972 assert(nir_op_infos[alu->op].num_inputs == 1);
973 result = emit_select(ctx, dest_type, src[0],
974 get_fvec_constant(ctx, 32, num_components, 1),
975 get_fvec_constant(ctx, 32, num_components, 0));
976 break;
977
978 #define BUILTIN_UNOP(nir_op, spirv_op) \
979 case nir_op: \
980 assert(nir_op_infos[alu->op].num_inputs == 1); \
981 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
982 break;
983
984 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
985 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
986 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
987 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
988 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
989 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
990 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
991 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
992 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
993 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
994 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
995 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
996 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
997 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
998 #undef BUILTIN_UNOP
999
1000 case nir_op_frcp:
1001 assert(nir_op_infos[alu->op].num_inputs == 1);
1002 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1003 get_fvec_constant(ctx, bit_size, num_components, 1),
1004 src[0]);
1005 break;
1006
1007 case nir_op_f2b1:
1008 assert(nir_op_infos[alu->op].num_inputs == 1);
1009 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1010 get_fvec_constant(ctx,
1011 nir_src_bit_size(alu->src[0].src),
1012 num_components, 0));
1013 break;
1014
1015
1016 #define BINOP(nir_op, spirv_op) \
1017 case nir_op: \
1018 assert(nir_op_infos[alu->op].num_inputs == 2); \
1019 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1020 break;
1021
1022 BINOP(nir_op_iadd, SpvOpIAdd)
1023 BINOP(nir_op_isub, SpvOpISub)
1024 BINOP(nir_op_imul, SpvOpIMul)
1025 BINOP(nir_op_idiv, SpvOpSDiv)
1026 BINOP(nir_op_udiv, SpvOpUDiv)
1027 BINOP(nir_op_umod, SpvOpUMod)
1028 BINOP(nir_op_fadd, SpvOpFAdd)
1029 BINOP(nir_op_fsub, SpvOpFSub)
1030 BINOP(nir_op_fmul, SpvOpFMul)
1031 BINOP(nir_op_fdiv, SpvOpFDiv)
1032 BINOP(nir_op_fmod, SpvOpFMod)
1033 BINOP(nir_op_ilt, SpvOpSLessThan)
1034 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1035 BINOP(nir_op_ieq, SpvOpIEqual)
1036 BINOP(nir_op_ine, SpvOpINotEqual)
1037 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1038 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1039 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1040 BINOP(nir_op_feq, SpvOpFOrdEqual)
1041 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1042 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1043 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1044 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1045 BINOP(nir_op_iand, SpvOpBitwiseAnd)
1046 BINOP(nir_op_ior, SpvOpBitwiseOr)
1047 #undef BINOP
1048
1049 #define BUILTIN_BINOP(nir_op, spirv_op) \
1050 case nir_op: \
1051 assert(nir_op_infos[alu->op].num_inputs == 2); \
1052 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1053 break;
1054
1055 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1056 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1057 #undef BUILTIN_BINOP
1058
1059 case nir_op_fdot2:
1060 case nir_op_fdot3:
1061 case nir_op_fdot4:
1062 assert(nir_op_infos[alu->op].num_inputs == 2);
1063 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1064 break;
1065
1066 case nir_op_fdph:
1067 unreachable("should already be lowered away");
1068
1069 case nir_op_seq:
1070 case nir_op_sne:
1071 case nir_op_slt:
1072 case nir_op_sge: {
1073 assert(nir_op_infos[alu->op].num_inputs == 2);
1074 int num_components = nir_dest_num_components(alu->dest.dest);
1075 SpvId bool_type = get_bvec_type(ctx, num_components);
1076
1077 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1078 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1079 if (num_components > 1) {
1080 SpvId zero_comps[num_components], one_comps[num_components];
1081 for (int i = 0; i < num_components; i++) {
1082 zero_comps[i] = zero;
1083 one_comps[i] = one;
1084 }
1085
1086 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1087 zero_comps, num_components);
1088 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1089 one_comps, num_components);
1090 }
1091
1092 SpvOp op;
1093 switch (alu->op) {
1094 case nir_op_seq: op = SpvOpFOrdEqual; break;
1095 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1096 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1097 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1098 default: unreachable("unexpected op");
1099 }
1100
1101 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1102 result = emit_select(ctx, dest_type, result, one, zero);
1103 }
1104 break;
1105
1106 case nir_op_flrp:
1107 assert(nir_op_infos[alu->op].num_inputs == 3);
1108 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1109 src[0], src[1], src[2]);
1110 break;
1111
1112 case nir_op_fcsel:
1113 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1114 get_bvec_type(ctx, num_components),
1115 src[0],
1116 get_fvec_constant(ctx,
1117 nir_src_bit_size(alu->src[0].src),
1118 num_components, 0));
1119 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1120 break;
1121
1122 case nir_op_bcsel:
1123 assert(nir_op_infos[alu->op].num_inputs == 3);
1124 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1125 break;
1126
1127 case nir_op_bany_fnequal2:
1128 case nir_op_bany_fnequal3:
1129 case nir_op_bany_fnequal4:
1130 assert(nir_op_infos[alu->op].num_inputs == 2);
1131 assert(alu_instr_src_components(alu, 0) ==
1132 alu_instr_src_components(alu, 1));
1133 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1134 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1135 src[0], src[1]);
1136 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1137 break;
1138
1139 case nir_op_ball_fequal2:
1140 case nir_op_ball_fequal3:
1141 case nir_op_ball_fequal4:
1142 assert(nir_op_infos[alu->op].num_inputs == 2);
1143 assert(alu_instr_src_components(alu, 0) ==
1144 alu_instr_src_components(alu, 1));
1145 result = emit_binop(ctx, SpvOpFOrdEqual,
1146 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1147 src[0], src[1]);
1148 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1149 break;
1150
1151 case nir_op_bany_inequal2:
1152 case nir_op_bany_inequal3:
1153 case nir_op_bany_inequal4:
1154 assert(nir_op_infos[alu->op].num_inputs == 2);
1155 assert(alu_instr_src_components(alu, 0) ==
1156 alu_instr_src_components(alu, 1));
1157 result = emit_binop(ctx, SpvOpINotEqual,
1158 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1159 src[0], src[1]);
1160 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1161 break;
1162
1163 case nir_op_ball_iequal2:
1164 case nir_op_ball_iequal3:
1165 case nir_op_ball_iequal4:
1166 assert(nir_op_infos[alu->op].num_inputs == 2);
1167 assert(alu_instr_src_components(alu, 0) ==
1168 alu_instr_src_components(alu, 1));
1169 result = emit_binop(ctx, SpvOpIEqual,
1170 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1171 src[0], src[1]);
1172 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1173 break;
1174
1175 case nir_op_vec2:
1176 case nir_op_vec3:
1177 case nir_op_vec4: {
1178 int num_inputs = nir_op_infos[alu->op].num_inputs;
1179 assert(2 <= num_inputs && num_inputs <= 4);
1180 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1181 src, num_inputs);
1182 }
1183 break;
1184
1185 default:
1186 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1187 nir_op_infos[alu->op].name);
1188
1189 unreachable("unsupported opcode");
1190 return;
1191 }
1192
1193 store_alu_result(ctx, alu, result);
1194 }
1195
1196 static void
1197 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1198 {
1199 unsigned bit_size = load_const->def.bit_size;
1200 unsigned num_components = load_const->def.num_components;
1201
1202 SpvId constant;
1203 if (num_components > 1) {
1204 SpvId components[num_components];
1205 SpvId type;
1206 if (bit_size == 1) {
1207 for (int i = 0; i < num_components; i++)
1208 components[i] = spirv_builder_const_bool(&ctx->builder,
1209 load_const->value[i].b);
1210
1211 type = get_bvec_type(ctx, num_components);
1212 } else {
1213 for (int i = 0; i < num_components; i++)
1214 components[i] = emit_uint_const(ctx, bit_size,
1215 load_const->value[i].u32);
1216
1217 type = get_uvec_type(ctx, bit_size, num_components);
1218 }
1219 constant = spirv_builder_const_composite(&ctx->builder, type,
1220 components, num_components);
1221 } else {
1222 assert(num_components == 1);
1223 if (bit_size == 1)
1224 constant = spirv_builder_const_bool(&ctx->builder,
1225 load_const->value[0].b);
1226 else
1227 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1228 }
1229
1230 if (bit_size == 1)
1231 constant = bvec_to_uvec(ctx, constant, num_components);
1232
1233 store_ssa_def(ctx, &load_const->def, constant);
1234 }
1235
1236 static void
1237 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1238 {
1239 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1240 assert(const_block_index); // no dynamic indexing for now
1241 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1242
1243 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1244 if (const_offset) {
1245 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1246 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1247 SpvStorageClassUniform,
1248 uvec4_type);
1249
1250 unsigned idx = const_offset->u32;
1251 SpvId member = emit_uint_const(ctx, 32, 0);
1252 SpvId offset = emit_uint_const(ctx, 32, idx);
1253 SpvId offsets[] = { member, offset };
1254 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1255 ctx->ubos[0], offsets,
1256 ARRAY_SIZE(offsets));
1257 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1258
1259 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1260 unsigned num_components = nir_dest_num_components(intr->dest);
1261 if (num_components == 1) {
1262 uint32_t components[] = { 0 };
1263 result = spirv_builder_emit_composite_extract(&ctx->builder,
1264 type,
1265 result, components,
1266 1);
1267 } else if (num_components < 4) {
1268 SpvId constituents[num_components];
1269 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1270 for (uint32_t i = 0; i < num_components; ++i)
1271 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1272 uint_type,
1273 result, &i,
1274 1);
1275
1276 result = spirv_builder_emit_composite_construct(&ctx->builder,
1277 type,
1278 constituents,
1279 num_components);
1280 }
1281
1282 store_dest(ctx, &intr->dest, result, nir_type_uint);
1283 } else
1284 unreachable("uniform-addressing not yet supported");
1285 }
1286
1287 static void
1288 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1289 {
1290 assert(ctx->block_started);
1291 spirv_builder_emit_kill(&ctx->builder);
1292 /* discard is weird in NIR, so let's just create an unreachable block after
1293 it and hope that the vulkan driver will DCE any instructinos in it. */
1294 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1295 }
1296
1297 static void
1298 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1299 {
1300 SpvId ptr = get_src(ctx, intr->src);
1301
1302 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1303 SpvId result = spirv_builder_emit_load(&ctx->builder,
1304 get_glsl_type(ctx, var->type),
1305 ptr);
1306 unsigned num_components = nir_dest_num_components(intr->dest);
1307 unsigned bit_size = nir_dest_bit_size(intr->dest);
1308 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1309 store_dest(ctx, &intr->dest, result, nir_type_uint);
1310 }
1311
1312 static void
1313 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1314 {
1315 SpvId ptr = get_src(ctx, &intr->src[0]);
1316 SpvId src = get_src(ctx, &intr->src[1]);
1317
1318 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1319 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1320 SpvId result = emit_bitcast(ctx, type, src);
1321 spirv_builder_emit_store(&ctx->builder, ptr, result);
1322 }
1323
1324 static SpvId
1325 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1326 SpvStorageClass storage_class,
1327 const char *name, SpvBuiltIn builtin)
1328 {
1329 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1330 storage_class,
1331 var_type);
1332 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1333 storage_class);
1334 spirv_builder_emit_name(&ctx->builder, var, name);
1335 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1336
1337 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1338 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1339 return var;
1340 }
1341
1342 static void
1343 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1344 {
1345 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1346 if (!ctx->front_face_var)
1347 ctx->front_face_var = create_builtin_var(ctx, var_type,
1348 SpvStorageClassInput,
1349 "gl_FrontFacing",
1350 SpvBuiltInFrontFacing);
1351
1352 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1353 ctx->front_face_var);
1354 assert(1 == nir_dest_num_components(intr->dest));
1355 store_dest(ctx, &intr->dest, result, nir_type_bool);
1356 }
1357
1358 static void
1359 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1360 {
1361 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1362 if (!ctx->instance_id_var)
1363 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1364 SpvStorageClassInput,
1365 "gl_InstanceId",
1366 SpvBuiltInInstanceIndex);
1367
1368 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1369 ctx->instance_id_var);
1370 assert(1 == nir_dest_num_components(intr->dest));
1371 store_dest(ctx, &intr->dest, result, nir_type_uint);
1372 }
1373
1374 static void
1375 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1376 {
1377 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1378 if (!ctx->vertex_id_var)
1379 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1380 SpvStorageClassInput,
1381 "gl_VertexID",
1382 SpvBuiltInVertexIndex);
1383
1384 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1385 ctx->vertex_id_var);
1386 assert(1 == nir_dest_num_components(intr->dest));
1387 store_dest(ctx, &intr->dest, result, nir_type_uint);
1388 }
1389
1390 static void
1391 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1392 {
1393 switch (intr->intrinsic) {
1394 case nir_intrinsic_load_ubo:
1395 emit_load_ubo(ctx, intr);
1396 break;
1397
1398 case nir_intrinsic_discard:
1399 emit_discard(ctx, intr);
1400 break;
1401
1402 case nir_intrinsic_load_deref:
1403 emit_load_deref(ctx, intr);
1404 break;
1405
1406 case nir_intrinsic_store_deref:
1407 emit_store_deref(ctx, intr);
1408 break;
1409
1410 case nir_intrinsic_load_front_face:
1411 emit_load_front_face(ctx, intr);
1412 break;
1413
1414 case nir_intrinsic_load_instance_id:
1415 emit_load_instance_id(ctx, intr);
1416 break;
1417
1418 case nir_intrinsic_load_vertex_id:
1419 emit_load_vertex_id(ctx, intr);
1420 break;
1421
1422 default:
1423 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1424 nir_intrinsic_infos[intr->intrinsic].name);
1425 unreachable("unsupported intrinsic");
1426 }
1427 }
1428
1429 static void
1430 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1431 {
1432 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1433 undef->def.num_components);
1434
1435 store_ssa_def(ctx, &undef->def,
1436 spirv_builder_emit_undef(&ctx->builder, type));
1437 }
1438
1439 static SpvId
1440 get_src_float(struct ntv_context *ctx, nir_src *src)
1441 {
1442 SpvId def = get_src(ctx, src);
1443 unsigned num_components = nir_src_num_components(*src);
1444 unsigned bit_size = nir_src_bit_size(*src);
1445 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1446 }
1447
1448 static SpvId
1449 get_src_int(struct ntv_context *ctx, nir_src *src)
1450 {
1451 SpvId def = get_src(ctx, src);
1452 unsigned num_components = nir_src_num_components(*src);
1453 unsigned bit_size = nir_src_bit_size(*src);
1454 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1455 }
1456
1457 static void
1458 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1459 {
1460 assert(tex->op == nir_texop_tex ||
1461 tex->op == nir_texop_txb ||
1462 tex->op == nir_texop_txl ||
1463 tex->op == nir_texop_txd ||
1464 tex->op == nir_texop_txf ||
1465 tex->op == nir_texop_txs);
1466 assert(tex->texture_index == tex->sampler_index);
1467
1468 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1469 offset = 0;
1470 unsigned coord_components = 0;
1471 for (unsigned i = 0; i < tex->num_srcs; i++) {
1472 switch (tex->src[i].src_type) {
1473 case nir_tex_src_coord:
1474 if (tex->op == nir_texop_txf)
1475 coord = get_src_int(ctx, &tex->src[i].src);
1476 else
1477 coord = get_src_float(ctx, &tex->src[i].src);
1478 coord_components = nir_src_num_components(tex->src[i].src);
1479 break;
1480
1481 case nir_tex_src_projector:
1482 assert(nir_src_num_components(tex->src[i].src) == 1);
1483 proj = get_src_float(ctx, &tex->src[i].src);
1484 assert(proj != 0);
1485 break;
1486
1487 case nir_tex_src_offset:
1488 offset = get_src_int(ctx, &tex->src[i].src);
1489 break;
1490
1491 case nir_tex_src_bias:
1492 assert(tex->op == nir_texop_txb);
1493 bias = get_src_float(ctx, &tex->src[i].src);
1494 assert(bias != 0);
1495 break;
1496
1497 case nir_tex_src_lod:
1498 assert(nir_src_num_components(tex->src[i].src) == 1);
1499 if (tex->op == nir_texop_txf ||
1500 tex->op == nir_texop_txs)
1501 lod = get_src_int(ctx, &tex->src[i].src);
1502 else
1503 lod = get_src_float(ctx, &tex->src[i].src);
1504 assert(lod != 0);
1505 break;
1506
1507 case nir_tex_src_comparator:
1508 assert(nir_src_num_components(tex->src[i].src) == 1);
1509 dref = get_src_float(ctx, &tex->src[i].src);
1510 assert(dref != 0);
1511 break;
1512
1513 case nir_tex_src_ddx:
1514 dx = get_src_float(ctx, &tex->src[i].src);
1515 assert(dx != 0);
1516 break;
1517
1518 case nir_tex_src_ddy:
1519 dy = get_src_float(ctx, &tex->src[i].src);
1520 assert(dy != 0);
1521 break;
1522
1523 default:
1524 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1525 unreachable("unknown texture source");
1526 }
1527 }
1528
1529 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1530 lod = emit_float_const(ctx, 32, 0.0f);
1531 assert(lod != 0);
1532 }
1533
1534 SpvId image_type = ctx->image_types[tex->texture_index];
1535 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1536 image_type);
1537
1538 assert(ctx->samplers_used & (1u << tex->texture_index));
1539 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1540 ctx->samplers[tex->texture_index]);
1541
1542 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1543
1544 if (tex->op == nir_texop_txs) {
1545 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1546 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1547 dest_type, image,
1548 lod);
1549 store_dest(ctx, &tex->dest, result, tex->dest_type);
1550 return;
1551 }
1552
1553 if (proj && coord_components > 0) {
1554 SpvId constituents[coord_components + 1];
1555 if (coord_components == 1)
1556 constituents[0] = coord;
1557 else {
1558 assert(coord_components > 1);
1559 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1560 for (uint32_t i = 0; i < coord_components; ++i)
1561 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1562 float_type,
1563 coord,
1564 &i, 1);
1565 }
1566
1567 constituents[coord_components++] = proj;
1568
1569 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1570 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1571 vec_type,
1572 constituents,
1573 coord_components);
1574 }
1575
1576 SpvId actual_dest_type = dest_type;
1577 if (dref)
1578 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1579
1580 SpvId result;
1581 if (tex->op == nir_texop_txf) {
1582 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1583 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1584 image, coord, lod);
1585 } else {
1586 result = spirv_builder_emit_image_sample(&ctx->builder,
1587 actual_dest_type, load,
1588 coord,
1589 proj != 0,
1590 lod, bias, dref, dx, dy,
1591 offset);
1592 }
1593
1594 spirv_builder_emit_decoration(&ctx->builder, result,
1595 SpvDecorationRelaxedPrecision);
1596
1597 if (dref && nir_dest_num_components(tex->dest) > 1) {
1598 SpvId components[4] = { result, result, result, result };
1599 result = spirv_builder_emit_composite_construct(&ctx->builder,
1600 dest_type,
1601 components,
1602 4);
1603 }
1604
1605 store_dest(ctx, &tex->dest, result, tex->dest_type);
1606 }
1607
1608 static void
1609 start_block(struct ntv_context *ctx, SpvId label)
1610 {
1611 /* terminate previous block if needed */
1612 if (ctx->block_started)
1613 spirv_builder_emit_branch(&ctx->builder, label);
1614
1615 /* start new block */
1616 spirv_builder_label(&ctx->builder, label);
1617 ctx->block_started = true;
1618 }
1619
1620 static void
1621 branch(struct ntv_context *ctx, SpvId label)
1622 {
1623 assert(ctx->block_started);
1624 spirv_builder_emit_branch(&ctx->builder, label);
1625 ctx->block_started = false;
1626 }
1627
1628 static void
1629 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1630 SpvId else_id)
1631 {
1632 assert(ctx->block_started);
1633 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1634 then_id, else_id);
1635 ctx->block_started = false;
1636 }
1637
1638 static void
1639 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1640 {
1641 switch (jump->type) {
1642 case nir_jump_break:
1643 assert(ctx->loop_break);
1644 branch(ctx, ctx->loop_break);
1645 break;
1646
1647 case nir_jump_continue:
1648 assert(ctx->loop_cont);
1649 branch(ctx, ctx->loop_cont);
1650 break;
1651
1652 default:
1653 unreachable("Unsupported jump type\n");
1654 }
1655 }
1656
1657 static void
1658 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1659 {
1660 assert(deref->deref_type == nir_deref_type_var);
1661
1662 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1663 assert(he);
1664 SpvId result = (SpvId)(intptr_t)he->data;
1665 /* uint is a bit of a lie here, it's really just an opaque type */
1666 store_dest(ctx, &deref->dest, result, nir_type_uint);
1667 }
1668
1669 static void
1670 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1671 {
1672 assert(deref->deref_type == nir_deref_type_array);
1673 nir_variable *var = nir_deref_instr_get_variable(deref);
1674
1675 SpvStorageClass storage_class;
1676 switch (var->data.mode) {
1677 case nir_var_shader_in:
1678 storage_class = SpvStorageClassInput;
1679 break;
1680
1681 case nir_var_shader_out:
1682 storage_class = SpvStorageClassOutput;
1683 break;
1684
1685 default:
1686 unreachable("Unsupported nir_variable_mode\n");
1687 }
1688
1689 SpvId index = get_src(ctx, &deref->arr.index);
1690
1691 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1692 storage_class,
1693 get_glsl_type(ctx, deref->type));
1694
1695 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1696 ptr_type,
1697 get_src(ctx, &deref->parent),
1698 &index, 1);
1699 /* uint is a bit of a lie here, it's really just an opaque type */
1700 store_dest(ctx, &deref->dest, result, nir_type_uint);
1701 }
1702
1703 static void
1704 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1705 {
1706 switch (deref->deref_type) {
1707 case nir_deref_type_var:
1708 emit_deref_var(ctx, deref);
1709 break;
1710
1711 case nir_deref_type_array:
1712 emit_deref_array(ctx, deref);
1713 break;
1714
1715 default:
1716 unreachable("unexpected deref_type");
1717 }
1718 }
1719
1720 static void
1721 emit_block(struct ntv_context *ctx, struct nir_block *block)
1722 {
1723 start_block(ctx, block_label(ctx, block));
1724 nir_foreach_instr(instr, block) {
1725 switch (instr->type) {
1726 case nir_instr_type_alu:
1727 emit_alu(ctx, nir_instr_as_alu(instr));
1728 break;
1729 case nir_instr_type_intrinsic:
1730 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1731 break;
1732 case nir_instr_type_load_const:
1733 emit_load_const(ctx, nir_instr_as_load_const(instr));
1734 break;
1735 case nir_instr_type_ssa_undef:
1736 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1737 break;
1738 case nir_instr_type_tex:
1739 emit_tex(ctx, nir_instr_as_tex(instr));
1740 break;
1741 case nir_instr_type_phi:
1742 unreachable("nir_instr_type_phi not supported");
1743 break;
1744 case nir_instr_type_jump:
1745 emit_jump(ctx, nir_instr_as_jump(instr));
1746 break;
1747 case nir_instr_type_call:
1748 unreachable("nir_instr_type_call not supported");
1749 break;
1750 case nir_instr_type_parallel_copy:
1751 unreachable("nir_instr_type_parallel_copy not supported");
1752 break;
1753 case nir_instr_type_deref:
1754 emit_deref(ctx, nir_instr_as_deref(instr));
1755 break;
1756 }
1757 }
1758 }
1759
1760 static void
1761 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1762
1763 static SpvId
1764 get_src_bool(struct ntv_context *ctx, nir_src *src)
1765 {
1766 SpvId def = get_src(ctx, src);
1767 assert(nir_src_bit_size(*src) == 1);
1768 unsigned num_components = nir_src_num_components(*src);
1769 return uvec_to_bvec(ctx, def, num_components);
1770 }
1771
1772 static void
1773 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1774 {
1775 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1776
1777 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1778 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1779 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1780 SpvId else_id = endif_id;
1781
1782 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1783 if (has_else) {
1784 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1785 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1786 }
1787
1788 /* create a header-block */
1789 start_block(ctx, header_id);
1790 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1791 SpvSelectionControlMaskNone);
1792 branch_conditional(ctx, condition, then_id, else_id);
1793
1794 emit_cf_list(ctx, &if_stmt->then_list);
1795
1796 if (has_else) {
1797 if (ctx->block_started)
1798 branch(ctx, endif_id);
1799
1800 emit_cf_list(ctx, &if_stmt->else_list);
1801 }
1802
1803 start_block(ctx, endif_id);
1804 }
1805
1806 static void
1807 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1808 {
1809 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1810 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1811 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1812 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1813
1814 /* create a header-block */
1815 start_block(ctx, header_id);
1816 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1817 branch(ctx, begin_id);
1818
1819 SpvId save_break = ctx->loop_break;
1820 SpvId save_cont = ctx->loop_cont;
1821 ctx->loop_break = break_id;
1822 ctx->loop_cont = cont_id;
1823
1824 emit_cf_list(ctx, &loop->body);
1825
1826 ctx->loop_break = save_break;
1827 ctx->loop_cont = save_cont;
1828
1829 branch(ctx, cont_id);
1830 start_block(ctx, cont_id);
1831 branch(ctx, header_id);
1832
1833 start_block(ctx, break_id);
1834 }
1835
1836 static void
1837 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1838 {
1839 foreach_list_typed(nir_cf_node, node, node, list) {
1840 switch (node->type) {
1841 case nir_cf_node_block:
1842 emit_block(ctx, nir_cf_node_as_block(node));
1843 break;
1844
1845 case nir_cf_node_if:
1846 emit_if(ctx, nir_cf_node_as_if(node));
1847 break;
1848
1849 case nir_cf_node_loop:
1850 emit_loop(ctx, nir_cf_node_as_loop(node));
1851 break;
1852
1853 case nir_cf_node_function:
1854 unreachable("nir_cf_node_function not supported");
1855 break;
1856 }
1857 }
1858 }
1859
1860 struct spirv_shader *
1861 nir_to_spirv(struct nir_shader *s)
1862 {
1863 struct spirv_shader *ret = NULL;
1864
1865 struct ntv_context ctx = {};
1866
1867 switch (s->info.stage) {
1868 case MESA_SHADER_VERTEX:
1869 case MESA_SHADER_FRAGMENT:
1870 case MESA_SHADER_COMPUTE:
1871 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1872 break;
1873
1874 case MESA_SHADER_TESS_CTRL:
1875 case MESA_SHADER_TESS_EVAL:
1876 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1877 break;
1878
1879 case MESA_SHADER_GEOMETRY:
1880 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1881 break;
1882
1883 default:
1884 unreachable("invalid stage");
1885 }
1886
1887 // TODO: only enable when needed
1888 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1889 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1890 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1891 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1892 }
1893
1894 ctx.stage = s->info.stage;
1895 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1896 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1897
1898 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1899 SpvMemoryModelGLSL450);
1900
1901 SpvExecutionModel exec_model;
1902 switch (s->info.stage) {
1903 case MESA_SHADER_VERTEX:
1904 exec_model = SpvExecutionModelVertex;
1905 break;
1906 case MESA_SHADER_TESS_CTRL:
1907 exec_model = SpvExecutionModelTessellationControl;
1908 break;
1909 case MESA_SHADER_TESS_EVAL:
1910 exec_model = SpvExecutionModelTessellationEvaluation;
1911 break;
1912 case MESA_SHADER_GEOMETRY:
1913 exec_model = SpvExecutionModelGeometry;
1914 break;
1915 case MESA_SHADER_FRAGMENT:
1916 exec_model = SpvExecutionModelFragment;
1917 break;
1918 case MESA_SHADER_COMPUTE:
1919 exec_model = SpvExecutionModelGLCompute;
1920 break;
1921 default:
1922 unreachable("invalid stage");
1923 }
1924
1925 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1926 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1927 NULL, 0);
1928 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1929 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1930
1931 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1932 _mesa_key_pointer_equal);
1933
1934 nir_foreach_variable(var, &s->inputs)
1935 emit_input(&ctx, var);
1936
1937 nir_foreach_variable(var, &s->outputs)
1938 emit_output(&ctx, var);
1939
1940 nir_foreach_variable(var, &s->uniforms)
1941 emit_uniform(&ctx, var);
1942
1943 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1944 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1945 SpvExecutionModeOriginUpperLeft);
1946 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1947 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1948 SpvExecutionModeDepthReplacing);
1949 }
1950
1951
1952 spirv_builder_function(&ctx.builder, entry_point, type_void,
1953 SpvFunctionControlMaskNone,
1954 type_main);
1955
1956 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1957 nir_metadata_require(entry, nir_metadata_block_index);
1958
1959 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1960 if (!ctx.defs)
1961 goto fail;
1962 ctx.num_defs = entry->ssa_alloc;
1963
1964 nir_index_local_regs(entry);
1965 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1966 if (!ctx.regs)
1967 goto fail;
1968 ctx.num_regs = entry->reg_alloc;
1969
1970 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1971 if (!block_ids)
1972 goto fail;
1973
1974 for (int i = 0; i < entry->num_blocks; ++i)
1975 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1976
1977 ctx.block_ids = block_ids;
1978 ctx.num_blocks = entry->num_blocks;
1979
1980 /* emit a block only for the variable declarations */
1981 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1982 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1983 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1984 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1985 SpvStorageClassFunction,
1986 type);
1987 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1988 SpvStorageClassFunction);
1989
1990 ctx.regs[reg->index] = var;
1991 }
1992
1993 emit_cf_list(&ctx, &entry->body);
1994
1995 free(ctx.defs);
1996
1997 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1998 spirv_builder_function_end(&ctx.builder);
1999
2000 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2001 "main", ctx.entry_ifaces,
2002 ctx.num_entry_ifaces);
2003
2004 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2005
2006 ret = CALLOC_STRUCT(spirv_shader);
2007 if (!ret)
2008 goto fail;
2009
2010 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2011 if (!ret->words)
2012 goto fail;
2013
2014 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2015 assert(ret->num_words == num_words);
2016
2017 return ret;
2018
2019 fail:
2020
2021 if (ret)
2022 spirv_shader_delete(ret);
2023
2024 if (ctx.vars)
2025 _mesa_hash_table_destroy(ctx.vars, NULL);
2026
2027 return NULL;
2028 }
2029
2030 void
2031 spirv_shader_delete(struct spirv_shader *s)
2032 {
2033 FREE(s->words);
2034 FREE(s);
2035 }