zink: implement streamout and xfb handling in ntv
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 /* this consistently maps slots to a zero-indexed value to avoid wasting slots */
33 static unsigned slot_pack_map[] = {
34 /* Position is builtin */
35 [VARYING_SLOT_POS] = UINT_MAX,
36 [VARYING_SLOT_COL0] = 0, /* input/output */
37 [VARYING_SLOT_COL1] = 1, /* input/output */
38 [VARYING_SLOT_FOGC] = 2, /* input/output */
39 /* TEX0-7 are translated to VAR0-7 by nir, so we don't need to reserve */
40 [VARYING_SLOT_TEX0] = UINT_MAX, /* input/output */
41 [VARYING_SLOT_TEX1] = UINT_MAX,
42 [VARYING_SLOT_TEX2] = UINT_MAX,
43 [VARYING_SLOT_TEX3] = UINT_MAX,
44 [VARYING_SLOT_TEX4] = UINT_MAX,
45 [VARYING_SLOT_TEX5] = UINT_MAX,
46 [VARYING_SLOT_TEX6] = UINT_MAX,
47 [VARYING_SLOT_TEX7] = UINT_MAX,
48
49 /* PointSize is builtin */
50 [VARYING_SLOT_PSIZ] = UINT_MAX,
51
52 [VARYING_SLOT_BFC0] = 3, /* output only */
53 [VARYING_SLOT_BFC1] = 4, /* output only */
54 [VARYING_SLOT_EDGE] = 5, /* output only */
55 [VARYING_SLOT_CLIP_VERTEX] = 6, /* output only */
56
57 /* ClipDistance is builtin */
58 [VARYING_SLOT_CLIP_DIST0] = UINT_MAX,
59 [VARYING_SLOT_CLIP_DIST1] = UINT_MAX,
60
61 /* CullDistance is builtin */
62 [VARYING_SLOT_CULL_DIST0] = UINT_MAX, /* input/output */
63 [VARYING_SLOT_CULL_DIST1] = UINT_MAX, /* never actually used */
64
65 /* PrimitiveId is builtin */
66 [VARYING_SLOT_PRIMITIVE_ID] = UINT_MAX,
67
68 /* Layer is builtin */
69 [VARYING_SLOT_LAYER] = UINT_MAX, /* input/output */
70
71 /* ViewportIndex is builtin */
72 [VARYING_SLOT_VIEWPORT] = UINT_MAX, /* input/output */
73
74 /* FrontFacing is builtin */
75 [VARYING_SLOT_FACE] = UINT_MAX,
76
77 /* PointCoord is builtin */
78 [VARYING_SLOT_PNTC] = UINT_MAX, /* input only */
79
80 /* TessLevelOuter is builtin */
81 [VARYING_SLOT_TESS_LEVEL_OUTER] = UINT_MAX,
82 /* TessLevelInner is builtin */
83 [VARYING_SLOT_TESS_LEVEL_INNER] = UINT_MAX,
84
85 [VARYING_SLOT_BOUNDING_BOX0] = 7, /* Only appears as TCS output. */
86 [VARYING_SLOT_BOUNDING_BOX1] = 8, /* Only appears as TCS output. */
87 [VARYING_SLOT_VIEW_INDEX] = 9, /* input/output */
88 [VARYING_SLOT_VIEWPORT_MASK] = 10, /* output only */
89 };
90 #define NTV_MIN_RESERVED_SLOTS 10
91
92 struct ntv_context {
93 struct spirv_builder builder;
94
95 SpvId GLSL_std_450;
96
97 gl_shader_stage stage;
98
99 SpvId ubos[128];
100 size_t num_ubos;
101 SpvId image_types[PIPE_MAX_SAMPLERS];
102 SpvId samplers[PIPE_MAX_SAMPLERS];
103 unsigned samplers_used : PIPE_MAX_SAMPLERS;
104 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
105 size_t num_entry_ifaces;
106
107 SpvId *defs;
108 size_t num_defs;
109
110 SpvId *regs;
111 size_t num_regs;
112
113 struct hash_table *vars; /* nir_variable -> SpvId */
114 struct hash_table *so_outputs; /* pipe_stream_output -> SpvId */
115 unsigned outputs[VARYING_SLOT_MAX];
116 const struct glsl_type *so_output_gl_types[VARYING_SLOT_MAX];
117 SpvId so_output_types[VARYING_SLOT_MAX];
118
119 const SpvId *block_ids;
120 size_t num_blocks;
121 bool block_started;
122 SpvId loop_break, loop_cont;
123
124 SpvId front_face_var, instance_id_var, vertex_id_var;
125 };
126
127 static SpvId
128 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
129 unsigned num_components, float value);
130
131 static SpvId
132 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
133 unsigned num_components, uint32_t value);
134
135 static SpvId
136 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
137 unsigned num_components, int32_t value);
138
139 static SpvId
140 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
141
142 static SpvId
143 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
144 SpvId src0, SpvId src1);
145
146 static SpvId
147 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
148 SpvId src0, SpvId src1, SpvId src2);
149
150 static SpvId
151 get_bvec_type(struct ntv_context *ctx, int num_components)
152 {
153 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
154 if (num_components > 1)
155 return spirv_builder_type_vector(&ctx->builder, bool_type,
156 num_components);
157
158 assert(num_components == 1);
159 return bool_type;
160 }
161
162 static SpvId
163 block_label(struct ntv_context *ctx, nir_block *block)
164 {
165 assert(block->index < ctx->num_blocks);
166 return ctx->block_ids[block->index];
167 }
168
169 static SpvId
170 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
171 {
172 assert(bit_size == 32);
173 return spirv_builder_const_float(&ctx->builder, bit_size, value);
174 }
175
176 static SpvId
177 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
178 {
179 assert(bit_size == 32);
180 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
181 }
182
183 static SpvId
184 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
185 {
186 assert(bit_size == 32);
187 return spirv_builder_const_int(&ctx->builder, bit_size, value);
188 }
189
190 static SpvId
191 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
192 {
193 assert(bit_size == 32); // only 32-bit floats supported so far
194
195 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
196 if (num_components > 1)
197 return spirv_builder_type_vector(&ctx->builder, float_type,
198 num_components);
199
200 assert(num_components == 1);
201 return float_type;
202 }
203
204 static SpvId
205 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
206 {
207 assert(bit_size == 32); // only 32-bit ints supported so far
208
209 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
210 if (num_components > 1)
211 return spirv_builder_type_vector(&ctx->builder, int_type,
212 num_components);
213
214 assert(num_components == 1);
215 return int_type;
216 }
217
218 static SpvId
219 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
220 {
221 assert(bit_size == 32); // only 32-bit uints supported so far
222
223 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
224 if (num_components > 1)
225 return spirv_builder_type_vector(&ctx->builder, uint_type,
226 num_components);
227
228 assert(num_components == 1);
229 return uint_type;
230 }
231
232 static SpvId
233 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
234 {
235 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
236 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
237 }
238
239 static SpvId
240 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
241 {
242 switch (type) {
243 case GLSL_TYPE_BOOL:
244 return spirv_builder_type_bool(&ctx->builder);
245
246 case GLSL_TYPE_FLOAT:
247 return spirv_builder_type_float(&ctx->builder, 32);
248
249 case GLSL_TYPE_INT:
250 return spirv_builder_type_int(&ctx->builder, 32);
251
252 case GLSL_TYPE_UINT:
253 return spirv_builder_type_uint(&ctx->builder, 32);
254 /* TODO: handle more types */
255
256 default:
257 unreachable("unknown GLSL type");
258 }
259 }
260
261 static SpvId
262 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
263 {
264 assert(type);
265 if (glsl_type_is_scalar(type))
266 return get_glsl_basetype(ctx, glsl_get_base_type(type));
267
268 if (glsl_type_is_vector(type))
269 return spirv_builder_type_vector(&ctx->builder,
270 get_glsl_basetype(ctx, glsl_get_base_type(type)),
271 glsl_get_vector_elements(type));
272
273 if (glsl_type_is_array(type)) {
274 SpvId ret = spirv_builder_type_array(&ctx->builder,
275 get_glsl_type(ctx, glsl_get_array_element(type)),
276 emit_uint_const(ctx, 32, glsl_get_length(type)));
277 uint32_t stride = glsl_get_explicit_stride(type);
278 if (stride)
279 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
280 return ret;
281 }
282
283
284 unreachable("we shouldn't get here, I think...");
285 }
286
/* Expands to a switch case for VARYING_SLOT_<SLOT> that decorates var_id with
 * SpvBuiltIn<BUILTIN>. `ctx` and `var_id` must be in scope at the expansion
 * site, and the invocation supplies the semicolon that ends the `break`.
 */
#define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN)                                    \
   case VARYING_SLOT_##SLOT:                                                  \
      spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
      break
291
292
293 static void
294 emit_input(struct ntv_context *ctx, struct nir_variable *var)
295 {
296 SpvId var_type = get_glsl_type(ctx, var->type);
297 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
298 SpvStorageClassInput,
299 var_type);
300 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
301 SpvStorageClassInput);
302
303 if (var->name)
304 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
305
306 if (ctx->stage == MESA_SHADER_FRAGMENT) {
307 unsigned slot = var->data.location;
308 switch (slot) {
309 HANDLE_EMIT_BUILTIN(POS, FragCoord);
310 HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
311 HANDLE_EMIT_BUILTIN(LAYER, Layer);
312 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
313 HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
314 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
315 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
316 HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
317
318 default:
319 if (slot < VARYING_SLOT_VAR0) {
320 slot = slot_pack_map[slot];
321 if (slot == UINT_MAX)
322 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 } else
324 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
325 assert(slot < VARYING_SLOT_VAR0);
326 spirv_builder_emit_location(&ctx->builder, var_id, slot);
327 }
328 } else {
329 spirv_builder_emit_location(&ctx->builder, var_id,
330 var->data.driver_location);
331 }
332
333 if (var->data.location_frac)
334 spirv_builder_emit_component(&ctx->builder, var_id,
335 var->data.location_frac);
336
337 if (var->data.interpolation == INTERP_MODE_FLAT)
338 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
339
340 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
341
342 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
343 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
344 }
345
346 static void
347 emit_output(struct ntv_context *ctx, struct nir_variable *var)
348 {
349 SpvId var_type = get_glsl_type(ctx, var->type);
350 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
351 SpvStorageClassOutput,
352 var_type);
353 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
354 SpvStorageClassOutput);
355 if (var->name)
356 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
357
358
359 if (ctx->stage == MESA_SHADER_VERTEX) {
360 unsigned slot = var->data.location;
361 switch (slot) {
362 HANDLE_EMIT_BUILTIN(POS, Position);
363 HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
364 HANDLE_EMIT_BUILTIN(LAYER, Layer);
365 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
366 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
367 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
368 HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
369 HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
370
371 case VARYING_SLOT_CLIP_DIST0:
372 assert(glsl_type_is_array(var->type));
373 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
374 /* this can be as large as 2x vec4, which requires 2 slots */
375 ctx->outputs[VARYING_SLOT_CLIP_DIST1] = var_id;
376 ctx->so_output_gl_types[VARYING_SLOT_CLIP_DIST1] = var->type;
377 ctx->so_output_types[VARYING_SLOT_CLIP_DIST1] = var_type;
378 break;
379
380 default:
381 if (slot < VARYING_SLOT_VAR0) {
382 slot = slot_pack_map[slot];
383 if (slot == UINT_MAX)
384 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
385 } else
386 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
387 assert(slot < VARYING_SLOT_VAR0);
388 spirv_builder_emit_location(&ctx->builder, var_id, slot);
389 /* non-builtins get location incremented by VARYING_SLOT_VAR0 in vtn, so
390 * use driver_location for non-builtins with defined slots to avoid overlap
391 */
392 }
393 ctx->outputs[var->data.location] = var_id;
394 ctx->so_output_gl_types[var->data.location] = var->type;
395 ctx->so_output_types[var->data.location] = var_type;
396 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
397 if (var->data.location >= FRAG_RESULT_DATA0)
398 spirv_builder_emit_location(&ctx->builder, var_id,
399 var->data.location - FRAG_RESULT_DATA0);
400 else {
401 switch (var->data.location) {
402 case FRAG_RESULT_COLOR:
403 spirv_builder_emit_location(&ctx->builder, var_id, 0);
404 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
405 break;
406
407 case FRAG_RESULT_DEPTH:
408 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
409 break;
410
411 default:
412 spirv_builder_emit_location(&ctx->builder, var_id,
413 var->data.driver_location);
414 }
415 }
416 }
417
418 if (var->data.location_frac)
419 spirv_builder_emit_component(&ctx->builder, var_id,
420 var->data.location_frac);
421
422 switch (var->data.interpolation) {
423 case INTERP_MODE_NONE:
424 case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
425 break;
426 case INTERP_MODE_FLAT:
427 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
428 break;
429 case INTERP_MODE_EXPLICIT:
430 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
431 break;
432 case INTERP_MODE_NOPERSPECTIVE:
433 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
434 break;
435 default:
436 unreachable("unknown interpolation value");
437 }
438
439 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
440
441 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
442 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
443 }
444
445 static SpvDim
446 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
447 {
448 *is_ms = false;
449 switch (gdim) {
450 case GLSL_SAMPLER_DIM_1D:
451 return SpvDim1D;
452 case GLSL_SAMPLER_DIM_2D:
453 return SpvDim2D;
454 case GLSL_SAMPLER_DIM_3D:
455 return SpvDim3D;
456 case GLSL_SAMPLER_DIM_CUBE:
457 return SpvDimCube;
458 case GLSL_SAMPLER_DIM_RECT:
459 return SpvDim2D;
460 case GLSL_SAMPLER_DIM_BUF:
461 return SpvDimBuffer;
462 case GLSL_SAMPLER_DIM_EXTERNAL:
463 return SpvDim2D; /* seems dodgy... */
464 case GLSL_SAMPLER_DIM_MS:
465 *is_ms = true;
466 return SpvDim2D;
467 default:
468 fprintf(stderr, "unknown sampler type %d\n", gdim);
469 break;
470 }
471 return SpvDim2D;
472 }
473
474 uint32_t
475 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
476 {
477 if (stage == MESA_SHADER_NONE ||
478 stage >= MESA_SHADER_COMPUTE) {
479 unreachable("not supported");
480 } else {
481 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
482 PIPE_MAX_SHADER_SAMPLER_VIEWS);
483
484 switch (type) {
485 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
486 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
487 return stage_offset + index;
488
489 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
490 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
491 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
492
493 default:
494 unreachable("unexpected type");
495 }
496 }
497 }
498
499 static void
500 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
501 {
502 const struct glsl_type *type = glsl_without_array(var->type);
503
504 bool is_ms;
505 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
506
507 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
508 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
509 dimension, false,
510 glsl_sampler_type_is_array(type),
511 is_ms, 1,
512 SpvImageFormatUnknown);
513
514 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
515 image_type);
516 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
517 SpvStorageClassUniformConstant,
518 sampled_type);
519
520 if (glsl_type_is_array(var->type)) {
521 for (int i = 0; i < glsl_get_length(var->type); ++i) {
522 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
523 SpvStorageClassUniformConstant);
524
525 if (var->name) {
526 char element_name[100];
527 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
528 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
529 }
530
531 int index = var->data.binding + i;
532 assert(!(ctx->samplers_used & (1 << index)));
533 assert(!ctx->image_types[index]);
534 ctx->image_types[index] = image_type;
535 ctx->samplers[index] = var_id;
536 ctx->samplers_used |= 1 << index;
537
538 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
539 var->data.descriptor_set);
540 int binding = zink_binding(ctx->stage,
541 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
542 var->data.binding + i);
543 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
544 }
545 } else {
546 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
547 SpvStorageClassUniformConstant);
548
549 if (var->name)
550 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
551
552 int index = var->data.binding;
553 assert(!(ctx->samplers_used & (1 << index)));
554 assert(!ctx->image_types[index]);
555 ctx->image_types[index] = image_type;
556 ctx->samplers[index] = var_id;
557 ctx->samplers_used |= 1 << index;
558
559 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
560 var->data.descriptor_set);
561 int binding = zink_binding(ctx->stage,
562 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
563 var->data.binding);
564 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
565 }
566 }
567
568 static void
569 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
570 {
571 uint32_t size = glsl_count_attribute_slots(var->type, false);
572 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
573 SpvId array_length = emit_uint_const(ctx, 32, size);
574 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
575 array_length);
576 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
577
578 // wrap UBO-array in a struct
579 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
580 if (var->name) {
581 char struct_name[100];
582 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
583 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
584 }
585
586 spirv_builder_emit_decoration(&ctx->builder, struct_type,
587 SpvDecorationBlock);
588 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
589
590
591 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
592 SpvStorageClassUniform,
593 struct_type);
594
595 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
596 SpvStorageClassUniform);
597 if (var->name)
598 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
599
600 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
601 ctx->ubos[ctx->num_ubos++] = var_id;
602
603 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
604 var->data.descriptor_set);
605 int binding = zink_binding(ctx->stage,
606 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
607 var->data.binding);
608 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
609 }
610
611 static void
612 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
613 {
614 if (var->data.mode == nir_var_mem_ubo)
615 emit_ubo(ctx, var);
616 else {
617 assert(var->data.mode == nir_var_uniform);
618 if (glsl_type_is_sampler(glsl_without_array(var->type)))
619 emit_sampler(ctx, var);
620 }
621 }
622
623 static SpvId
624 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
625 {
626 assert(ssa->index < ctx->num_defs);
627 assert(ctx->defs[ssa->index] != 0);
628 return ctx->defs[ssa->index];
629 }
630
631 static SpvId
632 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
633 {
634 assert(reg->index < ctx->num_regs);
635 assert(ctx->regs[reg->index] != 0);
636 return ctx->regs[reg->index];
637 }
638
639 static SpvId
640 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
641 {
642 assert(reg->reg);
643 assert(!reg->indirect);
644 assert(!reg->base_offset);
645
646 SpvId var = get_var_from_reg(ctx, reg->reg);
647 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
648 return spirv_builder_emit_load(&ctx->builder, type, var);
649 }
650
651 static SpvId
652 get_src(struct ntv_context *ctx, nir_src *src)
653 {
654 if (src->is_ssa)
655 return get_src_ssa(ctx, src->ssa);
656 else
657 return get_src_reg(ctx, &src->reg);
658 }
659
660 static SpvId
661 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
662 {
663 assert(!alu->src[src].negate);
664 assert(!alu->src[src].abs);
665
666 SpvId def = get_src(ctx, &alu->src[src].src);
667
668 unsigned used_channels = 0;
669 bool need_swizzle = false;
670 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
671 if (!nir_alu_instr_channel_used(alu, src, i))
672 continue;
673
674 used_channels++;
675
676 if (alu->src[src].swizzle[i] != i)
677 need_swizzle = true;
678 }
679 assert(used_channels != 0);
680
681 unsigned live_channels = nir_src_num_components(alu->src[src].src);
682 if (used_channels != live_channels)
683 need_swizzle = true;
684
685 if (!need_swizzle)
686 return def;
687
688 int bit_size = nir_src_bit_size(alu->src[src].src);
689 assert(bit_size == 1 || bit_size == 32);
690
691 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
692 spirv_builder_type_uint(&ctx->builder, bit_size);
693
694 if (used_channels == 1) {
695 uint32_t indices[] = { alu->src[src].swizzle[0] };
696 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
697 def, indices,
698 ARRAY_SIZE(indices));
699 } else if (live_channels == 1) {
700 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
701 raw_type,
702 used_channels);
703
704 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
705 for (unsigned i = 0; i < used_channels; ++i)
706 constituents[i] = def;
707
708 return spirv_builder_emit_composite_construct(&ctx->builder,
709 raw_vec_type,
710 constituents,
711 used_channels);
712 } else {
713 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
714 raw_type,
715 used_channels);
716
717 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
718 size_t num_components = 0;
719 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
720 if (!nir_alu_instr_channel_used(alu, src, i))
721 continue;
722
723 components[num_components++] = alu->src[src].swizzle[i];
724 }
725
726 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
727 def, def, components,
728 num_components);
729 }
730 }
731
732 static void
733 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
734 {
735 assert(result != 0);
736 assert(ssa->index < ctx->num_defs);
737 ctx->defs[ssa->index] = result;
738 }
739
740 static SpvId
741 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
742 SpvId if_true, SpvId if_false)
743 {
744 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
745 }
746
747 static SpvId
748 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
749 {
750 SpvId type = get_bvec_type(ctx, num_components);
751 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
752 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
753 }
754
755 static SpvId
756 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
757 {
758 return emit_unop(ctx, SpvOpBitcast, type, value);
759 }
760
761 static SpvId
762 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
763 unsigned num_components)
764 {
765 SpvId type = get_uvec_type(ctx, bit_size, num_components);
766 return emit_bitcast(ctx, type, value);
767 }
768
769 static SpvId
770 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
771 unsigned num_components)
772 {
773 SpvId type = get_ivec_type(ctx, bit_size, num_components);
774 return emit_bitcast(ctx, type, value);
775 }
776
777 static SpvId
778 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
779 unsigned num_components)
780 {
781 SpvId type = get_fvec_type(ctx, bit_size, num_components);
782 return emit_bitcast(ctx, type, value);
783 }
784
785 static void
786 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
787 {
788 SpvId var = get_var_from_reg(ctx, reg->reg);
789 assert(var);
790 spirv_builder_emit_store(&ctx->builder, var, result);
791 }
792
793 static void
794 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
795 {
796 if (dest->is_ssa)
797 store_ssa_def(ctx, &dest->ssa, result);
798 else
799 store_reg_def(ctx, &dest->reg, result);
800 }
801
802 static SpvId
803 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
804 {
805 unsigned num_components = nir_dest_num_components(*dest);
806 unsigned bit_size = nir_dest_bit_size(*dest);
807
808 if (bit_size != 1) {
809 switch (nir_alu_type_get_base_type(type)) {
810 case nir_type_bool:
811 assert("bool should have bit-size 1");
812
813 case nir_type_uint:
814 break; /* nothing to do! */
815
816 case nir_type_int:
817 case nir_type_float:
818 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
819 break;
820
821 default:
822 unreachable("unsupported nir_alu_type");
823 }
824 }
825
826 store_dest_raw(ctx, dest, result);
827 return result;
828 }
829
830 static SpvId
831 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
832 {
833 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
834 }
835
836 /* return the intended xfb output vec type based on base type and vector size */
837 static SpvId
838 get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_components)
839 {
840 const struct glsl_type *out_type = ctx->so_output_gl_types[register_index];
841 enum glsl_base_type base_type = glsl_get_base_type(out_type);
842 if (base_type == GLSL_TYPE_ARRAY)
843 base_type = glsl_get_base_type(glsl_without_array(out_type));
844
845 switch (base_type) {
846 case GLSL_TYPE_BOOL:
847 return get_bvec_type(ctx, num_components);
848
849 case GLSL_TYPE_FLOAT:
850 return get_fvec_type(ctx, 32, num_components);
851
852 case GLSL_TYPE_INT:
853 return get_ivec_type(ctx, 32, num_components);
854
855 case GLSL_TYPE_UINT:
856 return get_uvec_type(ctx, 32, num_components);
857
858 default:
859 break;
860 }
861 unreachable("unknown type");
862 return 0;
863 }
864
865 /* for streamout create new outputs, as streamout can be done on individual components,
866 from complete outputs, so we just can't use the created packed outputs */
867 static void
868 emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
869 const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
870 {
871 for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
872 struct pipe_stream_output so_output = local_so_info->output[i];
873 SpvId out_type = get_output_type(ctx, so_output.register_index, so_output.num_components);
874 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
875 SpvStorageClassOutput,
876 out_type);
877 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
878 SpvStorageClassOutput);
879 char name[10];
880
881 snprintf(name, 10, "xfb%d", i);
882 spirv_builder_emit_name(&ctx->builder, var_id, name);
883 spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
884 spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
885 spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->stride[so_output.output_buffer] * 4);
886
887 /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
888 * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
889 * outputs.
890 *
891 * if there's no previous outputs that take up user slots (VAR0+) then we can start right after the
892 * glsl builtin reserved slots, otherwise we start just after the adjusted user output slot
893 */
894 uint32_t location = NTV_MIN_RESERVED_SLOTS + i;
895 if (max_output_location >= VARYING_SLOT_VAR0)
896 location = max_output_location - VARYING_SLOT_VAR0 + 1 + i;
897 assert(location < VARYING_SLOT_VAR0);
898 spirv_builder_emit_location(&ctx->builder, var_id, location);
899
900 /* note: gl_ClipDistance[4] can the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
901 * so this is still the 0 component
902 */
903 if (so_output.start_component)
904 spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
905
906 uint32_t *key = ralloc_size(NULL, sizeof(uint32_t));
907 *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
908 _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
909
910 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
911 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
912 }
913 }
914
915 static void
916 emit_so_outputs(struct ntv_context *ctx,
917 const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
918 {
919 SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
920 for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
921 uint32_t components[NIR_MAX_VEC_COMPONENTS];
922 struct pipe_stream_output so_output = local_so_info->output[i];
923 uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
924 struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
925 assert(he);
926 SpvId so_output_var_id = (SpvId)(intptr_t)he->data;
927
928 SpvId type = get_output_type(ctx, so_output.register_index, so_output.num_components);
929 SpvId output = ctx->outputs[so_output.register_index];
930 SpvId output_type = ctx->so_output_types[so_output.register_index];
931 const struct glsl_type *out_type = ctx->so_output_gl_types[so_output.register_index];
932
933 if (!loaded_outputs[so_output.register_index])
934 loaded_outputs[so_output.register_index] = spirv_builder_emit_load(&ctx->builder, output_type, output);
935 SpvId src = loaded_outputs[so_output.register_index];
936
937 SpvId result;
938
939 for (unsigned c = 0; c < so_output.num_components; c++) {
940 components[c] = so_output.start_component + c;
941 /* this is the second half of a 2 * vec4 array */
942 if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
943 components[c] += 4;
944 }
945
946 /* if we're emitting a scalar or the type we're emitting matches the output's original type and we're
947 * emitting the same number of components, then we can skip any sort of conversion here
948 */
949 if (glsl_type_is_scalar(out_type) || (type == output_type && glsl_get_length(out_type) == so_output.num_components))
950 result = src;
951 else {
952 if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_POS) {
953 /* gl_Position was modified by nir_lower_clip_halfz, so we need to reverse that for streamout here:
954 *
955 * opengl gl_Position.z = (vulkan gl_Position.z * 2.0) - vulkan gl_Position.w
956 *
957 * to do this, we extract the z and w components, perform the multiply and subtract ops, then reinsert
958 */
959 uint32_t z_component[] = {2};
960 uint32_t w_component[] = {3};
961 SpvId ftype = spirv_builder_type_float(&ctx->builder, 32);
962 SpvId z = spirv_builder_emit_composite_extract(&ctx->builder, ftype, src, z_component, 1);
963 SpvId w = spirv_builder_emit_composite_extract(&ctx->builder, ftype, src, w_component, 1);
964 SpvId new_z = emit_binop(ctx, SpvOpFMul, ftype, z, spirv_builder_const_float(&ctx->builder, 32, 2.0));
965 new_z = emit_binop(ctx, SpvOpFSub, ftype, new_z, w);
966 src = spirv_builder_emit_vector_insert(&ctx->builder, type, src, new_z, 2);
967 }
968 /* OpCompositeExtract can only extract scalars for our use here */
969 if (so_output.num_components == 1) {
970 result = spirv_builder_emit_composite_extract(&ctx->builder, type, src, components, so_output.num_components);
971 } else if (glsl_type_is_vector(out_type)) {
972 /* OpVectorShuffle can select vector members into a differently-sized vector */
973 result = spirv_builder_emit_vector_shuffle(&ctx->builder, type,
974 src, src,
975 components, so_output.num_components);
976 result = emit_unop(ctx, SpvOpBitcast, type, result);
977 } else {
978 /* for arrays, we need to manually extract each desired member
979 * and re-pack them into the desired output type
980 */
981 for (unsigned c = 0; c < so_output.num_components; c++) {
982 uint32_t member[] = { so_output.start_component + c };
983 SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));
984
985 if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
986 member[0] += 4;
987 components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
988 }
989 result = spirv_builder_emit_composite_construct(&ctx->builder, type, components, so_output.num_components);
990 }
991 }
992
993 spirv_builder_emit_store(&ctx->builder, so_output_var_id, result);
994 }
995 }
996
997 static SpvId
998 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
999 SpvId src0, SpvId src1)
1000 {
1001 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
1002 }
1003
1004 static SpvId
1005 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
1006 SpvId src0, SpvId src1, SpvId src2)
1007 {
1008 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
1009 }
1010
1011 static SpvId
1012 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1013 SpvId src)
1014 {
1015 SpvId args[] = { src };
1016 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1017 op, args, ARRAY_SIZE(args));
1018 }
1019
1020 static SpvId
1021 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1022 SpvId src0, SpvId src1)
1023 {
1024 SpvId args[] = { src0, src1 };
1025 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1026 op, args, ARRAY_SIZE(args));
1027 }
1028
1029 static SpvId
1030 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1031 SpvId src0, SpvId src1, SpvId src2)
1032 {
1033 SpvId args[] = { src0, src1, src2 };
1034 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1035 op, args, ARRAY_SIZE(args));
1036 }
1037
1038 static SpvId
1039 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
1040 unsigned num_components, float value)
1041 {
1042 assert(bit_size == 32);
1043
1044 SpvId result = emit_float_const(ctx, bit_size, value);
1045 if (num_components == 1)
1046 return result;
1047
1048 assert(num_components > 1);
1049 SpvId components[num_components];
1050 for (int i = 0; i < num_components; i++)
1051 components[i] = result;
1052
1053 SpvId type = get_fvec_type(ctx, bit_size, num_components);
1054 return spirv_builder_const_composite(&ctx->builder, type, components,
1055 num_components);
1056 }
1057
1058 static SpvId
1059 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
1060 unsigned num_components, uint32_t value)
1061 {
1062 assert(bit_size == 32);
1063
1064 SpvId result = emit_uint_const(ctx, bit_size, value);
1065 if (num_components == 1)
1066 return result;
1067
1068 assert(num_components > 1);
1069 SpvId components[num_components];
1070 for (int i = 0; i < num_components; i++)
1071 components[i] = result;
1072
1073 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1074 return spirv_builder_const_composite(&ctx->builder, type, components,
1075 num_components);
1076 }
1077
1078 static SpvId
1079 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
1080 unsigned num_components, int32_t value)
1081 {
1082 assert(bit_size == 32);
1083
1084 SpvId result = emit_int_const(ctx, bit_size, value);
1085 if (num_components == 1)
1086 return result;
1087
1088 assert(num_components > 1);
1089 SpvId components[num_components];
1090 for (int i = 0; i < num_components; i++)
1091 components[i] = result;
1092
1093 SpvId type = get_ivec_type(ctx, bit_size, num_components);
1094 return spirv_builder_const_composite(&ctx->builder, type, components,
1095 num_components);
1096 }
1097
1098 static inline unsigned
1099 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
1100 {
1101 if (nir_op_infos[instr->op].input_sizes[src] > 0)
1102 return nir_op_infos[instr->op].input_sizes[src];
1103
1104 if (instr->dest.dest.is_ssa)
1105 return instr->dest.dest.ssa.num_components;
1106 else
1107 return instr->dest.dest.reg.reg->num_components;
1108 }
1109
1110 static SpvId
1111 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
1112 {
1113 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
1114
1115 unsigned num_components = alu_instr_src_components(alu, src);
1116 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
1117 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
1118
1119 if (bit_size == 1)
1120 return raw_value;
1121 else {
1122 switch (nir_alu_type_get_base_type(type)) {
1123 case nir_type_bool:
1124 unreachable("bool should have bit-size 1");
1125
1126 case nir_type_int:
1127 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
1128
1129 case nir_type_uint:
1130 return raw_value;
1131
1132 case nir_type_float:
1133 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
1134
1135 default:
1136 unreachable("unknown nir_alu_type");
1137 }
1138 }
1139 }
1140
1141 static SpvId
1142 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
1143 {
1144 assert(!alu->dest.saturate);
1145 return store_dest(ctx, &alu->dest.dest, result,
1146 nir_op_infos[alu->op].output_type);
1147 }
1148
1149 static SpvId
1150 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
1151 {
1152 unsigned num_components = nir_dest_num_components(*dest);
1153 unsigned bit_size = nir_dest_bit_size(*dest);
1154
1155 if (bit_size == 1)
1156 return get_bvec_type(ctx, num_components);
1157
1158 switch (nir_alu_type_get_base_type(type)) {
1159 case nir_type_bool:
1160 unreachable("bool should have bit-size 1");
1161
1162 case nir_type_int:
1163 return get_ivec_type(ctx, bit_size, num_components);
1164
1165 case nir_type_uint:
1166 return get_uvec_type(ctx, bit_size, num_components);
1167
1168 case nir_type_float:
1169 return get_fvec_type(ctx, bit_size, num_components);
1170
1171 default:
1172 unreachable("unsupported nir_alu_type");
1173 }
1174 }
1175
1176 static void
1177 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1178 {
1179 SpvId src[nir_op_infos[alu->op].num_inputs];
1180 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1181 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1182 src[i] = get_alu_src(ctx, alu, i);
1183 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1184 }
1185
1186 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1187 nir_op_infos[alu->op].output_type);
1188 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1189 unsigned num_components = nir_dest_num_components(alu->dest.dest);
1190
1191 SpvId result = 0;
1192 switch (alu->op) {
1193 case nir_op_mov:
1194 assert(nir_op_infos[alu->op].num_inputs == 1);
1195 result = src[0];
1196 break;
1197
1198 #define UNOP(nir_op, spirv_op) \
1199 case nir_op: \
1200 assert(nir_op_infos[alu->op].num_inputs == 1); \
1201 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1202 break;
1203
1204 UNOP(nir_op_ineg, SpvOpSNegate)
1205 UNOP(nir_op_fneg, SpvOpFNegate)
1206 UNOP(nir_op_fddx, SpvOpDPdx)
1207 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1208 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1209 UNOP(nir_op_fddy, SpvOpDPdy)
1210 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1211 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1212 UNOP(nir_op_f2i32, SpvOpConvertFToS)
1213 UNOP(nir_op_f2u32, SpvOpConvertFToU)
1214 UNOP(nir_op_i2f32, SpvOpConvertSToF)
1215 UNOP(nir_op_u2f32, SpvOpConvertUToF)
1216 #undef UNOP
1217
1218 case nir_op_inot:
1219 if (bit_size == 1)
1220 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1221 else
1222 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1223 break;
1224
1225 case nir_op_b2i32:
1226 assert(nir_op_infos[alu->op].num_inputs == 1);
1227 result = emit_select(ctx, dest_type, src[0],
1228 get_ivec_constant(ctx, 32, num_components, 1),
1229 get_ivec_constant(ctx, 32, num_components, 0));
1230 break;
1231
1232 case nir_op_b2f32:
1233 assert(nir_op_infos[alu->op].num_inputs == 1);
1234 result = emit_select(ctx, dest_type, src[0],
1235 get_fvec_constant(ctx, 32, num_components, 1),
1236 get_fvec_constant(ctx, 32, num_components, 0));
1237 break;
1238
1239 #define BUILTIN_UNOP(nir_op, spirv_op) \
1240 case nir_op: \
1241 assert(nir_op_infos[alu->op].num_inputs == 1); \
1242 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1243 break;
1244
1245 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1246 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1247 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1248 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1249 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1250 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1251 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1252 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1253 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1254 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1255 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1256 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1257 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1258 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1259 #undef BUILTIN_UNOP
1260
1261 case nir_op_frcp:
1262 assert(nir_op_infos[alu->op].num_inputs == 1);
1263 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1264 get_fvec_constant(ctx, bit_size, num_components, 1),
1265 src[0]);
1266 break;
1267
1268 case nir_op_f2b1:
1269 assert(nir_op_infos[alu->op].num_inputs == 1);
1270 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1271 get_fvec_constant(ctx,
1272 nir_src_bit_size(alu->src[0].src),
1273 num_components, 0));
1274 break;
1275 case nir_op_i2b1:
1276 assert(nir_op_infos[alu->op].num_inputs == 1);
1277 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1278 get_ivec_constant(ctx,
1279 nir_src_bit_size(alu->src[0].src),
1280 num_components, 0));
1281 break;
1282
1283
1284 #define BINOP(nir_op, spirv_op) \
1285 case nir_op: \
1286 assert(nir_op_infos[alu->op].num_inputs == 2); \
1287 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1288 break;
1289
1290 BINOP(nir_op_iadd, SpvOpIAdd)
1291 BINOP(nir_op_isub, SpvOpISub)
1292 BINOP(nir_op_imul, SpvOpIMul)
1293 BINOP(nir_op_idiv, SpvOpSDiv)
1294 BINOP(nir_op_udiv, SpvOpUDiv)
1295 BINOP(nir_op_umod, SpvOpUMod)
1296 BINOP(nir_op_fadd, SpvOpFAdd)
1297 BINOP(nir_op_fsub, SpvOpFSub)
1298 BINOP(nir_op_fmul, SpvOpFMul)
1299 BINOP(nir_op_fdiv, SpvOpFDiv)
1300 BINOP(nir_op_fmod, SpvOpFMod)
1301 BINOP(nir_op_ilt, SpvOpSLessThan)
1302 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1303 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1304 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1305 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1306 BINOP(nir_op_feq, SpvOpFOrdEqual)
1307 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1308 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1309 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1310 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1311 #undef BINOP
1312
1313 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1314 case nir_op: \
1315 assert(nir_op_infos[alu->op].num_inputs == 2); \
1316 if (nir_src_bit_size(alu->src[0].src) == 1) \
1317 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1318 else \
1319 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1320 break;
1321
1322 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1323 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1324 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1325 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1326 #undef BINOP_LOG
1327
1328 #define BUILTIN_BINOP(nir_op, spirv_op) \
1329 case nir_op: \
1330 assert(nir_op_infos[alu->op].num_inputs == 2); \
1331 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1332 break;
1333
1334 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1335 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1336 #undef BUILTIN_BINOP
1337
1338 case nir_op_fdot2:
1339 case nir_op_fdot3:
1340 case nir_op_fdot4:
1341 assert(nir_op_infos[alu->op].num_inputs == 2);
1342 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1343 break;
1344
1345 case nir_op_fdph:
1346 unreachable("should already be lowered away");
1347
1348 case nir_op_seq:
1349 case nir_op_sne:
1350 case nir_op_slt:
1351 case nir_op_sge: {
1352 assert(nir_op_infos[alu->op].num_inputs == 2);
1353 int num_components = nir_dest_num_components(alu->dest.dest);
1354 SpvId bool_type = get_bvec_type(ctx, num_components);
1355
1356 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1357 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1358 if (num_components > 1) {
1359 SpvId zero_comps[num_components], one_comps[num_components];
1360 for (int i = 0; i < num_components; i++) {
1361 zero_comps[i] = zero;
1362 one_comps[i] = one;
1363 }
1364
1365 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1366 zero_comps, num_components);
1367 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1368 one_comps, num_components);
1369 }
1370
1371 SpvOp op;
1372 switch (alu->op) {
1373 case nir_op_seq: op = SpvOpFOrdEqual; break;
1374 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1375 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1376 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1377 default: unreachable("unexpected op");
1378 }
1379
1380 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1381 result = emit_select(ctx, dest_type, result, one, zero);
1382 }
1383 break;
1384
1385 case nir_op_flrp:
1386 assert(nir_op_infos[alu->op].num_inputs == 3);
1387 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1388 src[0], src[1], src[2]);
1389 break;
1390
1391 case nir_op_fcsel:
1392 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1393 get_bvec_type(ctx, num_components),
1394 src[0],
1395 get_fvec_constant(ctx,
1396 nir_src_bit_size(alu->src[0].src),
1397 num_components, 0));
1398 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1399 break;
1400
1401 case nir_op_bcsel:
1402 assert(nir_op_infos[alu->op].num_inputs == 3);
1403 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1404 break;
1405
1406 case nir_op_bany_fnequal2:
1407 case nir_op_bany_fnequal3:
1408 case nir_op_bany_fnequal4: {
1409 assert(nir_op_infos[alu->op].num_inputs == 2);
1410 assert(alu_instr_src_components(alu, 0) ==
1411 alu_instr_src_components(alu, 1));
1412 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1413 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1414 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1415 result = emit_binop(ctx, op,
1416 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1417 src[0], src[1]);
1418 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1419 break;
1420 }
1421
1422 case nir_op_ball_fequal2:
1423 case nir_op_ball_fequal3:
1424 case nir_op_ball_fequal4: {
1425 assert(nir_op_infos[alu->op].num_inputs == 2);
1426 assert(alu_instr_src_components(alu, 0) ==
1427 alu_instr_src_components(alu, 1));
1428 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1429 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1430 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1431 result = emit_binop(ctx, op,
1432 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1433 src[0], src[1]);
1434 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1435 break;
1436 }
1437
1438 case nir_op_bany_inequal2:
1439 case nir_op_bany_inequal3:
1440 case nir_op_bany_inequal4: {
1441 assert(nir_op_infos[alu->op].num_inputs == 2);
1442 assert(alu_instr_src_components(alu, 0) ==
1443 alu_instr_src_components(alu, 1));
1444 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1445 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1446 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1447 result = emit_binop(ctx, op,
1448 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1449 src[0], src[1]);
1450 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1451 break;
1452 }
1453
1454 case nir_op_ball_iequal2:
1455 case nir_op_ball_iequal3:
1456 case nir_op_ball_iequal4: {
1457 assert(nir_op_infos[alu->op].num_inputs == 2);
1458 assert(alu_instr_src_components(alu, 0) ==
1459 alu_instr_src_components(alu, 1));
1460 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1461 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1462 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1463 result = emit_binop(ctx, op,
1464 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1465 src[0], src[1]);
1466 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1467 break;
1468 }
1469
1470 case nir_op_vec2:
1471 case nir_op_vec3:
1472 case nir_op_vec4: {
1473 int num_inputs = nir_op_infos[alu->op].num_inputs;
1474 assert(2 <= num_inputs && num_inputs <= 4);
1475 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1476 src, num_inputs);
1477 }
1478 break;
1479
1480 default:
1481 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1482 nir_op_infos[alu->op].name);
1483
1484 unreachable("unsupported opcode");
1485 return;
1486 }
1487
1488 store_alu_result(ctx, alu, result);
1489 }
1490
1491 static void
1492 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1493 {
1494 unsigned bit_size = load_const->def.bit_size;
1495 unsigned num_components = load_const->def.num_components;
1496
1497 SpvId constant;
1498 if (num_components > 1) {
1499 SpvId components[num_components];
1500 SpvId type;
1501 if (bit_size == 1) {
1502 for (int i = 0; i < num_components; i++)
1503 components[i] = spirv_builder_const_bool(&ctx->builder,
1504 load_const->value[i].b);
1505
1506 type = get_bvec_type(ctx, num_components);
1507 } else {
1508 for (int i = 0; i < num_components; i++)
1509 components[i] = emit_uint_const(ctx, bit_size,
1510 load_const->value[i].u32);
1511
1512 type = get_uvec_type(ctx, bit_size, num_components);
1513 }
1514 constant = spirv_builder_const_composite(&ctx->builder, type,
1515 components, num_components);
1516 } else {
1517 assert(num_components == 1);
1518 if (bit_size == 1)
1519 constant = spirv_builder_const_bool(&ctx->builder,
1520 load_const->value[0].b);
1521 else
1522 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1523 }
1524
1525 store_ssa_def(ctx, &load_const->def, constant);
1526 }
1527
1528 static void
1529 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1530 {
1531 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1532 assert(const_block_index); // no dynamic indexing for now
1533 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1534
1535 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1536 if (const_offset) {
1537 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1538 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1539 SpvStorageClassUniform,
1540 uvec4_type);
1541
1542 unsigned idx = const_offset->u32;
1543 SpvId member = emit_uint_const(ctx, 32, 0);
1544 SpvId offset = emit_uint_const(ctx, 32, idx);
1545 SpvId offsets[] = { member, offset };
1546 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1547 ctx->ubos[0], offsets,
1548 ARRAY_SIZE(offsets));
1549 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1550
1551 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1552 unsigned num_components = nir_dest_num_components(intr->dest);
1553 if (num_components == 1) {
1554 uint32_t components[] = { 0 };
1555 result = spirv_builder_emit_composite_extract(&ctx->builder,
1556 type,
1557 result, components,
1558 1);
1559 } else if (num_components < 4) {
1560 SpvId constituents[num_components];
1561 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1562 for (uint32_t i = 0; i < num_components; ++i)
1563 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1564 uint_type,
1565 result, &i,
1566 1);
1567
1568 result = spirv_builder_emit_composite_construct(&ctx->builder,
1569 type,
1570 constituents,
1571 num_components);
1572 }
1573
1574 if (nir_dest_bit_size(intr->dest) == 1)
1575 result = uvec_to_bvec(ctx, result, num_components);
1576
1577 store_dest(ctx, &intr->dest, result, nir_type_uint);
1578 } else
1579 unreachable("uniform-addressing not yet supported");
1580 }
1581
1582 static void
1583 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1584 {
1585 assert(ctx->block_started);
1586 spirv_builder_emit_kill(&ctx->builder);
1587 /* discard is weird in NIR, so let's just create an unreachable block after
1588 it and hope that the vulkan driver will DCE any instructinos in it. */
1589 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1590 }
1591
1592 static void
1593 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1594 {
1595 SpvId ptr = get_src(ctx, intr->src);
1596
1597 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1598 SpvId result = spirv_builder_emit_load(&ctx->builder,
1599 get_glsl_type(ctx, var->type),
1600 ptr);
1601 unsigned num_components = nir_dest_num_components(intr->dest);
1602 unsigned bit_size = nir_dest_bit_size(intr->dest);
1603 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1604 store_dest(ctx, &intr->dest, result, nir_type_uint);
1605 }
1606
1607 static void
1608 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1609 {
1610 SpvId ptr = get_src(ctx, &intr->src[0]);
1611 SpvId src = get_src(ctx, &intr->src[1]);
1612
1613 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1614 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1615 SpvId result = emit_bitcast(ctx, type, src);
1616 spirv_builder_emit_store(&ctx->builder, ptr, result);
1617 }
1618
1619 static SpvId
1620 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1621 SpvStorageClass storage_class,
1622 const char *name, SpvBuiltIn builtin)
1623 {
1624 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1625 storage_class,
1626 var_type);
1627 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1628 storage_class);
1629 spirv_builder_emit_name(&ctx->builder, var, name);
1630 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1631
1632 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1633 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1634 return var;
1635 }
1636
1637 static void
1638 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1639 {
1640 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1641 if (!ctx->front_face_var)
1642 ctx->front_face_var = create_builtin_var(ctx, var_type,
1643 SpvStorageClassInput,
1644 "gl_FrontFacing",
1645 SpvBuiltInFrontFacing);
1646
1647 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1648 ctx->front_face_var);
1649 assert(1 == nir_dest_num_components(intr->dest));
1650 store_dest(ctx, &intr->dest, result, nir_type_bool);
1651 }
1652
1653 static void
1654 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1655 {
1656 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1657 if (!ctx->instance_id_var)
1658 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1659 SpvStorageClassInput,
1660 "gl_InstanceId",
1661 SpvBuiltInInstanceIndex);
1662
1663 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1664 ctx->instance_id_var);
1665 assert(1 == nir_dest_num_components(intr->dest));
1666 store_dest(ctx, &intr->dest, result, nir_type_uint);
1667 }
1668
1669 static void
1670 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1671 {
1672 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1673 if (!ctx->vertex_id_var)
1674 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1675 SpvStorageClassInput,
1676 "gl_VertexID",
1677 SpvBuiltInVertexIndex);
1678
1679 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1680 ctx->vertex_id_var);
1681 assert(1 == nir_dest_num_components(intr->dest));
1682 store_dest(ctx, &intr->dest, result, nir_type_uint);
1683 }
1684
1685 static void
1686 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1687 {
1688 switch (intr->intrinsic) {
1689 case nir_intrinsic_load_ubo:
1690 emit_load_ubo(ctx, intr);
1691 break;
1692
1693 case nir_intrinsic_discard:
1694 emit_discard(ctx, intr);
1695 break;
1696
1697 case nir_intrinsic_load_deref:
1698 emit_load_deref(ctx, intr);
1699 break;
1700
1701 case nir_intrinsic_store_deref:
1702 emit_store_deref(ctx, intr);
1703 break;
1704
1705 case nir_intrinsic_load_front_face:
1706 emit_load_front_face(ctx, intr);
1707 break;
1708
1709 case nir_intrinsic_load_instance_id:
1710 emit_load_instance_id(ctx, intr);
1711 break;
1712
1713 case nir_intrinsic_load_vertex_id:
1714 emit_load_vertex_id(ctx, intr);
1715 break;
1716
1717 default:
1718 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1719 nir_intrinsic_infos[intr->intrinsic].name);
1720 unreachable("unsupported intrinsic");
1721 }
1722 }
1723
1724 static void
1725 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1726 {
1727 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1728 undef->def.num_components);
1729
1730 store_ssa_def(ctx, &undef->def,
1731 spirv_builder_emit_undef(&ctx->builder, type));
1732 }
1733
1734 static SpvId
1735 get_src_float(struct ntv_context *ctx, nir_src *src)
1736 {
1737 SpvId def = get_src(ctx, src);
1738 unsigned num_components = nir_src_num_components(*src);
1739 unsigned bit_size = nir_src_bit_size(*src);
1740 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1741 }
1742
1743 static SpvId
1744 get_src_int(struct ntv_context *ctx, nir_src *src)
1745 {
1746 SpvId def = get_src(ctx, src);
1747 unsigned num_components = nir_src_num_components(*src);
1748 unsigned bit_size = nir_src_bit_size(*src);
1749 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1750 }
1751
1752 static void
1753 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1754 {
1755 assert(tex->op == nir_texop_tex ||
1756 tex->op == nir_texop_txb ||
1757 tex->op == nir_texop_txl ||
1758 tex->op == nir_texop_txd ||
1759 tex->op == nir_texop_txf ||
1760 tex->op == nir_texop_txf_ms ||
1761 tex->op == nir_texop_txs);
1762 assert(tex->texture_index == tex->sampler_index);
1763
1764 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1765 offset = 0, sample = 0;
1766 unsigned coord_components = 0;
1767 for (unsigned i = 0; i < tex->num_srcs; i++) {
1768 switch (tex->src[i].src_type) {
1769 case nir_tex_src_coord:
1770 if (tex->op == nir_texop_txf ||
1771 tex->op == nir_texop_txf_ms)
1772 coord = get_src_int(ctx, &tex->src[i].src);
1773 else
1774 coord = get_src_float(ctx, &tex->src[i].src);
1775 coord_components = nir_src_num_components(tex->src[i].src);
1776 break;
1777
1778 case nir_tex_src_projector:
1779 assert(nir_src_num_components(tex->src[i].src) == 1);
1780 proj = get_src_float(ctx, &tex->src[i].src);
1781 assert(proj != 0);
1782 break;
1783
1784 case nir_tex_src_offset:
1785 offset = get_src_int(ctx, &tex->src[i].src);
1786 break;
1787
1788 case nir_tex_src_bias:
1789 assert(tex->op == nir_texop_txb);
1790 bias = get_src_float(ctx, &tex->src[i].src);
1791 assert(bias != 0);
1792 break;
1793
1794 case nir_tex_src_lod:
1795 assert(nir_src_num_components(tex->src[i].src) == 1);
1796 if (tex->op == nir_texop_txf ||
1797 tex->op == nir_texop_txf_ms ||
1798 tex->op == nir_texop_txs)
1799 lod = get_src_int(ctx, &tex->src[i].src);
1800 else
1801 lod = get_src_float(ctx, &tex->src[i].src);
1802 assert(lod != 0);
1803 break;
1804
1805 case nir_tex_src_ms_index:
1806 assert(nir_src_num_components(tex->src[i].src) == 1);
1807 sample = get_src_int(ctx, &tex->src[i].src);
1808 break;
1809
1810 case nir_tex_src_comparator:
1811 assert(nir_src_num_components(tex->src[i].src) == 1);
1812 dref = get_src_float(ctx, &tex->src[i].src);
1813 assert(dref != 0);
1814 break;
1815
1816 case nir_tex_src_ddx:
1817 dx = get_src_float(ctx, &tex->src[i].src);
1818 assert(dx != 0);
1819 break;
1820
1821 case nir_tex_src_ddy:
1822 dy = get_src_float(ctx, &tex->src[i].src);
1823 assert(dy != 0);
1824 break;
1825
1826 default:
1827 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1828 unreachable("unknown texture source");
1829 }
1830 }
1831
1832 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1833 lod = emit_float_const(ctx, 32, 0.0f);
1834 assert(lod != 0);
1835 }
1836
1837 SpvId image_type = ctx->image_types[tex->texture_index];
1838 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1839 image_type);
1840
1841 assert(ctx->samplers_used & (1u << tex->texture_index));
1842 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1843 ctx->samplers[tex->texture_index]);
1844
1845 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1846
1847 if (tex->op == nir_texop_txs) {
1848 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1849 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1850 dest_type, image,
1851 lod);
1852 store_dest(ctx, &tex->dest, result, tex->dest_type);
1853 return;
1854 }
1855
1856 if (proj && coord_components > 0) {
1857 SpvId constituents[coord_components + 1];
1858 if (coord_components == 1)
1859 constituents[0] = coord;
1860 else {
1861 assert(coord_components > 1);
1862 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1863 for (uint32_t i = 0; i < coord_components; ++i)
1864 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1865 float_type,
1866 coord,
1867 &i, 1);
1868 }
1869
1870 constituents[coord_components++] = proj;
1871
1872 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1873 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1874 vec_type,
1875 constituents,
1876 coord_components);
1877 }
1878
1879 SpvId actual_dest_type = dest_type;
1880 if (dref)
1881 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1882
1883 SpvId result;
1884 if (tex->op == nir_texop_txf ||
1885 tex->op == nir_texop_txf_ms) {
1886 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1887 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1888 image, coord, lod, sample);
1889 } else {
1890 result = spirv_builder_emit_image_sample(&ctx->builder,
1891 actual_dest_type, load,
1892 coord,
1893 proj != 0,
1894 lod, bias, dref, dx, dy,
1895 offset);
1896 }
1897
1898 spirv_builder_emit_decoration(&ctx->builder, result,
1899 SpvDecorationRelaxedPrecision);
1900
1901 if (dref && nir_dest_num_components(tex->dest) > 1) {
1902 SpvId components[4] = { result, result, result, result };
1903 result = spirv_builder_emit_composite_construct(&ctx->builder,
1904 dest_type,
1905 components,
1906 4);
1907 }
1908
1909 store_dest(ctx, &tex->dest, result, tex->dest_type);
1910 }
1911
1912 static void
1913 start_block(struct ntv_context *ctx, SpvId label)
1914 {
1915 /* terminate previous block if needed */
1916 if (ctx->block_started)
1917 spirv_builder_emit_branch(&ctx->builder, label);
1918
1919 /* start new block */
1920 spirv_builder_label(&ctx->builder, label);
1921 ctx->block_started = true;
1922 }
1923
1924 static void
1925 branch(struct ntv_context *ctx, SpvId label)
1926 {
1927 assert(ctx->block_started);
1928 spirv_builder_emit_branch(&ctx->builder, label);
1929 ctx->block_started = false;
1930 }
1931
1932 static void
1933 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1934 SpvId else_id)
1935 {
1936 assert(ctx->block_started);
1937 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1938 then_id, else_id);
1939 ctx->block_started = false;
1940 }
1941
1942 static void
1943 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1944 {
1945 switch (jump->type) {
1946 case nir_jump_break:
1947 assert(ctx->loop_break);
1948 branch(ctx, ctx->loop_break);
1949 break;
1950
1951 case nir_jump_continue:
1952 assert(ctx->loop_cont);
1953 branch(ctx, ctx->loop_cont);
1954 break;
1955
1956 default:
1957 unreachable("Unsupported jump type\n");
1958 }
1959 }
1960
1961 static void
1962 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1963 {
1964 assert(deref->deref_type == nir_deref_type_var);
1965
1966 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1967 assert(he);
1968 SpvId result = (SpvId)(intptr_t)he->data;
1969 store_dest_raw(ctx, &deref->dest, result);
1970 }
1971
1972 static void
1973 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1974 {
1975 assert(deref->deref_type == nir_deref_type_array);
1976 nir_variable *var = nir_deref_instr_get_variable(deref);
1977
1978 SpvStorageClass storage_class;
1979 switch (var->data.mode) {
1980 case nir_var_shader_in:
1981 storage_class = SpvStorageClassInput;
1982 break;
1983
1984 case nir_var_shader_out:
1985 storage_class = SpvStorageClassOutput;
1986 break;
1987
1988 default:
1989 unreachable("Unsupported nir_variable_mode\n");
1990 }
1991
1992 SpvId index = get_src(ctx, &deref->arr.index);
1993
1994 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1995 storage_class,
1996 get_glsl_type(ctx, deref->type));
1997
1998 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1999 ptr_type,
2000 get_src(ctx, &deref->parent),
2001 &index, 1);
2002 /* uint is a bit of a lie here, it's really just an opaque type */
2003 store_dest(ctx, &deref->dest, result, nir_type_uint);
2004 }
2005
2006 static void
2007 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
2008 {
2009 switch (deref->deref_type) {
2010 case nir_deref_type_var:
2011 emit_deref_var(ctx, deref);
2012 break;
2013
2014 case nir_deref_type_array:
2015 emit_deref_array(ctx, deref);
2016 break;
2017
2018 default:
2019 unreachable("unexpected deref_type");
2020 }
2021 }
2022
2023 static void
2024 emit_block(struct ntv_context *ctx, struct nir_block *block)
2025 {
2026 start_block(ctx, block_label(ctx, block));
2027 nir_foreach_instr(instr, block) {
2028 switch (instr->type) {
2029 case nir_instr_type_alu:
2030 emit_alu(ctx, nir_instr_as_alu(instr));
2031 break;
2032 case nir_instr_type_intrinsic:
2033 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
2034 break;
2035 case nir_instr_type_load_const:
2036 emit_load_const(ctx, nir_instr_as_load_const(instr));
2037 break;
2038 case nir_instr_type_ssa_undef:
2039 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
2040 break;
2041 case nir_instr_type_tex:
2042 emit_tex(ctx, nir_instr_as_tex(instr));
2043 break;
2044 case nir_instr_type_phi:
2045 unreachable("nir_instr_type_phi not supported");
2046 break;
2047 case nir_instr_type_jump:
2048 emit_jump(ctx, nir_instr_as_jump(instr));
2049 break;
2050 case nir_instr_type_call:
2051 unreachable("nir_instr_type_call not supported");
2052 break;
2053 case nir_instr_type_parallel_copy:
2054 unreachable("nir_instr_type_parallel_copy not supported");
2055 break;
2056 case nir_instr_type_deref:
2057 emit_deref(ctx, nir_instr_as_deref(instr));
2058 break;
2059 }
2060 }
2061 }
2062
2063 static void
2064 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
2065
2066 static SpvId
2067 get_src_bool(struct ntv_context *ctx, nir_src *src)
2068 {
2069 assert(nir_src_bit_size(*src) == 1);
2070 return get_src(ctx, src);
2071 }
2072
2073 static void
2074 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
2075 {
2076 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
2077
2078 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2079 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
2080 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
2081 SpvId else_id = endif_id;
2082
2083 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
2084 if (has_else) {
2085 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
2086 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
2087 }
2088
2089 /* create a header-block */
2090 start_block(ctx, header_id);
2091 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
2092 SpvSelectionControlMaskNone);
2093 branch_conditional(ctx, condition, then_id, else_id);
2094
2095 emit_cf_list(ctx, &if_stmt->then_list);
2096
2097 if (has_else) {
2098 if (ctx->block_started)
2099 branch(ctx, endif_id);
2100
2101 emit_cf_list(ctx, &if_stmt->else_list);
2102 }
2103
2104 start_block(ctx, endif_id);
2105 }
2106
2107 static void
2108 emit_loop(struct ntv_context *ctx, nir_loop *loop)
2109 {
2110 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2111 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
2112 SpvId break_id = spirv_builder_new_id(&ctx->builder);
2113 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
2114
2115 /* create a header-block */
2116 start_block(ctx, header_id);
2117 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
2118 branch(ctx, begin_id);
2119
2120 SpvId save_break = ctx->loop_break;
2121 SpvId save_cont = ctx->loop_cont;
2122 ctx->loop_break = break_id;
2123 ctx->loop_cont = cont_id;
2124
2125 emit_cf_list(ctx, &loop->body);
2126
2127 ctx->loop_break = save_break;
2128 ctx->loop_cont = save_cont;
2129
2130 branch(ctx, cont_id);
2131 start_block(ctx, cont_id);
2132 branch(ctx, header_id);
2133
2134 start_block(ctx, break_id);
2135 }
2136
2137 static void
2138 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
2139 {
2140 foreach_list_typed(nir_cf_node, node, node, list) {
2141 switch (node->type) {
2142 case nir_cf_node_block:
2143 emit_block(ctx, nir_cf_node_as_block(node));
2144 break;
2145
2146 case nir_cf_node_if:
2147 emit_if(ctx, nir_cf_node_as_if(node));
2148 break;
2149
2150 case nir_cf_node_loop:
2151 emit_loop(ctx, nir_cf_node_as_loop(node));
2152 break;
2153
2154 case nir_cf_node_function:
2155 unreachable("nir_cf_node_function not supported");
2156 break;
2157 }
2158 }
2159 }
2160
2161 struct spirv_shader *
2162 nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
2163 {
2164 struct spirv_shader *ret = NULL;
2165
2166 struct ntv_context ctx = {};
2167
2168 switch (s->info.stage) {
2169 case MESA_SHADER_VERTEX:
2170 case MESA_SHADER_FRAGMENT:
2171 case MESA_SHADER_COMPUTE:
2172 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2173 break;
2174
2175 case MESA_SHADER_TESS_CTRL:
2176 case MESA_SHADER_TESS_EVAL:
2177 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2178 break;
2179
2180 case MESA_SHADER_GEOMETRY:
2181 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2182 break;
2183
2184 default:
2185 unreachable("invalid stage");
2186 }
2187
2188 // TODO: only enable when needed
2189 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2190 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2191 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2192 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2193 }
2194
2195 ctx.stage = s->info.stage;
2196 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2197 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2198
2199 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2200 SpvMemoryModelGLSL450);
2201
2202 SpvExecutionModel exec_model;
2203 switch (s->info.stage) {
2204 case MESA_SHADER_VERTEX:
2205 exec_model = SpvExecutionModelVertex;
2206 break;
2207 case MESA_SHADER_TESS_CTRL:
2208 exec_model = SpvExecutionModelTessellationControl;
2209 break;
2210 case MESA_SHADER_TESS_EVAL:
2211 exec_model = SpvExecutionModelTessellationEvaluation;
2212 break;
2213 case MESA_SHADER_GEOMETRY:
2214 exec_model = SpvExecutionModelGeometry;
2215 break;
2216 case MESA_SHADER_FRAGMENT:
2217 exec_model = SpvExecutionModelFragment;
2218 break;
2219 case MESA_SHADER_COMPUTE:
2220 exec_model = SpvExecutionModelGLCompute;
2221 break;
2222 default:
2223 unreachable("invalid stage");
2224 }
2225
2226 SpvId type_void = spirv_builder_type_void(&ctx.builder);
2227 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2228 NULL, 0);
2229 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2230 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2231
2232 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
2233 _mesa_key_pointer_equal);
2234
2235 ctx.so_outputs = _mesa_hash_table_create(NULL, _mesa_hash_u32,
2236 _mesa_key_u32_equal);
2237
2238 nir_foreach_variable(var, &s->inputs)
2239 emit_input(&ctx, var);
2240
2241 nir_foreach_variable(var, &s->outputs)
2242 emit_output(&ctx, var);
2243
2244 if (so_info)
2245 emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info, local_so_info);
2246 nir_foreach_variable(var, &s->uniforms)
2247 emit_uniform(&ctx, var);
2248
2249 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2250 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2251 SpvExecutionModeOriginUpperLeft);
2252 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2253 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2254 SpvExecutionModeDepthReplacing);
2255 }
2256
2257 if (so_info && so_info->num_outputs) {
2258 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
2259 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2260 SpvExecutionModeXfb);
2261 }
2262
2263 spirv_builder_function(&ctx.builder, entry_point, type_void,
2264 SpvFunctionControlMaskNone,
2265 type_main);
2266
2267 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2268 nir_metadata_require(entry, nir_metadata_block_index);
2269
2270 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
2271 if (!ctx.defs)
2272 goto fail;
2273 ctx.num_defs = entry->ssa_alloc;
2274
2275 nir_index_local_regs(entry);
2276 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
2277 if (!ctx.regs)
2278 goto fail;
2279 ctx.num_regs = entry->reg_alloc;
2280
2281 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
2282 if (!block_ids)
2283 goto fail;
2284
2285 for (int i = 0; i < entry->num_blocks; ++i)
2286 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2287
2288 ctx.block_ids = block_ids;
2289 ctx.num_blocks = entry->num_blocks;
2290
2291 /* emit a block only for the variable declarations */
2292 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2293 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2294 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2295 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2296 SpvStorageClassFunction,
2297 type);
2298 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2299 SpvStorageClassFunction);
2300
2301 ctx.regs[reg->index] = var;
2302 }
2303
2304 emit_cf_list(&ctx, &entry->body);
2305
2306 free(ctx.defs);
2307
2308 if (so_info)
2309 emit_so_outputs(&ctx, so_info, local_so_info);
2310
2311 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2312 spirv_builder_function_end(&ctx.builder);
2313
2314 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2315 "main", ctx.entry_ifaces,
2316 ctx.num_entry_ifaces);
2317
2318 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2319
2320 ret = CALLOC_STRUCT(spirv_shader);
2321 if (!ret)
2322 goto fail;
2323
2324 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2325 if (!ret->words)
2326 goto fail;
2327
2328 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2329 assert(ret->num_words == num_words);
2330
2331 return ret;
2332
2333 fail:
2334
2335 if (ret)
2336 spirv_shader_delete(ret);
2337
2338 if (ctx.vars)
2339 _mesa_hash_table_destroy(ctx.vars, NULL);
2340
2341 if (ctx.so_outputs)
2342 _mesa_hash_table_destroy(ctx.so_outputs, NULL);
2343
2344 return NULL;
2345 }
2346
2347 void
2348 spirv_shader_delete(struct spirv_shader *s)
2349 {
2350 FREE(s->words);
2351 FREE(s);
2352 }