34bb14379c4b2b963703c04d557a3790f0473e2d
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId image_types[PIPE_MAX_SAMPLERS];
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var, instance_id_var, vertex_id_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
172 nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0)
238 spirv_builder_emit_location(&ctx->builder, var_id,
239 var->data.location -
240 VARYING_SLOT_VAR0 +
241 VARYING_SLOT_TEX0);
242 else if ((var->data.location >= VARYING_SLOT_COL0 &&
243 var->data.location <= VARYING_SLOT_TEX7) ||
244 var->data.location == VARYING_SLOT_BFC0 ||
245 var->data.location == VARYING_SLOT_BFC1) {
246 spirv_builder_emit_location(&ctx->builder, var_id,
247 var->data.location);
248 } else {
249 switch (var->data.location) {
250 case VARYING_SLOT_POS:
251 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
252 break;
253
254 case VARYING_SLOT_PNTC:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
256 break;
257
258 default:
259 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
260 unreachable("unexpected varying slot");
261 }
262 }
263 } else {
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267
268 if (var->data.location_frac)
269 spirv_builder_emit_component(&ctx->builder, var_id,
270 var->data.location_frac);
271
272 if (var->data.interpolation == INTERP_MODE_FLAT)
273 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
274
275 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
276
277 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
278 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
279 }
280
281 static void
282 emit_output(struct ntv_context *ctx, struct nir_variable *var)
283 {
284 SpvId var_type = get_glsl_type(ctx, var->type);
285 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
286 SpvStorageClassOutput,
287 var_type);
288 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
289 SpvStorageClassOutput);
290 if (var->name)
291 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
292
293
294 if (ctx->stage == MESA_SHADER_VERTEX) {
295 if (var->data.location >= VARYING_SLOT_VAR0)
296 spirv_builder_emit_location(&ctx->builder, var_id,
297 var->data.location -
298 VARYING_SLOT_VAR0 +
299 VARYING_SLOT_TEX0);
300 else if ((var->data.location >= VARYING_SLOT_COL0 &&
301 var->data.location <= VARYING_SLOT_TEX7) ||
302 var->data.location == VARYING_SLOT_BFC0 ||
303 var->data.location == VARYING_SLOT_BFC1) {
304 spirv_builder_emit_location(&ctx->builder, var_id,
305 var->data.location);
306 } else {
307 switch (var->data.location) {
308 case VARYING_SLOT_POS:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
310 break;
311
312 case VARYING_SLOT_PSIZ:
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
314 break;
315
316 case VARYING_SLOT_CLIP_DIST0:
317 assert(glsl_type_is_array(var->type));
318 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
319 break;
320
321 default:
322 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
323 unreachable("unexpected varying slot");
324 }
325 }
326 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
327 if (var->data.location >= FRAG_RESULT_DATA0)
328 spirv_builder_emit_location(&ctx->builder, var_id,
329 var->data.location - FRAG_RESULT_DATA0);
330 else {
331 switch (var->data.location) {
332 case FRAG_RESULT_COLOR:
333 spirv_builder_emit_location(&ctx->builder, var_id, 0);
334 break;
335
336 case FRAG_RESULT_DEPTH:
337 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
338 break;
339
340 default:
341 spirv_builder_emit_location(&ctx->builder, var_id,
342 var->data.driver_location);
343 }
344 }
345 }
346
347 if (var->data.location_frac)
348 spirv_builder_emit_component(&ctx->builder, var_id,
349 var->data.location_frac);
350
351 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
352
353 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
354 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
355 }
356
357 static SpvDim
358 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
359 {
360 *is_ms = false;
361 switch (gdim) {
362 case GLSL_SAMPLER_DIM_1D:
363 return SpvDim1D;
364 case GLSL_SAMPLER_DIM_2D:
365 return SpvDim2D;
366 case GLSL_SAMPLER_DIM_3D:
367 return SpvDim3D;
368 case GLSL_SAMPLER_DIM_CUBE:
369 return SpvDimCube;
370 case GLSL_SAMPLER_DIM_RECT:
371 return SpvDimRect;
372 case GLSL_SAMPLER_DIM_BUF:
373 return SpvDimBuffer;
374 case GLSL_SAMPLER_DIM_EXTERNAL:
375 return SpvDim2D; /* seems dodgy... */
376 case GLSL_SAMPLER_DIM_MS:
377 *is_ms = true;
378 return SpvDim2D;
379 default:
380 fprintf(stderr, "unknown sampler type %d\n", gdim);
381 break;
382 }
383 return SpvDim2D;
384 }
385
386 static void
387 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
388 {
389 const struct glsl_type *type = glsl_without_array(var->type);
390
391 bool is_ms;
392 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
393
394 SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
395 SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
396 dimension, false,
397 glsl_sampler_type_is_array(type),
398 is_ms, 1,
399 SpvImageFormatUnknown);
400
401 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
402 image_type);
403 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
404 SpvStorageClassUniformConstant,
405 sampled_type);
406
407 if (glsl_type_is_array(var->type)) {
408 for (int i = 0; i < glsl_get_length(var->type); ++i) {
409 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
410 SpvStorageClassUniformConstant);
411
412 if (var->name) {
413 char element_name[100];
414 snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
415 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
416 }
417
418 assert(ctx->num_samplers < ARRAY_SIZE(ctx->image_types));
419 ctx->image_types[ctx->num_samplers] = image_type;
420
421 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
422 ctx->samplers[ctx->num_samplers++] = var_id;
423
424 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
425 var->data.descriptor_set);
426 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
427 }
428 } else {
429 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
430 SpvStorageClassUniformConstant);
431
432 if (var->name)
433 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
434
435 assert(ctx->num_samplers < ARRAY_SIZE(ctx->image_types));
436 ctx->image_types[ctx->num_samplers] = image_type;
437
438 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
439 ctx->samplers[ctx->num_samplers++] = var_id;
440
441 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
442 var->data.descriptor_set);
443 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
444 }
445 }
446
447 static void
448 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
449 {
450 uint32_t size = glsl_count_attribute_slots(var->type, false);
451 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
452 SpvId array_length = emit_uint_const(ctx, 32, size);
453 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
454 array_length);
455 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
456
457 // wrap UBO-array in a struct
458 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
459 if (var->name) {
460 char struct_name[100];
461 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
462 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
463 }
464
465 spirv_builder_emit_decoration(&ctx->builder, struct_type,
466 SpvDecorationBlock);
467 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
468
469
470 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
471 SpvStorageClassUniform,
472 struct_type);
473
474 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
475 SpvStorageClassUniform);
476 if (var->name)
477 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
478
479 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
480 ctx->ubos[ctx->num_ubos++] = var_id;
481
482 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
483 var->data.descriptor_set);
484 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
485 }
486
487 static void
488 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
489 {
490 if (var->data.mode == nir_var_mem_ubo)
491 emit_ubo(ctx, var);
492 else {
493 assert(var->data.mode == nir_var_uniform);
494 if (glsl_type_is_sampler(glsl_without_array(var->type)))
495 emit_sampler(ctx, var);
496 }
497 }
498
499 static SpvId
500 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
501 {
502 assert(ssa->index < ctx->num_defs);
503 assert(ctx->defs[ssa->index] != 0);
504 return ctx->defs[ssa->index];
505 }
506
507 static SpvId
508 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
509 {
510 assert(reg->index < ctx->num_regs);
511 assert(ctx->regs[reg->index] != 0);
512 return ctx->regs[reg->index];
513 }
514
515 static SpvId
516 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
517 {
518 assert(reg->reg);
519 assert(!reg->indirect);
520 assert(!reg->base_offset);
521
522 SpvId var = get_var_from_reg(ctx, reg->reg);
523 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
524 return spirv_builder_emit_load(&ctx->builder, type, var);
525 }
526
527 static SpvId
528 get_src_uint(struct ntv_context *ctx, nir_src *src)
529 {
530 if (src->is_ssa)
531 return get_src_uint_ssa(ctx, src->ssa);
532 else
533 return get_src_uint_reg(ctx, &src->reg);
534 }
535
536 static SpvId
537 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
538 {
539 assert(!alu->src[src].negate);
540 assert(!alu->src[src].abs);
541
542 SpvId def = get_src_uint(ctx, &alu->src[src].src);
543
544 unsigned used_channels = 0;
545 bool need_swizzle = false;
546 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
547 if (!nir_alu_instr_channel_used(alu, src, i))
548 continue;
549
550 used_channels++;
551
552 if (alu->src[src].swizzle[i] != i)
553 need_swizzle = true;
554 }
555 assert(used_channels != 0);
556
557 unsigned live_channels = nir_src_num_components(alu->src[src].src);
558 if (used_channels != live_channels)
559 need_swizzle = true;
560
561 if (!need_swizzle)
562 return def;
563
564 int bit_size = nir_src_bit_size(alu->src[src].src);
565 assert(bit_size == 1 || bit_size == 32);
566
567 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
568 if (used_channels == 1) {
569 uint32_t indices[] = { alu->src[src].swizzle[0] };
570 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
571 def, indices,
572 ARRAY_SIZE(indices));
573 } else if (live_channels == 1) {
574 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
575 used_channels);
576
577 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
578 for (unsigned i = 0; i < used_channels; ++i)
579 constituents[i] = def;
580
581 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
582 constituents,
583 used_channels);
584 } else {
585 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
586 used_channels);
587
588 uint32_t components[NIR_MAX_VEC_COMPONENTS];
589 size_t num_components = 0;
590 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
591 if (!nir_alu_instr_channel_used(alu, src, i))
592 continue;
593
594 components[num_components++] = alu->src[src].swizzle[i];
595 }
596
597 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
598 def, def, components, num_components);
599 }
600 }
601
602 static void
603 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
604 {
605 assert(result != 0);
606 assert(ssa->index < ctx->num_defs);
607 ctx->defs[ssa->index] = result;
608 }
609
610 static SpvId
611 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
612 SpvId if_true, SpvId if_false)
613 {
614 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
615 }
616
617 static SpvId
618 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
619 {
620 SpvId otype = get_uvec_type(ctx, 32, num_components);
621 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
622 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
623 return emit_select(ctx, otype, value, one, zero);
624 }
625
626 static SpvId
627 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
628 {
629 SpvId type = get_bvec_type(ctx, num_components);
630 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
631 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
632 }
633
634 static SpvId
635 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
636 {
637 return emit_unop(ctx, SpvOpBitcast, type, value);
638 }
639
640 static SpvId
641 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
642 unsigned num_components)
643 {
644 SpvId type = get_uvec_type(ctx, bit_size, num_components);
645 return emit_bitcast(ctx, type, value);
646 }
647
648 static SpvId
649 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
650 unsigned num_components)
651 {
652 SpvId type = get_ivec_type(ctx, bit_size, num_components);
653 return emit_bitcast(ctx, type, value);
654 }
655
656 static SpvId
657 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
658 unsigned num_components)
659 {
660 SpvId type = get_fvec_type(ctx, bit_size, num_components);
661 return emit_bitcast(ctx, type, value);
662 }
663
664 static void
665 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
666 {
667 SpvId var = get_var_from_reg(ctx, reg->reg);
668 assert(var);
669 spirv_builder_emit_store(&ctx->builder, var, result);
670 }
671
672 static void
673 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
674 {
675 if (dest->is_ssa)
676 store_ssa_def_uint(ctx, &dest->ssa, result);
677 else
678 store_reg_def(ctx, &dest->reg, result);
679 }
680
681 static void
682 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
683 {
684 unsigned num_components = nir_dest_num_components(*dest);
685 unsigned bit_size = nir_dest_bit_size(*dest);
686
687 switch (nir_alu_type_get_base_type(type)) {
688 case nir_type_bool:
689 assert(bit_size == 1);
690 result = bvec_to_uvec(ctx, result, num_components);
691 break;
692
693 case nir_type_uint:
694 break; /* nothing to do! */
695
696 case nir_type_int:
697 case nir_type_float:
698 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
699 break;
700
701 default:
702 unreachable("unsupported nir_alu_type");
703 }
704
705 store_dest_uint(ctx, dest, result);
706 }
707
708 static SpvId
709 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
710 {
711 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
712 }
713
714 static SpvId
715 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
716 SpvId src0, SpvId src1)
717 {
718 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
719 }
720
721 static SpvId
722 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
723 SpvId src0, SpvId src1, SpvId src2)
724 {
725 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
726 }
727
728 static SpvId
729 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
730 SpvId src)
731 {
732 SpvId args[] = { src };
733 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
734 op, args, ARRAY_SIZE(args));
735 }
736
737 static SpvId
738 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
739 SpvId src0, SpvId src1)
740 {
741 SpvId args[] = { src0, src1 };
742 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
743 op, args, ARRAY_SIZE(args));
744 }
745
746 static SpvId
747 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
748 SpvId src0, SpvId src1, SpvId src2)
749 {
750 SpvId args[] = { src0, src1, src2 };
751 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
752 op, args, ARRAY_SIZE(args));
753 }
754
755 static SpvId
756 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
757 unsigned num_components, float value)
758 {
759 assert(bit_size == 32);
760
761 SpvId result = emit_float_const(ctx, bit_size, value);
762 if (num_components == 1)
763 return result;
764
765 assert(num_components > 1);
766 SpvId components[num_components];
767 for (int i = 0; i < num_components; i++)
768 components[i] = result;
769
770 SpvId type = get_fvec_type(ctx, bit_size, num_components);
771 return spirv_builder_const_composite(&ctx->builder, type, components,
772 num_components);
773 }
774
775 static SpvId
776 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
777 unsigned num_components, uint32_t value)
778 {
779 assert(bit_size == 32);
780
781 SpvId result = emit_uint_const(ctx, bit_size, value);
782 if (num_components == 1)
783 return result;
784
785 assert(num_components > 1);
786 SpvId components[num_components];
787 for (int i = 0; i < num_components; i++)
788 components[i] = result;
789
790 SpvId type = get_uvec_type(ctx, bit_size, num_components);
791 return spirv_builder_const_composite(&ctx->builder, type, components,
792 num_components);
793 }
794
795 static SpvId
796 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
797 unsigned num_components, int32_t value)
798 {
799 assert(bit_size == 32);
800
801 SpvId result = emit_int_const(ctx, bit_size, value);
802 if (num_components == 1)
803 return result;
804
805 assert(num_components > 1);
806 SpvId components[num_components];
807 for (int i = 0; i < num_components; i++)
808 components[i] = result;
809
810 SpvId type = get_ivec_type(ctx, bit_size, num_components);
811 return spirv_builder_const_composite(&ctx->builder, type, components,
812 num_components);
813 }
814
815 static inline unsigned
816 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
817 {
818 if (nir_op_infos[instr->op].input_sizes[src] > 0)
819 return nir_op_infos[instr->op].input_sizes[src];
820
821 if (instr->dest.dest.is_ssa)
822 return instr->dest.dest.ssa.num_components;
823 else
824 return instr->dest.dest.reg.reg->num_components;
825 }
826
827 static SpvId
828 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
829 {
830 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
831
832 unsigned num_components = alu_instr_src_components(alu, src);
833 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
834 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
835
836 switch (nir_alu_type_get_base_type(type)) {
837 case nir_type_bool:
838 assert(bit_size == 1);
839 return uvec_to_bvec(ctx, uint_value, num_components);
840
841 case nir_type_int:
842 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
843
844 case nir_type_uint:
845 return uint_value;
846
847 case nir_type_float:
848 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
849
850 default:
851 unreachable("unknown nir_alu_type");
852 }
853 }
854
855 static void
856 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
857 {
858 assert(!alu->dest.saturate);
859 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
860 }
861
862 static SpvId
863 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
864 {
865 unsigned num_components = nir_dest_num_components(*dest);
866 unsigned bit_size = nir_dest_bit_size(*dest);
867
868 switch (nir_alu_type_get_base_type(type)) {
869 case nir_type_bool:
870 return get_bvec_type(ctx, num_components);
871
872 case nir_type_int:
873 return get_ivec_type(ctx, bit_size, num_components);
874
875 case nir_type_uint:
876 return get_uvec_type(ctx, bit_size, num_components);
877
878 case nir_type_float:
879 return get_fvec_type(ctx, bit_size, num_components);
880
881 default:
882 unreachable("unsupported nir_alu_type");
883 }
884 }
885
886 static void
887 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
888 {
889 SpvId src[nir_op_infos[alu->op].num_inputs];
890 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
891 src[i] = get_alu_src(ctx, alu, i);
892
893 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
894 nir_op_infos[alu->op].output_type);
895 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
896 unsigned num_components = nir_dest_num_components(alu->dest.dest);
897
898 SpvId result = 0;
899 switch (alu->op) {
900 case nir_op_mov:
901 assert(nir_op_infos[alu->op].num_inputs == 1);
902 result = src[0];
903 break;
904
905 #define UNOP(nir_op, spirv_op) \
906 case nir_op: \
907 assert(nir_op_infos[alu->op].num_inputs == 1); \
908 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
909 break;
910
911 UNOP(nir_op_ineg, SpvOpSNegate)
912 UNOP(nir_op_fneg, SpvOpFNegate)
913 UNOP(nir_op_fddx, SpvOpDPdx)
914 UNOP(nir_op_fddy, SpvOpDPdy)
915 UNOP(nir_op_f2i32, SpvOpConvertFToS)
916 UNOP(nir_op_f2u32, SpvOpConvertFToU)
917 UNOP(nir_op_i2f32, SpvOpConvertSToF)
918 UNOP(nir_op_u2f32, SpvOpConvertUToF)
919 UNOP(nir_op_inot, SpvOpNot)
920 #undef UNOP
921
922 case nir_op_b2i32:
923 assert(nir_op_infos[alu->op].num_inputs == 1);
924 result = emit_select(ctx, dest_type, src[0],
925 get_ivec_constant(ctx, 32, num_components, 1),
926 get_ivec_constant(ctx, 32, num_components, 0));
927 break;
928
929 case nir_op_b2f32:
930 assert(nir_op_infos[alu->op].num_inputs == 1);
931 result = emit_select(ctx, dest_type, src[0],
932 get_fvec_constant(ctx, 32, num_components, 1),
933 get_fvec_constant(ctx, 32, num_components, 0));
934 break;
935
936 #define BUILTIN_UNOP(nir_op, spirv_op) \
937 case nir_op: \
938 assert(nir_op_infos[alu->op].num_inputs == 1); \
939 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
940 break;
941
942 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
943 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
944 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
945 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
946 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
947 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
948 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
949 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
950 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
951 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
952 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
953 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
954 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
955 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
956 #undef BUILTIN_UNOP
957
958 case nir_op_frcp:
959 assert(nir_op_infos[alu->op].num_inputs == 1);
960 result = emit_binop(ctx, SpvOpFDiv, dest_type,
961 get_fvec_constant(ctx, bit_size, num_components, 1),
962 src[0]);
963 break;
964
965 case nir_op_f2b1:
966 assert(nir_op_infos[alu->op].num_inputs == 1);
967 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
968 get_fvec_constant(ctx,
969 nir_src_bit_size(alu->src[0].src),
970 num_components, 0));
971 break;
972
973
974 #define BINOP(nir_op, spirv_op) \
975 case nir_op: \
976 assert(nir_op_infos[alu->op].num_inputs == 2); \
977 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
978 break;
979
980 BINOP(nir_op_iadd, SpvOpIAdd)
981 BINOP(nir_op_isub, SpvOpISub)
982 BINOP(nir_op_imul, SpvOpIMul)
983 BINOP(nir_op_idiv, SpvOpSDiv)
984 BINOP(nir_op_udiv, SpvOpUDiv)
985 BINOP(nir_op_umod, SpvOpUMod)
986 BINOP(nir_op_fadd, SpvOpFAdd)
987 BINOP(nir_op_fsub, SpvOpFSub)
988 BINOP(nir_op_fmul, SpvOpFMul)
989 BINOP(nir_op_fdiv, SpvOpFDiv)
990 BINOP(nir_op_fmod, SpvOpFMod)
991 BINOP(nir_op_ilt, SpvOpSLessThan)
992 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
993 BINOP(nir_op_ieq, SpvOpIEqual)
994 BINOP(nir_op_ine, SpvOpINotEqual)
995 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
996 BINOP(nir_op_flt, SpvOpFOrdLessThan)
997 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
998 BINOP(nir_op_feq, SpvOpFOrdEqual)
999 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
1000 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1001 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1002 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1003 BINOP(nir_op_iand, SpvOpBitwiseAnd)
1004 BINOP(nir_op_ior, SpvOpBitwiseOr)
1005 #undef BINOP
1006
1007 #define BUILTIN_BINOP(nir_op, spirv_op) \
1008 case nir_op: \
1009 assert(nir_op_infos[alu->op].num_inputs == 2); \
1010 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1011 break;
1012
1013 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1014 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1015 #undef BUILTIN_BINOP
1016
1017 case nir_op_fdot2:
1018 case nir_op_fdot3:
1019 case nir_op_fdot4:
1020 assert(nir_op_infos[alu->op].num_inputs == 2);
1021 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1022 break;
1023
1024 case nir_op_seq:
1025 case nir_op_sne:
1026 case nir_op_slt:
1027 case nir_op_sge: {
1028 assert(nir_op_infos[alu->op].num_inputs == 2);
1029 int num_components = nir_dest_num_components(alu->dest.dest);
1030 SpvId bool_type = get_bvec_type(ctx, num_components);
1031
1032 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1033 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1034 if (num_components > 1) {
1035 SpvId zero_comps[num_components], one_comps[num_components];
1036 for (int i = 0; i < num_components; i++) {
1037 zero_comps[i] = zero;
1038 one_comps[i] = one;
1039 }
1040
1041 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1042 zero_comps, num_components);
1043 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1044 one_comps, num_components);
1045 }
1046
1047 SpvOp op;
1048 switch (alu->op) {
1049 case nir_op_seq: op = SpvOpFOrdEqual; break;
1050 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1051 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1052 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1053 default: unreachable("unexpected op");
1054 }
1055
1056 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1057 result = emit_select(ctx, dest_type, result, one, zero);
1058 }
1059 break;
1060
1061 case nir_op_flrp:
1062 assert(nir_op_infos[alu->op].num_inputs == 3);
1063 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1064 src[0], src[1], src[2]);
1065 break;
1066
1067 case nir_op_fcsel:
1068 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1069 get_bvec_type(ctx, num_components),
1070 src[0],
1071 get_fvec_constant(ctx,
1072 nir_src_bit_size(alu->src[0].src),
1073 num_components, 0));
1074 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1075 break;
1076
1077 case nir_op_bcsel:
1078 assert(nir_op_infos[alu->op].num_inputs == 3);
1079 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1080 break;
1081
1082 case nir_op_bany_fnequal2:
1083 case nir_op_bany_fnequal3:
1084 case nir_op_bany_fnequal4:
1085 assert(nir_op_infos[alu->op].num_inputs == 2);
1086 assert(alu_instr_src_components(alu, 0) ==
1087 alu_instr_src_components(alu, 1));
1088 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1089 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1090 src[0], src[1]);
1091 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1092 break;
1093
1094 case nir_op_ball_fequal2:
1095 case nir_op_ball_fequal3:
1096 case nir_op_ball_fequal4:
1097 assert(nir_op_infos[alu->op].num_inputs == 2);
1098 assert(alu_instr_src_components(alu, 0) ==
1099 alu_instr_src_components(alu, 1));
1100 result = emit_binop(ctx, SpvOpFOrdEqual,
1101 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1102 src[0], src[1]);
1103 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1104 break;
1105
1106 case nir_op_bany_inequal2:
1107 case nir_op_bany_inequal3:
1108 case nir_op_bany_inequal4:
1109 assert(nir_op_infos[alu->op].num_inputs == 2);
1110 assert(alu_instr_src_components(alu, 0) ==
1111 alu_instr_src_components(alu, 1));
1112 result = emit_binop(ctx, SpvOpINotEqual,
1113 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1114 src[0], src[1]);
1115 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1116 break;
1117
1118 case nir_op_ball_iequal2:
1119 case nir_op_ball_iequal3:
1120 case nir_op_ball_iequal4:
1121 assert(nir_op_infos[alu->op].num_inputs == 2);
1122 assert(alu_instr_src_components(alu, 0) ==
1123 alu_instr_src_components(alu, 1));
1124 result = emit_binop(ctx, SpvOpIEqual,
1125 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1126 src[0], src[1]);
1127 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1128 break;
1129
1130 case nir_op_vec2:
1131 case nir_op_vec3:
1132 case nir_op_vec4: {
1133 int num_inputs = nir_op_infos[alu->op].num_inputs;
1134 assert(2 <= num_inputs && num_inputs <= 4);
1135 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1136 src, num_inputs);
1137 }
1138 break;
1139
1140 default:
1141 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1142 nir_op_infos[alu->op].name);
1143
1144 unreachable("unsupported opcode");
1145 return;
1146 }
1147
1148 store_alu_result(ctx, alu, result);
1149 }
1150
1151 static void
1152 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1153 {
1154 unsigned bit_size = load_const->def.bit_size;
1155 unsigned num_components = load_const->def.num_components;
1156
1157 SpvId constant;
1158 if (num_components > 1) {
1159 SpvId components[num_components];
1160 SpvId type;
1161 if (bit_size == 1) {
1162 for (int i = 0; i < num_components; i++)
1163 components[i] = spirv_builder_const_bool(&ctx->builder,
1164 load_const->value[i].b);
1165
1166 type = get_bvec_type(ctx, num_components);
1167 } else {
1168 for (int i = 0; i < num_components; i++)
1169 components[i] = emit_uint_const(ctx, bit_size,
1170 load_const->value[i].u32);
1171
1172 type = get_uvec_type(ctx, bit_size, num_components);
1173 }
1174 constant = spirv_builder_const_composite(&ctx->builder, type,
1175 components, num_components);
1176 } else {
1177 assert(num_components == 1);
1178 if (bit_size == 1)
1179 constant = spirv_builder_const_bool(&ctx->builder,
1180 load_const->value[0].b);
1181 else
1182 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1183 }
1184
1185 if (bit_size == 1)
1186 constant = bvec_to_uvec(ctx, constant, num_components);
1187
1188 store_ssa_def_uint(ctx, &load_const->def, constant);
1189 }
1190
1191 static void
1192 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1193 {
1194 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1195 assert(const_block_index); // no dynamic indexing for now
1196 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1197
1198 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1199 if (const_offset) {
1200 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1201 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1202 SpvStorageClassUniform,
1203 uvec4_type);
1204
1205 unsigned idx = const_offset->u32;
1206 SpvId member = emit_uint_const(ctx, 32, 0);
1207 SpvId offset = emit_uint_const(ctx, 32, idx);
1208 SpvId offsets[] = { member, offset };
1209 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1210 ctx->ubos[0], offsets,
1211 ARRAY_SIZE(offsets));
1212 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1213
1214 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1215 unsigned num_components = nir_dest_num_components(intr->dest);
1216 if (num_components == 1) {
1217 uint32_t components[] = { 0 };
1218 result = spirv_builder_emit_composite_extract(&ctx->builder,
1219 type,
1220 result, components,
1221 1);
1222 } else if (num_components < 4) {
1223 SpvId constituents[num_components];
1224 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1225 for (uint32_t i = 0; i < num_components; ++i)
1226 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1227 uint_type,
1228 result, &i,
1229 1);
1230
1231 result = spirv_builder_emit_composite_construct(&ctx->builder,
1232 type,
1233 constituents,
1234 num_components);
1235 }
1236
1237 store_dest_uint(ctx, &intr->dest, result);
1238 } else
1239 unreachable("uniform-addressing not yet supported");
1240 }
1241
1242 static void
1243 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1244 {
1245 assert(ctx->block_started);
1246 spirv_builder_emit_kill(&ctx->builder);
1247 /* discard is weird in NIR, so let's just create an unreachable block after
1248 it and hope that the vulkan driver will DCE any instructinos in it. */
1249 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1250 }
1251
1252 static void
1253 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1254 {
1255 /* uint is a bit of a lie here; it's really just a pointer */
1256 SpvId ptr = get_src_uint(ctx, intr->src);
1257
1258 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1259 SpvId result = spirv_builder_emit_load(&ctx->builder,
1260 get_glsl_type(ctx, var->type),
1261 ptr);
1262 unsigned num_components = nir_dest_num_components(intr->dest);
1263 unsigned bit_size = nir_dest_bit_size(intr->dest);
1264 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1265 store_dest_uint(ctx, &intr->dest, result);
1266 }
1267
1268 static void
1269 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1270 {
1271 /* uint is a bit of a lie here; it's really just a pointer */
1272 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1273 SpvId src = get_src_uint(ctx, &intr->src[1]);
1274
1275 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1276 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1277 SpvId result = emit_bitcast(ctx, type, src);
1278 spirv_builder_emit_store(&ctx->builder, ptr, result);
1279 }
1280
1281 static SpvId
1282 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1283 SpvStorageClass storage_class,
1284 const char *name, SpvBuiltIn builtin)
1285 {
1286 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1287 storage_class,
1288 var_type);
1289 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1290 storage_class);
1291 spirv_builder_emit_name(&ctx->builder, var, name);
1292 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1293
1294 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1295 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1296 return var;
1297 }
1298
1299 static void
1300 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1301 {
1302 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1303 if (!ctx->front_face_var)
1304 ctx->front_face_var = create_builtin_var(ctx, var_type,
1305 SpvStorageClassInput,
1306 "gl_FrontFacing",
1307 SpvBuiltInFrontFacing);
1308
1309 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1310 ctx->front_face_var);
1311 assert(1 == nir_dest_num_components(intr->dest));
1312 result = bvec_to_uvec(ctx, result, 1);
1313 store_dest_uint(ctx, &intr->dest, result);
1314 }
1315
1316 static void
1317 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1318 {
1319 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1320 if (!ctx->instance_id_var)
1321 ctx->instance_id_var = create_builtin_var(ctx, var_type,
1322 SpvStorageClassInput,
1323 "gl_InstanceId",
1324 SpvBuiltInInstanceIndex);
1325
1326 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1327 ctx->instance_id_var);
1328 assert(1 == nir_dest_num_components(intr->dest));
1329 store_dest_uint(ctx, &intr->dest, result);
1330 }
1331
1332 static void
1333 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1334 {
1335 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1336 if (!ctx->vertex_id_var)
1337 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1338 SpvStorageClassInput,
1339 "gl_VertexID",
1340 SpvBuiltInVertexIndex);
1341
1342 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1343 ctx->vertex_id_var);
1344 assert(1 == nir_dest_num_components(intr->dest));
1345 store_dest_uint(ctx, &intr->dest, result);
1346 }
1347
1348 static void
1349 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1350 {
1351 switch (intr->intrinsic) {
1352 case nir_intrinsic_load_ubo:
1353 emit_load_ubo(ctx, intr);
1354 break;
1355
1356 case nir_intrinsic_discard:
1357 emit_discard(ctx, intr);
1358 break;
1359
1360 case nir_intrinsic_load_deref:
1361 emit_load_deref(ctx, intr);
1362 break;
1363
1364 case nir_intrinsic_store_deref:
1365 emit_store_deref(ctx, intr);
1366 break;
1367
1368 case nir_intrinsic_load_front_face:
1369 emit_load_front_face(ctx, intr);
1370 break;
1371
1372 case nir_intrinsic_load_instance_id:
1373 emit_load_instance_id(ctx, intr);
1374 break;
1375
1376 case nir_intrinsic_load_vertex_id:
1377 emit_load_vertex_id(ctx, intr);
1378 break;
1379
1380 default:
1381 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1382 nir_intrinsic_infos[intr->intrinsic].name);
1383 unreachable("unsupported intrinsic");
1384 }
1385 }
1386
1387 static void
1388 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1389 {
1390 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1391 undef->def.num_components);
1392
1393 store_ssa_def_uint(ctx, &undef->def,
1394 spirv_builder_emit_undef(&ctx->builder, type));
1395 }
1396
1397 static SpvId
1398 get_src_float(struct ntv_context *ctx, nir_src *src)
1399 {
1400 SpvId def = get_src_uint(ctx, src);
1401 unsigned num_components = nir_src_num_components(*src);
1402 unsigned bit_size = nir_src_bit_size(*src);
1403 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1404 }
1405
1406 static SpvId
1407 get_src_int(struct ntv_context *ctx, nir_src *src)
1408 {
1409 SpvId def = get_src_uint(ctx, src);
1410 unsigned num_components = nir_src_num_components(*src);
1411 unsigned bit_size = nir_src_bit_size(*src);
1412 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1413 }
1414
1415 static void
1416 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1417 {
1418 assert(tex->op == nir_texop_tex ||
1419 tex->op == nir_texop_txb ||
1420 tex->op == nir_texop_txl ||
1421 tex->op == nir_texop_txd ||
1422 tex->op == nir_texop_txf ||
1423 tex->op == nir_texop_txs);
1424 assert(tex->texture_index == tex->sampler_index);
1425
1426 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1427 offset = 0;
1428 unsigned coord_components = 0;
1429 for (unsigned i = 0; i < tex->num_srcs; i++) {
1430 switch (tex->src[i].src_type) {
1431 case nir_tex_src_coord:
1432 if (tex->op == nir_texop_txf)
1433 coord = get_src_int(ctx, &tex->src[i].src);
1434 else
1435 coord = get_src_float(ctx, &tex->src[i].src);
1436 coord_components = nir_src_num_components(tex->src[i].src);
1437 break;
1438
1439 case nir_tex_src_projector:
1440 assert(nir_src_num_components(tex->src[i].src) == 1);
1441 proj = get_src_float(ctx, &tex->src[i].src);
1442 assert(proj != 0);
1443 break;
1444
1445 case nir_tex_src_offset:
1446 offset = get_src_int(ctx, &tex->src[i].src);
1447 break;
1448
1449 case nir_tex_src_bias:
1450 assert(tex->op == nir_texop_txb);
1451 bias = get_src_float(ctx, &tex->src[i].src);
1452 assert(bias != 0);
1453 break;
1454
1455 case nir_tex_src_lod:
1456 assert(nir_src_num_components(tex->src[i].src) == 1);
1457 if (tex->op == nir_texop_txf ||
1458 tex->op == nir_texop_txs)
1459 lod = get_src_int(ctx, &tex->src[i].src);
1460 else
1461 lod = get_src_float(ctx, &tex->src[i].src);
1462 assert(lod != 0);
1463 break;
1464
1465 case nir_tex_src_comparator:
1466 assert(nir_src_num_components(tex->src[i].src) == 1);
1467 dref = get_src_float(ctx, &tex->src[i].src);
1468 assert(dref != 0);
1469 break;
1470
1471 case nir_tex_src_ddx:
1472 dx = get_src_float(ctx, &tex->src[i].src);
1473 assert(dx != 0);
1474 break;
1475
1476 case nir_tex_src_ddy:
1477 dy = get_src_float(ctx, &tex->src[i].src);
1478 assert(dy != 0);
1479 break;
1480
1481 default:
1482 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1483 unreachable("unknown texture source");
1484 }
1485 }
1486
1487 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1488 lod = emit_float_const(ctx, 32, 0.0f);
1489 assert(lod != 0);
1490 }
1491
1492 SpvId image_type = ctx->image_types[tex->texture_index];
1493 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1494 image_type);
1495
1496 assert(tex->texture_index < ctx->num_samplers);
1497 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1498 ctx->samplers[tex->texture_index]);
1499
1500 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1501
1502 if (tex->op == nir_texop_txs) {
1503 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1504 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1505 dest_type, image,
1506 lod);
1507 store_dest(ctx, &tex->dest, result, tex->dest_type);
1508 return;
1509 }
1510
1511 if (proj && coord_components > 0) {
1512 SpvId constituents[coord_components + 1];
1513 if (coord_components == 1)
1514 constituents[0] = coord;
1515 else {
1516 assert(coord_components > 1);
1517 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1518 for (uint32_t i = 0; i < coord_components; ++i)
1519 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1520 float_type,
1521 coord,
1522 &i, 1);
1523 }
1524
1525 constituents[coord_components++] = proj;
1526
1527 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1528 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1529 vec_type,
1530 constituents,
1531 coord_components);
1532 }
1533
1534 SpvId actual_dest_type = dest_type;
1535 if (dref)
1536 actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1537
1538 SpvId result;
1539 if (tex->op == nir_texop_txf) {
1540 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1541 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1542 image, coord, lod);
1543 } else {
1544 result = spirv_builder_emit_image_sample(&ctx->builder,
1545 actual_dest_type, load,
1546 coord,
1547 proj != 0,
1548 lod, bias, dref, dx, dy,
1549 offset);
1550 }
1551
1552 spirv_builder_emit_decoration(&ctx->builder, result,
1553 SpvDecorationRelaxedPrecision);
1554
1555 if (dref && nir_dest_num_components(tex->dest) > 1) {
1556 SpvId components[4] = { result, result, result, result };
1557 result = spirv_builder_emit_composite_construct(&ctx->builder,
1558 dest_type,
1559 components,
1560 4);
1561 }
1562
1563 store_dest(ctx, &tex->dest, result, tex->dest_type);
1564 }
1565
1566 static void
1567 start_block(struct ntv_context *ctx, SpvId label)
1568 {
1569 /* terminate previous block if needed */
1570 if (ctx->block_started)
1571 spirv_builder_emit_branch(&ctx->builder, label);
1572
1573 /* start new block */
1574 spirv_builder_label(&ctx->builder, label);
1575 ctx->block_started = true;
1576 }
1577
1578 static void
1579 branch(struct ntv_context *ctx, SpvId label)
1580 {
1581 assert(ctx->block_started);
1582 spirv_builder_emit_branch(&ctx->builder, label);
1583 ctx->block_started = false;
1584 }
1585
1586 static void
1587 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1588 SpvId else_id)
1589 {
1590 assert(ctx->block_started);
1591 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1592 then_id, else_id);
1593 ctx->block_started = false;
1594 }
1595
1596 static void
1597 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1598 {
1599 switch (jump->type) {
1600 case nir_jump_break:
1601 assert(ctx->loop_break);
1602 branch(ctx, ctx->loop_break);
1603 break;
1604
1605 case nir_jump_continue:
1606 assert(ctx->loop_cont);
1607 branch(ctx, ctx->loop_cont);
1608 break;
1609
1610 default:
1611 unreachable("Unsupported jump type\n");
1612 }
1613 }
1614
1615 static void
1616 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1617 {
1618 assert(deref->deref_type == nir_deref_type_var);
1619
1620 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1621 assert(he);
1622 SpvId result = (SpvId)(intptr_t)he->data;
1623 /* uint is a bit of a lie here, it's really just an opaque type */
1624 store_dest_uint(ctx, &deref->dest, result);
1625 }
1626
1627 static void
1628 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1629 {
1630 assert(deref->deref_type == nir_deref_type_array);
1631 nir_variable *var = nir_deref_instr_get_variable(deref);
1632
1633 SpvStorageClass storage_class;
1634 switch (var->data.mode) {
1635 case nir_var_shader_in:
1636 storage_class = SpvStorageClassInput;
1637 break;
1638
1639 case nir_var_shader_out:
1640 storage_class = SpvStorageClassOutput;
1641 break;
1642
1643 default:
1644 unreachable("Unsupported nir_variable_mode\n");
1645 }
1646
1647 SpvId index = get_src_uint(ctx, &deref->arr.index);
1648
1649 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1650 storage_class,
1651 get_glsl_type(ctx, deref->type));
1652
1653 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1654 ptr_type,
1655 get_src_uint(ctx, &deref->parent),
1656 &index, 1);
1657 /* uint is a bit of a lie here, it's really just an opaque type */
1658 store_dest_uint(ctx, &deref->dest, result);
1659 }
1660
1661 static void
1662 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1663 {
1664 switch (deref->deref_type) {
1665 case nir_deref_type_var:
1666 emit_deref_var(ctx, deref);
1667 break;
1668
1669 case nir_deref_type_array:
1670 emit_deref_array(ctx, deref);
1671 break;
1672
1673 default:
1674 unreachable("unexpected deref_type");
1675 }
1676 }
1677
1678 static void
1679 emit_block(struct ntv_context *ctx, struct nir_block *block)
1680 {
1681 start_block(ctx, block_label(ctx, block));
1682 nir_foreach_instr(instr, block) {
1683 switch (instr->type) {
1684 case nir_instr_type_alu:
1685 emit_alu(ctx, nir_instr_as_alu(instr));
1686 break;
1687 case nir_instr_type_intrinsic:
1688 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1689 break;
1690 case nir_instr_type_load_const:
1691 emit_load_const(ctx, nir_instr_as_load_const(instr));
1692 break;
1693 case nir_instr_type_ssa_undef:
1694 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1695 break;
1696 case nir_instr_type_tex:
1697 emit_tex(ctx, nir_instr_as_tex(instr));
1698 break;
1699 case nir_instr_type_phi:
1700 unreachable("nir_instr_type_phi not supported");
1701 break;
1702 case nir_instr_type_jump:
1703 emit_jump(ctx, nir_instr_as_jump(instr));
1704 break;
1705 case nir_instr_type_call:
1706 unreachable("nir_instr_type_call not supported");
1707 break;
1708 case nir_instr_type_parallel_copy:
1709 unreachable("nir_instr_type_parallel_copy not supported");
1710 break;
1711 case nir_instr_type_deref:
1712 emit_deref(ctx, nir_instr_as_deref(instr));
1713 break;
1714 }
1715 }
1716 }
1717
/* Forward declaration: emit_if() and emit_loop() recurse into emit_cf_list(). */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1720
1721 static SpvId
1722 get_src_bool(struct ntv_context *ctx, nir_src *src)
1723 {
1724 SpvId def = get_src_uint(ctx, src);
1725 assert(nir_src_bit_size(*src) == 1);
1726 unsigned num_components = nir_src_num_components(*src);
1727 return uvec_to_bvec(ctx, def, num_components);
1728 }
1729
1730 static void
1731 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1732 {
1733 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1734
1735 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1736 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1737 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1738 SpvId else_id = endif_id;
1739
1740 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1741 if (has_else) {
1742 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1743 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1744 }
1745
1746 /* create a header-block */
1747 start_block(ctx, header_id);
1748 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1749 SpvSelectionControlMaskNone);
1750 branch_conditional(ctx, condition, then_id, else_id);
1751
1752 emit_cf_list(ctx, &if_stmt->then_list);
1753
1754 if (has_else) {
1755 if (ctx->block_started)
1756 branch(ctx, endif_id);
1757
1758 emit_cf_list(ctx, &if_stmt->else_list);
1759 }
1760
1761 start_block(ctx, endif_id);
1762 }
1763
1764 static void
1765 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1766 {
1767 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1768 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1769 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1770 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1771
1772 /* create a header-block */
1773 start_block(ctx, header_id);
1774 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1775 branch(ctx, begin_id);
1776
1777 SpvId save_break = ctx->loop_break;
1778 SpvId save_cont = ctx->loop_cont;
1779 ctx->loop_break = break_id;
1780 ctx->loop_cont = cont_id;
1781
1782 emit_cf_list(ctx, &loop->body);
1783
1784 ctx->loop_break = save_break;
1785 ctx->loop_cont = save_cont;
1786
1787 branch(ctx, cont_id);
1788 start_block(ctx, cont_id);
1789 branch(ctx, header_id);
1790
1791 start_block(ctx, break_id);
1792 }
1793
1794 static void
1795 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1796 {
1797 foreach_list_typed(nir_cf_node, node, node, list) {
1798 switch (node->type) {
1799 case nir_cf_node_block:
1800 emit_block(ctx, nir_cf_node_as_block(node));
1801 break;
1802
1803 case nir_cf_node_if:
1804 emit_if(ctx, nir_cf_node_as_if(node));
1805 break;
1806
1807 case nir_cf_node_loop:
1808 emit_loop(ctx, nir_cf_node_as_loop(node));
1809 break;
1810
1811 case nir_cf_node_function:
1812 unreachable("nir_cf_node_function not supported");
1813 break;
1814 }
1815 }
1816 }
1817
1818 struct spirv_shader *
1819 nir_to_spirv(struct nir_shader *s)
1820 {
1821 struct spirv_shader *ret = NULL;
1822
1823 struct ntv_context ctx = {};
1824
1825 switch (s->info.stage) {
1826 case MESA_SHADER_VERTEX:
1827 case MESA_SHADER_FRAGMENT:
1828 case MESA_SHADER_COMPUTE:
1829 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1830 break;
1831
1832 case MESA_SHADER_TESS_CTRL:
1833 case MESA_SHADER_TESS_EVAL:
1834 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1835 break;
1836
1837 case MESA_SHADER_GEOMETRY:
1838 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1839 break;
1840
1841 default:
1842 unreachable("invalid stage");
1843 }
1844
1845 // TODO: only enable when needed
1846 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1847 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1848 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1849 }
1850
1851 ctx.stage = s->info.stage;
1852 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1853 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1854
1855 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1856 SpvMemoryModelGLSL450);
1857
1858 SpvExecutionModel exec_model;
1859 switch (s->info.stage) {
1860 case MESA_SHADER_VERTEX:
1861 exec_model = SpvExecutionModelVertex;
1862 break;
1863 case MESA_SHADER_TESS_CTRL:
1864 exec_model = SpvExecutionModelTessellationControl;
1865 break;
1866 case MESA_SHADER_TESS_EVAL:
1867 exec_model = SpvExecutionModelTessellationEvaluation;
1868 break;
1869 case MESA_SHADER_GEOMETRY:
1870 exec_model = SpvExecutionModelGeometry;
1871 break;
1872 case MESA_SHADER_FRAGMENT:
1873 exec_model = SpvExecutionModelFragment;
1874 break;
1875 case MESA_SHADER_COMPUTE:
1876 exec_model = SpvExecutionModelGLCompute;
1877 break;
1878 default:
1879 unreachable("invalid stage");
1880 }
1881
1882 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1883 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1884 NULL, 0);
1885 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1886 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1887
1888 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1889 _mesa_key_pointer_equal);
1890
1891 nir_foreach_variable(var, &s->inputs)
1892 emit_input(&ctx, var);
1893
1894 nir_foreach_variable(var, &s->outputs)
1895 emit_output(&ctx, var);
1896
1897 nir_foreach_variable(var, &s->uniforms)
1898 emit_uniform(&ctx, var);
1899
1900 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1901 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1902 SpvExecutionModeOriginUpperLeft);
1903 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1904 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1905 SpvExecutionModeDepthReplacing);
1906 }
1907
1908
1909 spirv_builder_function(&ctx.builder, entry_point, type_void,
1910 SpvFunctionControlMaskNone,
1911 type_main);
1912
1913 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1914 nir_metadata_require(entry, nir_metadata_block_index);
1915
1916 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1917 if (!ctx.defs)
1918 goto fail;
1919 ctx.num_defs = entry->ssa_alloc;
1920
1921 nir_index_local_regs(entry);
1922 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1923 if (!ctx.regs)
1924 goto fail;
1925 ctx.num_regs = entry->reg_alloc;
1926
1927 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1928 if (!block_ids)
1929 goto fail;
1930
1931 for (int i = 0; i < entry->num_blocks; ++i)
1932 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1933
1934 ctx.block_ids = block_ids;
1935 ctx.num_blocks = entry->num_blocks;
1936
1937 /* emit a block only for the variable declarations */
1938 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1939 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1940 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1941 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1942 SpvStorageClassFunction,
1943 type);
1944 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1945 SpvStorageClassFunction);
1946
1947 ctx.regs[reg->index] = var;
1948 }
1949
1950 emit_cf_list(&ctx, &entry->body);
1951
1952 free(ctx.defs);
1953
1954 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1955 spirv_builder_function_end(&ctx.builder);
1956
1957 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1958 "main", ctx.entry_ifaces,
1959 ctx.num_entry_ifaces);
1960
1961 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1962
1963 ret = CALLOC_STRUCT(spirv_shader);
1964 if (!ret)
1965 goto fail;
1966
1967 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1968 if (!ret->words)
1969 goto fail;
1970
1971 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1972 assert(ret->num_words == num_words);
1973
1974 return ret;
1975
1976 fail:
1977
1978 if (ret)
1979 spirv_shader_delete(ret);
1980
1981 if (ctx.vars)
1982 _mesa_hash_table_destroy(ctx.vars, NULL);
1983
1984 return NULL;
1985 }
1986
1987 void
1988 spirv_shader_delete(struct spirv_shader *s)
1989 {
1990 FREE(s->words);
1991 FREE(s);
1992 }