zink: add spirv_builder methods for OpVectorExtractDynamic and OpVectorInsertDynamic
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 /* this consistently maps slots to a zero-indexed value to avoid wasting slots */
33 static unsigned slot_pack_map[] = {
34 /* Position is builtin */
35 [VARYING_SLOT_POS] = UINT_MAX,
36 [VARYING_SLOT_COL0] = 0, /* input/output */
37 [VARYING_SLOT_COL1] = 1, /* input/output */
38 [VARYING_SLOT_FOGC] = 2, /* input/output */
39 /* TEX0-7 are translated to VAR0-7 by nir, so we don't need to reserve */
40 [VARYING_SLOT_TEX0] = UINT_MAX, /* input/output */
41 [VARYING_SLOT_TEX1] = UINT_MAX,
42 [VARYING_SLOT_TEX2] = UINT_MAX,
43 [VARYING_SLOT_TEX3] = UINT_MAX,
44 [VARYING_SLOT_TEX4] = UINT_MAX,
45 [VARYING_SLOT_TEX5] = UINT_MAX,
46 [VARYING_SLOT_TEX6] = UINT_MAX,
47 [VARYING_SLOT_TEX7] = UINT_MAX,
48
49 /* PointSize is builtin */
50 [VARYING_SLOT_PSIZ] = UINT_MAX,
51
52 [VARYING_SLOT_BFC0] = 3, /* output only */
53 [VARYING_SLOT_BFC1] = 4, /* output only */
54 [VARYING_SLOT_EDGE] = 5, /* output only */
55 [VARYING_SLOT_CLIP_VERTEX] = 6, /* output only */
56
57 /* ClipDistance is builtin */
58 [VARYING_SLOT_CLIP_DIST0] = UINT_MAX,
59 [VARYING_SLOT_CLIP_DIST1] = UINT_MAX,
60
61 /* CullDistance is builtin */
62 [VARYING_SLOT_CULL_DIST0] = UINT_MAX, /* input/output */
63 [VARYING_SLOT_CULL_DIST1] = UINT_MAX, /* never actually used */
64
65 /* PrimitiveId is builtin */
66 [VARYING_SLOT_PRIMITIVE_ID] = UINT_MAX,
67
68 /* Layer is builtin */
69 [VARYING_SLOT_LAYER] = UINT_MAX, /* input/output */
70
71 /* ViewportIndex is builtin */
72 [VARYING_SLOT_VIEWPORT] = UINT_MAX, /* input/output */
73
74 /* FrontFacing is builtin */
75 [VARYING_SLOT_FACE] = UINT_MAX,
76
77 /* PointCoord is builtin */
78 [VARYING_SLOT_PNTC] = UINT_MAX, /* input only */
79
80 /* TessLevelOuter is builtin */
81 [VARYING_SLOT_TESS_LEVEL_OUTER] = UINT_MAX,
82 /* TessLevelInner is builtin */
83 [VARYING_SLOT_TESS_LEVEL_INNER] = UINT_MAX,
84
85 [VARYING_SLOT_BOUNDING_BOX0] = 7, /* Only appears as TCS output. */
86 [VARYING_SLOT_BOUNDING_BOX1] = 8, /* Only appears as TCS output. */
87 [VARYING_SLOT_VIEW_INDEX] = 9, /* input/output */
88 [VARYING_SLOT_VIEWPORT_MASK] = 10, /* output only */
89 };
90 #define NTV_MIN_RESERVED_SLOTS 10
91
92 struct ntv_context {
93 struct spirv_builder builder;
94
95 SpvId GLSL_std_450;
96
97 gl_shader_stage stage;
98
99 SpvId ubos[128];
100 size_t num_ubos;
101 SpvId image_types[PIPE_MAX_SAMPLERS];
102 SpvId samplers[PIPE_MAX_SAMPLERS];
103 unsigned samplers_used : PIPE_MAX_SAMPLERS;
104 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
105 size_t num_entry_ifaces;
106
107 SpvId *defs;
108 size_t num_defs;
109
110 SpvId *regs;
111 size_t num_regs;
112
113 struct hash_table *vars; /* nir_variable -> SpvId */
114
115 const SpvId *block_ids;
116 size_t num_blocks;
117 bool block_started;
118 SpvId loop_break, loop_cont;
119
120 SpvId front_face_var, instance_id_var, vertex_id_var;
121 };
122
123 static SpvId
124 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
125 unsigned num_components, float value);
126
127 static SpvId
128 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
129 unsigned num_components, uint32_t value);
130
131 static SpvId
132 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
133 unsigned num_components, int32_t value);
134
135 static SpvId
136 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
137
138 static SpvId
139 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
140 SpvId src0, SpvId src1);
141
142 static SpvId
143 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
144 SpvId src0, SpvId src1, SpvId src2);
145
146 static SpvId
147 get_bvec_type(struct ntv_context *ctx, int num_components)
148 {
149 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
150 if (num_components > 1)
151 return spirv_builder_type_vector(&ctx->builder, bool_type,
152 num_components);
153
154 assert(num_components == 1);
155 return bool_type;
156 }
157
158 static SpvId
159 block_label(struct ntv_context *ctx, nir_block *block)
160 {
161 assert(block->index < ctx->num_blocks);
162 return ctx->block_ids[block->index];
163 }
164
165 static SpvId
166 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
167 {
168 assert(bit_size == 32);
169 return spirv_builder_const_float(&ctx->builder, bit_size, value);
170 }
171
172 static SpvId
173 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
174 {
175 assert(bit_size == 32);
176 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
177 }
178
179 static SpvId
180 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
181 {
182 assert(bit_size == 32);
183 return spirv_builder_const_int(&ctx->builder, bit_size, value);
184 }
185
186 static SpvId
187 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
188 {
189 assert(bit_size == 32); // only 32-bit floats supported so far
190
191 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
192 if (num_components > 1)
193 return spirv_builder_type_vector(&ctx->builder, float_type,
194 num_components);
195
196 assert(num_components == 1);
197 return float_type;
198 }
199
200 static SpvId
201 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
202 {
203 assert(bit_size == 32); // only 32-bit ints supported so far
204
205 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
206 if (num_components > 1)
207 return spirv_builder_type_vector(&ctx->builder, int_type,
208 num_components);
209
210 assert(num_components == 1);
211 return int_type;
212 }
213
214 static SpvId
215 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
216 {
217 assert(bit_size == 32); // only 32-bit uints supported so far
218
219 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
220 if (num_components > 1)
221 return spirv_builder_type_vector(&ctx->builder, uint_type,
222 num_components);
223
224 assert(num_components == 1);
225 return uint_type;
226 }
227
228 static SpvId
229 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
230 {
231 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
232 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
233 }
234
235 static SpvId
236 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
237 {
238 switch (type) {
239 case GLSL_TYPE_BOOL:
240 return spirv_builder_type_bool(&ctx->builder);
241
242 case GLSL_TYPE_FLOAT:
243 return spirv_builder_type_float(&ctx->builder, 32);
244
245 case GLSL_TYPE_INT:
246 return spirv_builder_type_int(&ctx->builder, 32);
247
248 case GLSL_TYPE_UINT:
249 return spirv_builder_type_uint(&ctx->builder, 32);
250 /* TODO: handle more types */
251
252 default:
253 unreachable("unknown GLSL type");
254 }
255 }
256
257 static SpvId
258 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
259 {
260 assert(type);
261 if (glsl_type_is_scalar(type))
262 return get_glsl_basetype(ctx, glsl_get_base_type(type));
263
264 if (glsl_type_is_vector(type))
265 return spirv_builder_type_vector(&ctx->builder,
266 get_glsl_basetype(ctx, glsl_get_base_type(type)),
267 glsl_get_vector_elements(type));
268
269 if (glsl_type_is_array(type)) {
270 SpvId ret = spirv_builder_type_array(&ctx->builder,
271 get_glsl_type(ctx, glsl_get_array_element(type)),
272 emit_uint_const(ctx, 32, glsl_get_length(type)));
273 uint32_t stride = glsl_get_explicit_stride(type);
274 if (stride)
275 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
276 return ret;
277 }
278
279
280 unreachable("we shouldn't get here, I think...");
281 }
282
/* Expands to a `case VARYING_SLOT_<SLOT>:` that decorates var_id with
 * SpvBuiltIn<BUILTIN> and breaks.  The expansion site must have `ctx` and
 * `var_id` in scope and must supply the trailing semicolon.
 */
#define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN)                                    \
   case VARYING_SLOT_##SLOT:                                                  \
      spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
      break
287
288
289 static void
290 emit_input(struct ntv_context *ctx, struct nir_variable *var)
291 {
292 SpvId var_type = get_glsl_type(ctx, var->type);
293 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
294 SpvStorageClassInput,
295 var_type);
296 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
297 SpvStorageClassInput);
298
299 if (var->name)
300 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
301
302 if (ctx->stage == MESA_SHADER_FRAGMENT) {
303 unsigned slot = var->data.location;
304 switch (slot) {
305 HANDLE_EMIT_BUILTIN(POS, FragCoord);
306 HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
307 HANDLE_EMIT_BUILTIN(LAYER, Layer);
308 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
309 HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
310 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
311 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
312 HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
313
314 default:
315 if (slot < VARYING_SLOT_VAR0) {
316 slot = slot_pack_map[slot];
317 if (slot == UINT_MAX)
318 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
319 } else
320 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
321 assert(slot < VARYING_SLOT_VAR0);
322 spirv_builder_emit_location(&ctx->builder, var_id, slot);
323 }
324 } else {
325 spirv_builder_emit_location(&ctx->builder, var_id,
326 var->data.driver_location);
327 }
328
329 if (var->data.location_frac)
330 spirv_builder_emit_component(&ctx->builder, var_id,
331 var->data.location_frac);
332
333 if (var->data.interpolation == INTERP_MODE_FLAT)
334 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
335
336 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
337
338 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
339 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
340 }
341
342 static void
343 emit_output(struct ntv_context *ctx, struct nir_variable *var)
344 {
345 SpvId var_type = get_glsl_type(ctx, var->type);
346 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
347 SpvStorageClassOutput,
348 var_type);
349 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
350 SpvStorageClassOutput);
351 if (var->name)
352 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
353
354
355 if (ctx->stage == MESA_SHADER_VERTEX) {
356
357 unsigned slot = var->data.location;
358 switch (slot) {
359 HANDLE_EMIT_BUILTIN(POS, Position);
360 HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
361 HANDLE_EMIT_BUILTIN(LAYER, Layer);
362 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
363 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
364 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
365 HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
366 HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
367
368 case VARYING_SLOT_CLIP_DIST0:
369 assert(glsl_type_is_array(var->type));
370 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
371 break;
372
373 default:
374 if (slot < VARYING_SLOT_VAR0) {
375 slot = slot_pack_map[slot];
376 if (slot == UINT_MAX)
377 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
378 } else
379 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
380 assert(slot < VARYING_SLOT_VAR0);
381 spirv_builder_emit_location(&ctx->builder, var_id, slot);
382 /* non-builtins get location incremented by VARYING_SLOT_VAR0 in vtn, so
383 * use driver_location for non-builtins with defined slots to avoid overlap
384 */
385 }
386 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
387 if (var->data.location >= FRAG_RESULT_DATA0)
388 spirv_builder_emit_location(&ctx->builder, var_id,
389 var->data.location - FRAG_RESULT_DATA0);
390 else {
391 switch (var->data.location) {
392 case FRAG_RESULT_COLOR:
393 spirv_builder_emit_location(&ctx->builder, var_id, 0);
394 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
395 break;
396
397 case FRAG_RESULT_DEPTH:
398 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
399 break;
400
401 default:
402 spirv_builder_emit_location(&ctx->builder, var_id,
403 var->data.driver_location);
404 }
405 }
406 }
407
408 if (var->data.location_frac)
409 spirv_builder_emit_component(&ctx->builder, var_id,
410 var->data.location_frac);
411
412 switch (var->data.interpolation) {
413 case INTERP_MODE_NONE:
414 case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
415 break;
416 case INTERP_MODE_FLAT:
417 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
418 break;
419 case INTERP_MODE_EXPLICIT:
420 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
421 break;
422 case INTERP_MODE_NOPERSPECTIVE:
423 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
424 break;
425 default:
426 unreachable("unknown interpolation value");
427 }
428
429 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
430
431 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
432 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
433 }
434
435 static SpvDim
436 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
437 {
438 *is_ms = false;
439 switch (gdim) {
440 case GLSL_SAMPLER_DIM_1D:
441 return SpvDim1D;
442 case GLSL_SAMPLER_DIM_2D:
443 return SpvDim2D;
444 case GLSL_SAMPLER_DIM_3D:
445 return SpvDim3D;
446 case GLSL_SAMPLER_DIM_CUBE:
447 return SpvDimCube;
448 case GLSL_SAMPLER_DIM_RECT:
449 return SpvDim2D;
450 case GLSL_SAMPLER_DIM_BUF:
451 return SpvDimBuffer;
452 case GLSL_SAMPLER_DIM_EXTERNAL:
453 return SpvDim2D; /* seems dodgy... */
454 case GLSL_SAMPLER_DIM_MS:
455 *is_ms = true;
456 return SpvDim2D;
457 default:
458 fprintf(stderr, "unknown sampler type %d\n", gdim);
459 break;
460 }
461 return SpvDim2D;
462 }
463
464 uint32_t
465 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
466 {
467 if (stage == MESA_SHADER_NONE ||
468 stage >= MESA_SHADER_COMPUTE) {
469 unreachable("not supported");
470 } else {
471 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
472 PIPE_MAX_SHADER_SAMPLER_VIEWS);
473
474 switch (type) {
475 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
476 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
477 return stage_offset + index;
478
479 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
480 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
481 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
482
483 default:
484 unreachable("unexpected type");
485 }
486 }
487 }
488
489 static void
490 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
491 {
492 const struct glsl_type *type = glsl_without_array(var->type);
493
494 bool is_ms;
495 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
496
497 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
498 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
499 dimension, false,
500 glsl_sampler_type_is_array(type),
501 is_ms, 1,
502 SpvImageFormatUnknown);
503
504 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
505 image_type);
506 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
507 SpvStorageClassUniformConstant,
508 sampled_type);
509
510 if (glsl_type_is_array(var->type)) {
511 for (int i = 0; i < glsl_get_length(var->type); ++i) {
512 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
513 SpvStorageClassUniformConstant);
514
515 if (var->name) {
516 char element_name[100];
517 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
518 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
519 }
520
521 int index = var->data.binding + i;
522 assert(!(ctx->samplers_used & (1 << index)));
523 assert(!ctx->image_types[index]);
524 ctx->image_types[index] = image_type;
525 ctx->samplers[index] = var_id;
526 ctx->samplers_used |= 1 << index;
527
528 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
529 var->data.descriptor_set);
530 int binding = zink_binding(ctx->stage,
531 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
532 var->data.binding + i);
533 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
534 }
535 } else {
536 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
537 SpvStorageClassUniformConstant);
538
539 if (var->name)
540 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
541
542 int index = var->data.binding;
543 assert(!(ctx->samplers_used & (1 << index)));
544 assert(!ctx->image_types[index]);
545 ctx->image_types[index] = image_type;
546 ctx->samplers[index] = var_id;
547 ctx->samplers_used |= 1 << index;
548
549 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
550 var->data.descriptor_set);
551 int binding = zink_binding(ctx->stage,
552 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
553 var->data.binding);
554 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
555 }
556 }
557
558 static void
559 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
560 {
561 uint32_t size = glsl_count_attribute_slots(var->type, false);
562 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
563 SpvId array_length = emit_uint_const(ctx, 32, size);
564 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
565 array_length);
566 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
567
568 // wrap UBO-array in a struct
569 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
570 if (var->name) {
571 char struct_name[100];
572 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
573 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
574 }
575
576 spirv_builder_emit_decoration(&ctx->builder, struct_type,
577 SpvDecorationBlock);
578 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
579
580
581 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
582 SpvStorageClassUniform,
583 struct_type);
584
585 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
586 SpvStorageClassUniform);
587 if (var->name)
588 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
589
590 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
591 ctx->ubos[ctx->num_ubos++] = var_id;
592
593 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
594 var->data.descriptor_set);
595 int binding = zink_binding(ctx->stage,
596 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
597 var->data.binding);
598 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
599 }
600
601 static void
602 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
603 {
604 if (var->data.mode == nir_var_mem_ubo)
605 emit_ubo(ctx, var);
606 else {
607 assert(var->data.mode == nir_var_uniform);
608 if (glsl_type_is_sampler(glsl_without_array(var->type)))
609 emit_sampler(ctx, var);
610 }
611 }
612
613 static SpvId
614 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
615 {
616 assert(ssa->index < ctx->num_defs);
617 assert(ctx->defs[ssa->index] != 0);
618 return ctx->defs[ssa->index];
619 }
620
621 static SpvId
622 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
623 {
624 assert(reg->index < ctx->num_regs);
625 assert(ctx->regs[reg->index] != 0);
626 return ctx->regs[reg->index];
627 }
628
629 static SpvId
630 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
631 {
632 assert(reg->reg);
633 assert(!reg->indirect);
634 assert(!reg->base_offset);
635
636 SpvId var = get_var_from_reg(ctx, reg->reg);
637 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
638 return spirv_builder_emit_load(&ctx->builder, type, var);
639 }
640
641 static SpvId
642 get_src(struct ntv_context *ctx, nir_src *src)
643 {
644 if (src->is_ssa)
645 return get_src_ssa(ctx, src->ssa);
646 else
647 return get_src_reg(ctx, &src->reg);
648 }
649
650 static SpvId
651 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
652 {
653 assert(!alu->src[src].negate);
654 assert(!alu->src[src].abs);
655
656 SpvId def = get_src(ctx, &alu->src[src].src);
657
658 unsigned used_channels = 0;
659 bool need_swizzle = false;
660 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
661 if (!nir_alu_instr_channel_used(alu, src, i))
662 continue;
663
664 used_channels++;
665
666 if (alu->src[src].swizzle[i] != i)
667 need_swizzle = true;
668 }
669 assert(used_channels != 0);
670
671 unsigned live_channels = nir_src_num_components(alu->src[src].src);
672 if (used_channels != live_channels)
673 need_swizzle = true;
674
675 if (!need_swizzle)
676 return def;
677
678 int bit_size = nir_src_bit_size(alu->src[src].src);
679 assert(bit_size == 1 || bit_size == 32);
680
681 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
682 spirv_builder_type_uint(&ctx->builder, bit_size);
683
684 if (used_channels == 1) {
685 uint32_t indices[] = { alu->src[src].swizzle[0] };
686 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
687 def, indices,
688 ARRAY_SIZE(indices));
689 } else if (live_channels == 1) {
690 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
691 raw_type,
692 used_channels);
693
694 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
695 for (unsigned i = 0; i < used_channels; ++i)
696 constituents[i] = def;
697
698 return spirv_builder_emit_composite_construct(&ctx->builder,
699 raw_vec_type,
700 constituents,
701 used_channels);
702 } else {
703 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
704 raw_type,
705 used_channels);
706
707 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
708 size_t num_components = 0;
709 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
710 if (!nir_alu_instr_channel_used(alu, src, i))
711 continue;
712
713 components[num_components++] = alu->src[src].swizzle[i];
714 }
715
716 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
717 def, def, components,
718 num_components);
719 }
720 }
721
722 static void
723 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
724 {
725 assert(result != 0);
726 assert(ssa->index < ctx->num_defs);
727 ctx->defs[ssa->index] = result;
728 }
729
730 static SpvId
731 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
732 SpvId if_true, SpvId if_false)
733 {
734 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
735 }
736
737 static SpvId
738 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
739 {
740 SpvId type = get_bvec_type(ctx, num_components);
741 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
742 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
743 }
744
745 static SpvId
746 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
747 {
748 return emit_unop(ctx, SpvOpBitcast, type, value);
749 }
750
751 static SpvId
752 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
753 unsigned num_components)
754 {
755 SpvId type = get_uvec_type(ctx, bit_size, num_components);
756 return emit_bitcast(ctx, type, value);
757 }
758
759 static SpvId
760 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
761 unsigned num_components)
762 {
763 SpvId type = get_ivec_type(ctx, bit_size, num_components);
764 return emit_bitcast(ctx, type, value);
765 }
766
767 static SpvId
768 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
769 unsigned num_components)
770 {
771 SpvId type = get_fvec_type(ctx, bit_size, num_components);
772 return emit_bitcast(ctx, type, value);
773 }
774
775 static void
776 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
777 {
778 SpvId var = get_var_from_reg(ctx, reg->reg);
779 assert(var);
780 spirv_builder_emit_store(&ctx->builder, var, result);
781 }
782
783 static void
784 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
785 {
786 if (dest->is_ssa)
787 store_ssa_def(ctx, &dest->ssa, result);
788 else
789 store_reg_def(ctx, &dest->reg, result);
790 }
791
792 static SpvId
793 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
794 {
795 unsigned num_components = nir_dest_num_components(*dest);
796 unsigned bit_size = nir_dest_bit_size(*dest);
797
798 if (bit_size != 1) {
799 switch (nir_alu_type_get_base_type(type)) {
800 case nir_type_bool:
801 assert("bool should have bit-size 1");
802
803 case nir_type_uint:
804 break; /* nothing to do! */
805
806 case nir_type_int:
807 case nir_type_float:
808 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
809 break;
810
811 default:
812 unreachable("unsupported nir_alu_type");
813 }
814 }
815
816 store_dest_raw(ctx, dest, result);
817 return result;
818 }
819
820 static SpvId
821 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
822 {
823 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
824 }
825
826 static SpvId
827 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
828 SpvId src0, SpvId src1)
829 {
830 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
831 }
832
833 static SpvId
834 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
835 SpvId src0, SpvId src1, SpvId src2)
836 {
837 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
838 }
839
840 static SpvId
841 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
842 SpvId src)
843 {
844 SpvId args[] = { src };
845 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
846 op, args, ARRAY_SIZE(args));
847 }
848
849 static SpvId
850 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
851 SpvId src0, SpvId src1)
852 {
853 SpvId args[] = { src0, src1 };
854 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
855 op, args, ARRAY_SIZE(args));
856 }
857
858 static SpvId
859 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
860 SpvId src0, SpvId src1, SpvId src2)
861 {
862 SpvId args[] = { src0, src1, src2 };
863 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
864 op, args, ARRAY_SIZE(args));
865 }
866
867 static SpvId
868 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
869 unsigned num_components, float value)
870 {
871 assert(bit_size == 32);
872
873 SpvId result = emit_float_const(ctx, bit_size, value);
874 if (num_components == 1)
875 return result;
876
877 assert(num_components > 1);
878 SpvId components[num_components];
879 for (int i = 0; i < num_components; i++)
880 components[i] = result;
881
882 SpvId type = get_fvec_type(ctx, bit_size, num_components);
883 return spirv_builder_const_composite(&ctx->builder, type, components,
884 num_components);
885 }
886
887 static SpvId
888 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
889 unsigned num_components, uint32_t value)
890 {
891 assert(bit_size == 32);
892
893 SpvId result = emit_uint_const(ctx, bit_size, value);
894 if (num_components == 1)
895 return result;
896
897 assert(num_components > 1);
898 SpvId components[num_components];
899 for (int i = 0; i < num_components; i++)
900 components[i] = result;
901
902 SpvId type = get_uvec_type(ctx, bit_size, num_components);
903 return spirv_builder_const_composite(&ctx->builder, type, components,
904 num_components);
905 }
906
907 static SpvId
908 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
909 unsigned num_components, int32_t value)
910 {
911 assert(bit_size == 32);
912
913 SpvId result = emit_int_const(ctx, bit_size, value);
914 if (num_components == 1)
915 return result;
916
917 assert(num_components > 1);
918 SpvId components[num_components];
919 for (int i = 0; i < num_components; i++)
920 components[i] = result;
921
922 SpvId type = get_ivec_type(ctx, bit_size, num_components);
923 return spirv_builder_const_composite(&ctx->builder, type, components,
924 num_components);
925 }
926
927 static inline unsigned
928 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
929 {
930 if (nir_op_infos[instr->op].input_sizes[src] > 0)
931 return nir_op_infos[instr->op].input_sizes[src];
932
933 if (instr->dest.dest.is_ssa)
934 return instr->dest.dest.ssa.num_components;
935 else
936 return instr->dest.dest.reg.reg->num_components;
937 }
938
939 static SpvId
940 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
941 {
942 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
943
944 unsigned num_components = alu_instr_src_components(alu, src);
945 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
946 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
947
948 if (bit_size == 1)
949 return raw_value;
950 else {
951 switch (nir_alu_type_get_base_type(type)) {
952 case nir_type_bool:
953 unreachable("bool should have bit-size 1");
954
955 case nir_type_int:
956 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
957
958 case nir_type_uint:
959 return raw_value;
960
961 case nir_type_float:
962 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
963
964 default:
965 unreachable("unknown nir_alu_type");
966 }
967 }
968 }
969
970 static SpvId
971 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
972 {
973 assert(!alu->dest.saturate);
974 return store_dest(ctx, &alu->dest.dest, result,
975 nir_op_infos[alu->op].output_type);
976 }
977
978 static SpvId
979 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
980 {
981 unsigned num_components = nir_dest_num_components(*dest);
982 unsigned bit_size = nir_dest_bit_size(*dest);
983
984 if (bit_size == 1)
985 return get_bvec_type(ctx, num_components);
986
987 switch (nir_alu_type_get_base_type(type)) {
988 case nir_type_bool:
989 unreachable("bool should have bit-size 1");
990
991 case nir_type_int:
992 return get_ivec_type(ctx, bit_size, num_components);
993
994 case nir_type_uint:
995 return get_uvec_type(ctx, bit_size, num_components);
996
997 case nir_type_float:
998 return get_fvec_type(ctx, bit_size, num_components);
999
1000 default:
1001 unreachable("unsupported nir_alu_type");
1002 }
1003 }
1004
1005 static void
1006 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1007 {
1008 SpvId src[nir_op_infos[alu->op].num_inputs];
1009 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1010 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1011 src[i] = get_alu_src(ctx, alu, i);
1012 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1013 }
1014
1015 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1016 nir_op_infos[alu->op].output_type);
1017 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1018 unsigned num_components = nir_dest_num_components(alu->dest.dest);
1019
1020 SpvId result = 0;
1021 switch (alu->op) {
1022 case nir_op_mov:
1023 assert(nir_op_infos[alu->op].num_inputs == 1);
1024 result = src[0];
1025 break;
1026
1027 #define UNOP(nir_op, spirv_op) \
1028 case nir_op: \
1029 assert(nir_op_infos[alu->op].num_inputs == 1); \
1030 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1031 break;
1032
1033 UNOP(nir_op_ineg, SpvOpSNegate)
1034 UNOP(nir_op_fneg, SpvOpFNegate)
1035 UNOP(nir_op_fddx, SpvOpDPdx)
1036 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1037 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1038 UNOP(nir_op_fddy, SpvOpDPdy)
1039 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1040 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1041 UNOP(nir_op_f2i32, SpvOpConvertFToS)
1042 UNOP(nir_op_f2u32, SpvOpConvertFToU)
1043 UNOP(nir_op_i2f32, SpvOpConvertSToF)
1044 UNOP(nir_op_u2f32, SpvOpConvertUToF)
1045 #undef UNOP
1046
1047 case nir_op_inot:
1048 if (bit_size == 1)
1049 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1050 else
1051 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1052 break;
1053
1054 case nir_op_b2i32:
1055 assert(nir_op_infos[alu->op].num_inputs == 1);
1056 result = emit_select(ctx, dest_type, src[0],
1057 get_ivec_constant(ctx, 32, num_components, 1),
1058 get_ivec_constant(ctx, 32, num_components, 0));
1059 break;
1060
1061 case nir_op_b2f32:
1062 assert(nir_op_infos[alu->op].num_inputs == 1);
1063 result = emit_select(ctx, dest_type, src[0],
1064 get_fvec_constant(ctx, 32, num_components, 1),
1065 get_fvec_constant(ctx, 32, num_components, 0));
1066 break;
1067
1068 #define BUILTIN_UNOP(nir_op, spirv_op) \
1069 case nir_op: \
1070 assert(nir_op_infos[alu->op].num_inputs == 1); \
1071 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1072 break;
1073
1074 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1075 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1076 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1077 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1078 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1079 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1080 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1081 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1082 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1083 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1084 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1085 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1086 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1087 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1088 #undef BUILTIN_UNOP
1089
1090 case nir_op_frcp:
1091 assert(nir_op_infos[alu->op].num_inputs == 1);
1092 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1093 get_fvec_constant(ctx, bit_size, num_components, 1),
1094 src[0]);
1095 break;
1096
1097 case nir_op_f2b1:
1098 assert(nir_op_infos[alu->op].num_inputs == 1);
1099 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1100 get_fvec_constant(ctx,
1101 nir_src_bit_size(alu->src[0].src),
1102 num_components, 0));
1103 break;
1104 case nir_op_i2b1:
1105 assert(nir_op_infos[alu->op].num_inputs == 1);
1106 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1107 get_ivec_constant(ctx,
1108 nir_src_bit_size(alu->src[0].src),
1109 num_components, 0));
1110 break;
1111
1112
1113 #define BINOP(nir_op, spirv_op) \
1114 case nir_op: \
1115 assert(nir_op_infos[alu->op].num_inputs == 2); \
1116 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1117 break;
1118
1119 BINOP(nir_op_iadd, SpvOpIAdd)
1120 BINOP(nir_op_isub, SpvOpISub)
1121 BINOP(nir_op_imul, SpvOpIMul)
1122 BINOP(nir_op_idiv, SpvOpSDiv)
1123 BINOP(nir_op_udiv, SpvOpUDiv)
1124 BINOP(nir_op_umod, SpvOpUMod)
1125 BINOP(nir_op_fadd, SpvOpFAdd)
1126 BINOP(nir_op_fsub, SpvOpFSub)
1127 BINOP(nir_op_fmul, SpvOpFMul)
1128 BINOP(nir_op_fdiv, SpvOpFDiv)
1129 BINOP(nir_op_fmod, SpvOpFMod)
1130 BINOP(nir_op_ilt, SpvOpSLessThan)
1131 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1132 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1133 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1134 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1135 BINOP(nir_op_feq, SpvOpFOrdEqual)
1136 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1137 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1138 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1139 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1140 #undef BINOP
1141
1142 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1143 case nir_op: \
1144 assert(nir_op_infos[alu->op].num_inputs == 2); \
1145 if (nir_src_bit_size(alu->src[0].src) == 1) \
1146 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1147 else \
1148 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1149 break;
1150
1151 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1152 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1153 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1154 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1155 #undef BINOP_LOG
1156
1157 #define BUILTIN_BINOP(nir_op, spirv_op) \
1158 case nir_op: \
1159 assert(nir_op_infos[alu->op].num_inputs == 2); \
1160 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1161 break;
1162
1163 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1164 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1165 #undef BUILTIN_BINOP
1166
1167 case nir_op_fdot2:
1168 case nir_op_fdot3:
1169 case nir_op_fdot4:
1170 assert(nir_op_infos[alu->op].num_inputs == 2);
1171 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1172 break;
1173
1174 case nir_op_fdph:
1175 unreachable("should already be lowered away");
1176
1177 case nir_op_seq:
1178 case nir_op_sne:
1179 case nir_op_slt:
1180 case nir_op_sge: {
1181 assert(nir_op_infos[alu->op].num_inputs == 2);
1182 int num_components = nir_dest_num_components(alu->dest.dest);
1183 SpvId bool_type = get_bvec_type(ctx, num_components);
1184
1185 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1186 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1187 if (num_components > 1) {
1188 SpvId zero_comps[num_components], one_comps[num_components];
1189 for (int i = 0; i < num_components; i++) {
1190 zero_comps[i] = zero;
1191 one_comps[i] = one;
1192 }
1193
1194 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1195 zero_comps, num_components);
1196 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1197 one_comps, num_components);
1198 }
1199
1200 SpvOp op;
1201 switch (alu->op) {
1202 case nir_op_seq: op = SpvOpFOrdEqual; break;
1203 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1204 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1205 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1206 default: unreachable("unexpected op");
1207 }
1208
1209 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1210 result = emit_select(ctx, dest_type, result, one, zero);
1211 }
1212 break;
1213
1214 case nir_op_flrp:
1215 assert(nir_op_infos[alu->op].num_inputs == 3);
1216 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1217 src[0], src[1], src[2]);
1218 break;
1219
1220 case nir_op_fcsel:
1221 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1222 get_bvec_type(ctx, num_components),
1223 src[0],
1224 get_fvec_constant(ctx,
1225 nir_src_bit_size(alu->src[0].src),
1226 num_components, 0));
1227 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1228 break;
1229
1230 case nir_op_bcsel:
1231 assert(nir_op_infos[alu->op].num_inputs == 3);
1232 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1233 break;
1234
1235 case nir_op_bany_fnequal2:
1236 case nir_op_bany_fnequal3:
1237 case nir_op_bany_fnequal4: {
1238 assert(nir_op_infos[alu->op].num_inputs == 2);
1239 assert(alu_instr_src_components(alu, 0) ==
1240 alu_instr_src_components(alu, 1));
1241 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1242 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1243 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1244 result = emit_binop(ctx, op,
1245 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1246 src[0], src[1]);
1247 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1248 break;
1249 }
1250
1251 case nir_op_ball_fequal2:
1252 case nir_op_ball_fequal3:
1253 case nir_op_ball_fequal4: {
1254 assert(nir_op_infos[alu->op].num_inputs == 2);
1255 assert(alu_instr_src_components(alu, 0) ==
1256 alu_instr_src_components(alu, 1));
1257 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1258 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1259 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1260 result = emit_binop(ctx, op,
1261 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1262 src[0], src[1]);
1263 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1264 break;
1265 }
1266
1267 case nir_op_bany_inequal2:
1268 case nir_op_bany_inequal3:
1269 case nir_op_bany_inequal4: {
1270 assert(nir_op_infos[alu->op].num_inputs == 2);
1271 assert(alu_instr_src_components(alu, 0) ==
1272 alu_instr_src_components(alu, 1));
1273 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1274 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1275 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1276 result = emit_binop(ctx, op,
1277 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1278 src[0], src[1]);
1279 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1280 break;
1281 }
1282
1283 case nir_op_ball_iequal2:
1284 case nir_op_ball_iequal3:
1285 case nir_op_ball_iequal4: {
1286 assert(nir_op_infos[alu->op].num_inputs == 2);
1287 assert(alu_instr_src_components(alu, 0) ==
1288 alu_instr_src_components(alu, 1));
1289 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1290 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1291 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1292 result = emit_binop(ctx, op,
1293 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1294 src[0], src[1]);
1295 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1296 break;
1297 }
1298
1299 case nir_op_vec2:
1300 case nir_op_vec3:
1301 case nir_op_vec4: {
1302 int num_inputs = nir_op_infos[alu->op].num_inputs;
1303 assert(2 <= num_inputs && num_inputs <= 4);
1304 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1305 src, num_inputs);
1306 }
1307 break;
1308
1309 default:
1310 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1311 nir_op_infos[alu->op].name);
1312
1313 unreachable("unsupported opcode");
1314 return;
1315 }
1316
1317 store_alu_result(ctx, alu, result);
1318 }
1319
1320 static void
1321 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1322 {
1323 unsigned bit_size = load_const->def.bit_size;
1324 unsigned num_components = load_const->def.num_components;
1325
1326 SpvId constant;
1327 if (num_components > 1) {
1328 SpvId components[num_components];
1329 SpvId type;
1330 if (bit_size == 1) {
1331 for (int i = 0; i < num_components; i++)
1332 components[i] = spirv_builder_const_bool(&ctx->builder,
1333 load_const->value[i].b);
1334
1335 type = get_bvec_type(ctx, num_components);
1336 } else {
1337 for (int i = 0; i < num_components; i++)
1338 components[i] = emit_uint_const(ctx, bit_size,
1339 load_const->value[i].u32);
1340
1341 type = get_uvec_type(ctx, bit_size, num_components);
1342 }
1343 constant = spirv_builder_const_composite(&ctx->builder, type,
1344 components, num_components);
1345 } else {
1346 assert(num_components == 1);
1347 if (bit_size == 1)
1348 constant = spirv_builder_const_bool(&ctx->builder,
1349 load_const->value[0].b);
1350 else
1351 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1352 }
1353
1354 store_ssa_def(ctx, &load_const->def, constant);
1355 }
1356
1357 static void
1358 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1359 {
1360 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1361 assert(const_block_index); // no dynamic indexing for now
1362 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1363
1364 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1365 if (const_offset) {
1366 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1367 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1368 SpvStorageClassUniform,
1369 uvec4_type);
1370
1371 unsigned idx = const_offset->u32;
1372 SpvId member = emit_uint_const(ctx, 32, 0);
1373 SpvId offset = emit_uint_const(ctx, 32, idx);
1374 SpvId offsets[] = { member, offset };
1375 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1376 ctx->ubos[0], offsets,
1377 ARRAY_SIZE(offsets));
1378 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1379
1380 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1381 unsigned num_components = nir_dest_num_components(intr->dest);
1382 if (num_components == 1) {
1383 uint32_t components[] = { 0 };
1384 result = spirv_builder_emit_composite_extract(&ctx->builder,
1385 type,
1386 result, components,
1387 1);
1388 } else if (num_components < 4) {
1389 SpvId constituents[num_components];
1390 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1391 for (uint32_t i = 0; i < num_components; ++i)
1392 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1393 uint_type,
1394 result, &i,
1395 1);
1396
1397 result = spirv_builder_emit_composite_construct(&ctx->builder,
1398 type,
1399 constituents,
1400 num_components);
1401 }
1402
1403 if (nir_dest_bit_size(intr->dest) == 1)
1404 result = uvec_to_bvec(ctx, result, num_components);
1405
1406 store_dest(ctx, &intr->dest, result, nir_type_uint);
1407 } else
1408 unreachable("uniform-addressing not yet supported");
1409 }
1410
1411 static void
1412 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1413 {
1414 assert(ctx->block_started);
1415 spirv_builder_emit_kill(&ctx->builder);
1416 /* discard is weird in NIR, so let's just create an unreachable block after
1417 it and hope that the vulkan driver will DCE any instructinos in it. */
1418 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1419 }
1420
1421 static void
1422 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1423 {
1424 SpvId ptr = get_src(ctx, intr->src);
1425
1426 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1427 SpvId result = spirv_builder_emit_load(&ctx->builder,
1428 get_glsl_type(ctx, var->type),
1429 ptr);
1430 unsigned num_components = nir_dest_num_components(intr->dest);
1431 unsigned bit_size = nir_dest_bit_size(intr->dest);
1432 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1433 store_dest(ctx, &intr->dest, result, nir_type_uint);
1434 }
1435
1436 static void
1437 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1438 {
1439 SpvId ptr = get_src(ctx, &intr->src[0]);
1440 SpvId src = get_src(ctx, &intr->src[1]);
1441
1442 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1443 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1444 SpvId result = emit_bitcast(ctx, type, src);
1445 spirv_builder_emit_store(&ctx->builder, ptr, result);
1446 }
1447
1448 static SpvId
1449 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1450 SpvStorageClass storage_class,
1451 const char *name, SpvBuiltIn builtin)
1452 {
1453 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1454 storage_class,
1455 var_type);
1456 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1457 storage_class);
1458 spirv_builder_emit_name(&ctx->builder, var, name);
1459 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1460
1461 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1462 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1463 return var;
1464 }
1465
1466 static void
1467 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1468 {
1469 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1470 if (!ctx->front_face_var)
1471 ctx->front_face_var = create_builtin_var(ctx, var_type,
1472 SpvStorageClassInput,
1473 "gl_FrontFacing",
1474 SpvBuiltInFrontFacing);
1475
1476 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1477 ctx->front_face_var);
1478 assert(1 == nir_dest_num_components(intr->dest));
1479 store_dest(ctx, &intr->dest, result, nir_type_bool);
1480 }
1481
1482 static void
1483 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1484 {
1485 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1486 if (!ctx->instance_id_var)
1487 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1488 SpvStorageClassInput,
1489 "gl_InstanceId",
1490 SpvBuiltInInstanceIndex);
1491
1492 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1493 ctx->instance_id_var);
1494 assert(1 == nir_dest_num_components(intr->dest));
1495 store_dest(ctx, &intr->dest, result, nir_type_uint);
1496 }
1497
1498 static void
1499 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1500 {
1501 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1502 if (!ctx->vertex_id_var)
1503 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1504 SpvStorageClassInput,
1505 "gl_VertexID",
1506 SpvBuiltInVertexIndex);
1507
1508 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1509 ctx->vertex_id_var);
1510 assert(1 == nir_dest_num_components(intr->dest));
1511 store_dest(ctx, &intr->dest, result, nir_type_uint);
1512 }
1513
1514 static void
1515 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1516 {
1517 switch (intr->intrinsic) {
1518 case nir_intrinsic_load_ubo:
1519 emit_load_ubo(ctx, intr);
1520 break;
1521
1522 case nir_intrinsic_discard:
1523 emit_discard(ctx, intr);
1524 break;
1525
1526 case nir_intrinsic_load_deref:
1527 emit_load_deref(ctx, intr);
1528 break;
1529
1530 case nir_intrinsic_store_deref:
1531 emit_store_deref(ctx, intr);
1532 break;
1533
1534 case nir_intrinsic_load_front_face:
1535 emit_load_front_face(ctx, intr);
1536 break;
1537
1538 case nir_intrinsic_load_instance_id:
1539 emit_load_instance_id(ctx, intr);
1540 break;
1541
1542 case nir_intrinsic_load_vertex_id:
1543 emit_load_vertex_id(ctx, intr);
1544 break;
1545
1546 default:
1547 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1548 nir_intrinsic_infos[intr->intrinsic].name);
1549 unreachable("unsupported intrinsic");
1550 }
1551 }
1552
1553 static void
1554 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1555 {
1556 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1557 undef->def.num_components);
1558
1559 store_ssa_def(ctx, &undef->def,
1560 spirv_builder_emit_undef(&ctx->builder, type));
1561 }
1562
1563 static SpvId
1564 get_src_float(struct ntv_context *ctx, nir_src *src)
1565 {
1566 SpvId def = get_src(ctx, src);
1567 unsigned num_components = nir_src_num_components(*src);
1568 unsigned bit_size = nir_src_bit_size(*src);
1569 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1570 }
1571
1572 static SpvId
1573 get_src_int(struct ntv_context *ctx, nir_src *src)
1574 {
1575 SpvId def = get_src(ctx, src);
1576 unsigned num_components = nir_src_num_components(*src);
1577 unsigned bit_size = nir_src_bit_size(*src);
1578 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1579 }
1580
1581 static void
1582 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1583 {
1584 assert(tex->op == nir_texop_tex ||
1585 tex->op == nir_texop_txb ||
1586 tex->op == nir_texop_txl ||
1587 tex->op == nir_texop_txd ||
1588 tex->op == nir_texop_txf ||
1589 tex->op == nir_texop_txf_ms ||
1590 tex->op == nir_texop_txs);
1591 assert(tex->texture_index == tex->sampler_index);
1592
1593 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1594 offset = 0, sample = 0;
1595 unsigned coord_components = 0;
1596 for (unsigned i = 0; i < tex->num_srcs; i++) {
1597 switch (tex->src[i].src_type) {
1598 case nir_tex_src_coord:
1599 if (tex->op == nir_texop_txf ||
1600 tex->op == nir_texop_txf_ms)
1601 coord = get_src_int(ctx, &tex->src[i].src);
1602 else
1603 coord = get_src_float(ctx, &tex->src[i].src);
1604 coord_components = nir_src_num_components(tex->src[i].src);
1605 break;
1606
1607 case nir_tex_src_projector:
1608 assert(nir_src_num_components(tex->src[i].src) == 1);
1609 proj = get_src_float(ctx, &tex->src[i].src);
1610 assert(proj != 0);
1611 break;
1612
1613 case nir_tex_src_offset:
1614 offset = get_src_int(ctx, &tex->src[i].src);
1615 break;
1616
1617 case nir_tex_src_bias:
1618 assert(tex->op == nir_texop_txb);
1619 bias = get_src_float(ctx, &tex->src[i].src);
1620 assert(bias != 0);
1621 break;
1622
1623 case nir_tex_src_lod:
1624 assert(nir_src_num_components(tex->src[i].src) == 1);
1625 if (tex->op == nir_texop_txf ||
1626 tex->op == nir_texop_txf_ms ||
1627 tex->op == nir_texop_txs)
1628 lod = get_src_int(ctx, &tex->src[i].src);
1629 else
1630 lod = get_src_float(ctx, &tex->src[i].src);
1631 assert(lod != 0);
1632 break;
1633
1634 case nir_tex_src_ms_index:
1635 assert(nir_src_num_components(tex->src[i].src) == 1);
1636 sample = get_src_int(ctx, &tex->src[i].src);
1637 break;
1638
1639 case nir_tex_src_comparator:
1640 assert(nir_src_num_components(tex->src[i].src) == 1);
1641 dref = get_src_float(ctx, &tex->src[i].src);
1642 assert(dref != 0);
1643 break;
1644
1645 case nir_tex_src_ddx:
1646 dx = get_src_float(ctx, &tex->src[i].src);
1647 assert(dx != 0);
1648 break;
1649
1650 case nir_tex_src_ddy:
1651 dy = get_src_float(ctx, &tex->src[i].src);
1652 assert(dy != 0);
1653 break;
1654
1655 default:
1656 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1657 unreachable("unknown texture source");
1658 }
1659 }
1660
1661 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1662 lod = emit_float_const(ctx, 32, 0.0f);
1663 assert(lod != 0);
1664 }
1665
1666 SpvId image_type = ctx->image_types[tex->texture_index];
1667 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1668 image_type);
1669
1670 assert(ctx->samplers_used & (1u << tex->texture_index));
1671 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1672 ctx->samplers[tex->texture_index]);
1673
1674 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1675
1676 if (tex->op == nir_texop_txs) {
1677 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1678 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1679 dest_type, image,
1680 lod);
1681 store_dest(ctx, &tex->dest, result, tex->dest_type);
1682 return;
1683 }
1684
1685 if (proj && coord_components > 0) {
1686 SpvId constituents[coord_components + 1];
1687 if (coord_components == 1)
1688 constituents[0] = coord;
1689 else {
1690 assert(coord_components > 1);
1691 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1692 for (uint32_t i = 0; i < coord_components; ++i)
1693 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1694 float_type,
1695 coord,
1696 &i, 1);
1697 }
1698
1699 constituents[coord_components++] = proj;
1700
1701 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1702 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1703 vec_type,
1704 constituents,
1705 coord_components);
1706 }
1707
1708 SpvId actual_dest_type = dest_type;
1709 if (dref)
1710 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1711
1712 SpvId result;
1713 if (tex->op == nir_texop_txf ||
1714 tex->op == nir_texop_txf_ms) {
1715 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1716 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1717 image, coord, lod, sample);
1718 } else {
1719 result = spirv_builder_emit_image_sample(&ctx->builder,
1720 actual_dest_type, load,
1721 coord,
1722 proj != 0,
1723 lod, bias, dref, dx, dy,
1724 offset);
1725 }
1726
1727 spirv_builder_emit_decoration(&ctx->builder, result,
1728 SpvDecorationRelaxedPrecision);
1729
1730 if (dref && nir_dest_num_components(tex->dest) > 1) {
1731 SpvId components[4] = { result, result, result, result };
1732 result = spirv_builder_emit_composite_construct(&ctx->builder,
1733 dest_type,
1734 components,
1735 4);
1736 }
1737
1738 store_dest(ctx, &tex->dest, result, tex->dest_type);
1739 }
1740
1741 static void
1742 start_block(struct ntv_context *ctx, SpvId label)
1743 {
1744 /* terminate previous block if needed */
1745 if (ctx->block_started)
1746 spirv_builder_emit_branch(&ctx->builder, label);
1747
1748 /* start new block */
1749 spirv_builder_label(&ctx->builder, label);
1750 ctx->block_started = true;
1751 }
1752
1753 static void
1754 branch(struct ntv_context *ctx, SpvId label)
1755 {
1756 assert(ctx->block_started);
1757 spirv_builder_emit_branch(&ctx->builder, label);
1758 ctx->block_started = false;
1759 }
1760
1761 static void
1762 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1763 SpvId else_id)
1764 {
1765 assert(ctx->block_started);
1766 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1767 then_id, else_id);
1768 ctx->block_started = false;
1769 }
1770
1771 static void
1772 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1773 {
1774 switch (jump->type) {
1775 case nir_jump_break:
1776 assert(ctx->loop_break);
1777 branch(ctx, ctx->loop_break);
1778 break;
1779
1780 case nir_jump_continue:
1781 assert(ctx->loop_cont);
1782 branch(ctx, ctx->loop_cont);
1783 break;
1784
1785 default:
1786 unreachable("Unsupported jump type\n");
1787 }
1788 }
1789
1790 static void
1791 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1792 {
1793 assert(deref->deref_type == nir_deref_type_var);
1794
1795 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1796 assert(he);
1797 SpvId result = (SpvId)(intptr_t)he->data;
1798 store_dest_raw(ctx, &deref->dest, result);
1799 }
1800
1801 static void
1802 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1803 {
1804 assert(deref->deref_type == nir_deref_type_array);
1805 nir_variable *var = nir_deref_instr_get_variable(deref);
1806
1807 SpvStorageClass storage_class;
1808 switch (var->data.mode) {
1809 case nir_var_shader_in:
1810 storage_class = SpvStorageClassInput;
1811 break;
1812
1813 case nir_var_shader_out:
1814 storage_class = SpvStorageClassOutput;
1815 break;
1816
1817 default:
1818 unreachable("Unsupported nir_variable_mode\n");
1819 }
1820
1821 SpvId index = get_src(ctx, &deref->arr.index);
1822
1823 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1824 storage_class,
1825 get_glsl_type(ctx, deref->type));
1826
1827 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1828 ptr_type,
1829 get_src(ctx, &deref->parent),
1830 &index, 1);
1831 /* uint is a bit of a lie here, it's really just an opaque type */
1832 store_dest(ctx, &deref->dest, result, nir_type_uint);
1833 }
1834
1835 static void
1836 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1837 {
1838 switch (deref->deref_type) {
1839 case nir_deref_type_var:
1840 emit_deref_var(ctx, deref);
1841 break;
1842
1843 case nir_deref_type_array:
1844 emit_deref_array(ctx, deref);
1845 break;
1846
1847 default:
1848 unreachable("unexpected deref_type");
1849 }
1850 }
1851
1852 static void
1853 emit_block(struct ntv_context *ctx, struct nir_block *block)
1854 {
1855 start_block(ctx, block_label(ctx, block));
1856 nir_foreach_instr(instr, block) {
1857 switch (instr->type) {
1858 case nir_instr_type_alu:
1859 emit_alu(ctx, nir_instr_as_alu(instr));
1860 break;
1861 case nir_instr_type_intrinsic:
1862 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1863 break;
1864 case nir_instr_type_load_const:
1865 emit_load_const(ctx, nir_instr_as_load_const(instr));
1866 break;
1867 case nir_instr_type_ssa_undef:
1868 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1869 break;
1870 case nir_instr_type_tex:
1871 emit_tex(ctx, nir_instr_as_tex(instr));
1872 break;
1873 case nir_instr_type_phi:
1874 unreachable("nir_instr_type_phi not supported");
1875 break;
1876 case nir_instr_type_jump:
1877 emit_jump(ctx, nir_instr_as_jump(instr));
1878 break;
1879 case nir_instr_type_call:
1880 unreachable("nir_instr_type_call not supported");
1881 break;
1882 case nir_instr_type_parallel_copy:
1883 unreachable("nir_instr_type_parallel_copy not supported");
1884 break;
1885 case nir_instr_type_deref:
1886 emit_deref(ctx, nir_instr_as_deref(instr));
1887 break;
1888 }
1889 }
1890 }
1891
/* Forward declaration: emit_if() and emit_loop() recurse into nested
 * control-flow lists before emit_cf_list() is defined. */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1894
1895 static SpvId
1896 get_src_bool(struct ntv_context *ctx, nir_src *src)
1897 {
1898 assert(nir_src_bit_size(*src) == 1);
1899 return get_src(ctx, src);
1900 }
1901
1902 static void
1903 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1904 {
1905 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1906
1907 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1908 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1909 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1910 SpvId else_id = endif_id;
1911
1912 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1913 if (has_else) {
1914 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1915 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1916 }
1917
1918 /* create a header-block */
1919 start_block(ctx, header_id);
1920 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1921 SpvSelectionControlMaskNone);
1922 branch_conditional(ctx, condition, then_id, else_id);
1923
1924 emit_cf_list(ctx, &if_stmt->then_list);
1925
1926 if (has_else) {
1927 if (ctx->block_started)
1928 branch(ctx, endif_id);
1929
1930 emit_cf_list(ctx, &if_stmt->else_list);
1931 }
1932
1933 start_block(ctx, endif_id);
1934 }
1935
1936 static void
1937 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1938 {
1939 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1940 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1941 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1942 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1943
1944 /* create a header-block */
1945 start_block(ctx, header_id);
1946 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1947 branch(ctx, begin_id);
1948
1949 SpvId save_break = ctx->loop_break;
1950 SpvId save_cont = ctx->loop_cont;
1951 ctx->loop_break = break_id;
1952 ctx->loop_cont = cont_id;
1953
1954 emit_cf_list(ctx, &loop->body);
1955
1956 ctx->loop_break = save_break;
1957 ctx->loop_cont = save_cont;
1958
1959 branch(ctx, cont_id);
1960 start_block(ctx, cont_id);
1961 branch(ctx, header_id);
1962
1963 start_block(ctx, break_id);
1964 }
1965
1966 static void
1967 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1968 {
1969 foreach_list_typed(nir_cf_node, node, node, list) {
1970 switch (node->type) {
1971 case nir_cf_node_block:
1972 emit_block(ctx, nir_cf_node_as_block(node));
1973 break;
1974
1975 case nir_cf_node_if:
1976 emit_if(ctx, nir_cf_node_as_if(node));
1977 break;
1978
1979 case nir_cf_node_loop:
1980 emit_loop(ctx, nir_cf_node_as_loop(node));
1981 break;
1982
1983 case nir_cf_node_function:
1984 unreachable("nir_cf_node_function not supported");
1985 break;
1986 }
1987 }
1988 }
1989
1990 struct spirv_shader *
1991 nir_to_spirv(struct nir_shader *s)
1992 {
1993 struct spirv_shader *ret = NULL;
1994
1995 struct ntv_context ctx = {};
1996
1997 switch (s->info.stage) {
1998 case MESA_SHADER_VERTEX:
1999 case MESA_SHADER_FRAGMENT:
2000 case MESA_SHADER_COMPUTE:
2001 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2002 break;
2003
2004 case MESA_SHADER_TESS_CTRL:
2005 case MESA_SHADER_TESS_EVAL:
2006 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2007 break;
2008
2009 case MESA_SHADER_GEOMETRY:
2010 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2011 break;
2012
2013 default:
2014 unreachable("invalid stage");
2015 }
2016
2017 // TODO: only enable when needed
2018 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2019 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2020 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2021 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2022 }
2023
2024 ctx.stage = s->info.stage;
2025 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2026 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2027
2028 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2029 SpvMemoryModelGLSL450);
2030
2031 SpvExecutionModel exec_model;
2032 switch (s->info.stage) {
2033 case MESA_SHADER_VERTEX:
2034 exec_model = SpvExecutionModelVertex;
2035 break;
2036 case MESA_SHADER_TESS_CTRL:
2037 exec_model = SpvExecutionModelTessellationControl;
2038 break;
2039 case MESA_SHADER_TESS_EVAL:
2040 exec_model = SpvExecutionModelTessellationEvaluation;
2041 break;
2042 case MESA_SHADER_GEOMETRY:
2043 exec_model = SpvExecutionModelGeometry;
2044 break;
2045 case MESA_SHADER_FRAGMENT:
2046 exec_model = SpvExecutionModelFragment;
2047 break;
2048 case MESA_SHADER_COMPUTE:
2049 exec_model = SpvExecutionModelGLCompute;
2050 break;
2051 default:
2052 unreachable("invalid stage");
2053 }
2054
2055 SpvId type_void = spirv_builder_type_void(&ctx.builder);
2056 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2057 NULL, 0);
2058 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2059 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2060
2061 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
2062 _mesa_key_pointer_equal);
2063
2064 nir_foreach_variable(var, &s->inputs)
2065 emit_input(&ctx, var);
2066
2067 nir_foreach_variable(var, &s->outputs)
2068 emit_output(&ctx, var);
2069
2070 nir_foreach_variable(var, &s->uniforms)
2071 emit_uniform(&ctx, var);
2072
2073 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2074 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2075 SpvExecutionModeOriginUpperLeft);
2076 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2077 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2078 SpvExecutionModeDepthReplacing);
2079 }
2080
2081
2082 spirv_builder_function(&ctx.builder, entry_point, type_void,
2083 SpvFunctionControlMaskNone,
2084 type_main);
2085
2086 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2087 nir_metadata_require(entry, nir_metadata_block_index);
2088
2089 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
2090 if (!ctx.defs)
2091 goto fail;
2092 ctx.num_defs = entry->ssa_alloc;
2093
2094 nir_index_local_regs(entry);
2095 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
2096 if (!ctx.regs)
2097 goto fail;
2098 ctx.num_regs = entry->reg_alloc;
2099
2100 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
2101 if (!block_ids)
2102 goto fail;
2103
2104 for (int i = 0; i < entry->num_blocks; ++i)
2105 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2106
2107 ctx.block_ids = block_ids;
2108 ctx.num_blocks = entry->num_blocks;
2109
2110 /* emit a block only for the variable declarations */
2111 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2112 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2113 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
2114 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2115 SpvStorageClassFunction,
2116 type);
2117 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2118 SpvStorageClassFunction);
2119
2120 ctx.regs[reg->index] = var;
2121 }
2122
2123 emit_cf_list(&ctx, &entry->body);
2124
2125 free(ctx.defs);
2126
2127 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2128 spirv_builder_function_end(&ctx.builder);
2129
2130 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2131 "main", ctx.entry_ifaces,
2132 ctx.num_entry_ifaces);
2133
2134 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2135
2136 ret = CALLOC_STRUCT(spirv_shader);
2137 if (!ret)
2138 goto fail;
2139
2140 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2141 if (!ret->words)
2142 goto fail;
2143
2144 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2145 assert(ret->num_words == num_words);
2146
2147 return ret;
2148
2149 fail:
2150
2151 if (ret)
2152 spirv_shader_delete(ret);
2153
2154 if (ctx.vars)
2155 _mesa_hash_table_destroy(ctx.vars, NULL);
2156
2157 return NULL;
2158 }
2159
2160 void
2161 spirv_shader_delete(struct spirv_shader *s)
2162 {
2163 FREE(s->words);
2164 FREE(s);
2165 }