zink: use OpFUnordNotEqual for nir_op_fne
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 /* this consistently maps slots to a zero-indexed value to avoid wasting slots */
33 static unsigned slot_pack_map[] = {
34 /* Position is builtin */
35 [VARYING_SLOT_POS] = UINT_MAX,
36 [VARYING_SLOT_COL0] = 0, /* input/output */
37 [VARYING_SLOT_COL1] = 1, /* input/output */
38 [VARYING_SLOT_FOGC] = 2, /* input/output */
39 /* TEX0-7 are translated to VAR0-7 by nir, so we don't need to reserve */
40 [VARYING_SLOT_TEX0] = UINT_MAX, /* input/output */
41 [VARYING_SLOT_TEX1] = UINT_MAX,
42 [VARYING_SLOT_TEX2] = UINT_MAX,
43 [VARYING_SLOT_TEX3] = UINT_MAX,
44 [VARYING_SLOT_TEX4] = UINT_MAX,
45 [VARYING_SLOT_TEX5] = UINT_MAX,
46 [VARYING_SLOT_TEX6] = UINT_MAX,
47 [VARYING_SLOT_TEX7] = UINT_MAX,
48
49 /* PointSize is builtin */
50 [VARYING_SLOT_PSIZ] = UINT_MAX,
51
52 [VARYING_SLOT_BFC0] = 3, /* output only */
53 [VARYING_SLOT_BFC1] = 4, /* output only */
54 [VARYING_SLOT_EDGE] = 5, /* output only */
55 [VARYING_SLOT_CLIP_VERTEX] = 6, /* output only */
56
57 /* ClipDistance is builtin */
58 [VARYING_SLOT_CLIP_DIST0] = UINT_MAX,
59 [VARYING_SLOT_CLIP_DIST1] = UINT_MAX,
60
61 /* CullDistance is builtin */
62 [VARYING_SLOT_CULL_DIST0] = UINT_MAX, /* input/output */
63 [VARYING_SLOT_CULL_DIST1] = UINT_MAX, /* never actually used */
64
65 /* PrimitiveId is builtin */
66 [VARYING_SLOT_PRIMITIVE_ID] = UINT_MAX,
67
68 /* Layer is builtin */
69 [VARYING_SLOT_LAYER] = UINT_MAX, /* input/output */
70
71 /* ViewportIndex is builtin */
72 [VARYING_SLOT_VIEWPORT] = UINT_MAX, /* input/output */
73
74 /* FrontFacing is builtin */
75 [VARYING_SLOT_FACE] = UINT_MAX,
76
77 /* PointCoord is builtin */
78 [VARYING_SLOT_PNTC] = UINT_MAX, /* input only */
79
80 /* TessLevelOuter is builtin */
81 [VARYING_SLOT_TESS_LEVEL_OUTER] = UINT_MAX,
82 /* TessLevelInner is builtin */
83 [VARYING_SLOT_TESS_LEVEL_INNER] = UINT_MAX,
84
85 [VARYING_SLOT_BOUNDING_BOX0] = 7, /* Only appears as TCS output. */
86 [VARYING_SLOT_BOUNDING_BOX1] = 8, /* Only appears as TCS output. */
87 [VARYING_SLOT_VIEW_INDEX] = 9, /* input/output */
88 [VARYING_SLOT_VIEWPORT_MASK] = 10, /* output only */
89 };
90 #define NTV_MIN_RESERVED_SLOTS 11
91
92 struct ntv_context {
93 struct spirv_builder builder;
94
95 SpvId GLSL_std_450;
96
97 gl_shader_stage stage;
98
99 SpvId ubos[128];
100 size_t num_ubos;
101 SpvId image_types[PIPE_MAX_SAMPLERS];
102 SpvId samplers[PIPE_MAX_SAMPLERS];
103 unsigned samplers_used : PIPE_MAX_SAMPLERS;
104 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
105 size_t num_entry_ifaces;
106
107 SpvId *defs;
108 size_t num_defs;
109
110 SpvId *regs;
111 size_t num_regs;
112
113 struct hash_table *vars; /* nir_variable -> SpvId */
114 struct hash_table *so_outputs; /* pipe_stream_output -> SpvId */
115 unsigned outputs[VARYING_SLOT_MAX];
116 const struct glsl_type *so_output_gl_types[VARYING_SLOT_MAX];
117 SpvId so_output_types[VARYING_SLOT_MAX];
118
119 const SpvId *block_ids;
120 size_t num_blocks;
121 bool block_started;
122 SpvId loop_break, loop_cont;
123
124 SpvId front_face_var, instance_id_var, vertex_id_var;
125 };
126
127 static SpvId
128 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
129 unsigned num_components, float value);
130
131 static SpvId
132 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
133 unsigned num_components, uint32_t value);
134
135 static SpvId
136 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
137 unsigned num_components, int32_t value);
138
139 static SpvId
140 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
141
142 static SpvId
143 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
144 SpvId src0, SpvId src1);
145
146 static SpvId
147 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
148 SpvId src0, SpvId src1, SpvId src2);
149
150 static SpvId
151 get_bvec_type(struct ntv_context *ctx, int num_components)
152 {
153 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
154 if (num_components > 1)
155 return spirv_builder_type_vector(&ctx->builder, bool_type,
156 num_components);
157
158 assert(num_components == 1);
159 return bool_type;
160 }
161
162 static SpvId
163 block_label(struct ntv_context *ctx, nir_block *block)
164 {
165 assert(block->index < ctx->num_blocks);
166 return ctx->block_ids[block->index];
167 }
168
169 static SpvId
170 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
171 {
172 assert(bit_size == 32);
173 return spirv_builder_const_float(&ctx->builder, bit_size, value);
174 }
175
176 static SpvId
177 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
178 {
179 assert(bit_size == 32);
180 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
181 }
182
183 static SpvId
184 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
185 {
186 assert(bit_size == 32);
187 return spirv_builder_const_int(&ctx->builder, bit_size, value);
188 }
189
190 static SpvId
191 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
192 {
193 assert(bit_size == 32); // only 32-bit floats supported so far
194
195 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
196 if (num_components > 1)
197 return spirv_builder_type_vector(&ctx->builder, float_type,
198 num_components);
199
200 assert(num_components == 1);
201 return float_type;
202 }
203
204 static SpvId
205 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
206 {
207 assert(bit_size == 32); // only 32-bit ints supported so far
208
209 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
210 if (num_components > 1)
211 return spirv_builder_type_vector(&ctx->builder, int_type,
212 num_components);
213
214 assert(num_components == 1);
215 return int_type;
216 }
217
218 static SpvId
219 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
220 {
221 assert(bit_size == 32); // only 32-bit uints supported so far
222
223 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
224 if (num_components > 1)
225 return spirv_builder_type_vector(&ctx->builder, uint_type,
226 num_components);
227
228 assert(num_components == 1);
229 return uint_type;
230 }
231
232 static SpvId
233 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
234 {
235 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
236 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
237 }
238
239 static SpvId
240 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
241 {
242 switch (type) {
243 case GLSL_TYPE_BOOL:
244 return spirv_builder_type_bool(&ctx->builder);
245
246 case GLSL_TYPE_FLOAT:
247 return spirv_builder_type_float(&ctx->builder, 32);
248
249 case GLSL_TYPE_INT:
250 return spirv_builder_type_int(&ctx->builder, 32);
251
252 case GLSL_TYPE_UINT:
253 return spirv_builder_type_uint(&ctx->builder, 32);
254 /* TODO: handle more types */
255
256 default:
257 unreachable("unknown GLSL type");
258 }
259 }
260
261 static SpvId
262 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
263 {
264 assert(type);
265 if (glsl_type_is_scalar(type))
266 return get_glsl_basetype(ctx, glsl_get_base_type(type));
267
268 if (glsl_type_is_vector(type))
269 return spirv_builder_type_vector(&ctx->builder,
270 get_glsl_basetype(ctx, glsl_get_base_type(type)),
271 glsl_get_vector_elements(type));
272
273 if (glsl_type_is_array(type)) {
274 SpvId ret = spirv_builder_type_array(&ctx->builder,
275 get_glsl_type(ctx, glsl_get_array_element(type)),
276 emit_uint_const(ctx, 32, glsl_get_length(type)));
277 uint32_t stride = glsl_get_explicit_stride(type);
278 if (stride)
279 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
280 return ret;
281 }
282
283
284 unreachable("we shouldn't get here, I think...");
285 }
286
/*
 * Expands to a switch case for VARYING_SLOT_<SLOT> that decorates `var_id`
 * with SpvBuiltIn<BUILTIN>.  It relies on `ctx` and `var_id` being in scope
 * at the point of use, and the caller supplies the ';' after the final break.
 */
#define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN)                           \
   case VARYING_SLOT_##SLOT:                                         \
      spirv_builder_emit_builtin(&ctx->builder, var_id,              \
                                 SpvBuiltIn##BUILTIN);               \
      break
291
292
293 static void
294 emit_input(struct ntv_context *ctx, struct nir_variable *var)
295 {
296 SpvId var_type = get_glsl_type(ctx, var->type);
297 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
298 SpvStorageClassInput,
299 var_type);
300 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
301 SpvStorageClassInput);
302
303 if (var->name)
304 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
305
306 if (ctx->stage == MESA_SHADER_FRAGMENT) {
307 unsigned slot = var->data.location;
308 switch (slot) {
309 HANDLE_EMIT_BUILTIN(POS, FragCoord);
310 HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
311 HANDLE_EMIT_BUILTIN(LAYER, Layer);
312 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
313 HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
314 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
315 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
316 HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
317
318 default:
319 if (slot < VARYING_SLOT_VAR0) {
320 slot = slot_pack_map[slot];
321 if (slot == UINT_MAX)
322 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 } else
324 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
325 assert(slot < VARYING_SLOT_VAR0);
326 spirv_builder_emit_location(&ctx->builder, var_id, slot);
327 }
328 } else {
329 spirv_builder_emit_location(&ctx->builder, var_id,
330 var->data.driver_location);
331 }
332
333 if (var->data.location_frac)
334 spirv_builder_emit_component(&ctx->builder, var_id,
335 var->data.location_frac);
336
337 if (var->data.interpolation == INTERP_MODE_FLAT)
338 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
339
340 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
341
342 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
343 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
344 }
345
346 static void
347 emit_output(struct ntv_context *ctx, struct nir_variable *var)
348 {
349 SpvId var_type = get_glsl_type(ctx, var->type);
350 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
351 SpvStorageClassOutput,
352 var_type);
353 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
354 SpvStorageClassOutput);
355 if (var->name)
356 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
357
358
359 if (ctx->stage == MESA_SHADER_VERTEX) {
360 unsigned slot = var->data.location;
361 switch (slot) {
362 HANDLE_EMIT_BUILTIN(POS, Position);
363 HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
364 HANDLE_EMIT_BUILTIN(LAYER, Layer);
365 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
366 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
367 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
368 HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
369 HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
370
371 case VARYING_SLOT_CLIP_DIST0:
372 assert(glsl_type_is_array(var->type));
373 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
374 /* this can be as large as 2x vec4, which requires 2 slots */
375 ctx->outputs[VARYING_SLOT_CLIP_DIST1] = var_id;
376 ctx->so_output_gl_types[VARYING_SLOT_CLIP_DIST1] = var->type;
377 ctx->so_output_types[VARYING_SLOT_CLIP_DIST1] = var_type;
378 break;
379
380 default:
381 if (slot < VARYING_SLOT_VAR0) {
382 slot = slot_pack_map[slot];
383 if (slot == UINT_MAX)
384 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
385 } else
386 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
387 assert(slot < VARYING_SLOT_VAR0);
388 spirv_builder_emit_location(&ctx->builder, var_id, slot);
389 /* non-builtins get location incremented by VARYING_SLOT_VAR0 in vtn, so
390 * use driver_location for non-builtins with defined slots to avoid overlap
391 */
392 }
393 ctx->outputs[var->data.location] = var_id;
394 ctx->so_output_gl_types[var->data.location] = var->type;
395 ctx->so_output_types[var->data.location] = var_type;
396 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
397 if (var->data.location >= FRAG_RESULT_DATA0)
398 spirv_builder_emit_location(&ctx->builder, var_id,
399 var->data.location - FRAG_RESULT_DATA0);
400 else {
401 switch (var->data.location) {
402 case FRAG_RESULT_COLOR:
403 spirv_builder_emit_location(&ctx->builder, var_id, 0);
404 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
405 break;
406
407 case FRAG_RESULT_DEPTH:
408 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
409 break;
410
411 default:
412 spirv_builder_emit_location(&ctx->builder, var_id,
413 var->data.driver_location);
414 }
415 }
416 }
417
418 if (var->data.location_frac)
419 spirv_builder_emit_component(&ctx->builder, var_id,
420 var->data.location_frac);
421
422 switch (var->data.interpolation) {
423 case INTERP_MODE_NONE:
424 case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
425 break;
426 case INTERP_MODE_FLAT:
427 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
428 break;
429 case INTERP_MODE_EXPLICIT:
430 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
431 break;
432 case INTERP_MODE_NOPERSPECTIVE:
433 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
434 break;
435 default:
436 unreachable("unknown interpolation value");
437 }
438
439 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
440
441 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
442 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
443 }
444
445 static SpvDim
446 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
447 {
448 *is_ms = false;
449 switch (gdim) {
450 case GLSL_SAMPLER_DIM_1D:
451 return SpvDim1D;
452 case GLSL_SAMPLER_DIM_2D:
453 return SpvDim2D;
454 case GLSL_SAMPLER_DIM_3D:
455 return SpvDim3D;
456 case GLSL_SAMPLER_DIM_CUBE:
457 return SpvDimCube;
458 case GLSL_SAMPLER_DIM_RECT:
459 return SpvDim2D;
460 case GLSL_SAMPLER_DIM_BUF:
461 return SpvDimBuffer;
462 case GLSL_SAMPLER_DIM_EXTERNAL:
463 return SpvDim2D; /* seems dodgy... */
464 case GLSL_SAMPLER_DIM_MS:
465 *is_ms = true;
466 return SpvDim2D;
467 default:
468 fprintf(stderr, "unknown sampler type %d\n", gdim);
469 break;
470 }
471 return SpvDim2D;
472 }
473
474 uint32_t
475 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
476 {
477 if (stage == MESA_SHADER_NONE ||
478 stage >= MESA_SHADER_COMPUTE) {
479 unreachable("not supported");
480 } else {
481 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
482 PIPE_MAX_SHADER_SAMPLER_VIEWS);
483
484 switch (type) {
485 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
486 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
487 return stage_offset + index;
488
489 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
490 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
491 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
492
493 default:
494 unreachable("unexpected type");
495 }
496 }
497 }
498
499 static void
500 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
501 {
502 const struct glsl_type *type = glsl_without_array(var->type);
503
504 bool is_ms;
505 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
506
507 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
508 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
509 dimension, false,
510 glsl_sampler_type_is_array(type),
511 is_ms, 1,
512 SpvImageFormatUnknown);
513
514 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
515 image_type);
516 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
517 SpvStorageClassUniformConstant,
518 sampled_type);
519
520 if (glsl_type_is_array(var->type)) {
521 for (int i = 0; i < glsl_get_length(var->type); ++i) {
522 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
523 SpvStorageClassUniformConstant);
524
525 if (var->name) {
526 char element_name[100];
527 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
528 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
529 }
530
531 int index = var->data.binding + i;
532 assert(!(ctx->samplers_used & (1 << index)));
533 assert(!ctx->image_types[index]);
534 ctx->image_types[index] = image_type;
535 ctx->samplers[index] = var_id;
536 ctx->samplers_used |= 1 << index;
537
538 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
539 var->data.descriptor_set);
540 int binding = zink_binding(ctx->stage,
541 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
542 var->data.binding + i);
543 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
544 }
545 } else {
546 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
547 SpvStorageClassUniformConstant);
548
549 if (var->name)
550 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
551
552 int index = var->data.binding;
553 assert(!(ctx->samplers_used & (1 << index)));
554 assert(!ctx->image_types[index]);
555 ctx->image_types[index] = image_type;
556 ctx->samplers[index] = var_id;
557 ctx->samplers_used |= 1 << index;
558
559 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
560 var->data.descriptor_set);
561 int binding = zink_binding(ctx->stage,
562 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
563 var->data.binding);
564 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
565 }
566 }
567
568 static void
569 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
570 {
571 uint32_t size = glsl_count_attribute_slots(var->type, false);
572 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
573 SpvId array_length = emit_uint_const(ctx, 32, size);
574 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
575 array_length);
576 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
577
578 // wrap UBO-array in a struct
579 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
580 if (var->name) {
581 char struct_name[100];
582 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
583 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
584 }
585
586 spirv_builder_emit_decoration(&ctx->builder, struct_type,
587 SpvDecorationBlock);
588 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
589
590
591 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
592 SpvStorageClassUniform,
593 struct_type);
594
595 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
596 SpvStorageClassUniform);
597 if (var->name)
598 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
599
600 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
601 ctx->ubos[ctx->num_ubos++] = var_id;
602
603 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
604 var->data.descriptor_set);
605 int binding = zink_binding(ctx->stage,
606 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
607 var->data.binding);
608 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
609 }
610
611 static void
612 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
613 {
614 if (var->data.mode == nir_var_mem_ubo)
615 emit_ubo(ctx, var);
616 else {
617 assert(var->data.mode == nir_var_uniform);
618 if (glsl_type_is_sampler(glsl_without_array(var->type)))
619 emit_sampler(ctx, var);
620 }
621 }
622
623 static SpvId
624 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
625 {
626 assert(ssa->index < ctx->num_defs);
627 assert(ctx->defs[ssa->index] != 0);
628 return ctx->defs[ssa->index];
629 }
630
631 static SpvId
632 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
633 {
634 assert(reg->index < ctx->num_regs);
635 assert(ctx->regs[reg->index] != 0);
636 return ctx->regs[reg->index];
637 }
638
639 static SpvId
640 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
641 {
642 assert(reg->reg);
643 assert(!reg->indirect);
644 assert(!reg->base_offset);
645
646 SpvId var = get_var_from_reg(ctx, reg->reg);
647 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
648 return spirv_builder_emit_load(&ctx->builder, type, var);
649 }
650
651 static SpvId
652 get_src(struct ntv_context *ctx, nir_src *src)
653 {
654 if (src->is_ssa)
655 return get_src_ssa(ctx, src->ssa);
656 else
657 return get_src_reg(ctx, &src->reg);
658 }
659
660 static SpvId
661 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
662 {
663 assert(!alu->src[src].negate);
664 assert(!alu->src[src].abs);
665
666 SpvId def = get_src(ctx, &alu->src[src].src);
667
668 unsigned used_channels = 0;
669 bool need_swizzle = false;
670 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
671 if (!nir_alu_instr_channel_used(alu, src, i))
672 continue;
673
674 used_channels++;
675
676 if (alu->src[src].swizzle[i] != i)
677 need_swizzle = true;
678 }
679 assert(used_channels != 0);
680
681 unsigned live_channels = nir_src_num_components(alu->src[src].src);
682 if (used_channels != live_channels)
683 need_swizzle = true;
684
685 if (!need_swizzle)
686 return def;
687
688 int bit_size = nir_src_bit_size(alu->src[src].src);
689 assert(bit_size == 1 || bit_size == 32);
690
691 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
692 spirv_builder_type_uint(&ctx->builder, bit_size);
693
694 if (used_channels == 1) {
695 uint32_t indices[] = { alu->src[src].swizzle[0] };
696 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
697 def, indices,
698 ARRAY_SIZE(indices));
699 } else if (live_channels == 1) {
700 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
701 raw_type,
702 used_channels);
703
704 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
705 for (unsigned i = 0; i < used_channels; ++i)
706 constituents[i] = def;
707
708 return spirv_builder_emit_composite_construct(&ctx->builder,
709 raw_vec_type,
710 constituents,
711 used_channels);
712 } else {
713 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
714 raw_type,
715 used_channels);
716
717 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
718 size_t num_components = 0;
719 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
720 if (!nir_alu_instr_channel_used(alu, src, i))
721 continue;
722
723 components[num_components++] = alu->src[src].swizzle[i];
724 }
725
726 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
727 def, def, components,
728 num_components);
729 }
730 }
731
732 static void
733 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
734 {
735 assert(result != 0);
736 assert(ssa->index < ctx->num_defs);
737 ctx->defs[ssa->index] = result;
738 }
739
740 static SpvId
741 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
742 SpvId if_true, SpvId if_false)
743 {
744 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
745 }
746
747 static SpvId
748 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
749 {
750 SpvId type = get_bvec_type(ctx, num_components);
751 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
752 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
753 }
754
755 static SpvId
756 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
757 {
758 return emit_unop(ctx, SpvOpBitcast, type, value);
759 }
760
761 static SpvId
762 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
763 unsigned num_components)
764 {
765 SpvId type = get_uvec_type(ctx, bit_size, num_components);
766 return emit_bitcast(ctx, type, value);
767 }
768
769 static SpvId
770 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
771 unsigned num_components)
772 {
773 SpvId type = get_ivec_type(ctx, bit_size, num_components);
774 return emit_bitcast(ctx, type, value);
775 }
776
777 static SpvId
778 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
779 unsigned num_components)
780 {
781 SpvId type = get_fvec_type(ctx, bit_size, num_components);
782 return emit_bitcast(ctx, type, value);
783 }
784
785 static void
786 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
787 {
788 SpvId var = get_var_from_reg(ctx, reg->reg);
789 assert(var);
790 spirv_builder_emit_store(&ctx->builder, var, result);
791 }
792
793 static void
794 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
795 {
796 if (dest->is_ssa)
797 store_ssa_def(ctx, &dest->ssa, result);
798 else
799 store_reg_def(ctx, &dest->reg, result);
800 }
801
802 static SpvId
803 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
804 {
805 unsigned num_components = nir_dest_num_components(*dest);
806 unsigned bit_size = nir_dest_bit_size(*dest);
807
808 if (bit_size != 1) {
809 switch (nir_alu_type_get_base_type(type)) {
810 case nir_type_bool:
811 assert("bool should have bit-size 1");
812
813 case nir_type_uint:
814 break; /* nothing to do! */
815
816 case nir_type_int:
817 case nir_type_float:
818 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
819 break;
820
821 default:
822 unreachable("unsupported nir_alu_type");
823 }
824 }
825
826 store_dest_raw(ctx, dest, result);
827 return result;
828 }
829
830 static SpvId
831 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
832 {
833 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
834 }
835
836 /* return the intended xfb output vec type based on base type and vector size */
837 static SpvId
838 get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_components)
839 {
840 const struct glsl_type *out_type = ctx->so_output_gl_types[register_index];
841 enum glsl_base_type base_type = glsl_get_base_type(out_type);
842 if (base_type == GLSL_TYPE_ARRAY)
843 base_type = glsl_get_base_type(glsl_without_array(out_type));
844
845 switch (base_type) {
846 case GLSL_TYPE_BOOL:
847 return get_bvec_type(ctx, num_components);
848
849 case GLSL_TYPE_FLOAT:
850 return get_fvec_type(ctx, 32, num_components);
851
852 case GLSL_TYPE_INT:
853 return get_ivec_type(ctx, 32, num_components);
854
855 case GLSL_TYPE_UINT:
856 return get_uvec_type(ctx, 32, num_components);
857
858 default:
859 break;
860 }
861 unreachable("unknown type");
862 return 0;
863 }
864
865 /* for streamout create new outputs, as streamout can be done on individual components,
866 from complete outputs, so we just can't use the created packed outputs */
867 static void
868 emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
869 const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
870 {
871 for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
872 struct pipe_stream_output so_output = local_so_info->output[i];
873 SpvId out_type = get_output_type(ctx, so_output.register_index, so_output.num_components);
874 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
875 SpvStorageClassOutput,
876 out_type);
877 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
878 SpvStorageClassOutput);
879 char name[10];
880
881 snprintf(name, 10, "xfb%d", i);
882 spirv_builder_emit_name(&ctx->builder, var_id, name);
883 spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
884 spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
885 spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->stride[so_output.output_buffer] * 4);
886
887 /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
888 * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
889 * outputs.
890 *
891 * if there's no previous outputs that take up user slots (VAR0+) then we can start right after the
892 * glsl builtin reserved slots, otherwise we start just after the adjusted user output slot
893 */
894 uint32_t location = NTV_MIN_RESERVED_SLOTS + i;
895 if (max_output_location >= VARYING_SLOT_VAR0)
896 location = max_output_location - VARYING_SLOT_VAR0 + 1 + i;
897 assert(location < VARYING_SLOT_VAR0);
898 spirv_builder_emit_location(&ctx->builder, var_id, location);
899
900 /* note: gl_ClipDistance[4] can the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
901 * so this is still the 0 component
902 */
903 if (so_output.start_component)
904 spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
905
906 uint32_t *key = ralloc_size(NULL, sizeof(uint32_t));
907 *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
908 _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
909
910 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
911 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
912 }
913 }
914
915 static void
916 emit_so_outputs(struct ntv_context *ctx,
917 const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
918 {
919 SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
920 for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
921 uint32_t components[NIR_MAX_VEC_COMPONENTS];
922 struct pipe_stream_output so_output = local_so_info->output[i];
923 uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
924 struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
925 assert(he);
926 SpvId so_output_var_id = (SpvId)(intptr_t)he->data;
927
928 SpvId type = get_output_type(ctx, so_output.register_index, so_output.num_components);
929 SpvId output = ctx->outputs[so_output.register_index];
930 SpvId output_type = ctx->so_output_types[so_output.register_index];
931 const struct glsl_type *out_type = ctx->so_output_gl_types[so_output.register_index];
932
933 if (!loaded_outputs[so_output.register_index])
934 loaded_outputs[so_output.register_index] = spirv_builder_emit_load(&ctx->builder, output_type, output);
935 SpvId src = loaded_outputs[so_output.register_index];
936
937 SpvId result;
938
939 for (unsigned c = 0; c < so_output.num_components; c++) {
940 components[c] = so_output.start_component + c;
941 /* this is the second half of a 2 * vec4 array */
942 if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
943 components[c] += 4;
944 }
945
946 /* if we're emitting a scalar or the type we're emitting matches the output's original type and we're
947 * emitting the same number of components, then we can skip any sort of conversion here
948 */
949 if (glsl_type_is_scalar(out_type) || (type == output_type && glsl_get_length(out_type) == so_output.num_components))
950 result = src;
951 else {
952 /* OpCompositeExtract can only extract scalars for our use here */
953 if (so_output.num_components == 1) {
954 result = spirv_builder_emit_composite_extract(&ctx->builder, type, src, components, so_output.num_components);
955 } else if (glsl_type_is_vector(out_type)) {
956 /* OpVectorShuffle can select vector members into a differently-sized vector */
957 result = spirv_builder_emit_vector_shuffle(&ctx->builder, type,
958 src, src,
959 components, so_output.num_components);
960 result = emit_unop(ctx, SpvOpBitcast, type, result);
961 } else {
962 /* for arrays, we need to manually extract each desired member
963 * and re-pack them into the desired output type
964 */
965 for (unsigned c = 0; c < so_output.num_components; c++) {
966 uint32_t member[] = { so_output.start_component + c };
967 SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));
968
969 if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
970 member[0] += 4;
971 components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
972 }
973 result = spirv_builder_emit_composite_construct(&ctx->builder, type, components, so_output.num_components);
974 }
975 }
976
977 spirv_builder_emit_store(&ctx->builder, so_output_var_id, result);
978 }
979 }
980
981 static SpvId
982 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
983 SpvId src0, SpvId src1)
984 {
985 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
986 }
987
988 static SpvId
989 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
990 SpvId src0, SpvId src1, SpvId src2)
991 {
992 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
993 }
994
995 static SpvId
996 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
997 SpvId src)
998 {
999 SpvId args[] = { src };
1000 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1001 op, args, ARRAY_SIZE(args));
1002 }
1003
1004 static SpvId
1005 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1006 SpvId src0, SpvId src1)
1007 {
1008 SpvId args[] = { src0, src1 };
1009 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1010 op, args, ARRAY_SIZE(args));
1011 }
1012
1013 static SpvId
1014 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1015 SpvId src0, SpvId src1, SpvId src2)
1016 {
1017 SpvId args[] = { src0, src1, src2 };
1018 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1019 op, args, ARRAY_SIZE(args));
1020 }
1021
1022 static SpvId
1023 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
1024 unsigned num_components, float value)
1025 {
1026 assert(bit_size == 32);
1027
1028 SpvId result = emit_float_const(ctx, bit_size, value);
1029 if (num_components == 1)
1030 return result;
1031
1032 assert(num_components > 1);
1033 SpvId components[num_components];
1034 for (int i = 0; i < num_components; i++)
1035 components[i] = result;
1036
1037 SpvId type = get_fvec_type(ctx, bit_size, num_components);
1038 return spirv_builder_const_composite(&ctx->builder, type, components,
1039 num_components);
1040 }
1041
1042 static SpvId
1043 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
1044 unsigned num_components, uint32_t value)
1045 {
1046 assert(bit_size == 32);
1047
1048 SpvId result = emit_uint_const(ctx, bit_size, value);
1049 if (num_components == 1)
1050 return result;
1051
1052 assert(num_components > 1);
1053 SpvId components[num_components];
1054 for (int i = 0; i < num_components; i++)
1055 components[i] = result;
1056
1057 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1058 return spirv_builder_const_composite(&ctx->builder, type, components,
1059 num_components);
1060 }
1061
1062 static SpvId
1063 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
1064 unsigned num_components, int32_t value)
1065 {
1066 assert(bit_size == 32);
1067
1068 SpvId result = emit_int_const(ctx, bit_size, value);
1069 if (num_components == 1)
1070 return result;
1071
1072 assert(num_components > 1);
1073 SpvId components[num_components];
1074 for (int i = 0; i < num_components; i++)
1075 components[i] = result;
1076
1077 SpvId type = get_ivec_type(ctx, bit_size, num_components);
1078 return spirv_builder_const_composite(&ctx->builder, type, components,
1079 num_components);
1080 }
1081
1082 static inline unsigned
1083 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
1084 {
1085 if (nir_op_infos[instr->op].input_sizes[src] > 0)
1086 return nir_op_infos[instr->op].input_sizes[src];
1087
1088 if (instr->dest.dest.is_ssa)
1089 return instr->dest.dest.ssa.num_components;
1090 else
1091 return instr->dest.dest.reg.reg->num_components;
1092 }
1093
1094 static SpvId
1095 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
1096 {
1097 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
1098
1099 unsigned num_components = alu_instr_src_components(alu, src);
1100 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
1101 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
1102
1103 if (bit_size == 1)
1104 return raw_value;
1105 else {
1106 switch (nir_alu_type_get_base_type(type)) {
1107 case nir_type_bool:
1108 unreachable("bool should have bit-size 1");
1109
1110 case nir_type_int:
1111 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
1112
1113 case nir_type_uint:
1114 return raw_value;
1115
1116 case nir_type_float:
1117 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
1118
1119 default:
1120 unreachable("unknown nir_alu_type");
1121 }
1122 }
1123 }
1124
1125 static SpvId
1126 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
1127 {
1128 assert(!alu->dest.saturate);
1129 return store_dest(ctx, &alu->dest.dest, result,
1130 nir_op_infos[alu->op].output_type);
1131 }
1132
1133 static SpvId
1134 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
1135 {
1136 unsigned num_components = nir_dest_num_components(*dest);
1137 unsigned bit_size = nir_dest_bit_size(*dest);
1138
1139 if (bit_size == 1)
1140 return get_bvec_type(ctx, num_components);
1141
1142 switch (nir_alu_type_get_base_type(type)) {
1143 case nir_type_bool:
1144 unreachable("bool should have bit-size 1");
1145
1146 case nir_type_int:
1147 return get_ivec_type(ctx, bit_size, num_components);
1148
1149 case nir_type_uint:
1150 return get_uvec_type(ctx, bit_size, num_components);
1151
1152 case nir_type_float:
1153 return get_fvec_type(ctx, bit_size, num_components);
1154
1155 default:
1156 unreachable("unsupported nir_alu_type");
1157 }
1158 }
1159
1160 static void
1161 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1162 {
1163 SpvId src[nir_op_infos[alu->op].num_inputs];
1164 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1165 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1166 src[i] = get_alu_src(ctx, alu, i);
1167 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1168 }
1169
1170 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1171 nir_op_infos[alu->op].output_type);
1172 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1173 unsigned num_components = nir_dest_num_components(alu->dest.dest);
1174
1175 SpvId result = 0;
1176 switch (alu->op) {
1177 case nir_op_mov:
1178 assert(nir_op_infos[alu->op].num_inputs == 1);
1179 result = src[0];
1180 break;
1181
1182 #define UNOP(nir_op, spirv_op) \
1183 case nir_op: \
1184 assert(nir_op_infos[alu->op].num_inputs == 1); \
1185 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1186 break;
1187
1188 UNOP(nir_op_ineg, SpvOpSNegate)
1189 UNOP(nir_op_fneg, SpvOpFNegate)
1190 UNOP(nir_op_fddx, SpvOpDPdx)
1191 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1192 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1193 UNOP(nir_op_fddy, SpvOpDPdy)
1194 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1195 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1196 UNOP(nir_op_f2i32, SpvOpConvertFToS)
1197 UNOP(nir_op_f2u32, SpvOpConvertFToU)
1198 UNOP(nir_op_i2f32, SpvOpConvertSToF)
1199 UNOP(nir_op_u2f32, SpvOpConvertUToF)
1200 UNOP(nir_op_bitfield_reverse, SpvOpBitReverse)
1201 #undef UNOP
1202
1203 case nir_op_inot:
1204 if (bit_size == 1)
1205 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1206 else
1207 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1208 break;
1209
1210 case nir_op_b2i32:
1211 assert(nir_op_infos[alu->op].num_inputs == 1);
1212 result = emit_select(ctx, dest_type, src[0],
1213 get_ivec_constant(ctx, 32, num_components, 1),
1214 get_ivec_constant(ctx, 32, num_components, 0));
1215 break;
1216
1217 case nir_op_b2f32:
1218 assert(nir_op_infos[alu->op].num_inputs == 1);
1219 result = emit_select(ctx, dest_type, src[0],
1220 get_fvec_constant(ctx, 32, num_components, 1),
1221 get_fvec_constant(ctx, 32, num_components, 0));
1222 break;
1223
1224 #define BUILTIN_UNOP(nir_op, spirv_op) \
1225 case nir_op: \
1226 assert(nir_op_infos[alu->op].num_inputs == 1); \
1227 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1228 break;
1229
1230 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1231 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1232 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1233 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1234 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1235 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1236 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1237 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1238 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1239 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1240 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1241 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1242 BUILTIN_UNOP(nir_op_isign, GLSLstd450SSign)
1243 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1244 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1245 #undef BUILTIN_UNOP
1246
1247 case nir_op_frcp:
1248 assert(nir_op_infos[alu->op].num_inputs == 1);
1249 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1250 get_fvec_constant(ctx, bit_size, num_components, 1),
1251 src[0]);
1252 break;
1253
1254 case nir_op_f2b1:
1255 assert(nir_op_infos[alu->op].num_inputs == 1);
1256 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1257 get_fvec_constant(ctx,
1258 nir_src_bit_size(alu->src[0].src),
1259 num_components, 0));
1260 break;
1261 case nir_op_i2b1:
1262 assert(nir_op_infos[alu->op].num_inputs == 1);
1263 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1264 get_ivec_constant(ctx,
1265 nir_src_bit_size(alu->src[0].src),
1266 num_components, 0));
1267 break;
1268
1269
1270 #define BINOP(nir_op, spirv_op) \
1271 case nir_op: \
1272 assert(nir_op_infos[alu->op].num_inputs == 2); \
1273 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1274 break;
1275
1276 BINOP(nir_op_iadd, SpvOpIAdd)
1277 BINOP(nir_op_isub, SpvOpISub)
1278 BINOP(nir_op_imul, SpvOpIMul)
1279 BINOP(nir_op_idiv, SpvOpSDiv)
1280 BINOP(nir_op_udiv, SpvOpUDiv)
1281 BINOP(nir_op_umod, SpvOpUMod)
1282 BINOP(nir_op_fadd, SpvOpFAdd)
1283 BINOP(nir_op_fsub, SpvOpFSub)
1284 BINOP(nir_op_fmul, SpvOpFMul)
1285 BINOP(nir_op_fdiv, SpvOpFDiv)
1286 BINOP(nir_op_fmod, SpvOpFMod)
1287 BINOP(nir_op_ilt, SpvOpSLessThan)
1288 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1289 BINOP(nir_op_ult, SpvOpULessThan)
1290 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1291 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1292 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1293 BINOP(nir_op_feq, SpvOpFOrdEqual)
1294 BINOP(nir_op_fne, SpvOpFUnordNotEqual)
1295 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1296 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1297 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1298 BINOP(nir_op_ixor, SpvOpBitwiseXor)
1299 #undef BINOP
1300
1301 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1302 case nir_op: \
1303 assert(nir_op_infos[alu->op].num_inputs == 2); \
1304 if (nir_src_bit_size(alu->src[0].src) == 1) \
1305 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1306 else \
1307 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1308 break;
1309
1310 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1311 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1312 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1313 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1314 #undef BINOP_LOG
1315
1316 #define BUILTIN_BINOP(nir_op, spirv_op) \
1317 case nir_op: \
1318 assert(nir_op_infos[alu->op].num_inputs == 2); \
1319 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1320 break;
1321
1322 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1323 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1324 BUILTIN_BINOP(nir_op_imin, GLSLstd450SMin)
1325 BUILTIN_BINOP(nir_op_imax, GLSLstd450SMax)
1326 BUILTIN_BINOP(nir_op_umin, GLSLstd450UMin)
1327 BUILTIN_BINOP(nir_op_umax, GLSLstd450UMax)
1328 #undef BUILTIN_BINOP
1329
1330 case nir_op_fdot2:
1331 case nir_op_fdot3:
1332 case nir_op_fdot4:
1333 assert(nir_op_infos[alu->op].num_inputs == 2);
1334 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1335 break;
1336
1337 case nir_op_fdph:
1338 unreachable("should already be lowered away");
1339
1340 case nir_op_seq:
1341 case nir_op_sne:
1342 case nir_op_slt:
1343 case nir_op_sge: {
1344 assert(nir_op_infos[alu->op].num_inputs == 2);
1345 int num_components = nir_dest_num_components(alu->dest.dest);
1346 SpvId bool_type = get_bvec_type(ctx, num_components);
1347
1348 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1349 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1350 if (num_components > 1) {
1351 SpvId zero_comps[num_components], one_comps[num_components];
1352 for (int i = 0; i < num_components; i++) {
1353 zero_comps[i] = zero;
1354 one_comps[i] = one;
1355 }
1356
1357 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1358 zero_comps, num_components);
1359 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1360 one_comps, num_components);
1361 }
1362
1363 SpvOp op;
1364 switch (alu->op) {
1365 case nir_op_seq: op = SpvOpFOrdEqual; break;
1366 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1367 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1368 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1369 default: unreachable("unexpected op");
1370 }
1371
1372 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1373 result = emit_select(ctx, dest_type, result, one, zero);
1374 }
1375 break;
1376
1377 case nir_op_flrp:
1378 assert(nir_op_infos[alu->op].num_inputs == 3);
1379 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1380 src[0], src[1], src[2]);
1381 break;
1382
1383 case nir_op_fcsel:
1384 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1385 get_bvec_type(ctx, num_components),
1386 src[0],
1387 get_fvec_constant(ctx,
1388 nir_src_bit_size(alu->src[0].src),
1389 num_components, 0));
1390 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1391 break;
1392
1393 case nir_op_bcsel:
1394 assert(nir_op_infos[alu->op].num_inputs == 3);
1395 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1396 break;
1397
1398 case nir_op_bany_fnequal2:
1399 case nir_op_bany_fnequal3:
1400 case nir_op_bany_fnequal4: {
1401 assert(nir_op_infos[alu->op].num_inputs == 2);
1402 assert(alu_instr_src_components(alu, 0) ==
1403 alu_instr_src_components(alu, 1));
1404 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1405 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1406 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1407 result = emit_binop(ctx, op,
1408 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1409 src[0], src[1]);
1410 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1411 break;
1412 }
1413
1414 case nir_op_ball_fequal2:
1415 case nir_op_ball_fequal3:
1416 case nir_op_ball_fequal4: {
1417 assert(nir_op_infos[alu->op].num_inputs == 2);
1418 assert(alu_instr_src_components(alu, 0) ==
1419 alu_instr_src_components(alu, 1));
1420 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1421 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1422 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1423 result = emit_binop(ctx, op,
1424 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1425 src[0], src[1]);
1426 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1427 break;
1428 }
1429
1430 case nir_op_bany_inequal2:
1431 case nir_op_bany_inequal3:
1432 case nir_op_bany_inequal4: {
1433 assert(nir_op_infos[alu->op].num_inputs == 2);
1434 assert(alu_instr_src_components(alu, 0) ==
1435 alu_instr_src_components(alu, 1));
1436 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1437 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1438 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1439 result = emit_binop(ctx, op,
1440 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1441 src[0], src[1]);
1442 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1443 break;
1444 }
1445
1446 case nir_op_ball_iequal2:
1447 case nir_op_ball_iequal3:
1448 case nir_op_ball_iequal4: {
1449 assert(nir_op_infos[alu->op].num_inputs == 2);
1450 assert(alu_instr_src_components(alu, 0) ==
1451 alu_instr_src_components(alu, 1));
1452 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1453 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1454 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1455 result = emit_binop(ctx, op,
1456 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1457 src[0], src[1]);
1458 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1459 break;
1460 }
1461
1462 case nir_op_vec2:
1463 case nir_op_vec3:
1464 case nir_op_vec4: {
1465 int num_inputs = nir_op_infos[alu->op].num_inputs;
1466 assert(2 <= num_inputs && num_inputs <= 4);
1467 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1468 src, num_inputs);
1469 }
1470 break;
1471
1472 default:
1473 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1474 nir_op_infos[alu->op].name);
1475
1476 unreachable("unsupported opcode");
1477 return;
1478 }
1479
1480 store_alu_result(ctx, alu, result);
1481 }
1482
1483 static void
1484 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1485 {
1486 unsigned bit_size = load_const->def.bit_size;
1487 unsigned num_components = load_const->def.num_components;
1488
1489 SpvId constant;
1490 if (num_components > 1) {
1491 SpvId components[num_components];
1492 SpvId type;
1493 if (bit_size == 1) {
1494 for (int i = 0; i < num_components; i++)
1495 components[i] = spirv_builder_const_bool(&ctx->builder,
1496 load_const->value[i].b);
1497
1498 type = get_bvec_type(ctx, num_components);
1499 } else {
1500 for (int i = 0; i < num_components; i++)
1501 components[i] = emit_uint_const(ctx, bit_size,
1502 load_const->value[i].u32);
1503
1504 type = get_uvec_type(ctx, bit_size, num_components);
1505 }
1506 constant = spirv_builder_const_composite(&ctx->builder, type,
1507 components, num_components);
1508 } else {
1509 assert(num_components == 1);
1510 if (bit_size == 1)
1511 constant = spirv_builder_const_bool(&ctx->builder,
1512 load_const->value[0].b);
1513 else
1514 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1515 }
1516
1517 store_ssa_def(ctx, &load_const->def, constant);
1518 }
1519
1520 static void
1521 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1522 {
1523 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1524 assert(const_block_index); // no dynamic indexing for now
1525 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1526
1527 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1528 if (const_offset) {
1529 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1530 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1531 SpvStorageClassUniform,
1532 uvec4_type);
1533
1534 unsigned idx = const_offset->u32;
1535 SpvId member = emit_uint_const(ctx, 32, 0);
1536 SpvId offset = emit_uint_const(ctx, 32, idx);
1537 SpvId offsets[] = { member, offset };
1538 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1539 ctx->ubos[0], offsets,
1540 ARRAY_SIZE(offsets));
1541 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1542
1543 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1544 unsigned num_components = nir_dest_num_components(intr->dest);
1545 if (num_components == 1) {
1546 uint32_t components[] = { 0 };
1547 result = spirv_builder_emit_composite_extract(&ctx->builder,
1548 type,
1549 result, components,
1550 1);
1551 } else if (num_components < 4) {
1552 SpvId constituents[num_components];
1553 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1554 for (uint32_t i = 0; i < num_components; ++i)
1555 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1556 uint_type,
1557 result, &i,
1558 1);
1559
1560 result = spirv_builder_emit_composite_construct(&ctx->builder,
1561 type,
1562 constituents,
1563 num_components);
1564 }
1565
1566 if (nir_dest_bit_size(intr->dest) == 1)
1567 result = uvec_to_bvec(ctx, result, num_components);
1568
1569 store_dest(ctx, &intr->dest, result, nir_type_uint);
1570 } else
1571 unreachable("uniform-addressing not yet supported");
1572 }
1573
1574 static void
1575 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1576 {
1577 assert(ctx->block_started);
1578 spirv_builder_emit_kill(&ctx->builder);
1579 /* discard is weird in NIR, so let's just create an unreachable block after
1580 it and hope that the vulkan driver will DCE any instructinos in it. */
1581 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1582 }
1583
1584 static void
1585 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1586 {
1587 SpvId ptr = get_src(ctx, intr->src);
1588
1589 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1590 SpvId result = spirv_builder_emit_load(&ctx->builder,
1591 get_glsl_type(ctx, var->type),
1592 ptr);
1593 unsigned num_components = nir_dest_num_components(intr->dest);
1594 unsigned bit_size = nir_dest_bit_size(intr->dest);
1595 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1596 store_dest(ctx, &intr->dest, result, nir_type_uint);
1597 }
1598
1599 static void
1600 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1601 {
1602 SpvId ptr = get_src(ctx, &intr->src[0]);
1603 SpvId src = get_src(ctx, &intr->src[1]);
1604
1605 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1606 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1607 SpvId result = emit_bitcast(ctx, type, src);
1608 spirv_builder_emit_store(&ctx->builder, ptr, result);
1609 }
1610
1611 static SpvId
1612 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1613 SpvStorageClass storage_class,
1614 const char *name, SpvBuiltIn builtin)
1615 {
1616 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1617 storage_class,
1618 var_type);
1619 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1620 storage_class);
1621 spirv_builder_emit_name(&ctx->builder, var, name);
1622 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1623
1624 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1625 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1626 return var;
1627 }
1628
1629 static void
1630 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1631 {
1632 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1633 if (!ctx->front_face_var)
1634 ctx->front_face_var = create_builtin_var(ctx, var_type,
1635 SpvStorageClassInput,
1636 "gl_FrontFacing",
1637 SpvBuiltInFrontFacing);
1638
1639 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1640 ctx->front_face_var);
1641 assert(1 == nir_dest_num_components(intr->dest));
1642 store_dest(ctx, &intr->dest, result, nir_type_bool);
1643 }
1644
1645 static void
1646 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1647 {
1648 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1649 if (!ctx->instance_id_var)
1650 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1651 SpvStorageClassInput,
1652 "gl_InstanceId",
1653 SpvBuiltInInstanceIndex);
1654
1655 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1656 ctx->instance_id_var);
1657 assert(1 == nir_dest_num_components(intr->dest));
1658 store_dest(ctx, &intr->dest, result, nir_type_uint);
1659 }
1660
1661 static void
1662 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1663 {
1664 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1665 if (!ctx->vertex_id_var)
1666 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1667 SpvStorageClassInput,
1668 "gl_VertexID",
1669 SpvBuiltInVertexIndex);
1670
1671 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1672 ctx->vertex_id_var);
1673 assert(1 == nir_dest_num_components(intr->dest));
1674 store_dest(ctx, &intr->dest, result, nir_type_uint);
1675 }
1676
1677 static void
1678 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1679 {
1680 switch (intr->intrinsic) {
1681 case nir_intrinsic_load_ubo:
1682 emit_load_ubo(ctx, intr);
1683 break;
1684
1685 case nir_intrinsic_discard:
1686 emit_discard(ctx, intr);
1687 break;
1688
1689 case nir_intrinsic_load_deref:
1690 emit_load_deref(ctx, intr);
1691 break;
1692
1693 case nir_intrinsic_store_deref:
1694 emit_store_deref(ctx, intr);
1695 break;
1696
1697 case nir_intrinsic_load_front_face:
1698 emit_load_front_face(ctx, intr);
1699 break;
1700
1701 case nir_intrinsic_load_instance_id:
1702 emit_load_instance_id(ctx, intr);
1703 break;
1704
1705 case nir_intrinsic_load_vertex_id:
1706 emit_load_vertex_id(ctx, intr);
1707 break;
1708
1709 default:
1710 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1711 nir_intrinsic_infos[intr->intrinsic].name);
1712 unreachable("unsupported intrinsic");
1713 }
1714 }
1715
1716 static void
1717 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1718 {
1719 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1720 undef->def.num_components);
1721
1722 store_ssa_def(ctx, &undef->def,
1723 spirv_builder_emit_undef(&ctx->builder, type));
1724 }
1725
1726 static SpvId
1727 get_src_float(struct ntv_context *ctx, nir_src *src)
1728 {
1729 SpvId def = get_src(ctx, src);
1730 unsigned num_components = nir_src_num_components(*src);
1731 unsigned bit_size = nir_src_bit_size(*src);
1732 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1733 }
1734
1735 static SpvId
1736 get_src_int(struct ntv_context *ctx, nir_src *src)
1737 {
1738 SpvId def = get_src(ctx, src);
1739 unsigned num_components = nir_src_num_components(*src);
1740 unsigned bit_size = nir_src_bit_size(*src);
1741 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1742 }
1743
1744 static void
1745 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1746 {
1747 assert(tex->op == nir_texop_tex ||
1748 tex->op == nir_texop_txb ||
1749 tex->op == nir_texop_txl ||
1750 tex->op == nir_texop_txd ||
1751 tex->op == nir_texop_txf ||
1752 tex->op == nir_texop_txf_ms ||
1753 tex->op == nir_texop_txs);
1754 assert(tex->texture_index == tex->sampler_index);
1755
1756 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1757 offset = 0, sample = 0;
1758 unsigned coord_components = 0;
1759 for (unsigned i = 0; i < tex->num_srcs; i++) {
1760 switch (tex->src[i].src_type) {
1761 case nir_tex_src_coord:
1762 if (tex->op == nir_texop_txf ||
1763 tex->op == nir_texop_txf_ms)
1764 coord = get_src_int(ctx, &tex->src[i].src);
1765 else
1766 coord = get_src_float(ctx, &tex->src[i].src);
1767 coord_components = nir_src_num_components(tex->src[i].src);
1768 break;
1769
1770 case nir_tex_src_projector:
1771 assert(nir_src_num_components(tex->src[i].src) == 1);
1772 proj = get_src_float(ctx, &tex->src[i].src);
1773 assert(proj != 0);
1774 break;
1775
1776 case nir_tex_src_offset:
1777 offset = get_src_int(ctx, &tex->src[i].src);
1778 break;
1779
1780 case nir_tex_src_bias:
1781 assert(tex->op == nir_texop_txb);
1782 bias = get_src_float(ctx, &tex->src[i].src);
1783 assert(bias != 0);
1784 break;
1785
1786 case nir_tex_src_lod:
1787 assert(nir_src_num_components(tex->src[i].src) == 1);
1788 if (tex->op == nir_texop_txf ||
1789 tex->op == nir_texop_txf_ms ||
1790 tex->op == nir_texop_txs)
1791 lod = get_src_int(ctx, &tex->src[i].src);
1792 else
1793 lod = get_src_float(ctx, &tex->src[i].src);
1794 assert(lod != 0);
1795 break;
1796
1797 case nir_tex_src_ms_index:
1798 assert(nir_src_num_components(tex->src[i].src) == 1);
1799 sample = get_src_int(ctx, &tex->src[i].src);
1800 break;
1801
1802 case nir_tex_src_comparator:
1803 assert(nir_src_num_components(tex->src[i].src) == 1);
1804 dref = get_src_float(ctx, &tex->src[i].src);
1805 assert(dref != 0);
1806 break;
1807
1808 case nir_tex_src_ddx:
1809 dx = get_src_float(ctx, &tex->src[i].src);
1810 assert(dx != 0);
1811 break;
1812
1813 case nir_tex_src_ddy:
1814 dy = get_src_float(ctx, &tex->src[i].src);
1815 assert(dy != 0);
1816 break;
1817
1818 default:
1819 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1820 unreachable("unknown texture source");
1821 }
1822 }
1823
1824 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1825 lod = emit_float_const(ctx, 32, 0.0f);
1826 assert(lod != 0);
1827 }
1828
1829 SpvId image_type = ctx->image_types[tex->texture_index];
1830 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1831 image_type);
1832
1833 assert(ctx->samplers_used & (1u << tex->texture_index));
1834 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1835 ctx->samplers[tex->texture_index]);
1836
1837 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1838
1839 if (tex->op == nir_texop_txs) {
1840 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1841 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1842 dest_type, image,
1843 lod);
1844 store_dest(ctx, &tex->dest, result, tex->dest_type);
1845 return;
1846 }
1847
1848 if (proj && coord_components > 0) {
1849 SpvId constituents[coord_components + 1];
1850 if (coord_components == 1)
1851 constituents[0] = coord;
1852 else {
1853 assert(coord_components > 1);
1854 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1855 for (uint32_t i = 0; i < coord_components; ++i)
1856 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1857 float_type,
1858 coord,
1859 &i, 1);
1860 }
1861
1862 constituents[coord_components++] = proj;
1863
1864 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1865 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1866 vec_type,
1867 constituents,
1868 coord_components);
1869 }
1870
1871 SpvId actual_dest_type = dest_type;
1872 if (dref)
1873 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1874
1875 SpvId result;
1876 if (tex->op == nir_texop_txf ||
1877 tex->op == nir_texop_txf_ms) {
1878 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1879 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1880 image, coord, lod, sample);
1881 } else {
1882 result = spirv_builder_emit_image_sample(&ctx->builder,
1883 actual_dest_type, load,
1884 coord,
1885 proj != 0,
1886 lod, bias, dref, dx, dy,
1887 offset);
1888 }
1889
1890 spirv_builder_emit_decoration(&ctx->builder, result,
1891 SpvDecorationRelaxedPrecision);
1892
1893 if (dref && nir_dest_num_components(tex->dest) > 1) {
1894 SpvId components[4] = { result, result, result, result };
1895 result = spirv_builder_emit_composite_construct(&ctx->builder,
1896 dest_type,
1897 components,
1898 4);
1899 }
1900
1901 store_dest(ctx, &tex->dest, result, tex->dest_type);
1902 }
1903
1904 static void
1905 start_block(struct ntv_context *ctx, SpvId label)
1906 {
1907 /* terminate previous block if needed */
1908 if (ctx->block_started)
1909 spirv_builder_emit_branch(&ctx->builder, label);
1910
1911 /* start new block */
1912 spirv_builder_label(&ctx->builder, label);
1913 ctx->block_started = true;
1914 }
1915
1916 static void
1917 branch(struct ntv_context *ctx, SpvId label)
1918 {
1919 assert(ctx->block_started);
1920 spirv_builder_emit_branch(&ctx->builder, label);
1921 ctx->block_started = false;
1922 }
1923
1924 static void
1925 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1926 SpvId else_id)
1927 {
1928 assert(ctx->block_started);
1929 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1930 then_id, else_id);
1931 ctx->block_started = false;
1932 }
1933
1934 static void
1935 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1936 {
1937 switch (jump->type) {
1938 case nir_jump_break:
1939 assert(ctx->loop_break);
1940 branch(ctx, ctx->loop_break);
1941 break;
1942
1943 case nir_jump_continue:
1944 assert(ctx->loop_cont);
1945 branch(ctx, ctx->loop_cont);
1946 break;
1947
1948 default:
1949 unreachable("Unsupported jump type\n");
1950 }
1951 }
1952
1953 static void
1954 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1955 {
1956 assert(deref->deref_type == nir_deref_type_var);
1957
1958 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1959 assert(he);
1960 SpvId result = (SpvId)(intptr_t)he->data;
1961 store_dest_raw(ctx, &deref->dest, result);
1962 }
1963
1964 static void
1965 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1966 {
1967 assert(deref->deref_type == nir_deref_type_array);
1968 nir_variable *var = nir_deref_instr_get_variable(deref);
1969
1970 SpvStorageClass storage_class;
1971 switch (var->data.mode) {
1972 case nir_var_shader_in:
1973 storage_class = SpvStorageClassInput;
1974 break;
1975
1976 case nir_var_shader_out:
1977 storage_class = SpvStorageClassOutput;
1978 break;
1979
1980 default:
1981 unreachable("Unsupported nir_variable_mode\n");
1982 }
1983
1984 SpvId index = get_src(ctx, &deref->arr.index);
1985
1986 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1987 storage_class,
1988 get_glsl_type(ctx, deref->type));
1989
1990 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1991 ptr_type,
1992 get_src(ctx, &deref->parent),
1993 &index, 1);
1994 /* uint is a bit of a lie here, it's really just an opaque type */
1995 store_dest(ctx, &deref->dest, result, nir_type_uint);
1996 }
1997
1998 static void
1999 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
2000 {
2001 switch (deref->deref_type) {
2002 case nir_deref_type_var:
2003 emit_deref_var(ctx, deref);
2004 break;
2005
2006 case nir_deref_type_array:
2007 emit_deref_array(ctx, deref);
2008 break;
2009
2010 default:
2011 unreachable("unexpected deref_type");
2012 }
2013 }
2014
2015 static void
2016 emit_block(struct ntv_context *ctx, struct nir_block *block)
2017 {
2018 start_block(ctx, block_label(ctx, block));
2019 nir_foreach_instr(instr, block) {
2020 switch (instr->type) {
2021 case nir_instr_type_alu:
2022 emit_alu(ctx, nir_instr_as_alu(instr));
2023 break;
2024 case nir_instr_type_intrinsic:
2025 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
2026 break;
2027 case nir_instr_type_load_const:
2028 emit_load_const(ctx, nir_instr_as_load_const(instr));
2029 break;
2030 case nir_instr_type_ssa_undef:
2031 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
2032 break;
2033 case nir_instr_type_tex:
2034 emit_tex(ctx, nir_instr_as_tex(instr));
2035 break;
2036 case nir_instr_type_phi:
2037 unreachable("nir_instr_type_phi not supported");
2038 break;
2039 case nir_instr_type_jump:
2040 emit_jump(ctx, nir_instr_as_jump(instr));
2041 break;
2042 case nir_instr_type_call:
2043 unreachable("nir_instr_type_call not supported");
2044 break;
2045 case nir_instr_type_parallel_copy:
2046 unreachable("nir_instr_type_parallel_copy not supported");
2047 break;
2048 case nir_instr_type_deref:
2049 emit_deref(ctx, nir_instr_as_deref(instr));
2050 break;
2051 }
2052 }
2053 }
2054
/* forward declaration: emit_if() and emit_loop() recurse into nested lists */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
2057
2058 static SpvId
2059 get_src_bool(struct ntv_context *ctx, nir_src *src)
2060 {
2061 assert(nir_src_bit_size(*src) == 1);
2062 return get_src(ctx, src);
2063 }
2064
2065 static void
2066 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
2067 {
2068 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
2069
2070 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2071 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
2072 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
2073 SpvId else_id = endif_id;
2074
2075 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
2076 if (has_else) {
2077 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
2078 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
2079 }
2080
2081 /* create a header-block */
2082 start_block(ctx, header_id);
2083 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
2084 SpvSelectionControlMaskNone);
2085 branch_conditional(ctx, condition, then_id, else_id);
2086
2087 emit_cf_list(ctx, &if_stmt->then_list);
2088
2089 if (has_else) {
2090 if (ctx->block_started)
2091 branch(ctx, endif_id);
2092
2093 emit_cf_list(ctx, &if_stmt->else_list);
2094 }
2095
2096 start_block(ctx, endif_id);
2097 }
2098
2099 static void
2100 emit_loop(struct ntv_context *ctx, nir_loop *loop)
2101 {
2102 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2103 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
2104 SpvId break_id = spirv_builder_new_id(&ctx->builder);
2105 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
2106
2107 /* create a header-block */
2108 start_block(ctx, header_id);
2109 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
2110 branch(ctx, begin_id);
2111
2112 SpvId save_break = ctx->loop_break;
2113 SpvId save_cont = ctx->loop_cont;
2114 ctx->loop_break = break_id;
2115 ctx->loop_cont = cont_id;
2116
2117 emit_cf_list(ctx, &loop->body);
2118
2119 ctx->loop_break = save_break;
2120 ctx->loop_cont = save_cont;
2121
2122 branch(ctx, cont_id);
2123 start_block(ctx, cont_id);
2124 branch(ctx, header_id);
2125
2126 start_block(ctx, break_id);
2127 }
2128
2129 static void
2130 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
2131 {
2132 foreach_list_typed(nir_cf_node, node, node, list) {
2133 switch (node->type) {
2134 case nir_cf_node_block:
2135 emit_block(ctx, nir_cf_node_as_block(node));
2136 break;
2137
2138 case nir_cf_node_if:
2139 emit_if(ctx, nir_cf_node_as_if(node));
2140 break;
2141
2142 case nir_cf_node_loop:
2143 emit_loop(ctx, nir_cf_node_as_loop(node));
2144 break;
2145
2146 case nir_cf_node_function:
2147 unreachable("nir_cf_node_function not supported");
2148 break;
2149 }
2150 }
2151 }
2152
2153 struct spirv_shader *
2154 nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
2155 {
2156 struct spirv_shader *ret = NULL;
2157
2158 struct ntv_context ctx = {};
2159
2160 switch (s->info.stage) {
2161 case MESA_SHADER_VERTEX:
2162 case MESA_SHADER_FRAGMENT:
2163 case MESA_SHADER_COMPUTE:
2164 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2165 break;
2166
2167 case MESA_SHADER_TESS_CTRL:
2168 case MESA_SHADER_TESS_EVAL:
2169 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2170 break;
2171
2172 case MESA_SHADER_GEOMETRY:
2173 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2174 break;
2175
2176 default:
2177 unreachable("invalid stage");
2178 }
2179
2180 // TODO: only enable when needed
2181 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2182 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2183 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2184 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2185 }
2186
2187 ctx.stage = s->info.stage;
2188 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2189 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2190
2191 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2192 SpvMemoryModelGLSL450);
2193
2194 SpvExecutionModel exec_model;
2195 switch (s->info.stage) {
2196 case MESA_SHADER_VERTEX:
2197 exec_model = SpvExecutionModelVertex;
2198 break;
2199 case MESA_SHADER_TESS_CTRL:
2200 exec_model = SpvExecutionModelTessellationControl;
2201 break;
2202 case MESA_SHADER_TESS_EVAL:
2203 exec_model = SpvExecutionModelTessellationEvaluation;
2204 break;
2205 case MESA_SHADER_GEOMETRY:
2206 exec_model = SpvExecutionModelGeometry;
2207 break;
2208 case MESA_SHADER_FRAGMENT:
2209 exec_model = SpvExecutionModelFragment;
2210 break;
2211 case MESA_SHADER_COMPUTE:
2212 exec_model = SpvExecutionModelGLCompute;
2213 break;
2214 default:
2215 unreachable("invalid stage");
2216 }
2217
2218 SpvId type_void = spirv_builder_type_void(&ctx.builder);
2219 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2220 NULL, 0);
2221 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2222 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2223
2224 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
2225 _mesa_key_pointer_equal);
2226
2227 ctx.so_outputs = _mesa_hash_table_create(NULL, _mesa_hash_u32,
2228 _mesa_key_u32_equal);
2229
2230 nir_foreach_variable(var, &s->inputs)
2231 emit_input(&ctx, var);
2232
2233 nir_foreach_variable(var, &s->outputs)
2234 emit_output(&ctx, var);
2235
2236 if (so_info)
2237 emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info, local_so_info);
2238 nir_foreach_variable(var, &s->uniforms)
2239 emit_uniform(&ctx, var);
2240
2241 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2242 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2243 SpvExecutionModeOriginUpperLeft);
2244 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2245 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2246 SpvExecutionModeDepthReplacing);
2247 }
2248
2249 if (so_info && so_info->num_outputs) {
2250 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
2251 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2252 SpvExecutionModeXfb);
2253 }
2254
2255 spirv_builder_function(&ctx.builder, entry_point, type_void,
2256 SpvFunctionControlMaskNone,
2257 type_main);
2258
2259 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2260 nir_metadata_require(entry, nir_metadata_block_index);
2261
2262 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
2263 if (!ctx.defs)
2264 goto fail;
2265 ctx.num_defs = entry->ssa_alloc;
2266
2267 nir_index_local_regs(entry);
2268 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
2269 if (!ctx.regs)
2270 goto fail;
2271 ctx.num_regs = entry->reg_alloc;
2272
2273 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
2274 if (!block_ids)
2275 goto fail;
2276
2277 for (int i = 0; i < entry->num_blocks; ++i)
2278 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2279
2280 ctx.block_ids = block_ids;
2281 ctx.num_blocks = entry->num_blocks;
2282
2283 /* emit a block only for the variable declarations */
2284 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2285 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2286 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2287 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2288 SpvStorageClassFunction,
2289 type);
2290 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2291 SpvStorageClassFunction);
2292
2293 ctx.regs[reg->index] = var;
2294 }
2295
2296 emit_cf_list(&ctx, &entry->body);
2297
2298 free(ctx.defs);
2299
2300 if (so_info)
2301 emit_so_outputs(&ctx, so_info, local_so_info);
2302
2303 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2304 spirv_builder_function_end(&ctx.builder);
2305
2306 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2307 "main", ctx.entry_ifaces,
2308 ctx.num_entry_ifaces);
2309
2310 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2311
2312 ret = CALLOC_STRUCT(spirv_shader);
2313 if (!ret)
2314 goto fail;
2315
2316 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2317 if (!ret->words)
2318 goto fail;
2319
2320 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2321 assert(ret->num_words == num_words);
2322
2323 return ret;
2324
2325 fail:
2326
2327 if (ret)
2328 spirv_shader_delete(ret);
2329
2330 if (ctx.vars)
2331 _mesa_hash_table_destroy(ctx.vars, NULL);
2332
2333 if (ctx.so_outputs)
2334 _mesa_hash_table_destroy(ctx.so_outputs, NULL);
2335
2336 return NULL;
2337 }
2338
2339 void
2340 spirv_shader_delete(struct spirv_shader *s)
2341 {
2342 FREE(s->words);
2343 FREE(s);
2344 }