zink: generically handle matrix types
[mesa.git] / src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 /* this consistently maps slots to a zero-indexed value to avoid wasting slots */
33 static unsigned slot_pack_map[] = {
34 /* Position is builtin */
35 [VARYING_SLOT_POS] = UINT_MAX,
36 [VARYING_SLOT_COL0] = 0, /* input/output */
37 [VARYING_SLOT_COL1] = 1, /* input/output */
38 [VARYING_SLOT_FOGC] = 2, /* input/output */
39 /* TEX0-7 are deprecated, so we put them at the end of the range and hope nobody uses them all */
40 [VARYING_SLOT_TEX0] = VARYING_SLOT_VAR0 - 1, /* input/output */
41 [VARYING_SLOT_TEX1] = VARYING_SLOT_VAR0 - 2,
42 [VARYING_SLOT_TEX2] = VARYING_SLOT_VAR0 - 3,
43 [VARYING_SLOT_TEX3] = VARYING_SLOT_VAR0 - 4,
44 [VARYING_SLOT_TEX4] = VARYING_SLOT_VAR0 - 5,
45 [VARYING_SLOT_TEX5] = VARYING_SLOT_VAR0 - 6,
46 [VARYING_SLOT_TEX6] = VARYING_SLOT_VAR0 - 7,
47 [VARYING_SLOT_TEX7] = VARYING_SLOT_VAR0 - 8,
48
49 /* PointSize is builtin */
50 [VARYING_SLOT_PSIZ] = UINT_MAX,
51
52 [VARYING_SLOT_BFC0] = 3, /* output only */
53 [VARYING_SLOT_BFC1] = 4, /* output only */
54 [VARYING_SLOT_EDGE] = 5, /* output only */
55 [VARYING_SLOT_CLIP_VERTEX] = 6, /* output only */
56
57 /* ClipDistance is builtin */
58 [VARYING_SLOT_CLIP_DIST0] = UINT_MAX,
59 [VARYING_SLOT_CLIP_DIST1] = UINT_MAX,
60
61 /* CullDistance is builtin */
62 [VARYING_SLOT_CULL_DIST0] = UINT_MAX, /* input/output */
63 [VARYING_SLOT_CULL_DIST1] = UINT_MAX, /* never actually used */
64
65 /* PrimitiveId is builtin */
66 [VARYING_SLOT_PRIMITIVE_ID] = UINT_MAX,
67
68 /* Layer is builtin */
69 [VARYING_SLOT_LAYER] = UINT_MAX, /* input/output */
70
71 /* ViewportIndex is builtin */
72 [VARYING_SLOT_VIEWPORT] = UINT_MAX, /* input/output */
73
74 /* FrontFacing is builtin */
75 [VARYING_SLOT_FACE] = UINT_MAX,
76
77 /* PointCoord is builtin */
78 [VARYING_SLOT_PNTC] = UINT_MAX, /* input only */
79
80 /* TessLevelOuter is builtin */
81 [VARYING_SLOT_TESS_LEVEL_OUTER] = UINT_MAX,
82 /* TessLevelInner is builtin */
83 [VARYING_SLOT_TESS_LEVEL_INNER] = UINT_MAX,
84
85 [VARYING_SLOT_BOUNDING_BOX0] = 7, /* Only appears as TCS output. */
86 [VARYING_SLOT_BOUNDING_BOX1] = 8, /* Only appears as TCS output. */
87 [VARYING_SLOT_VIEW_INDEX] = 9, /* input/output */
88 [VARYING_SLOT_VIEWPORT_MASK] = 10, /* output only */
89 };
90 #define NTV_MIN_RESERVED_SLOTS 11
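/* For illustration (a rough summary of the table above): COL0/COL1/FOGC map to
 * locations 0/1/2, BFC0..VIEWPORT_MASK fill 3..10, and the deprecated TEX0..TEX7 are
 * parked just below VARYING_SLOT_VAR0 (VAR0-1 .. VAR0-8). User varyings VAR0+n are
 * then packed to NTV_MIN_RESERVED_SLOTS + n by handle_slot() below, e.g.
 * VARYING_SLOT_VAR2 -> location 13.
 */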
91
92 struct ntv_context {
93 void *mem_ctx;
94
95 struct spirv_builder builder;
96
97 SpvId GLSL_std_450;
98
99 gl_shader_stage stage;
100
101 SpvId ubos[128];
102 size_t num_ubos;
103 SpvId image_types[PIPE_MAX_SAMPLERS];
104 SpvId samplers[PIPE_MAX_SAMPLERS];
105 unsigned samplers_used : PIPE_MAX_SAMPLERS;
106 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
107 size_t num_entry_ifaces;
108
109 SpvId *defs;
110 size_t num_defs;
111
112 SpvId *regs;
113 size_t num_regs;
114
115 struct hash_table *vars; /* nir_variable -> SpvId */
116 struct hash_table *so_outputs; /* pipe_stream_output -> SpvId */
117 unsigned outputs[VARYING_SLOT_MAX];
118 const struct glsl_type *so_output_gl_types[VARYING_SLOT_MAX];
119 SpvId so_output_types[VARYING_SLOT_MAX];
120
121 const SpvId *block_ids;
122 size_t num_blocks;
123 bool block_started;
124 SpvId loop_break, loop_cont;
125
126 SpvId front_face_var, instance_id_var, vertex_id_var;
127 #ifndef NDEBUG
128 bool seen_texcoord[8]; //whether we've seen a VARYING_SLOT_TEX[n] this pass
129 #endif
130 };
131
132 static SpvId
133 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
134 unsigned num_components, float value);
135
136 static SpvId
137 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
138 unsigned num_components, uint32_t value);
139
140 static SpvId
141 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
142 unsigned num_components, int32_t value);
143
144 static SpvId
145 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
146
147 static SpvId
148 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
149 SpvId src0, SpvId src1);
150
151 static SpvId
152 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
153 SpvId src0, SpvId src1, SpvId src2);
154
155 static SpvId
156 get_bvec_type(struct ntv_context *ctx, int num_components)
157 {
158 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
159 if (num_components > 1)
160 return spirv_builder_type_vector(&ctx->builder, bool_type,
161 num_components);
162
163 assert(num_components == 1);
164 return bool_type;
165 }
166
167 static SpvId
168 block_label(struct ntv_context *ctx, nir_block *block)
169 {
170 assert(block->index < ctx->num_blocks);
171 return ctx->block_ids[block->index];
172 }
173
174 static SpvId
175 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
176 {
177 assert(bit_size == 32);
178 return spirv_builder_const_float(&ctx->builder, bit_size, value);
179 }
180
181 static SpvId
182 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
183 {
184 assert(bit_size == 32);
185 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
186 }
187
188 static SpvId
189 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
190 {
191 assert(bit_size == 32);
192 return spirv_builder_const_int(&ctx->builder, bit_size, value);
193 }
194
195 static SpvId
196 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
197 {
198 assert(bit_size == 32); // only 32-bit floats supported so far
199
200 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
201 if (num_components > 1)
202 return spirv_builder_type_vector(&ctx->builder, float_type,
203 num_components);
204
205 assert(num_components == 1);
206 return float_type;
207 }
208
209 static SpvId
210 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
211 {
212 assert(bit_size == 32); // only 32-bit ints supported so far
213
214 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
215 if (num_components > 1)
216 return spirv_builder_type_vector(&ctx->builder, int_type,
217 num_components);
218
219 assert(num_components == 1);
220 return int_type;
221 }
222
223 static SpvId
224 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
225 {
226 assert(bit_size == 32); // only 32-bit uints supported so far
227
228 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
229 if (num_components > 1)
230 return spirv_builder_type_vector(&ctx->builder, uint_type,
231 num_components);
232
233 assert(num_components == 1);
234 return uint_type;
235 }
236
237 static SpvId
238 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
239 {
240 unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
241 return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
242 }
243
244 static SpvId
245 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
246 {
247 switch (type) {
248 case GLSL_TYPE_BOOL:
249 return spirv_builder_type_bool(&ctx->builder);
250
251 case GLSL_TYPE_FLOAT:
252 return spirv_builder_type_float(&ctx->builder, 32);
253
254 case GLSL_TYPE_INT:
255 return spirv_builder_type_int(&ctx->builder, 32);
256
257 case GLSL_TYPE_UINT:
258 return spirv_builder_type_uint(&ctx->builder, 32);
259 /* TODO: handle more types */
260
261 default:
262 unreachable("unknown GLSL type");
263 }
264 }
265
266 static SpvId
267 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
268 {
269 assert(type);
270 if (glsl_type_is_scalar(type))
271 return get_glsl_basetype(ctx, glsl_get_base_type(type));
272
273 if (glsl_type_is_vector(type))
274 return spirv_builder_type_vector(&ctx->builder,
275 get_glsl_basetype(ctx, glsl_get_base_type(type)),
276 glsl_get_vector_elements(type));
277
278 if (glsl_type_is_array(type)) {
279 SpvId ret = spirv_builder_type_array(&ctx->builder,
280 get_glsl_type(ctx, glsl_get_array_element(type)),
281 emit_uint_const(ctx, 32, glsl_get_length(type)));
282 uint32_t stride = glsl_get_explicit_stride(type);
283 if (stride)
284 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
285 return ret;
286 }
287
288 if (glsl_type_is_matrix(type))
289 return spirv_builder_type_matrix(&ctx->builder,
290 spirv_builder_type_vector(&ctx->builder,
291 get_glsl_basetype(ctx, glsl_get_base_type(type)),
292 glsl_get_vector_elements(type)),
293 glsl_get_matrix_columns(type));
294
295 unreachable("we shouldn't get here, I think...");
296 }
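/* For illustration: a GLSL mat3 comes through here as an OpTypeMatrix whose column
 * type is a 3-component float vector, i.e. matrix(vector(float32, 3), 3); base types
 * other than the scalar/vector/array/matrix cases above are not handled yet.
 */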
297
298 static inline unsigned
299 handle_slot(struct ntv_context *ctx, unsigned slot)
300 {
301 unsigned orig = slot;
302 if (slot < VARYING_SLOT_VAR0) {
303 #ifndef NDEBUG
304 if (slot >= VARYING_SLOT_TEX0 && slot <= VARYING_SLOT_TEX7)
305 ctx->seen_texcoord[slot - VARYING_SLOT_TEX0] = true;
306 #endif
307 slot = slot_pack_map[slot];
308 if (slot == UINT_MAX)
309 debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(orig));
310 } else {
311 slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
312 assert(slot <= VARYING_SLOT_VAR0 - 8 ||
313 !ctx->seen_texcoord[VARYING_SLOT_VAR0 - slot - 1]);
314
315 }
316 assert(slot < VARYING_SLOT_VAR0);
317 return slot;
318 }
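/* Worked example (illustrative): VARYING_SLOT_COL1 packs to location 1 via the table,
 * VARYING_SLOT_TEX1 to VARYING_SLOT_VAR0 - 2, and a user varying VARYING_SLOT_VAR3 to
 * 3 + NTV_MIN_RESERVED_SLOTS = 14; the asserts guard against a packed VAR slot
 * colliding with a TEX0-7 slot that was actually seen this pass.
 */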
319
320 #define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN) \
321 case VARYING_SLOT_##SLOT: \
322 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
323 break
324
325
326 static void
327 emit_input(struct ntv_context *ctx, struct nir_variable *var)
328 {
329 SpvId var_type = get_glsl_type(ctx, var->type);
330 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
331 SpvStorageClassInput,
332 var_type);
333 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
334 SpvStorageClassInput);
335
336 if (var->name)
337 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
338
339 if (ctx->stage == MESA_SHADER_FRAGMENT) {
340 unsigned slot = var->data.location;
341 switch (slot) {
342 HANDLE_EMIT_BUILTIN(POS, FragCoord);
343 HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
344 HANDLE_EMIT_BUILTIN(LAYER, Layer);
345 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
346 HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
347 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
348 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
349 HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
350
351 default:
352 slot = handle_slot(ctx, slot);
353 spirv_builder_emit_location(&ctx->builder, var_id, slot);
354 }
355 } else {
356 spirv_builder_emit_location(&ctx->builder, var_id,
357 var->data.driver_location);
358 }
359
360 if (var->data.location_frac)
361 spirv_builder_emit_component(&ctx->builder, var_id,
362 var->data.location_frac);
363
364 if (var->data.interpolation == INTERP_MODE_FLAT)
365 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
366
367 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
368
369 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
370 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
371 }
372
373 static void
374 emit_output(struct ntv_context *ctx, struct nir_variable *var)
375 {
376 SpvId var_type = get_glsl_type(ctx, var->type);
377 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
378 SpvStorageClassOutput,
379 var_type);
380 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
381 SpvStorageClassOutput);
382 if (var->name)
383 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
384
385
386 if (ctx->stage == MESA_SHADER_VERTEX) {
387 unsigned slot = var->data.location;
388 switch (slot) {
389 HANDLE_EMIT_BUILTIN(POS, Position);
390 HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
391 HANDLE_EMIT_BUILTIN(LAYER, Layer);
392 HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
393 HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
394 HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
395 HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
396 HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
397
398 case VARYING_SLOT_CLIP_DIST0:
399 assert(glsl_type_is_array(var->type));
400 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
401 /* this can be as large as 2x vec4, which requires 2 slots */
402 ctx->outputs[VARYING_SLOT_CLIP_DIST1] = var_id;
403 ctx->so_output_gl_types[VARYING_SLOT_CLIP_DIST1] = var->type;
404 ctx->so_output_types[VARYING_SLOT_CLIP_DIST1] = var_type;
405 break;
406
407 default:
408 slot = handle_slot(ctx, slot);
409 spirv_builder_emit_location(&ctx->builder, var_id, slot);
410 }
411 ctx->outputs[var->data.location] = var_id;
412 ctx->so_output_gl_types[var->data.location] = var->type;
413 ctx->so_output_types[var->data.location] = var_type;
414 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
415 if (var->data.location >= FRAG_RESULT_DATA0) {
416 spirv_builder_emit_location(&ctx->builder, var_id,
417 var->data.location - FRAG_RESULT_DATA0);
418 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
419 } else {
420 switch (var->data.location) {
421 case FRAG_RESULT_COLOR:
422 unreachable("gl_FragColor should be lowered by now");
423
424 case FRAG_RESULT_DEPTH:
425 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
426 break;
427
428 default:
429 spirv_builder_emit_location(&ctx->builder, var_id,
430 var->data.driver_location);
431 spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
432 }
433 }
434 }
435
436 if (var->data.location_frac)
437 spirv_builder_emit_component(&ctx->builder, var_id,
438 var->data.location_frac);
439
440 switch (var->data.interpolation) {
441 case INTERP_MODE_NONE:
442 case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
443 break;
444 case INTERP_MODE_FLAT:
445 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
446 break;
447 case INTERP_MODE_EXPLICIT:
448 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
449 break;
450 case INTERP_MODE_NOPERSPECTIVE:
451 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
452 break;
453 default:
454 unreachable("unknown interpolation value");
455 }
456
457 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
458
459 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
460 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
461 }
462
463 static SpvDim
464 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
465 {
466 *is_ms = false;
467 switch (gdim) {
468 case GLSL_SAMPLER_DIM_1D:
469 return SpvDim1D;
470 case GLSL_SAMPLER_DIM_2D:
471 return SpvDim2D;
472 case GLSL_SAMPLER_DIM_3D:
473 return SpvDim3D;
474 case GLSL_SAMPLER_DIM_CUBE:
475 return SpvDimCube;
476 case GLSL_SAMPLER_DIM_RECT:
477 return SpvDim2D;
478 case GLSL_SAMPLER_DIM_BUF:
479 return SpvDimBuffer;
480 case GLSL_SAMPLER_DIM_EXTERNAL:
481 return SpvDim2D; /* seems dodgy... */
482 case GLSL_SAMPLER_DIM_MS:
483 *is_ms = true;
484 return SpvDim2D;
485 default:
486 fprintf(stderr, "unknown sampler type %d\n", gdim);
487 break;
488 }
489 return SpvDim2D;
490 }
491
492 uint32_t
493 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
494 {
495 if (stage == MESA_SHADER_NONE ||
496 stage >= MESA_SHADER_COMPUTE) {
497 unreachable("not supported");
498 } else {
499 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
500 PIPE_MAX_SHADER_SAMPLER_VIEWS);
501
502 switch (type) {
503 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
504 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
505 return stage_offset + index;
506
507 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
508 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
509 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
510
511 default:
512 unreachable("unexpected type");
513 }
514 }
515 }
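/* For illustration, the per-stage binding layout produced above is:
 *    [0 .. PIPE_MAX_CONSTANT_BUFFERS)                              -> UBOs
 *    [PIPE_MAX_CONSTANT_BUFFERS .. +PIPE_MAX_SHADER_SAMPLER_VIEWS) -> combined image/samplers
 * offset by stage * (PIPE_MAX_CONSTANT_BUFFERS + PIPE_MAX_SHADER_SAMPLER_VIEWS), so e.g.
 * sampler index 2 in a given stage gets stage_offset + PIPE_MAX_CONSTANT_BUFFERS + 2.
 */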
516
517 static void
518 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
519 {
520 const struct glsl_type *type = glsl_without_array(var->type);
521
522 bool is_ms;
523 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
524
525 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
526 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
527 dimension, false,
528 glsl_sampler_type_is_array(type),
529 is_ms, 1,
530 SpvImageFormatUnknown);
531
532 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
533 image_type);
534 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
535 SpvStorageClassUniformConstant,
536 sampled_type);
537
538 if (glsl_type_is_array(var->type)) {
539 for (int i = 0; i < glsl_get_length(var->type); ++i) {
540 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
541 SpvStorageClassUniformConstant);
542
543 if (var->name) {
544 char element_name[100];
545 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
546 spirv_builder_emit_name(&ctx->builder, var_id, element_name);
547 }
548
549 int index = var->data.binding + i;
550 assert(!(ctx->samplers_used & (1 << index)));
551 assert(!ctx->image_types[index]);
552 ctx->image_types[index] = image_type;
553 ctx->samplers[index] = var_id;
554 ctx->samplers_used |= 1 << index;
555
556 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
557 var->data.descriptor_set);
558 int binding = zink_binding(ctx->stage,
559 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
560 var->data.binding + i);
561 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
562 }
563 } else {
564 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
565 SpvStorageClassUniformConstant);
566
567 if (var->name)
568 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
569
570 int index = var->data.binding;
571 assert(!(ctx->samplers_used & (1 << index)));
572 assert(!ctx->image_types[index]);
573 ctx->image_types[index] = image_type;
574 ctx->samplers[index] = var_id;
575 ctx->samplers_used |= 1 << index;
576
577 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
578 var->data.descriptor_set);
579 int binding = zink_binding(ctx->stage,
580 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
581 var->data.binding);
582 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
583 }
584 }
585
586 static void
587 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
588 {
589 uint32_t size = glsl_count_attribute_slots(var->type, false);
590 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
591 SpvId array_length = emit_uint_const(ctx, 32, size);
592 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
593 array_length);
594 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
595
596 // wrap UBO-array in a struct
597 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
598 if (var->name) {
599 char struct_name[100];
600 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
601 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
602 }
603
604 spirv_builder_emit_decoration(&ctx->builder, struct_type,
605 SpvDecorationBlock);
606 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
607
608
609 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
610 SpvStorageClassUniform,
611 struct_type);
612
613 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
614 SpvStorageClassUniform);
615 if (var->name)
616 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
617
618 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
619 ctx->ubos[ctx->num_ubos++] = var_id;
620
621 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
622 var->data.descriptor_set);
623 int binding = zink_binding(ctx->stage,
624 VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
625 var->data.binding);
626 spirv_builder_emit_binding(&ctx->builder, var_id, binding);
627 }
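/* For illustration, the UBO is declared roughly as the SPIR-V equivalent of
 *    struct { uvec4 words[N]; }   (Block-decorated, member offset 0, ArrayStride 16)
 * where N = glsl_count_attribute_slots(var->type, false), so each array element is one
 * 16-byte vec4 slot; emit_load_ubo() indexes it with a constant vec4 index.
 */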
628
629 static void
630 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
631 {
632 if (var->data.mode == nir_var_mem_ubo)
633 emit_ubo(ctx, var);
634 else {
635 assert(var->data.mode == nir_var_uniform);
636 if (glsl_type_is_sampler(glsl_without_array(var->type)))
637 emit_sampler(ctx, var);
638 }
639 }
640
641 static SpvId
642 get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_components)
643 {
644 if (bit_size == 1)
645 return get_bvec_type(ctx, num_components);
646 if (bit_size == 32)
647 return get_uvec_type(ctx, bit_size, num_components);
648 unreachable("unhandled register bit size");
649 return 0;
650 }
651
652 static SpvId
653 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
654 {
655 assert(ssa->index < ctx->num_defs);
656 assert(ctx->defs[ssa->index] != 0);
657 return ctx->defs[ssa->index];
658 }
659
660 static SpvId
661 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
662 {
663 assert(reg->index < ctx->num_regs);
664 assert(ctx->regs[reg->index] != 0);
665 return ctx->regs[reg->index];
666 }
667
668 static SpvId
669 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
670 {
671 assert(reg->reg);
672 assert(!reg->indirect);
673 assert(!reg->base_offset);
674
675 SpvId var = get_var_from_reg(ctx, reg->reg);
676 SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components);
677 return spirv_builder_emit_load(&ctx->builder, type, var);
678 }
679
680 static SpvId
681 get_src(struct ntv_context *ctx, nir_src *src)
682 {
683 if (src->is_ssa)
684 return get_src_ssa(ctx, src->ssa);
685 else
686 return get_src_reg(ctx, &src->reg);
687 }
688
689 static SpvId
690 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
691 {
692 assert(!alu->src[src].negate);
693 assert(!alu->src[src].abs);
694
695 SpvId def = get_src(ctx, &alu->src[src].src);
696
697 unsigned used_channels = 0;
698 bool need_swizzle = false;
699 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
700 if (!nir_alu_instr_channel_used(alu, src, i))
701 continue;
702
703 used_channels++;
704
705 if (alu->src[src].swizzle[i] != i)
706 need_swizzle = true;
707 }
708 assert(used_channels != 0);
709
710 unsigned live_channels = nir_src_num_components(alu->src[src].src);
711 if (used_channels != live_channels)
712 need_swizzle = true;
713
714 if (!need_swizzle)
715 return def;
716
717 int bit_size = nir_src_bit_size(alu->src[src].src);
718 assert(bit_size == 1 || bit_size == 32);
719
720 SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
721 spirv_builder_type_uint(&ctx->builder, bit_size);
722
723 if (used_channels == 1) {
724 uint32_t indices[] = { alu->src[src].swizzle[0] };
725 return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
726 def, indices,
727 ARRAY_SIZE(indices));
728 } else if (live_channels == 1) {
729 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
730 raw_type,
731 used_channels);
732
733 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
734 for (unsigned i = 0; i < used_channels; ++i)
735 constituents[i] = def;
736
737 return spirv_builder_emit_composite_construct(&ctx->builder,
738 raw_vec_type,
739 constituents,
740 used_channels);
741 } else {
742 SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
743 raw_type,
744 used_channels);
745
746 uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
747 size_t num_components = 0;
748 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
749 if (!nir_alu_instr_channel_used(alu, src, i))
750 continue;
751
752 components[num_components++] = alu->src[src].swizzle[i];
753 }
754
755 return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
756 def, def, components,
757 num_components);
758 }
759 }
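/* For illustration: a vec4 source read as .yzx with three channels used becomes an
 * OpVectorShuffle with components {1, 2, 0}; a single used channel becomes an
 * OpCompositeExtract of that channel; and a scalar source feeding a multi-channel ALU op
 * is replicated via OpCompositeConstruct. When the identity swizzle covers every live
 * channel, the SSA def is returned unmodified.
 */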
760
761 static void
762 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
763 {
764 assert(result != 0);
765 assert(ssa->index < ctx->num_defs);
766 ctx->defs[ssa->index] = result;
767 }
768
769 static SpvId
770 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
771 SpvId if_true, SpvId if_false)
772 {
773 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
774 }
775
776 static SpvId
777 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
778 {
779 SpvId type = get_bvec_type(ctx, num_components);
780 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
781 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
782 }
783
784 static SpvId
785 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
786 {
787 return emit_unop(ctx, SpvOpBitcast, type, value);
788 }
789
790 static SpvId
791 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
792 unsigned num_components)
793 {
794 SpvId type = get_uvec_type(ctx, bit_size, num_components);
795 return emit_bitcast(ctx, type, value);
796 }
797
798 static SpvId
799 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
800 unsigned num_components)
801 {
802 SpvId type = get_ivec_type(ctx, bit_size, num_components);
803 return emit_bitcast(ctx, type, value);
804 }
805
806 static SpvId
807 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
808 unsigned num_components)
809 {
810 SpvId type = get_fvec_type(ctx, bit_size, num_components);
811 return emit_bitcast(ctx, type, value);
812 }
813
814 static void
815 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
816 {
817 SpvId var = get_var_from_reg(ctx, reg->reg);
818 assert(var);
819 spirv_builder_emit_store(&ctx->builder, var, result);
820 }
821
822 static void
823 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
824 {
825 if (dest->is_ssa)
826 store_ssa_def(ctx, &dest->ssa, result);
827 else
828 store_reg_def(ctx, &dest->reg, result);
829 }
830
831 static SpvId
832 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
833 {
834 unsigned num_components = nir_dest_num_components(*dest);
835 unsigned bit_size = nir_dest_bit_size(*dest);
836
837 if (bit_size != 1) {
838 switch (nir_alu_type_get_base_type(type)) {
839 case nir_type_bool:
840 unreachable("bool should have bit-size 1");
841
842 case nir_type_uint:
843 break; /* nothing to do! */
844
845 case nir_type_int:
846 case nir_type_float:
847 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
848 break;
849
850 default:
851 unreachable("unsupported nir_alu_type");
852 }
853 }
854
855 store_dest_raw(ctx, dest, result);
856 return result;
857 }
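/* Note (illustrative): non-boolean results are stored here in their uint bit pattern,
 * so int and float results get an OpBitcast to a uvec of the same width; consumers such
 * as get_alu_src() / get_src_float() bitcast back to the type they actually need.
 */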
858
859 static SpvId
860 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
861 {
862 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
863 }
864
865 /* return the intended xfb output vec type based on base type and vector size */
866 static SpvId
867 get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_components)
868 {
869 const struct glsl_type *out_type = ctx->so_output_gl_types[register_index];
870 enum glsl_base_type base_type = glsl_get_base_type(out_type);
871 if (base_type == GLSL_TYPE_ARRAY)
872 base_type = glsl_get_base_type(glsl_without_array(out_type));
873
874 switch (base_type) {
875 case GLSL_TYPE_BOOL:
876 return get_bvec_type(ctx, num_components);
877
878 case GLSL_TYPE_FLOAT:
879 return get_fvec_type(ctx, 32, num_components);
880
881 case GLSL_TYPE_INT:
882 return get_ivec_type(ctx, 32, num_components);
883
884 case GLSL_TYPE_UINT:
885 return get_uvec_type(ctx, 32, num_components);
886
887 default:
888 break;
889 }
890 unreachable("unknown type");
891 return 0;
892 }
893
894 /* for streamout we create new output variables, since streamout can capture individual
895  components of complete outputs, which means we can't simply reuse the packed outputs created above */
896 static void
897 emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
898 const struct zink_so_info *so_info)
899 {
900 for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
901 struct pipe_stream_output so_output = so_info->so_info.output[i];
902 unsigned slot = so_info->so_info_slots[i];
903 SpvId out_type = get_output_type(ctx, slot, so_output.num_components);
904 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
905 SpvStorageClassOutput,
906 out_type);
907 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
908 SpvStorageClassOutput);
909 char name[10];
910
911 snprintf(name, 10, "xfb%d", i);
912 spirv_builder_emit_name(&ctx->builder, var_id, name);
913 spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
914 spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
915 spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->so_info.stride[so_output.output_buffer] * 4);
916
917 /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
918 * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
919 * outputs.
920 *
921 * if there's no previous outputs that take up user slots (VAR0+) then we can start right after the
922 * glsl builtin reserved slots, otherwise we start just after the adjusted user output slot
923 */
924 uint32_t location = NTV_MIN_RESERVED_SLOTS + i;
925 if (max_output_location >= VARYING_SLOT_VAR0)
926 location = max_output_location - VARYING_SLOT_VAR0 + 1 + i;
927 assert(location < VARYING_SLOT_VAR0);
928 assert(location <= VARYING_SLOT_VAR0 - 8 ||
929 !ctx->seen_texcoord[VARYING_SLOT_VAR0 - location - 1]);
930 spirv_builder_emit_location(&ctx->builder, var_id, location);
931
932 /* note: gl_ClipDistance[4] can be the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
933 * so this is still the 0 component
934 */
935 if (so_output.start_component)
936 spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
937
938 uint32_t *key = ralloc_size(ctx->mem_ctx, sizeof(uint32_t));
939 *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
940 _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
941
942 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
943 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
944 }
945 }
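/* For illustration: an output captured from register_index 4, start_component 1 is keyed
 * as (4 << 2) | 1 = 17 in ctx->so_outputs, and its dedicated xfb variable gets
 * Offset = dst_offset * 4 bytes, XfbBuffer/XfbStride from the pipe_stream_output_info,
 * and a Location just past the last user output (or at NTV_MIN_RESERVED_SLOTS when
 * there are no user-slot outputs).
 */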
946
947 static void
948 emit_so_outputs(struct ntv_context *ctx,
949 const struct zink_so_info *so_info)
950 {
951 SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
952 for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
953 uint32_t components[NIR_MAX_VEC_COMPONENTS];
954 unsigned slot = so_info->so_info_slots[i];
955 struct pipe_stream_output so_output = so_info->so_info.output[i];
956 uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
957 struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
958 assert(he);
959 SpvId so_output_var_id = (SpvId)(intptr_t)he->data;
960
961 SpvId type = get_output_type(ctx, slot, so_output.num_components);
962 SpvId output = ctx->outputs[slot];
963 SpvId output_type = ctx->so_output_types[slot];
964 const struct glsl_type *out_type = ctx->so_output_gl_types[slot];
965
966 if (!loaded_outputs[slot])
967 loaded_outputs[slot] = spirv_builder_emit_load(&ctx->builder, output_type, output);
968 SpvId src = loaded_outputs[slot];
969
970 SpvId result;
971
972 for (unsigned c = 0; c < so_output.num_components; c++) {
973 components[c] = so_output.start_component + c;
974 /* this is the second half of a 2 * vec4 array */
975 if (ctx->stage == MESA_SHADER_VERTEX && slot == VARYING_SLOT_CLIP_DIST1)
976 components[c] += 4;
977 }
978
979 /* if we're emitting a scalar or the type we're emitting matches the output's original type and we're
980 * emitting the same number of components, then we can skip any sort of conversion here
981 */
982 if (glsl_type_is_scalar(out_type) || (type == output_type && glsl_get_length(out_type) == so_output.num_components))
983 result = src;
984 else {
985 /* OpCompositeExtract can only extract scalars for our use here */
986 if (so_output.num_components == 1) {
987 result = spirv_builder_emit_composite_extract(&ctx->builder, type, src, components, so_output.num_components);
988 } else if (glsl_type_is_vector(out_type)) {
989 /* OpVectorShuffle can select vector members into a differently-sized vector */
990 result = spirv_builder_emit_vector_shuffle(&ctx->builder, type,
991 src, src,
992 components, so_output.num_components);
993 result = emit_unop(ctx, SpvOpBitcast, type, result);
994 } else {
995 /* for arrays, we need to manually extract each desired member
996 * and re-pack them into the desired output type
997 */
998 for (unsigned c = 0; c < so_output.num_components; c++) {
999 uint32_t member[] = { so_output.start_component + c };
1000 SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));
1001
1002 if (ctx->stage == MESA_SHADER_VERTEX && slot == VARYING_SLOT_CLIP_DIST1)
1003 member[0] += 4;
1004 components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
1005 }
1006 result = spirv_builder_emit_composite_construct(&ctx->builder, type, components, so_output.num_components);
1007 }
1008 }
1009
1010 spirv_builder_emit_store(&ctx->builder, so_output_var_id, result);
1011 }
1012 }
1013
1014 static SpvId
1015 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
1016 SpvId src0, SpvId src1)
1017 {
1018 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
1019 }
1020
1021 static SpvId
1022 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
1023 SpvId src0, SpvId src1, SpvId src2)
1024 {
1025 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
1026 }
1027
1028 static SpvId
1029 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1030 SpvId src)
1031 {
1032 SpvId args[] = { src };
1033 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1034 op, args, ARRAY_SIZE(args));
1035 }
1036
1037 static SpvId
1038 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1039 SpvId src0, SpvId src1)
1040 {
1041 SpvId args[] = { src0, src1 };
1042 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1043 op, args, ARRAY_SIZE(args));
1044 }
1045
1046 static SpvId
1047 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1048 SpvId src0, SpvId src1, SpvId src2)
1049 {
1050 SpvId args[] = { src0, src1, src2 };
1051 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1052 op, args, ARRAY_SIZE(args));
1053 }
1054
1055 static SpvId
1056 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
1057 unsigned num_components, float value)
1058 {
1059 assert(bit_size == 32);
1060
1061 SpvId result = emit_float_const(ctx, bit_size, value);
1062 if (num_components == 1)
1063 return result;
1064
1065 assert(num_components > 1);
1066 SpvId components[num_components];
1067 for (int i = 0; i < num_components; i++)
1068 components[i] = result;
1069
1070 SpvId type = get_fvec_type(ctx, bit_size, num_components);
1071 return spirv_builder_const_composite(&ctx->builder, type, components,
1072 num_components);
1073 }
1074
1075 static SpvId
1076 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
1077 unsigned num_components, uint32_t value)
1078 {
1079 assert(bit_size == 32);
1080
1081 SpvId result = emit_uint_const(ctx, bit_size, value);
1082 if (num_components == 1)
1083 return result;
1084
1085 assert(num_components > 1);
1086 SpvId components[num_components];
1087 for (int i = 0; i < num_components; i++)
1088 components[i] = result;
1089
1090 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1091 return spirv_builder_const_composite(&ctx->builder, type, components,
1092 num_components);
1093 }
1094
1095 static SpvId
1096 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
1097 unsigned num_components, int32_t value)
1098 {
1099 assert(bit_size == 32);
1100
1101 SpvId result = emit_int_const(ctx, bit_size, value);
1102 if (num_components == 1)
1103 return result;
1104
1105 assert(num_components > 1);
1106 SpvId components[num_components];
1107 for (int i = 0; i < num_components; i++)
1108 components[i] = result;
1109
1110 SpvId type = get_ivec_type(ctx, bit_size, num_components);
1111 return spirv_builder_const_composite(&ctx->builder, type, components,
1112 num_components);
1113 }
1114
1115 static inline unsigned
1116 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
1117 {
1118 if (nir_op_infos[instr->op].input_sizes[src] > 0)
1119 return nir_op_infos[instr->op].input_sizes[src];
1120
1121 if (instr->dest.dest.is_ssa)
1122 return instr->dest.dest.ssa.num_components;
1123 else
1124 return instr->dest.dest.reg.reg->num_components;
1125 }
1126
1127 static SpvId
1128 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
1129 {
1130 SpvId raw_value = get_alu_src_raw(ctx, alu, src);
1131
1132 unsigned num_components = alu_instr_src_components(alu, src);
1133 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
1134 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
1135
1136 if (bit_size == 1)
1137 return raw_value;
1138 else {
1139 switch (nir_alu_type_get_base_type(type)) {
1140 case nir_type_bool:
1141 unreachable("bool should have bit-size 1");
1142
1143 case nir_type_int:
1144 return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
1145
1146 case nir_type_uint:
1147 return raw_value;
1148
1149 case nir_type_float:
1150 return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
1151
1152 default:
1153 unreachable("unknown nir_alu_type");
1154 }
1155 }
1156 }
1157
1158 static SpvId
1159 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
1160 {
1161 assert(!alu->dest.saturate);
1162 return store_dest(ctx, &alu->dest.dest, result,
1163 nir_op_infos[alu->op].output_type);
1164 }
1165
1166 static SpvId
1167 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
1168 {
1169 unsigned num_components = nir_dest_num_components(*dest);
1170 unsigned bit_size = nir_dest_bit_size(*dest);
1171
1172 if (bit_size == 1)
1173 return get_bvec_type(ctx, num_components);
1174
1175 switch (nir_alu_type_get_base_type(type)) {
1176 case nir_type_bool:
1177 unreachable("bool should have bit-size 1");
1178
1179 case nir_type_int:
1180 return get_ivec_type(ctx, bit_size, num_components);
1181
1182 case nir_type_uint:
1183 return get_uvec_type(ctx, bit_size, num_components);
1184
1185 case nir_type_float:
1186 return get_fvec_type(ctx, bit_size, num_components);
1187
1188 default:
1189 unreachable("unsupported nir_alu_type");
1190 }
1191 }
1192
1193 static void
1194 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1195 {
1196 SpvId src[nir_op_infos[alu->op].num_inputs];
1197 unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1198 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1199 src[i] = get_alu_src(ctx, alu, i);
1200 in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1201 }
1202
1203 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1204 nir_op_infos[alu->op].output_type);
1205 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1206 unsigned num_components = nir_dest_num_components(alu->dest.dest);
1207
1208 SpvId result = 0;
1209 switch (alu->op) {
1210 case nir_op_mov:
1211 assert(nir_op_infos[alu->op].num_inputs == 1);
1212 result = src[0];
1213 break;
1214
1215 #define UNOP(nir_op, spirv_op) \
1216 case nir_op: \
1217 assert(nir_op_infos[alu->op].num_inputs == 1); \
1218 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1219 break;
1220
1221 UNOP(nir_op_ineg, SpvOpSNegate)
1222 UNOP(nir_op_fneg, SpvOpFNegate)
1223 UNOP(nir_op_fddx, SpvOpDPdx)
1224 UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1225 UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1226 UNOP(nir_op_fddy, SpvOpDPdy)
1227 UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1228 UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1229 UNOP(nir_op_f2i32, SpvOpConvertFToS)
1230 UNOP(nir_op_f2u32, SpvOpConvertFToU)
1231 UNOP(nir_op_i2f32, SpvOpConvertSToF)
1232 UNOP(nir_op_u2f32, SpvOpConvertUToF)
1233 UNOP(nir_op_bitfield_reverse, SpvOpBitReverse)
1234 #undef UNOP
1235
1236 case nir_op_inot:
1237 if (bit_size == 1)
1238 result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1239 else
1240 result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1241 break;
1242
1243 case nir_op_b2i32:
1244 assert(nir_op_infos[alu->op].num_inputs == 1);
1245 result = emit_select(ctx, dest_type, src[0],
1246 get_ivec_constant(ctx, 32, num_components, 1),
1247 get_ivec_constant(ctx, 32, num_components, 0));
1248 break;
1249
1250 case nir_op_b2f32:
1251 assert(nir_op_infos[alu->op].num_inputs == 1);
1252 result = emit_select(ctx, dest_type, src[0],
1253 get_fvec_constant(ctx, 32, num_components, 1),
1254 get_fvec_constant(ctx, 32, num_components, 0));
1255 break;
1256
1257 #define BUILTIN_UNOP(nir_op, spirv_op) \
1258 case nir_op: \
1259 assert(nir_op_infos[alu->op].num_inputs == 1); \
1260 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1261 break;
1262
1263 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1264 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1265 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1266 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1267 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1268 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1269 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1270 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1271 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1272 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1273 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1274 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1275 BUILTIN_UNOP(nir_op_isign, GLSLstd450SSign)
1276 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1277 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1278 #undef BUILTIN_UNOP
1279
1280 case nir_op_frcp:
1281 assert(nir_op_infos[alu->op].num_inputs == 1);
1282 result = emit_binop(ctx, SpvOpFDiv, dest_type,
1283 get_fvec_constant(ctx, bit_size, num_components, 1),
1284 src[0]);
1285 break;
1286
1287 case nir_op_f2b1:
1288 assert(nir_op_infos[alu->op].num_inputs == 1);
1289 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1290 get_fvec_constant(ctx,
1291 nir_src_bit_size(alu->src[0].src),
1292 num_components, 0));
1293 break;
1294 case nir_op_i2b1:
1295 assert(nir_op_infos[alu->op].num_inputs == 1);
1296 result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1297 get_ivec_constant(ctx,
1298 nir_src_bit_size(alu->src[0].src),
1299 num_components, 0));
1300 break;
1301
1302
1303 #define BINOP(nir_op, spirv_op) \
1304 case nir_op: \
1305 assert(nir_op_infos[alu->op].num_inputs == 2); \
1306 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1307 break;
1308
1309 BINOP(nir_op_iadd, SpvOpIAdd)
1310 BINOP(nir_op_isub, SpvOpISub)
1311 BINOP(nir_op_imul, SpvOpIMul)
1312 BINOP(nir_op_idiv, SpvOpSDiv)
1313 BINOP(nir_op_udiv, SpvOpUDiv)
1314 BINOP(nir_op_umod, SpvOpUMod)
1315 BINOP(nir_op_fadd, SpvOpFAdd)
1316 BINOP(nir_op_fsub, SpvOpFSub)
1317 BINOP(nir_op_fmul, SpvOpFMul)
1318 BINOP(nir_op_fdiv, SpvOpFDiv)
1319 BINOP(nir_op_fmod, SpvOpFMod)
1320 BINOP(nir_op_ilt, SpvOpSLessThan)
1321 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1322 BINOP(nir_op_ult, SpvOpULessThan)
1323 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1324 BINOP(nir_op_flt, SpvOpFOrdLessThan)
1325 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1326 BINOP(nir_op_feq, SpvOpFOrdEqual)
1327 BINOP(nir_op_fneu, SpvOpFUnordNotEqual)
1328 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1329 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1330 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1331 BINOP(nir_op_ixor, SpvOpBitwiseXor)
1332 #undef BINOP
1333
1334 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1335 case nir_op: \
1336 assert(nir_op_infos[alu->op].num_inputs == 2); \
1337 if (nir_src_bit_size(alu->src[0].src) == 1) \
1338 result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1339 else \
1340 result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1341 break;
1342
1343 BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1344 BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1345 BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1346 BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1347 #undef BINOP_LOG
1348
1349 #define BUILTIN_BINOP(nir_op, spirv_op) \
1350 case nir_op: \
1351 assert(nir_op_infos[alu->op].num_inputs == 2); \
1352 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1353 break;
1354
1355 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1356 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1357 BUILTIN_BINOP(nir_op_imin, GLSLstd450SMin)
1358 BUILTIN_BINOP(nir_op_imax, GLSLstd450SMax)
1359 BUILTIN_BINOP(nir_op_umin, GLSLstd450UMin)
1360 BUILTIN_BINOP(nir_op_umax, GLSLstd450UMax)
1361 #undef BUILTIN_BINOP
1362
1363 case nir_op_fdot2:
1364 case nir_op_fdot3:
1365 case nir_op_fdot4:
1366 assert(nir_op_infos[alu->op].num_inputs == 2);
1367 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1368 break;
1369
1370 case nir_op_fdph:
1371 unreachable("should already be lowered away");
1372
1373 case nir_op_seq:
1374 case nir_op_sne:
1375 case nir_op_slt:
1376 case nir_op_sge: {
1377 assert(nir_op_infos[alu->op].num_inputs == 2);
1378 int num_components = nir_dest_num_components(alu->dest.dest);
1379 SpvId bool_type = get_bvec_type(ctx, num_components);
1380
1381 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1382 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1383 if (num_components > 1) {
1384 SpvId zero_comps[num_components], one_comps[num_components];
1385 for (int i = 0; i < num_components; i++) {
1386 zero_comps[i] = zero;
1387 one_comps[i] = one;
1388 }
1389
1390 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1391 zero_comps, num_components);
1392 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1393 one_comps, num_components);
1394 }
1395
1396 SpvOp op;
1397 switch (alu->op) {
1398 case nir_op_seq: op = SpvOpFOrdEqual; break;
1399 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1400 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1401 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1402 default: unreachable("unexpected op");
1403 }
1404
1405 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1406 result = emit_select(ctx, dest_type, result, one, zero);
1407 }
1408 break;
1409
1410 case nir_op_flrp:
1411 assert(nir_op_infos[alu->op].num_inputs == 3);
1412 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1413 src[0], src[1], src[2]);
1414 break;
1415
1416 case nir_op_fcsel:
1417 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1418 get_bvec_type(ctx, num_components),
1419 src[0],
1420 get_fvec_constant(ctx,
1421 nir_src_bit_size(alu->src[0].src),
1422 num_components, 0));
1423 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1424 break;
1425
1426 case nir_op_bcsel:
1427 assert(nir_op_infos[alu->op].num_inputs == 3);
1428 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1429 break;
1430
1431 case nir_op_bany_fnequal2:
1432 case nir_op_bany_fnequal3:
1433 case nir_op_bany_fnequal4: {
1434 assert(nir_op_infos[alu->op].num_inputs == 2);
1435 assert(alu_instr_src_components(alu, 0) ==
1436 alu_instr_src_components(alu, 1));
1437 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1438 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1439 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1440 result = emit_binop(ctx, op,
1441 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1442 src[0], src[1]);
1443 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1444 break;
1445 }
1446
1447 case nir_op_ball_fequal2:
1448 case nir_op_ball_fequal3:
1449 case nir_op_ball_fequal4: {
1450 assert(nir_op_infos[alu->op].num_inputs == 2);
1451 assert(alu_instr_src_components(alu, 0) ==
1452 alu_instr_src_components(alu, 1));
1453 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1454 /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1455 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1456 result = emit_binop(ctx, op,
1457 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1458 src[0], src[1]);
1459 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1460 break;
1461 }
1462
1463 case nir_op_bany_inequal2:
1464 case nir_op_bany_inequal3:
1465 case nir_op_bany_inequal4: {
1466 assert(nir_op_infos[alu->op].num_inputs == 2);
1467 assert(alu_instr_src_components(alu, 0) ==
1468 alu_instr_src_components(alu, 1));
1469 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1470 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1471 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1472 result = emit_binop(ctx, op,
1473 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1474 src[0], src[1]);
1475 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1476 break;
1477 }
1478
1479 case nir_op_ball_iequal2:
1480 case nir_op_ball_iequal3:
1481 case nir_op_ball_iequal4: {
1482 assert(nir_op_infos[alu->op].num_inputs == 2);
1483 assert(alu_instr_src_components(alu, 0) ==
1484 alu_instr_src_components(alu, 1));
1485 assert(in_bit_sizes[0] == in_bit_sizes[1]);
1486 /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1487 SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1488 result = emit_binop(ctx, op,
1489 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1490 src[0], src[1]);
1491 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1492 break;
1493 }
1494
1495 case nir_op_vec2:
1496 case nir_op_vec3:
1497 case nir_op_vec4: {
1498 int num_inputs = nir_op_infos[alu->op].num_inputs;
1499 assert(2 <= num_inputs && num_inputs <= 4);
1500 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1501 src, num_inputs);
1502 }
1503 break;
1504
1505 default:
1506 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1507 nir_op_infos[alu->op].name);
1508
1509 unreachable("unsupported opcode");
1510 return;
1511 }
1512
1513 store_alu_result(ctx, alu, result);
1514 }
1515
1516 static void
1517 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1518 {
1519 unsigned bit_size = load_const->def.bit_size;
1520 unsigned num_components = load_const->def.num_components;
1521
1522 SpvId constant;
1523 if (num_components > 1) {
1524 SpvId components[num_components];
1525 SpvId type = get_vec_from_bit_size(ctx, bit_size, num_components);
1526 if (bit_size == 1) {
1527 for (int i = 0; i < num_components; i++)
1528 components[i] = spirv_builder_const_bool(&ctx->builder,
1529 load_const->value[i].b);
1530
1531 } else {
1532 for (int i = 0; i < num_components; i++)
1533 components[i] = emit_uint_const(ctx, bit_size,
1534 load_const->value[i].u32);
1535
1536 }
1537 constant = spirv_builder_const_composite(&ctx->builder, type,
1538 components, num_components);
1539 } else {
1540 assert(num_components == 1);
1541 if (bit_size == 1)
1542 constant = spirv_builder_const_bool(&ctx->builder,
1543 load_const->value[0].b);
1544 else
1545 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1546 }
1547
1548 store_ssa_def(ctx, &load_const->def, constant);
1549 }
1550
1551 static void
1552 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1553 {
1554 ASSERTED nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1555 assert(const_block_index); // no dynamic indexing for now
1556 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1557
1558 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1559 if (const_offset) {
1560 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1561 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1562 SpvStorageClassUniform,
1563 uvec4_type);
1564
1565 unsigned idx = const_offset->u32;
1566 SpvId member = emit_uint_const(ctx, 32, 0);
1567 SpvId offset = emit_uint_const(ctx, 32, idx);
1568 SpvId offsets[] = { member, offset };
1569 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1570 ctx->ubos[0], offsets,
1571 ARRAY_SIZE(offsets));
1572 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1573
1574 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1575 unsigned num_components = nir_dest_num_components(intr->dest);
1576 if (num_components == 1) {
1577 uint32_t components[] = { 0 };
1578 result = spirv_builder_emit_composite_extract(&ctx->builder,
1579 type,
1580 result, components,
1581 1);
1582 } else if (num_components < 4) {
1583 SpvId constituents[num_components];
1584 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1585 for (uint32_t i = 0; i < num_components; ++i)
1586 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1587 uint_type,
1588 result, &i,
1589 1);
1590
1591 result = spirv_builder_emit_composite_construct(&ctx->builder,
1592 type,
1593 constituents,
1594 num_components);
1595 }
1596
1597 if (nir_dest_bit_size(intr->dest) == 1)
1598 result = uvec_to_bvec(ctx, result, num_components);
1599
1600 store_dest(ctx, &intr->dest, result, nir_type_uint);
1601 } else
1602 unreachable("uniform-addressing not yet supported");
1603 }
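/* For illustration: a constant load at vec4 index 3 builds an access chain with
 * indices { 0, 3 } (struct member 0, then array element 3), loads that uvec4, and trims
 * it with OpCompositeExtract / OpCompositeConstruct when the destination has fewer than
 * four components.
 */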
1604
1605 static void
1606 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1607 {
1608 assert(ctx->block_started);
1609 spirv_builder_emit_kill(&ctx->builder);
1610 /* discard is weird in NIR, so let's just create an unreachable block after
1611 it and hope that the Vulkan driver will DCE any instructions in it. */
1612 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1613 }
1614
1615 static void
1616 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1617 {
1618 SpvId ptr = get_src(ctx, &intr->src[0]);
1619
1620 SpvId result = spirv_builder_emit_load(&ctx->builder,
1621 get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type),
1622 ptr);
1623 unsigned num_components = nir_dest_num_components(intr->dest);
1624 unsigned bit_size = nir_dest_bit_size(intr->dest);
1625 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1626 store_dest(ctx, &intr->dest, result, nir_type_uint);
1627 }
1628
1629 static void
1630 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1631 {
1632 SpvId ptr = get_src(ctx, &intr->src[0]);
1633 SpvId src = get_src(ctx, &intr->src[1]);
1634
1635 SpvId type = get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type);
1636 SpvId result = emit_bitcast(ctx, type, src);
1637 spirv_builder_emit_store(&ctx->builder, ptr, result);
1638 }
1639
1640 static SpvId
1641 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1642 SpvStorageClass storage_class,
1643 const char *name, SpvBuiltIn builtin)
1644 {
1645 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1646 storage_class,
1647 var_type);
1648 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1649 storage_class);
1650 spirv_builder_emit_name(&ctx->builder, var, name);
1651 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1652
1653 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1654 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1655 return var;
1656 }
1657
1658 static void
1659 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1660 {
1661 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1662 if (!ctx->front_face_var)
1663 ctx->front_face_var = create_builtin_var(ctx, var_type,
1664 SpvStorageClassInput,
1665 "gl_FrontFacing",
1666 SpvBuiltInFrontFacing);
1667
1668 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1669 ctx->front_face_var);
1670 assert(1 == nir_dest_num_components(intr->dest));
1671 store_dest(ctx, &intr->dest, result, nir_type_bool);
1672 }
1673
1674 static void
1675 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1676 {
1677 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1678 if (!ctx->instance_id_var)
1679 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1680 SpvStorageClassInput,
1681 "gl_InstanceId",
1682 SpvBuiltInInstanceIndex);
1683
1684 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1685 ctx->instance_id_var);
1686 assert(1 == nir_dest_num_components(intr->dest));
1687 store_dest(ctx, &intr->dest, result, nir_type_uint);
1688 }
1689
1690 static void
1691 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1692 {
1693 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1694 if (!ctx->vertex_id_var)
1695 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1696 SpvStorageClassInput,
1697 "gl_VertexID",
1698 SpvBuiltInVertexIndex);
1699
1700 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1701 ctx->vertex_id_var);
1702 assert(1 == nir_dest_num_components(intr->dest));
1703 store_dest(ctx, &intr->dest, result, nir_type_uint);
1704 }
1705
1706 static void
1707 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1708 {
1709 switch (intr->intrinsic) {
1710 case nir_intrinsic_load_ubo:
1711 emit_load_ubo(ctx, intr);
1712 break;
1713
1714 case nir_intrinsic_discard:
1715 emit_discard(ctx, intr);
1716 break;
1717
1718 case nir_intrinsic_load_deref:
1719 emit_load_deref(ctx, intr);
1720 break;
1721
1722 case nir_intrinsic_store_deref:
1723 emit_store_deref(ctx, intr);
1724 break;
1725
1726 case nir_intrinsic_load_front_face:
1727 emit_load_front_face(ctx, intr);
1728 break;
1729
1730 case nir_intrinsic_load_instance_id:
1731 emit_load_instance_id(ctx, intr);
1732 break;
1733
1734 case nir_intrinsic_load_vertex_id:
1735 emit_load_vertex_id(ctx, intr);
1736 break;
1737
1738 default:
1739 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1740 nir_intrinsic_infos[intr->intrinsic].name);
1741 unreachable("unsupported intrinsic");
1742 }
1743 }
1744
1745 static void
1746 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1747 {
1748 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1749 undef->def.num_components);
1750
1751 store_ssa_def(ctx, &undef->def,
1752 spirv_builder_emit_undef(&ctx->builder, type));
1753 }
1754
1755 static SpvId
1756 get_src_float(struct ntv_context *ctx, nir_src *src)
1757 {
1758 SpvId def = get_src(ctx, src);
1759 unsigned num_components = nir_src_num_components(*src);
1760 unsigned bit_size = nir_src_bit_size(*src);
1761 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1762 }
1763
1764 static SpvId
1765 get_src_int(struct ntv_context *ctx, nir_src *src)
1766 {
1767 SpvId def = get_src(ctx, src);
1768 unsigned num_components = nir_src_num_components(*src);
1769 unsigned bit_size = nir_src_bit_size(*src);
1770 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1771 }
1772
1773 static inline bool
1774 tex_instr_is_lod_allowed(nir_tex_instr *tex)
1775 {
1776 /* The Lod image operand can only be used with an OpTypeImage that has a Dim operand of
1777 * 1D, 2D, 3D, or Cube (SPIR-V 3.14, "Image Operands").
1778 */
1779
1780 return (tex->sampler_dim == GLSL_SAMPLER_DIM_1D ||
1781 tex->sampler_dim == GLSL_SAMPLER_DIM_2D ||
1782 tex->sampler_dim == GLSL_SAMPLER_DIM_3D ||
1783 tex->sampler_dim == GLSL_SAMPLER_DIM_CUBE);
1784 }
1785
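/* Widens an ivec by appending zero components, so that coordinate and
 * offset vectors of different sizes can be combined with OpIAdd. */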
1786 static SpvId
1787 pad_coord_vector(struct ntv_context *ctx, SpvId orig, unsigned old_size, unsigned new_size)
1788 {
1789 SpvId int_type = spirv_builder_type_int(&ctx->builder, 32);
1790 SpvId type = get_ivec_type(ctx, 32, new_size);
1791 SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
1792 SpvId zero = emit_int_const(ctx, 32, 0);
1793 assert(new_size < NIR_MAX_VEC_COMPONENTS);
1794
1795 if (old_size == 1)
1796 constituents[0] = orig;
1797 else {
1798 for (unsigned i = 0; i < old_size; i++)
1799 constituents[i] = spirv_builder_emit_vector_extract(&ctx->builder, int_type, orig, i);
1800 }
1801
1802 for (unsigned i = old_size; i < new_size; i++)
1803 constituents[i] = zero;
1804
1805 return spirv_builder_emit_composite_construct(&ctx->builder, type,
1806 constituents, new_size);
1807 }
1808
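/* Translates nir_texop_* instructions: txs is lowered to an image size
 * query, txf/txf_ms to OpImageFetch, and the remaining ops to the
 * OpImageSample* family, after gathering coord/lod/bias/dref/gradient/
 * offset operands from the texture sources. */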
1809 static void
1810 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1811 {
1812 assert(tex->op == nir_texop_tex ||
1813 tex->op == nir_texop_txb ||
1814 tex->op == nir_texop_txl ||
1815 tex->op == nir_texop_txd ||
1816 tex->op == nir_texop_txf ||
1817 tex->op == nir_texop_txf_ms ||
1818 tex->op == nir_texop_txs);
1819 assert(tex->texture_index == tex->sampler_index);
1820
1821 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1822 offset = 0, sample = 0;
1823 unsigned coord_components = 0, coord_bitsize = 0, offset_components = 0;
1824 for (unsigned i = 0; i < tex->num_srcs; i++) {
1825 switch (tex->src[i].src_type) {
1826 case nir_tex_src_coord:
1827 if (tex->op == nir_texop_txf ||
1828 tex->op == nir_texop_txf_ms)
1829 coord = get_src_int(ctx, &tex->src[i].src);
1830 else
1831 coord = get_src_float(ctx, &tex->src[i].src);
1832 coord_components = nir_src_num_components(tex->src[i].src);
1833 coord_bitsize = nir_src_bit_size(tex->src[i].src);
1834 break;
1835
1836 case nir_tex_src_projector:
1837 assert(nir_src_num_components(tex->src[i].src) == 1);
1838 proj = get_src_float(ctx, &tex->src[i].src);
1839 assert(proj != 0);
1840 break;
1841
1842 case nir_tex_src_offset:
1843 offset = get_src_int(ctx, &tex->src[i].src);
1844 offset_components = nir_src_num_components(tex->src[i].src);
1845 break;
1846
1847 case nir_tex_src_bias:
1848 assert(tex->op == nir_texop_txb);
1849 bias = get_src_float(ctx, &tex->src[i].src);
1850 assert(bias != 0);
1851 break;
1852
1853 case nir_tex_src_lod:
1854 assert(nir_src_num_components(tex->src[i].src) == 1);
1855 if (tex->op == nir_texop_txf ||
1856 tex->op == nir_texop_txf_ms ||
1857 tex->op == nir_texop_txs)
1858 lod = get_src_int(ctx, &tex->src[i].src);
1859 else
1860 lod = get_src_float(ctx, &tex->src[i].src);
1861 assert(lod != 0);
1862 break;
1863
1864 case nir_tex_src_ms_index:
1865 assert(nir_src_num_components(tex->src[i].src) == 1);
1866 sample = get_src_int(ctx, &tex->src[i].src);
1867 break;
1868
1869 case nir_tex_src_comparator:
1870 assert(nir_src_num_components(tex->src[i].src) == 1);
1871 dref = get_src_float(ctx, &tex->src[i].src);
1872 assert(dref != 0);
1873 break;
1874
1875 case nir_tex_src_ddx:
1876 dx = get_src_float(ctx, &tex->src[i].src);
1877 assert(dx != 0);
1878 break;
1879
1880 case nir_tex_src_ddy:
1881 dy = get_src_float(ctx, &tex->src[i].src);
1882 assert(dy != 0);
1883 break;
1884
1885 default:
1886 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1887 unreachable("unknown texture source");
1888 }
1889 }
1890
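/* implicit-LOD sampling is only allowed in fragment shaders, so force an
 * explicit LOD of 0.0 in every other stage */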
1891 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1892 lod = emit_float_const(ctx, 32, 0.0f);
1893 assert(lod != 0);
1894 }
1895
1896 SpvId image_type = ctx->image_types[tex->texture_index];
1897 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1898 image_type);
1899
1900 assert(ctx->samplers_used & (1u << tex->texture_index));
1901 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1902 ctx->samplers[tex->texture_index]);
1903
1904 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1905
1906 if (!tex_instr_is_lod_allowed(tex))
1907 lod = 0;
1908 if (tex->op == nir_texop_txs) {
1909 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1910 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1911 dest_type, image,
1912 lod);
1913 store_dest(ctx, &tex->dest, result, tex->dest_type);
1914 return;
1915 }
1916
1917 if (proj && coord_components > 0) {
1918 SpvId constituents[coord_components + 1];
1919 if (coord_components == 1)
1920 constituents[0] = coord;
1921 else {
1922 assert(coord_components > 1);
1923 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1924 for (uint32_t i = 0; i < coord_components; ++i)
1925 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1926 float_type,
1927 coord,
1928 &i, 1);
1929 }
1930
1931 constituents[coord_components++] = proj;
1932
1933 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1934 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1935 vec_type,
1936 constituents,
1937 coord_components);
1938 }
1939
1940 SpvId actual_dest_type = dest_type;
1941 if (dref)
1942 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1943
1944 SpvId result;
1945 if (tex->op == nir_texop_txf ||
1946 tex->op == nir_texop_txf_ms) {
1947 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1948 if (offset) {
1949 /* SPIR-V requires vectors of matching length for OpIAdd, so if a shader
1950 * uses coord and offset vectors of differing sizes, we need to build a new
1951 * vector padded with zeroes to mimic what GLSL does implicitly
1952 */
1953 if (offset_components > coord_components)
1954 coord = pad_coord_vector(ctx, coord, coord_components, offset_components);
1955 else if (coord_components > offset_components)
1956 offset = pad_coord_vector(ctx, offset, offset_components, coord_components);
1957 coord = emit_binop(ctx, SpvOpIAdd,
1958 get_ivec_type(ctx, coord_bitsize, coord_components),
1959 coord, offset);
1960 }
1961 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1962 image, coord, lod, sample);
1963 } else {
1964 result = spirv_builder_emit_image_sample(&ctx->builder,
1965 actual_dest_type, load,
1966 coord,
1967 proj != 0,
1968 lod, bias, dref, dx, dy,
1969 offset);
1970 }
1971
1972 spirv_builder_emit_decoration(&ctx->builder, result,
1973 SpvDecorationRelaxedPrecision);
1974
1975 if (dref && nir_dest_num_components(tex->dest) > 1) {
1976 SpvId components[4] = { result, result, result, result };
1977 result = spirv_builder_emit_composite_construct(&ctx->builder,
1978 dest_type,
1979 components,
1980 4);
1981 }
1982
1983 store_dest(ctx, &tex->dest, result, tex->dest_type);
1984 }
1985
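/* Every SPIR-V block must end in a terminator, so if the previous block is
 * still open, close it with a branch to the new label before starting the
 * next one. */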
1986 static void
1987 start_block(struct ntv_context *ctx, SpvId label)
1988 {
1989 /* terminate previous block if needed */
1990 if (ctx->block_started)
1991 spirv_builder_emit_branch(&ctx->builder, label);
1992
1993 /* start new block */
1994 spirv_builder_label(&ctx->builder, label);
1995 ctx->block_started = true;
1996 }
1997
1998 static void
1999 branch(struct ntv_context *ctx, SpvId label)
2000 {
2001 assert(ctx->block_started);
2002 spirv_builder_emit_branch(&ctx->builder, label);
2003 ctx->block_started = false;
2004 }
2005
2006 static void
2007 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
2008 SpvId else_id)
2009 {
2010 assert(ctx->block_started);
2011 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
2012 then_id, else_id);
2013 ctx->block_started = false;
2014 }
2015
2016 static void
2017 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
2018 {
2019 switch (jump->type) {
2020 case nir_jump_break:
2021 assert(ctx->loop_break);
2022 branch(ctx, ctx->loop_break);
2023 break;
2024
2025 case nir_jump_continue:
2026 assert(ctx->loop_cont);
2027 branch(ctx, ctx->loop_cont);
2028 break;
2029
2030 default:
2031 unreachable("Unsupported jump type\n");
2032 }
2033 }
2034
2035 static void
2036 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
2037 {
2038 assert(deref->deref_type == nir_deref_type_var);
2039
2040 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
2041 assert(he);
2042 SpvId result = (SpvId)(intptr_t)he->data;
2043 store_dest_raw(ctx, &deref->dest, result);
2044 }
2045
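/* Indexes into an arrayed shader input/output by emitting an access chain;
 * the result is a pointer into the variable, stored as an opaque value. */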
2046 static void
2047 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
2048 {
2049 assert(deref->deref_type == nir_deref_type_array);
2050 nir_variable *var = nir_deref_instr_get_variable(deref);
2051
2052 SpvStorageClass storage_class;
2053 switch (var->data.mode) {
2054 case nir_var_shader_in:
2055 storage_class = SpvStorageClassInput;
2056 break;
2057
2058 case nir_var_shader_out:
2059 storage_class = SpvStorageClassOutput;
2060 break;
2061
2062 default:
2063 unreachable("Unsupported nir_variable_mode\n");
2064 }
2065
2066 SpvId index = get_src(ctx, &deref->arr.index);
2067
2068 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
2069 storage_class,
2070 get_glsl_type(ctx, deref->type));
2071
2072 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
2073 ptr_type,
2074 get_src(ctx, &deref->parent),
2075 &index, 1);
2076 /* uint is a bit of a lie here; the result is really just an opaque pointer */
2077 store_dest(ctx, &deref->dest, result, nir_type_uint);
2078 }
2079
2080 static void
2081 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
2082 {
2083 switch (deref->deref_type) {
2084 case nir_deref_type_var:
2085 emit_deref_var(ctx, deref);
2086 break;
2087
2088 case nir_deref_type_array:
2089 emit_deref_array(ctx, deref);
2090 break;
2091
2092 default:
2093 unreachable("unexpected deref_type");
2094 }
2095 }
2096
2097 static void
2098 emit_block(struct ntv_context *ctx, struct nir_block *block)
2099 {
2100 start_block(ctx, block_label(ctx, block));
2101 nir_foreach_instr(instr, block) {
2102 switch (instr->type) {
2103 case nir_instr_type_alu:
2104 emit_alu(ctx, nir_instr_as_alu(instr));
2105 break;
2106 case nir_instr_type_intrinsic:
2107 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
2108 break;
2109 case nir_instr_type_load_const:
2110 emit_load_const(ctx, nir_instr_as_load_const(instr));
2111 break;
2112 case nir_instr_type_ssa_undef:
2113 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
2114 break;
2115 case nir_instr_type_tex:
2116 emit_tex(ctx, nir_instr_as_tex(instr));
2117 break;
2118 case nir_instr_type_phi:
2119 unreachable("nir_instr_type_phi not supported");
2120 break;
2121 case nir_instr_type_jump:
2122 emit_jump(ctx, nir_instr_as_jump(instr));
2123 break;
2124 case nir_instr_type_call:
2125 unreachable("nir_instr_type_call not supported");
2126 break;
2127 case nir_instr_type_parallel_copy:
2128 unreachable("nir_instr_type_parallel_copy not supported");
2129 break;
2130 case nir_instr_type_deref:
2131 emit_deref(ctx, nir_instr_as_deref(instr));
2132 break;
2133 }
2134 }
2135 }
2136
2137 static void
2138 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
2139
2140 static SpvId
2141 get_src_bool(struct ntv_context *ctx, nir_src *src)
2142 {
2143 assert(nir_src_bit_size(*src) == 1);
2144 return get_src(ctx, src);
2145 }
2146
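/* SPIR-V structured selection: a fresh header block carries the
 * OpSelectionMerge pointing at the endif label, followed by the conditional
 * branch to the then/else blocks (a missing else branches straight to endif).
 * Roughly:
 *
 *   header: OpSelectionMerge %endif
 *           OpBranchConditional %cond %then %else-or-endif
 *   ...then/else blocks...
 *   endif:  (merge block)
 */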
2147 static void
2148 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
2149 {
2150 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
2151
2152 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2153 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
2154 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
2155 SpvId else_id = endif_id;
2156
2157 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
2158 if (has_else) {
2159 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
2160 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
2161 }
2162
2163 /* create a header-block */
2164 start_block(ctx, header_id);
2165 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
2166 SpvSelectionControlMaskNone);
2167 branch_conditional(ctx, condition, then_id, else_id);
2168
2169 emit_cf_list(ctx, &if_stmt->then_list);
2170
2171 if (has_else) {
2172 if (ctx->block_started)
2173 branch(ctx, endif_id);
2174
2175 emit_cf_list(ctx, &if_stmt->else_list);
2176 }
2177
2178 start_block(ctx, endif_id);
2179 }
2180
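/* SPIR-V structured loops: a header block carries the OpLoopMerge that
 * names the break (merge) and continue blocks, then branches into the body.
 * Roughly:
 *
 *   header: OpLoopMerge %break %cont
 *           OpBranch %begin
 *   ...loop body (nir_jump_break/continue branch to %break/%cont)...
 *   cont:   OpBranch %header
 *   break:  (merge block, code after the loop)
 */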
2181 static void
2182 emit_loop(struct ntv_context *ctx, nir_loop *loop)
2183 {
2184 SpvId header_id = spirv_builder_new_id(&ctx->builder);
2185 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
2186 SpvId break_id = spirv_builder_new_id(&ctx->builder);
2187 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
2188
2189 /* create a header-block */
2190 start_block(ctx, header_id);
2191 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
2192 branch(ctx, begin_id);
2193
2194 SpvId save_break = ctx->loop_break;
2195 SpvId save_cont = ctx->loop_cont;
2196 ctx->loop_break = break_id;
2197 ctx->loop_cont = cont_id;
2198
2199 emit_cf_list(ctx, &loop->body);
2200
2201 ctx->loop_break = save_break;
2202 ctx->loop_cont = save_cont;
2203
2204 /* loop->body may have already ended our block */
2205 if (ctx->block_started)
2206 branch(ctx, cont_id);
2207 start_block(ctx, cont_id);
2208 branch(ctx, header_id);
2209
2210 start_block(ctx, break_id);
2211 }
2212
2213 static void
2214 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
2215 {
2216 foreach_list_typed(nir_cf_node, node, node, list) {
2217 switch (node->type) {
2218 case nir_cf_node_block:
2219 emit_block(ctx, nir_cf_node_as_block(node));
2220 break;
2221
2222 case nir_cf_node_if:
2223 emit_if(ctx, nir_cf_node_as_if(node));
2224 break;
2225
2226 case nir_cf_node_loop:
2227 emit_loop(ctx, nir_cf_node_as_loop(node));
2228 break;
2229
2230 case nir_cf_node_function:
2231 unreachable("nir_cf_node_function not supported");
2232 break;
2233 }
2234 }
2235 }
2236
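/* Top-level entry point: emits the capabilities, memory model and shader
 * interface (inputs, outputs, uniforms, transform feedback), then translates
 * the NIR entrypoint's control-flow list into a single "main" function and
 * serializes the module into a word array. */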
2237 struct spirv_shader *
2238 nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info)
2239 {
2240 struct spirv_shader *ret = NULL;
2241
2242 struct ntv_context ctx = {};
2243 ctx.mem_ctx = ralloc_context(NULL);
2244 ctx.builder.mem_ctx = ctx.mem_ctx;
2245
2246 switch (s->info.stage) {
2247 case MESA_SHADER_VERTEX:
2248 case MESA_SHADER_FRAGMENT:
2249 case MESA_SHADER_COMPUTE:
2250 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2251 break;
2252
2253 case MESA_SHADER_TESS_CTRL:
2254 case MESA_SHADER_TESS_EVAL:
2255 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2256 break;
2257
2258 case MESA_SHADER_GEOMETRY:
2259 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2260 break;
2261
2262 default:
2263 unreachable("invalid stage");
2264 }
2265
2266 // TODO: only enable when needed
2267 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2268 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2269 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2270 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2271 }
2272
2273 ctx.stage = s->info.stage;
2274 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2275 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2276
2277 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2278 SpvMemoryModelGLSL450);
2279
2280 SpvExecutionModel exec_model;
2281 switch (s->info.stage) {
2282 case MESA_SHADER_VERTEX:
2283 exec_model = SpvExecutionModelVertex;
2284 break;
2285 case MESA_SHADER_TESS_CTRL:
2286 exec_model = SpvExecutionModelTessellationControl;
2287 break;
2288 case MESA_SHADER_TESS_EVAL:
2289 exec_model = SpvExecutionModelTessellationEvaluation;
2290 break;
2291 case MESA_SHADER_GEOMETRY:
2292 exec_model = SpvExecutionModelGeometry;
2293 break;
2294 case MESA_SHADER_FRAGMENT:
2295 exec_model = SpvExecutionModelFragment;
2296 break;
2297 case MESA_SHADER_COMPUTE:
2298 exec_model = SpvExecutionModelGLCompute;
2299 break;
2300 default:
2301 unreachable("invalid stage");
2302 }
2303
2304 SpvId type_void = spirv_builder_type_void(&ctx.builder);
2305 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2306 NULL, 0);
2307 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2308 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2309
2310 ctx.vars = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_pointer,
2311 _mesa_key_pointer_equal);
2312
2313 ctx.so_outputs = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_u32,
2314 _mesa_key_u32_equal);
2315
2316 nir_foreach_shader_in_variable(var, s)
2317 emit_input(&ctx, var);
2318
2319 nir_foreach_shader_out_variable(var, s)
2320 emit_output(&ctx, var);
2321
2322 if (so_info)
2323 emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info);
2324 nir_foreach_variable_with_modes(var, s, nir_var_uniform |
2325 nir_var_mem_ubo |
2326 nir_var_mem_ssbo)
2327 emit_uniform(&ctx, var);
2328
2329 if (s->info.stage == MESA_SHADER_FRAGMENT) {
2330 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2331 SpvExecutionModeOriginUpperLeft);
2332 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2333 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2334 SpvExecutionModeDepthReplacing);
2335 }
2336
2337 if (so_info && so_info->so_info.num_outputs) {
2338 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
2339 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2340 SpvExecutionModeXfb);
2341 }
2342
2343 spirv_builder_function(&ctx.builder, entry_point, type_void,
2344 SpvFunctionControlMaskNone,
2345 type_main);
2346
2347 nir_function_impl *entry = nir_shader_get_entrypoint(s);
2348 nir_metadata_require(entry, nir_metadata_block_index);
2349
2350 ctx.defs = ralloc_array_size(ctx.mem_ctx,
2351 sizeof(SpvId), entry->ssa_alloc);
2352 if (!ctx.defs)
2353 goto fail;
2354 ctx.num_defs = entry->ssa_alloc;
2355
2356 nir_index_local_regs(entry);
2357 ctx.regs = ralloc_array_size(ctx.mem_ctx,
2358 sizeof(SpvId), entry->reg_alloc);
2359 if (!ctx.regs)
2360 goto fail;
2361 ctx.num_regs = entry->reg_alloc;
2362
2363 SpvId *block_ids = ralloc_array_size(ctx.mem_ctx,
2364 sizeof(SpvId), entry->num_blocks);
2365 if (!block_ids)
2366 goto fail;
2367
2368 for (int i = 0; i < entry->num_blocks; ++i)
2369 block_ids[i] = spirv_builder_new_id(&ctx.builder);
2370
2371 ctx.block_ids = block_ids;
2372 ctx.num_blocks = entry->num_blocks;
2373
2374 /* emit a block only for the variable declarations */
2375 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2376 foreach_list_typed(nir_register, reg, node, &entry->registers) {
2377 SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
2378 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2379 SpvStorageClassFunction,
2380 type);
2381 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2382 SpvStorageClassFunction);
2383
2384 ctx.regs[reg->index] = var;
2385 }
2386
2387 emit_cf_list(&ctx, &entry->body);
2388
2389 if (so_info)
2390 emit_so_outputs(&ctx, so_info);
2391
2392 spirv_builder_return(&ctx.builder); // XXX: doesn't really belong here, but we need to terminate the final block
2393 spirv_builder_function_end(&ctx.builder);
2394
2395 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2396 "main", ctx.entry_ifaces,
2397 ctx.num_entry_ifaces);
2398
2399 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2400
2401 ret = CALLOC_STRUCT(spirv_shader);
2402 if (!ret)
2403 goto fail;
2404
2405 ret->words = MALLOC(sizeof(uint32_t) * num_words);
2406 if (!ret->words)
2407 goto fail;
2408
2409 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2410 assert(ret->num_words == num_words);
2411
2412 ralloc_free(ctx.mem_ctx);
2413
2414 return ret;
2415
2416 fail:
2417 ralloc_free(ctx.mem_ctx);
2418
2419 if (ret)
2420 spirv_shader_delete(ret);
2421
2422 return NULL;
2423 }
2424
2425 void
2426 spirv_shader_delete(struct spirv_shader *s)
2427 {
2428 FREE(s->words);
2429 FREE(s);
2430 }