zink: do not convert bools to/from uint
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 unsigned samplers_used : PIPE_MAX_SAMPLERS;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
172 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
335 break;
336
337 case FRAG_RESULT_DEPTH:
338 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
339 break;
340
341 default:
342 spirv_builder_emit_location(&ctx->builder, var_id,
343 var->data.driver_location);
344 }
345 }
346 }
347
348 if (var->data.location_frac)
349 spirv_builder_emit_component(&ctx->builder, var_id,
350 var->data.location_frac);
351
352 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
353
354 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
355 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
356 }
357
358 static SpvDim
359 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
360 {
361 *is_ms = false;
362 switch (gdim) {
363 case GLSL_SAMPLER_DIM_1D:
364 return SpvDim1D;
365 case GLSL_SAMPLER_DIM_2D:
366 return SpvDim2D;
367 case GLSL_SAMPLER_DIM_3D:
368 return SpvDim3D;
369 case GLSL_SAMPLER_DIM_CUBE:
370 return SpvDimCube;
371 case GLSL_SAMPLER_DIM_RECT:
372 return SpvDim2D;
373 case GLSL_SAMPLER_DIM_BUF:
374 return SpvDimBuffer;
375 case GLSL_SAMPLER_DIM_EXTERNAL:
376 return SpvDim2D; /* seems dodgy... */
377 case GLSL_SAMPLER_DIM_MS:
378 *is_ms = true;
379 return SpvDim2D;
380 default:
381 fprintf(stderr, "unknown sampler type %d\n", gdim);
382 break;
383 }
384 return SpvDim2D;
385 }
386
387 uint32_t
388 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
389 {
390 if (stage == MESA_SHADER_NONE ||
391 stage >= MESA_SHADER_COMPUTE) {
392 unreachable("not supported");
393 } else {
394 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
395 PIPE_MAX_SHADER_SAMPLER_VIEWS);
396
397 switch (type) {
398 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
399 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
400 return stage_offset + index;
401
402 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
403 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
404 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
405
406 default:
407 unreachable("unexpected type");
408 }
409 }
410 }
411
412 static void
413 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
414 {
415 const struct glsl_type *type = glsl_without_array(var->type);
416
417 bool is_ms;
418 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
419
420 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
421 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
422 dimension, false,
423 glsl_sampler_type_is_array(type),
424 is_ms, 1,
425 SpvImageFormatUnknown);
426
427 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
428 image_type);
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniformConstant,
431 sampled_type);
432
433 if (glsl_type_is_array(var->type)) {
434 for (int i = 0; i < glsl_get_length(var->type); ++i) {
435 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
436 SpvStorageClassUniformConstant);
437
438 if (var->name) {
439 char element_name[100];
440 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
441 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
442 }
443
444 int index = var->data.binding + i;
445 assert(!(ctx->samplers_used & (1 << index)));
446 assert(!ctx->image_types[index]);
447 ctx->image_types[index] = image_type;
448 ctx->samplers[index] = var_id;
449 ctx->samplers_used |= 1 << index;
450
451 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
452 var->data.descriptor_set);
453 int binding = zink_binding(ctx->stage,
454 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
455 var->data.binding + i);
456 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
457 }
458 } else {
459 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
460 SpvStorageClassUniformConstant);
461
462 if (var->name)
463 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
464
465 int index = var->data.binding;
466 assert(!(ctx->samplers_used & (1 << index)));
467 assert(!ctx->image_types[index]);
468 ctx->image_types[index] = image_type;
469 ctx->samplers[index] = var_id;
470 ctx->samplers_used |= 1 << index;
471
472 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
473 var->data.descriptor_set);
474 int binding = zink_binding(ctx->stage,
475 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
476 var->data.binding);
477 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
478 }
479 }
480
481 static void
482 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
483 {
484 uint32_t size = glsl_count_attribute_slots(var->type, false);
485 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
486 SpvId array_length = emit_uint_const(ctx, 32, size);
487 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
488 array_length);
489 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
490
491 // wrap UBO-array in a struct
492 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
493 if (var->name) {
494 char struct_name[100];
495 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
496 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
497 }
498
499 spirv_builder_emit_decoration(&ctx->builder, struct_type,
500 SpvDecorationBlock);
501 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
502
503
504 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
505 SpvStorageClassUniform,
506 struct_type);
507
508 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
509 SpvStorageClassUniform);
510 if (var->name)
511 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
512
513 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
514 ctx->ubos[ctx->num_ubos++] = var_id;
515
516 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
517 var->data.descriptor_set);
518 int binding = zink_binding(ctx->stage,
519 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
520 var->data.binding);
521 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
522 }
523
524 static void
525 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
526 {
527 if (var->data.mode == nir_var_mem_ubo)
528 emit_ubo(ctx, var);
529 else {
530 assert(var->data.mode == nir_var_uniform);
531 if (glsl_type_is_sampler(glsl_without_array(var->type)))
532 emit_sampler(ctx, var);
533 }
534 }
535
536 static SpvId
537 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
538 {
539 assert(ssa->index < ctx->num_defs);
540 assert(ctx->defs[ssa->index] != 0);
541 return ctx->defs[ssa->index];
542 }
543
544 static SpvId
545 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
546 {
547 assert(reg->index < ctx->num_regs);
548 assert(ctx->regs[reg->index] != 0);
549 return ctx->regs[reg->index];
550 }
551
552 static SpvId
553 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
554 {
555 assert(reg->reg);
556 assert(!reg->indirect);
557 assert(!reg->base_offset);
558
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
561 return spirv_builder_emit_load(&ctx->builder, type, var);
562 }
563
564 static SpvId
565 get_src(struct ntv_context *ctx, nir_src *src)
566 {
567 if (src->is_ssa)
568 return get_src_ssa(ctx, src->ssa);
569 else
570 return get_src_reg(ctx, &src->reg);
571 }
572
573 static SpvId
574 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
575 {
576 assert(!alu->src[src].negate);
577 assert(!alu->src[src].abs);
578
579 SpvId def = get_src(ctx, &alu->src[src].src);
580
581 unsigned used_channels = 0;
582 bool need_swizzle = false;
583 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
584 if (!nir_alu_instr_channel_used(alu, src, i))
585 continue;
586
587 used_channels++;
588
589 if (alu->src[src].swizzle[i] != i)
590 need_swizzle = true;
591 }
592 assert(used_channels != 0);
593
594 unsigned live_channels = nir_src_num_components(alu->src[src].src);
595 if (used_channels != live_channels)
596 need_swizzle = true;
597
598 if (!need_swizzle)
599 return def;
600
601 int bit_size = nir_src_bit_size(alu->src[src].src);
602 assert(bit_size == 1 || bit_size == 32);
603
604 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
605 spirv_builder_type_uint(&ctx->builder, bit_size);
606
607 if (used_channels == 1) {
608 uint32_t indices[] = { alu->src[src].swizzle[0] };
609 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
610 def, indices,
611 ARRAY_SIZE(indices));
612 } else if (live_channels == 1) {
613 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
614 raw_type,
615 used_channels);
616
617 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
618 for (unsigned i = 0; i < used_channels; ++i)
619 constituents[i] = def;
620
621 return spirv_builder_emit_composite_construct(&ctx->builder,
622 raw_vec_type,
623 constituents,
624 used_channels);
625 } else {
626 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
627 raw_type,
628 used_channels);
629
630 uint32_t components[NIR_MAX_VEC_COMPONENTS];
631 size_t num_components = 0;
632 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
633 if (!nir_alu_instr_channel_used(alu, src, i))
634 continue;
635
636 components[num_components++] = alu->src[src].swizzle[i];
637 }
638
639 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
640 def, def, components,
641 num_components);
642 }
643 }
644
645 static void
646 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
647 {
648 assert(result != 0);
649 assert(ssa->index < ctx->num_defs);
650 ctx->defs[ssa->index] = result;
651 }
652
653 static SpvId
654 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
655 SpvId if_true, SpvId if_false)
656 {
657 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
658 }
659
660 static SpvId
661 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
662 {
663 SpvId type = get_bvec_type(ctx, num_components);
664 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
665 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
666 }
667
668 static SpvId
669 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
670 {
671 return emit_unop(ctx, SpvOpBitcast, type, value);
672 }
673
674 static SpvId
675 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
676 unsigned num_components)
677 {
678 SpvId type = get_uvec_type(ctx, bit_size, num_components);
679 return emit_bitcast(ctx, type, value);
680 }
681
682 static SpvId
683 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
684 unsigned num_components)
685 {
686 SpvId type = get_ivec_type(ctx, bit_size, num_components);
687 return emit_bitcast(ctx, type, value);
688 }
689
690 static SpvId
691 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
692 unsigned num_components)
693 {
694 SpvId type = get_fvec_type(ctx, bit_size, num_components);
695 return emit_bitcast(ctx, type, value);
696 }
697
698 static void
699 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
700 {
701 SpvId var = get_var_from_reg(ctx, reg->reg);
702 assert(var);
703 spirv_builder_emit_store(&ctx->builder, var, result);
704 }
705
706 static void
707 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
708 {
709 if (dest->is_ssa)
710 store_ssa_def(ctx, &dest->ssa, result);
711 else
712 store_reg_def(ctx, &dest->reg, result);
713 }
714
715 static void
716 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
717 {
718 unsigned num_components = nir_dest_num_components(*dest);
719 unsigned bit_size = nir_dest_bit_size(*dest);
720
721 if (bit_size != 1) {
722 switch (nir_alu_type_get_base_type(type)) {
723 case nir_type_bool:
724 assert("bool should have bit-size 1");
725
726 case nir_type_uint:
727 break; /* nothing to do! */
728
729 case nir_type_int:
730 case nir_type_float:
731 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
732 break;
733
734 default:
735 unreachable("unsupported nir_alu_type");
736 }
737 }
738
739 store_dest_raw(ctx, dest, result);
740 }
741
742 static SpvId
743 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
744 {
745 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
746 }
747
748 static SpvId
749 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
750 SpvId src0, SpvId src1)
751 {
752 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
753 }
754
755 static SpvId
756 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
757 SpvId src0, SpvId src1, SpvId src2)
758 {
759 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
760 }
761
762 static SpvId
763 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
764 SpvId src)
765 {
766 SpvId args[] = { src };
767 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
768 op, args, ARRAY_SIZE(args));
769 }
770
771 static SpvId
772 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
773 SpvId src0, SpvId src1)
774 {
775 SpvId args[] = { src0, src1 };
776 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
777 op, args, ARRAY_SIZE(args));
778 }
779
780 static SpvId
781 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
782 SpvId src0, SpvId src1, SpvId src2)
783 {
784 SpvId args[] = { src0, src1, src2 };
785 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
786 op, args, ARRAY_SIZE(args));
787 }
788
789 static SpvId
790 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
791 unsigned num_components, float value)
792 {
793 assert(bit_size == 32);
794
795 SpvId result = emit_float_const(ctx, bit_size, value);
796 if (num_components == 1)
797 return result;
798
799 assert(num_components > 1);
800 SpvId components[num_components];
801 for (int i = 0; i < num_components; i++)
802 components[i] = result;
803
804 SpvId type = get_fvec_type(ctx, bit_size, num_components);
805 return spirv_builder_const_composite(&ctx->builder, type, components,
806 num_components);
807 }
808
809 static SpvId
810 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
811 unsigned num_components, uint32_t value)
812 {
813 assert(bit_size == 32);
814
815 SpvId result = emit_uint_const(ctx, bit_size, value);
816 if (num_components == 1)
817 return result;
818
819 assert(num_components > 1);
820 SpvId components[num_components];
821 for (int i = 0; i < num_components; i++)
822 components[i] = result;
823
824 SpvId type = get_uvec_type(ctx, bit_size, num_components);
825 return spirv_builder_const_composite(&ctx->builder, type, components,
826 num_components);
827 }
828
829 static SpvId
830 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
831 unsigned num_components, int32_t value)
832 {
833 assert(bit_size == 32);
834
835 SpvId result = emit_int_const(ctx, bit_size, value);
836 if (num_components == 1)
837 return result;
838
839 assert(num_components > 1);
840 SpvId components[num_components];
841 for (int i = 0; i < num_components; i++)
842 components[i] = result;
843
844 SpvId type = get_ivec_type(ctx, bit_size, num_components);
845 return spirv_builder_const_composite(&ctx->builder, type, components,
846 num_components);
847 }
848
849 static inline unsigned
850 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
851 {
852 if (nir_op_infos[instr->op].input_sizes[src] > 0)
853 return nir_op_infos[instr->op].input_sizes[src];
854
855 if (instr->dest.dest.is_ssa)
856 return instr->dest.dest.ssa.num_components;
857 else
858 return instr->dest.dest.reg.reg->num_components;
859 }
860
861 static SpvId
862 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
863 {
864 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
865
866 unsigned num_components = alu_instr_src_components(alu, src);
867 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
868 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
869
870 if (bit_size == 1)
871 return raw_value;
872 else {
873 switch (nir_alu_type_get_base_type(type)) {
874 case nir_type_bool:
875 unreachable("bool should have bit-size 1");
876
877 case nir_type_int:
878 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
879
880 case nir_type_uint:
881 return raw_value;
882
883 case nir_type_float:
884 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
885
886 default:
887 unreachable("unknown nir_alu_type");
888 }
889 }
890 }
891
892 static void
893 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
894 {
895 assert(!alu->dest.saturate);
896 return store_dest(ctx, &alu->dest.dest, result,
897 nir_op_infos[alu->op].output_type);
898 }
899
900 static SpvId
901 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
902 {
903 unsigned num_components = nir_dest_num_components(*dest);
904 unsigned bit_size = nir_dest_bit_size(*dest);
905
906 if (bit_size == 1)
907 return get_bvec_type(ctx, num_components);
908
909 switch (nir_alu_type_get_base_type(type)) {
910 case nir_type_bool:
911 unreachable("bool should have bit-size 1");
912
913 case nir_type_int:
914 return get_ivec_type(ctx, bit_size, num_components);
915
916 case nir_type_uint:
917 return get_uvec_type(ctx, bit_size, num_components);
918
919 case nir_type_float:
920 return get_fvec_type(ctx, bit_size, num_components);
921
922 default:
923 unreachable("unsupported nir_alu_type");
924 }
925 }
926
927 static void
928 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
929 {
930 SpvId src[nir_op_infos[alu->op].num_inputs];
931 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
932 src[i] = get_alu_src(ctx, alu, i);
933
934 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
935 nir_op_infos[alu->op].output_type);
936 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
937 unsigned num_components = nir_dest_num_components(alu->dest.dest);
938
939 SpvId result = 0;
940 switch (alu->op) {
941 case nir_op_mov:
942 assert(nir_op_infos[alu->op].num_inputs == 1);
943 result = src[0];
944 break;
945
946 #define UNOP(nir_op, spirv_op) \
947 case nir_op: \
948 assert(nir_op_infos[alu->op].num_inputs == 1); \
949 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
950 break;
951
952 UNOP(nir_op_ineg, SpvOpSNegate)
953 UNOP(nir_op_fneg, SpvOpFNegate)
954 UNOP(nir_op_fddx, SpvOpDPdx)
955 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
956 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
957 UNOP(nir_op_fddy, SpvOpDPdy)
958 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
959 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
960 UNOP(nir_op_f2i32, SpvOpConvertFToS)
961 UNOP(nir_op_f2u32, SpvOpConvertFToU)
962 UNOP(nir_op_i2f32, SpvOpConvertSToF)
963 UNOP(nir_op_u2f32, SpvOpConvertUToF)
964 UNOP(nir_op_inot, SpvOpNot)
965 #undef UNOP
966
967 case nir_op_b2i32:
968 assert(nir_op_infos[alu->op].num_inputs == 1);
969 result = emit_select(ctx, dest_type, src[0],
970 get_ivec_constant(ctx, 32, num_components, 1),
971 get_ivec_constant(ctx, 32, num_components, 0));
972 break;
973
974 case nir_op_b2f32:
975 assert(nir_op_infos[alu->op].num_inputs == 1);
976 result = emit_select(ctx, dest_type, src[0],
977 get_fvec_constant(ctx, 32, num_components, 1),
978 get_fvec_constant(ctx, 32, num_components, 0));
979 break;
980
981 #define BUILTIN_UNOP(nir_op, spirv_op) \
982 case nir_op: \
983 assert(nir_op_infos[alu->op].num_inputs == 1); \
984 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
985 break;
986
987 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
988 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
989 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
990 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
991 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
992 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
993 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
994 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
995 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
996 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
997 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
998 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
999 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1000 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1001 #undef BUILTIN_UNOP
1002
1003 case nir_op_frcp:
1004 assert(nir_op_infos[alu->op].num_inputs == 1);
1005 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1006 get_fvec_constant(ctx, bit_size, num_components, 1),
1007 src[0]);
1008 break;
1009
1010 case nir_op_f2b1:
1011 assert(nir_op_infos[alu->op].num_inputs == 1);
1012 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1013 get_fvec_constant(ctx,
1014 nir_src_bit_size(alu->src[0].src),
1015 num_components, 0));
1016 break;
1017
1018
1019 #define BINOP(nir_op, spirv_op) \
1020 case nir_op: \
1021 assert(nir_op_infos[alu->op].num_inputs == 2); \
1022 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1023 break;
1024
1025 BINOP(nir_op_iadd, SpvOpIAdd)
1026 BINOP(nir_op_isub, SpvOpISub)
1027 BINOP(nir_op_imul, SpvOpIMul)
1028 BINOP(nir_op_idiv, SpvOpSDiv)
1029 BINOP(nir_op_udiv, SpvOpUDiv)
1030 BINOP(nir_op_umod, SpvOpUMod)
1031 BINOP(nir_op_fadd, SpvOpFAdd)
1032 BINOP(nir_op_fsub, SpvOpFSub)
1033 BINOP(nir_op_fmul, SpvOpFMul)
1034 BINOP(nir_op_fdiv, SpvOpFDiv)
1035 BINOP(nir_op_fmod, SpvOpFMod)
1036 BINOP(nir_op_ilt, SpvOpSLessThan)
1037 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1038 BINOP(nir_op_ieq, SpvOpIEqual)
1039 BINOP(nir_op_ine, SpvOpINotEqual)
1040 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1041 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1042 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1043 BINOP(nir_op_feq, SpvOpFOrdEqual)
1044 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1045 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1046 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1047 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1048 BINOP(nir_op_iand, SpvOpBitwiseAnd)
1049 BINOP(nir_op_ior, SpvOpBitwiseOr)
1050 #undef BINOP
1051
1052 #define BUILTIN_BINOP(nir_op, spirv_op) \
1053 case nir_op: \
1054 assert(nir_op_infos[alu->op].num_inputs == 2); \
1055 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1056 break;
1057
1058 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1059 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1060 #undef BUILTIN_BINOP
1061
1062 case nir_op_fdot2:
1063 case nir_op_fdot3:
1064 case nir_op_fdot4:
1065 assert(nir_op_infos[alu->op].num_inputs == 2);
1066 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1067 break;
1068
1069 case nir_op_fdph:
1070 unreachable("should already be lowered away");
1071
1072 case nir_op_seq:
1073 case nir_op_sne:
1074 case nir_op_slt:
1075 case nir_op_sge: {
1076 assert(nir_op_infos[alu->op].num_inputs == 2);
1077 int num_components = nir_dest_num_components(alu->dest.dest);
1078 SpvId bool_type = get_bvec_type(ctx, num_components);
1079
1080 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1081 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1082 if (num_components > 1) {
1083 SpvId zero_comps[num_components], one_comps[num_components];
1084 for (int i = 0; i < num_components; i++) {
1085 zero_comps[i] = zero;
1086 one_comps[i] = one;
1087 }
1088
1089 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1090 zero_comps, num_components);
1091 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1092 one_comps, num_components);
1093 }
1094
1095 SpvOp op;
1096 switch (alu->op) {
1097 case nir_op_seq: op = SpvOpFOrdEqual; break;
1098 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1099 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1100 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1101 default: unreachable("unexpected op");
1102 }
1103
1104 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1105 result = emit_select(ctx, dest_type, result, one, zero);
1106 }
1107 break;
1108
1109 case nir_op_flrp:
1110 assert(nir_op_infos[alu->op].num_inputs == 3);
1111 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1112 src[0], src[1], src[2]);
1113 break;
1114
1115 case nir_op_fcsel:
1116 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1117 get_bvec_type(ctx, num_components),
1118 src[0],
1119 get_fvec_constant(ctx,
1120 nir_src_bit_size(alu->src[0].src),
1121 num_components, 0));
1122 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1123 break;
1124
1125 case nir_op_bcsel:
1126 assert(nir_op_infos[alu->op].num_inputs == 3);
1127 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1128 break;
1129
1130 case nir_op_bany_fnequal2:
1131 case nir_op_bany_fnequal3:
1132 case nir_op_bany_fnequal4:
1133 assert(nir_op_infos[alu->op].num_inputs == 2);
1134 assert(alu_instr_src_components(alu, 0) ==
1135 alu_instr_src_components(alu, 1));
1136 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1137 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1138 src[0], src[1]);
1139 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1140 break;
1141
1142 case nir_op_ball_fequal2:
1143 case nir_op_ball_fequal3:
1144 case nir_op_ball_fequal4:
1145 assert(nir_op_infos[alu->op].num_inputs == 2);
1146 assert(alu_instr_src_components(alu, 0) ==
1147 alu_instr_src_components(alu, 1));
1148 result = emit_binop(ctx, SpvOpFOrdEqual,
1149 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1150 src[0], src[1]);
1151 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1152 break;
1153
1154 case nir_op_bany_inequal2:
1155 case nir_op_bany_inequal3:
1156 case nir_op_bany_inequal4:
1157 assert(nir_op_infos[alu->op].num_inputs == 2);
1158 assert(alu_instr_src_components(alu, 0) ==
1159 alu_instr_src_components(alu, 1));
1160 result = emit_binop(ctx, SpvOpINotEqual,
1161 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1162 src[0], src[1]);
1163 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1164 break;
1165
1166 case nir_op_ball_iequal2:
1167 case nir_op_ball_iequal3:
1168 case nir_op_ball_iequal4:
1169 assert(nir_op_infos[alu->op].num_inputs == 2);
1170 assert(alu_instr_src_components(alu, 0) ==
1171 alu_instr_src_components(alu, 1));
1172 result = emit_binop(ctx, SpvOpIEqual,
1173 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1174 src[0], src[1]);
1175 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1176 break;
1177
1178 case nir_op_vec2:
1179 case nir_op_vec3:
1180 case nir_op_vec4: {
1181 int num_inputs = nir_op_infos[alu->op].num_inputs;
1182 assert(2 <= num_inputs && num_inputs <= 4);
1183 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1184 src, num_inputs);
1185 }
1186 break;
1187
1188 default:
1189 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1190 nir_op_infos[alu->op].name);
1191
1192 unreachable("unsupported opcode");
1193 return;
1194 }
1195
1196 store_alu_result(ctx, alu, result);
1197 }
1198
1199 static void
1200 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1201 {
1202 unsigned bit_size = load_const->def.bit_size;
1203 unsigned num_components = load_const->def.num_components;
1204
1205 SpvId constant;
1206 if (num_components > 1) {
1207 SpvId components[num_components];
1208 SpvId type;
1209 if (bit_size == 1) {
1210 for (int i = 0; i < num_components; i++)
1211 components[i] = spirv_builder_const_bool(&ctx->builder,
1212 load_const->value[i].b);
1213
1214 type = get_bvec_type(ctx, num_components);
1215 } else {
1216 for (int i = 0; i < num_components; i++)
1217 components[i] = emit_uint_const(ctx, bit_size,
1218 load_const->value[i].u32);
1219
1220 type = get_uvec_type(ctx, bit_size, num_components);
1221 }
1222 constant = spirv_builder_const_composite(&ctx->builder, type,
1223 components, num_components);
1224 } else {
1225 assert(num_components == 1);
1226 if (bit_size == 1)
1227 constant = spirv_builder_const_bool(&ctx->builder,
1228 load_const->value[0].b);
1229 else
1230 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1231 }
1232
1233 store_ssa_def(ctx, &load_const->def, constant);
1234 }
1235
1236 static void
1237 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1238 {
1239 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1240 assert(const_block_index); // no dynamic indexing for now
1241 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1242
1243 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1244 if (const_offset) {
1245 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1246 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1247 SpvStorageClassUniform,
1248 uvec4_type);
1249
1250 unsigned idx = const_offset->u32;
1251 SpvId member = emit_uint_const(ctx, 32, 0);
1252 SpvId offset = emit_uint_const(ctx, 32, idx);
1253 SpvId offsets[] = { member, offset };
1254 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1255 ctx->ubos[0], offsets,
1256 ARRAY_SIZE(offsets));
1257 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1258
1259 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1260 unsigned num_components = nir_dest_num_components(intr->dest);
1261 if (num_components == 1) {
1262 uint32_t components[] = { 0 };
1263 result = spirv_builder_emit_composite_extract(&ctx->builder,
1264 type,
1265 result, components,
1266 1);
1267 } else if (num_components < 4) {
1268 SpvId constituents[num_components];
1269 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1270 for (uint32_t i = 0; i < num_components; ++i)
1271 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1272 uint_type,
1273 result, &i,
1274 1);
1275
1276 result = spirv_builder_emit_composite_construct(&ctx->builder,
1277 type,
1278 constituents,
1279 num_components);
1280 }
1281
1282 if (nir_dest_bit_size(intr->dest) == 1)
1283 result = uvec_to_bvec(ctx, result, num_components);
1284
1285 store_dest(ctx, &intr->dest, result, nir_type_uint);
1286 } else
1287 unreachable("uniform-addressing not yet supported");
1288 }
1289
1290 static void
1291 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1292 {
1293 assert(ctx->block_started);
1294 spirv_builder_emit_kill(&ctx->builder);
1295 /* discard is weird in NIR, so let's just create an unreachable block after
1296 it and hope that the vulkan driver will DCE any instructinos in it. */
1297 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1298 }
1299
1300 static void
1301 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1302 {
1303 SpvId ptr = get_src(ctx, intr->src);
1304
1305 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1306 SpvId result = spirv_builder_emit_load(&ctx->builder,
1307 get_glsl_type(ctx, var->type),
1308 ptr);
1309 unsigned num_components = nir_dest_num_components(intr->dest);
1310 unsigned bit_size = nir_dest_bit_size(intr->dest);
1311 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1312 store_dest(ctx, &intr->dest, result, nir_type_uint);
1313 }
1314
1315 static void
1316 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1317 {
1318 SpvId ptr = get_src(ctx, &intr->src[0]);
1319 SpvId src = get_src(ctx, &intr->src[1]);
1320
1321 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1322 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1323 SpvId result = emit_bitcast(ctx, type, src);
1324 spirv_builder_emit_store(&ctx->builder, ptr, result);
1325 }
1326
1327 static SpvId
1328 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1329 SpvStorageClass storage_class,
1330 const char *name, SpvBuiltIn builtin)
1331 {
1332 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1333 storage_class,
1334 var_type);
1335 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1336 storage_class);
1337 spirv_builder_emit_name(&ctx->builder, var, name);
1338 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1339
1340 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1341 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1342 return var;
1343 }
1344
1345 static void
1346 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1347 {
1348 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1349 if (!ctx->front_face_var)
1350 ctx->front_face_var = create_builtin_var(ctx, var_type,
1351 SpvStorageClassInput,
1352 "gl_FrontFacing",
1353 SpvBuiltInFrontFacing);
1354
1355 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1356 ctx->front_face_var);
1357 assert(1 == nir_dest_num_components(intr->dest));
1358 store_dest(ctx, &intr->dest, result, nir_type_bool);
1359 }
1360
1361 static void
1362 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1363 {
1364 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1365 if (!ctx->instance_id_var)
1366 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1367 SpvStorageClassInput,
1368 "gl_InstanceId",
1369 SpvBuiltInInstanceIndex);
1370
1371 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1372 ctx->instance_id_var);
1373 assert(1 == nir_dest_num_components(intr->dest));
1374 store_dest(ctx, &intr->dest, result, nir_type_uint);
1375 }
1376
1377 static void
1378 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1379 {
1380 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1381 if (!ctx->vertex_id_var)
1382 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1383 SpvStorageClassInput,
1384 "gl_VertexID",
1385 SpvBuiltInVertexIndex);
1386
1387 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1388 ctx->vertex_id_var);
1389 assert(1 == nir_dest_num_components(intr->dest));
1390 store_dest(ctx, &intr->dest, result, nir_type_uint);
1391 }
1392
1393 static void
1394 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1395 {
1396 switch (intr->intrinsic) {
1397 case nir_intrinsic_load_ubo:
1398 emit_load_ubo(ctx, intr);
1399 break;
1400
1401 case nir_intrinsic_discard:
1402 emit_discard(ctx, intr);
1403 break;
1404
1405 case nir_intrinsic_load_deref:
1406 emit_load_deref(ctx, intr);
1407 break;
1408
1409 case nir_intrinsic_store_deref:
1410 emit_store_deref(ctx, intr);
1411 break;
1412
1413 case nir_intrinsic_load_front_face:
1414 emit_load_front_face(ctx, intr);
1415 break;
1416
1417 case nir_intrinsic_load_instance_id:
1418 emit_load_instance_id(ctx, intr);
1419 break;
1420
1421 case nir_intrinsic_load_vertex_id:
1422 emit_load_vertex_id(ctx, intr);
1423 break;
1424
1425 default:
1426 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1427 nir_intrinsic_infos[intr->intrinsic].name);
1428 unreachable("unsupported intrinsic");
1429 }
1430 }
1431
1432 static void
1433 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1434 {
1435 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1436 undef->def.num_components);
1437
1438 store_ssa_def(ctx, &undef->def,
1439 spirv_builder_emit_undef(&ctx->builder, type));
1440 }
1441
1442 static SpvId
1443 get_src_float(struct ntv_context *ctx, nir_src *src)
1444 {
1445 SpvId def = get_src(ctx, src);
1446 unsigned num_components = nir_src_num_components(*src);
1447 unsigned bit_size = nir_src_bit_size(*src);
1448 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1449 }
1450
1451 static SpvId
1452 get_src_int(struct ntv_context *ctx, nir_src *src)
1453 {
1454 SpvId def = get_src(ctx, src);
1455 unsigned num_components = nir_src_num_components(*src);
1456 unsigned bit_size = nir_src_bit_size(*src);
1457 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1458 }
1459
1460 static void
1461 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1462 {
1463 assert(tex->op == nir_texop_tex ||
1464 tex->op == nir_texop_txb ||
1465 tex->op == nir_texop_txl ||
1466 tex->op == nir_texop_txd ||
1467 tex->op == nir_texop_txf ||
1468 tex->op == nir_texop_txs);
1469 assert(tex->texture_index == tex->sampler_index);
1470
1471 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1472 offset = 0;
1473 unsigned coord_components = 0;
1474 for (unsigned i = 0; i < tex->num_srcs; i++) {
1475 switch (tex->src[i].src_type) {
1476 case nir_tex_src_coord:
1477 if (tex->op == nir_texop_txf)
1478 coord = get_src_int(ctx, &tex->src[i].src);
1479 else
1480 coord = get_src_float(ctx, &tex->src[i].src);
1481 coord_components = nir_src_num_components(tex->src[i].src);
1482 break;
1483
1484 case nir_tex_src_projector:
1485 assert(nir_src_num_components(tex->src[i].src) == 1);
1486 proj = get_src_float(ctx, &tex->src[i].src);
1487 assert(proj != 0);
1488 break;
1489
1490 case nir_tex_src_offset:
1491 offset = get_src_int(ctx, &tex->src[i].src);
1492 break;
1493
1494 case nir_tex_src_bias:
1495 assert(tex->op == nir_texop_txb);
1496 bias = get_src_float(ctx, &tex->src[i].src);
1497 assert(bias != 0);
1498 break;
1499
1500 case nir_tex_src_lod:
1501 assert(nir_src_num_components(tex->src[i].src) == 1);
1502 if (tex->op == nir_texop_txf ||
1503 tex->op == nir_texop_txs)
1504 lod = get_src_int(ctx, &tex->src[i].src);
1505 else
1506 lod = get_src_float(ctx, &tex->src[i].src);
1507 assert(lod != 0);
1508 break;
1509
1510 case nir_tex_src_comparator:
1511 assert(nir_src_num_components(tex->src[i].src) == 1);
1512 dref = get_src_float(ctx, &tex->src[i].src);
1513 assert(dref != 0);
1514 break;
1515
1516 case nir_tex_src_ddx:
1517 dx = get_src_float(ctx, &tex->src[i].src);
1518 assert(dx != 0);
1519 break;
1520
1521 case nir_tex_src_ddy:
1522 dy = get_src_float(ctx, &tex->src[i].src);
1523 assert(dy != 0);
1524 break;
1525
1526 default:
1527 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1528 unreachable("unknown texture source");
1529 }
1530 }
1531
1532 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1533 lod = emit_float_const(ctx, 32, 0.0f);
1534 assert(lod != 0);
1535 }
1536
1537 SpvId image_type = ctx->image_types[tex->texture_index];
1538 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1539 image_type);
1540
1541 assert(ctx->samplers_used & (1u << tex->texture_index));
1542 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1543 ctx->samplers[tex->texture_index]);
1544
1545 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1546
1547 if (tex->op == nir_texop_txs) {
1548 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1549 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1550 dest_type, image,
1551 lod);
1552 store_dest(ctx, &tex->dest, result, tex->dest_type);
1553 return;
1554 }
1555
1556 if (proj && coord_components > 0) {
1557 SpvId constituents[coord_components + 1];
1558 if (coord_components == 1)
1559 constituents[0] = coord;
1560 else {
1561 assert(coord_components > 1);
1562 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1563 for (uint32_t i = 0; i < coord_components; ++i)
1564 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1565 float_type,
1566 coord,
1567 &i, 1);
1568 }
1569
1570 constituents[coord_components++] = proj;
1571
1572 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1573 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1574 vec_type,
1575 constituents,
1576 coord_components);
1577 }
1578
1579 SpvId actual_dest_type = dest_type;
1580 if (dref)
1581 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1582
1583 SpvId result;
1584 if (tex->op == nir_texop_txf) {
1585 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1586 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1587 image, coord, lod);
1588 } else {
1589 result = spirv_builder_emit_image_sample(&ctx->builder,
1590 actual_dest_type, load,
1591 coord,
1592 proj != 0,
1593 lod, bias, dref, dx, dy,
1594 offset);
1595 }
1596
1597 spirv_builder_emit_decoration(&ctx->builder, result,
1598 SpvDecorationRelaxedPrecision);
1599
1600 if (dref && nir_dest_num_components(tex->dest) > 1) {
1601 SpvId components[4] = { result, result, result, result };
1602 result = spirv_builder_emit_composite_construct(&ctx->builder,
1603 dest_type,
1604 components,
1605 4);
1606 }
1607
1608 store_dest(ctx, &tex->dest, result, tex->dest_type);
1609 }
1610
1611 static void
1612 start_block(struct ntv_context *ctx, SpvId label)
1613 {
1614 /* terminate previous block if needed */
1615 if (ctx->block_started)
1616 spirv_builder_emit_branch(&ctx->builder, label);
1617
1618 /* start new block */
1619 spirv_builder_label(&ctx->builder, label);
1620 ctx->block_started = true;
1621 }
1622
1623 static void
1624 branch(struct ntv_context *ctx, SpvId label)
1625 {
1626 assert(ctx->block_started);
1627 spirv_builder_emit_branch(&ctx->builder, label);
1628 ctx->block_started = false;
1629 }
1630
1631 static void
1632 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1633 SpvId else_id)
1634 {
1635 assert(ctx->block_started);
1636 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1637 then_id, else_id);
1638 ctx->block_started = false;
1639 }
1640
1641 static void
1642 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1643 {
1644 switch (jump->type) {
1645 case nir_jump_break:
1646 assert(ctx->loop_break);
1647 branch(ctx, ctx->loop_break);
1648 break;
1649
1650 case nir_jump_continue:
1651 assert(ctx->loop_cont);
1652 branch(ctx, ctx->loop_cont);
1653 break;
1654
1655 default:
1656 unreachable("Unsupported jump type\n");
1657 }
1658 }
1659
1660 static void
1661 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1662 {
1663 assert(deref->deref_type == nir_deref_type_var);
1664
1665 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1666 assert(he);
1667 SpvId result = (SpvId)(intptr_t)he->data;
1668 /* uint is a bit of a lie here, it's really just an opaque type */
1669 store_dest(ctx, &deref->dest, result, nir_type_uint);
1670 }
1671
1672 static void
1673 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1674 {
1675 assert(deref->deref_type == nir_deref_type_array);
1676 nir_variable *var = nir_deref_instr_get_variable(deref);
1677
1678 SpvStorageClass storage_class;
1679 switch (var->data.mode) {
1680 case nir_var_shader_in:
1681 storage_class = SpvStorageClassInput;
1682 break;
1683
1684 case nir_var_shader_out:
1685 storage_class = SpvStorageClassOutput;
1686 break;
1687
1688 default:
1689 unreachable("Unsupported nir_variable_mode\n");
1690 }
1691
1692 SpvId index = get_src(ctx, &deref->arr.index);
1693
1694 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1695 storage_class,
1696 get_glsl_type(ctx, deref->type));
1697
1698 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1699 ptr_type,
1700 get_src(ctx, &deref->parent),
1701 &index, 1);
1702 /* uint is a bit of a lie here, it's really just an opaque type */
1703 store_dest(ctx, &deref->dest, result, nir_type_uint);
1704 }
1705
1706 static void
1707 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1708 {
1709 switch (deref->deref_type) {
1710 case nir_deref_type_var:
1711 emit_deref_var(ctx, deref);
1712 break;
1713
1714 case nir_deref_type_array:
1715 emit_deref_array(ctx, deref);
1716 break;
1717
1718 default:
1719 unreachable("unexpected deref_type");
1720 }
1721 }
1722
1723 static void
1724 emit_block(struct ntv_context *ctx, struct nir_block *block)
1725 {
1726 start_block(ctx, block_label(ctx, block));
1727 nir_foreach_instr(instr, block) {
1728 switch (instr->type) {
1729 case nir_instr_type_alu:
1730 emit_alu(ctx, nir_instr_as_alu(instr));
1731 break;
1732 case nir_instr_type_intrinsic:
1733 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1734 break;
1735 case nir_instr_type_load_const:
1736 emit_load_const(ctx, nir_instr_as_load_const(instr));
1737 break;
1738 case nir_instr_type_ssa_undef:
1739 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1740 break;
1741 case nir_instr_type_tex:
1742 emit_tex(ctx, nir_instr_as_tex(instr));
1743 break;
1744 case nir_instr_type_phi:
1745 unreachable("nir_instr_type_phi not supported");
1746 break;
1747 case nir_instr_type_jump:
1748 emit_jump(ctx, nir_instr_as_jump(instr));
1749 break;
1750 case nir_instr_type_call:
1751 unreachable("nir_instr_type_call not supported");
1752 break;
1753 case nir_instr_type_parallel_copy:
1754 unreachable("nir_instr_type_parallel_copy not supported");
1755 break;
1756 case nir_instr_type_deref:
1757 emit_deref(ctx, nir_instr_as_deref(instr));
1758 break;
1759 }
1760 }
1761 }
1762
/* Forward declaration: emit_if() and emit_loop() below recurse into
 * emit_cf_list() for their nested control-flow lists. */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1765
1766 static SpvId
1767 get_src_bool(struct ntv_context *ctx, nir_src *src)
1768 {
1769 assert(nir_src_bit_size(*src) == 1);
1770 return get_src(ctx, src);
1771 }
1772
1773 static void
1774 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1775 {
1776 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1777
1778 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1779 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1780 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1781 SpvId else_id = endif_id;
1782
1783 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1784 if (has_else) {
1785 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1786 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1787 }
1788
1789 /* create a header-block */
1790 start_block(ctx, header_id);
1791 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1792 SpvSelectionControlMaskNone);
1793 branch_conditional(ctx, condition, then_id, else_id);
1794
1795 emit_cf_list(ctx, &if_stmt->then_list);
1796
1797 if (has_else) {
1798 if (ctx->block_started)
1799 branch(ctx, endif_id);
1800
1801 emit_cf_list(ctx, &if_stmt->else_list);
1802 }
1803
1804 start_block(ctx, endif_id);
1805 }
1806
1807 static void
1808 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1809 {
1810 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1811 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1812 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1813 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1814
1815 /* create a header-block */
1816 start_block(ctx, header_id);
1817 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1818 branch(ctx, begin_id);
1819
1820 SpvId save_break = ctx->loop_break;
1821 SpvId save_cont = ctx->loop_cont;
1822 ctx->loop_break = break_id;
1823 ctx->loop_cont = cont_id;
1824
1825 emit_cf_list(ctx, &loop->body);
1826
1827 ctx->loop_break = save_break;
1828 ctx->loop_cont = save_cont;
1829
1830 branch(ctx, cont_id);
1831 start_block(ctx, cont_id);
1832 branch(ctx, header_id);
1833
1834 start_block(ctx, break_id);
1835 }
1836
1837 static void
1838 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1839 {
1840 foreach_list_typed(nir_cf_node, node, node, list) {
1841 switch (node->type) {
1842 case nir_cf_node_block:
1843 emit_block(ctx, nir_cf_node_as_block(node));
1844 break;
1845
1846 case nir_cf_node_if:
1847 emit_if(ctx, nir_cf_node_as_if(node));
1848 break;
1849
1850 case nir_cf_node_loop:
1851 emit_loop(ctx, nir_cf_node_as_loop(node));
1852 break;
1853
1854 case nir_cf_node_function:
1855 unreachable("nir_cf_node_function not supported");
1856 break;
1857 }
1858 }
1859 }
1860
1861 struct spirv_shader *
1862 nir_to_spirv(struct nir_shader *s)
1863 {
1864 struct spirv_shader *ret = NULL;
1865
1866 struct ntv_context ctx = {};
1867
1868 switch (s->info.stage) {
1869 case MESA_SHADER_VERTEX:
1870 case MESA_SHADER_FRAGMENT:
1871 case MESA_SHADER_COMPUTE:
1872 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1873 break;
1874
1875 case MESA_SHADER_TESS_CTRL:
1876 case MESA_SHADER_TESS_EVAL:
1877 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1878 break;
1879
1880 case MESA_SHADER_GEOMETRY:
1881 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1882 break;
1883
1884 default:
1885 unreachable("invalid stage");
1886 }
1887
1888 // TODO: only enable when needed
1889 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1890 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1891 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1892 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
1893 }
1894
1895 ctx.stage = s->info.stage;
1896 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1897 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1898
1899 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1900 SpvMemoryModelGLSL450);
1901
1902 SpvExecutionModel exec_model;
1903 switch (s->info.stage) {
1904 case MESA_SHADER_VERTEX:
1905 exec_model = SpvExecutionModelVertex;
1906 break;
1907 case MESA_SHADER_TESS_CTRL:
1908 exec_model = SpvExecutionModelTessellationControl;
1909 break;
1910 case MESA_SHADER_TESS_EVAL:
1911 exec_model = SpvExecutionModelTessellationEvaluation;
1912 break;
1913 case MESA_SHADER_GEOMETRY:
1914 exec_model = SpvExecutionModelGeometry;
1915 break;
1916 case MESA_SHADER_FRAGMENT:
1917 exec_model = SpvExecutionModelFragment;
1918 break;
1919 case MESA_SHADER_COMPUTE:
1920 exec_model = SpvExecutionModelGLCompute;
1921 break;
1922 default:
1923 unreachable("invalid stage");
1924 }
1925
1926 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1927 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1928 NULL, 0);
1929 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1930 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1931
1932 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1933 _mesa_key_pointer_equal);
1934
1935 nir_foreach_variable(var, &s->inputs)
1936 emit_input(&ctx, var);
1937
1938 nir_foreach_variable(var, &s->outputs)
1939 emit_output(&ctx, var);
1940
1941 nir_foreach_variable(var, &s->uniforms)
1942 emit_uniform(&ctx, var);
1943
1944 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1945 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1946 SpvExecutionModeOriginUpperLeft);
1947 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1948 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1949 SpvExecutionModeDepthReplacing);
1950 }
1951
1952
1953 spirv_builder_function(&ctx.builder, entry_point, type_void,
1954 SpvFunctionControlMaskNone,
1955 type_main);
1956
1957 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1958 nir_metadata_require(entry, nir_metadata_block_index);
1959
1960 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1961 if (!ctx.defs)
1962 goto fail;
1963 ctx.num_defs = entry->ssa_alloc;
1964
1965 nir_index_local_regs(entry);
1966 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1967 if (!ctx.regs)
1968 goto fail;
1969 ctx.num_regs = entry->reg_alloc;
1970
1971 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1972 if (!block_ids)
1973 goto fail;
1974
1975 for (int i = 0; i < entry->num_blocks; ++i)
1976 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1977
1978 ctx.block_ids = block_ids;
1979 ctx.num_blocks = entry->num_blocks;
1980
1981 /* emit a block only for the variable declarations */
1982 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1983 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1984 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1985 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1986 SpvStorageClassFunction,
1987 type);
1988 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1989 SpvStorageClassFunction);
1990
1991 ctx.regs[reg->index] = var;
1992 }
1993
1994 emit_cf_list(&ctx, &entry->body);
1995
1996 free(ctx.defs);
1997
1998 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1999 spirv_builder_function_end(&ctx.builder);
2000
2001 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2002 "main", ctx.entry_ifaces,
2003 ctx.num_entry_ifaces);
2004
2005 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2006
2007 ret = CALLOC_STRUCT(spirv_shader);
2008 if (!ret)
2009 goto fail;
2010
2011 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2012 if (!ret->words)
2013 goto fail;
2014
2015 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2016 assert(ret->num_words == num_words);
2017
2018 return ret;
2019
2020 fail:
2021
2022 if (ret)
2023 spirv_shader_delete(ret);
2024
2025 if (ctx.vars)
2026 _mesa_hash_table_destroy(ctx.vars, NULL);
2027
2028 return NULL;
2029 }
2030
2031 void
2032 spirv_shader_delete(struct spirv_shader *s)
2033 {
2034 FREE(s->words);
2035 FREE(s);
2036 }