zink: do not lower bools to float
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59 };
60
61 static SpvId
62 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
63 unsigned num_components, float value);
64
65 static SpvId
66 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
67 unsigned num_components, uint32_t value);
68
69 static SpvId
70 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
71
72 static SpvId
73 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
74 SpvId src0, SpvId src1);
75
76 static SpvId
77 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
78 SpvId src0, SpvId src1, SpvId src2);
79
80 static SpvId
81 get_bvec_type(struct ntv_context *ctx, int num_components)
82 {
83 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
84 if (num_components > 1)
85 return spirv_builder_type_vector(&ctx->builder, bool_type,
86 num_components);
87
88 assert(num_components == 1);
89 return bool_type;
90 }
91
92 static SpvId
93 block_label(struct ntv_context *ctx, nir_block *block)
94 {
95 assert(block->index < ctx->num_blocks);
96 return ctx->block_ids[block->index];
97 }
98
99 static SpvId
100 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
101 {
102 assert(bit_size == 32);
103 return spirv_builder_const_float(&ctx->builder, bit_size, value);
104 }
105
106 static SpvId
107 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
108 {
109 assert(bit_size == 32);
110 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
111 }
112
113 static SpvId
114 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
115 {
116 assert(bit_size == 32); // only 32-bit floats supported so far
117
118 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
119 if (num_components > 1)
120 return spirv_builder_type_vector(&ctx->builder, float_type,
121 num_components);
122
123 assert(num_components == 1);
124 return float_type;
125 }
126
127 static SpvId
128 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
129 {
130 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
131
132 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
133 if (num_components > 1)
134 return spirv_builder_type_vector(&ctx->builder, int_type,
135 num_components);
136
137 assert(num_components == 1);
138 return int_type;
139 }
140
141 static SpvId
142 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
143 {
144 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
145
146 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
147 if (num_components > 1)
148 return spirv_builder_type_vector(&ctx->builder, uint_type,
149 num_components);
150
151 assert(num_components == 1);
152 return uint_type;
153 }
154
155 static SpvId
156 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
157 {
158 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
159 nir_dest_num_components(*dest));
160 }
161
162 static SpvId
163 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
164 {
165 switch (type) {
166 case GLSL_TYPE_FLOAT:
167 return spirv_builder_type_float(&ctx->builder, 32);
168
169 case GLSL_TYPE_INT:
170 return spirv_builder_type_int(&ctx->builder, 32);
171
172 case GLSL_TYPE_UINT:
173 return spirv_builder_type_uint(&ctx->builder, 32);
174 /* TODO: handle more types */
175
176 default:
177 unreachable("unknown GLSL type");
178 }
179 }
180
181 static SpvId
182 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
183 {
184 assert(type);
185 if (glsl_type_is_scalar(type))
186 return get_glsl_basetype(ctx, glsl_get_base_type(type));
187
188 if (glsl_type_is_vector(type))
189 return spirv_builder_type_vector(&ctx->builder,
190 get_glsl_basetype(ctx, glsl_get_base_type(type)),
191 glsl_get_vector_elements(type));
192
193 if (glsl_type_is_array(type)) {
194 SpvId ret = spirv_builder_type_array(&ctx->builder,
195 get_glsl_type(ctx, glsl_get_array_element(type)),
196 emit_uint_const(ctx, 32, glsl_get_length(type)));
197 uint32_t stride = glsl_get_explicit_stride(type);
198 if (stride)
199 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
200 return ret;
201 }
202
203
204 unreachable("we shouldn't get here, I think...");
205 }
206
207 static void
208 emit_input(struct ntv_context *ctx, struct nir_variable *var)
209 {
210 SpvId var_type = get_glsl_type(ctx, var->type);
211 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
212 SpvStorageClassInput,
213 var_type);
214 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
215 SpvStorageClassInput);
216
217 if (var->name)
218 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
219
220 if (ctx->stage == MESA_SHADER_FRAGMENT) {
221 if (var->data.location >= VARYING_SLOT_VAR0 ||
222 (var->data.location >= VARYING_SLOT_COL0 &&
223 var->data.location <= VARYING_SLOT_TEX7)) {
224 spirv_builder_emit_location(&ctx->builder, var_id,
225 ctx->var_location++);
226 } else {
227 switch (var->data.location) {
228 case VARYING_SLOT_POS:
229 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
230 break;
231
232 case VARYING_SLOT_PNTC:
233 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
234 break;
235
236 default:
237 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
238 unreachable("unexpected varying slot");
239 }
240 }
241 } else {
242 spirv_builder_emit_location(&ctx->builder, var_id,
243 var->data.driver_location);
244 }
245
246 if (var->data.location_frac)
247 spirv_builder_emit_component(&ctx->builder, var_id,
248 var->data.location_frac);
249
250 if (var->data.interpolation == INTERP_MODE_FLAT)
251 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
252
253 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
254
255 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
256 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
257 }
258
259 static void
260 emit_output(struct ntv_context *ctx, struct nir_variable *var)
261 {
262 SpvId var_type = get_glsl_type(ctx, var->type);
263 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
264 SpvStorageClassOutput,
265 var_type);
266 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
267 SpvStorageClassOutput);
268 if (var->name)
269 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
270
271
272 if (ctx->stage == MESA_SHADER_VERTEX) {
273 if (var->data.location >= VARYING_SLOT_VAR0 ||
274 (var->data.location >= VARYING_SLOT_COL0 &&
275 var->data.location <= VARYING_SLOT_TEX7)) {
276 spirv_builder_emit_location(&ctx->builder, var_id,
277 ctx->var_location++);
278 } else {
279 switch (var->data.location) {
280 case VARYING_SLOT_POS:
281 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
282 break;
283
284 case VARYING_SLOT_PSIZ:
285 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
286 break;
287
288 case VARYING_SLOT_CLIP_DIST0:
289 assert(glsl_type_is_array(var->type));
290 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
291 break;
292
293 default:
294 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
295 unreachable("unexpected varying slot");
296 }
297 }
298 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
299 if (var->data.location >= FRAG_RESULT_DATA0)
300 spirv_builder_emit_location(&ctx->builder, var_id,
301 var->data.location - FRAG_RESULT_DATA0);
302 else {
303 switch (var->data.location) {
304 case FRAG_RESULT_COLOR:
305 spirv_builder_emit_location(&ctx->builder, var_id, 0);
306 break;
307
308 case FRAG_RESULT_DEPTH:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
310 break;
311
312 default:
313 spirv_builder_emit_location(&ctx->builder, var_id,
314 var->data.driver_location);
315 }
316 }
317 }
318
319 if (var->data.location_frac)
320 spirv_builder_emit_component(&ctx->builder, var_id,
321 var->data.location_frac);
322
323 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
324
325 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
326 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
327 }
328
329 static SpvDim
330 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
331 {
332 *is_ms = false;
333 switch (gdim) {
334 case GLSL_SAMPLER_DIM_1D:
335 return SpvDim1D;
336 case GLSL_SAMPLER_DIM_2D:
337 return SpvDim2D;
338 case GLSL_SAMPLER_DIM_RECT:
339 return SpvDimRect;
340 case GLSL_SAMPLER_DIM_CUBE:
341 return SpvDimCube;
342 case GLSL_SAMPLER_DIM_3D:
343 return SpvDim3D;
344 case GLSL_SAMPLER_DIM_MS:
345 *is_ms = true;
346 return SpvDim2D;
347 default:
348 fprintf(stderr, "unknown sampler type %d\n", gdim);
349 break;
350 }
351 return SpvDim2D;
352 }
353
354 static void
355 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
356 {
357 bool is_ms;
358 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
359 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
360 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
361 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
362 SpvImageFormatUnknown);
363
364 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
365 image_type);
366 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
367 SpvStorageClassUniformConstant,
368 sampled_type);
369 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
370 SpvStorageClassUniformConstant);
371
372 if (var->name)
373 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
374
375 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
376 ctx->samplers[ctx->num_samplers++] = var_id;
377
378 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
379 var->data.descriptor_set);
380 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
381 }
382
383 static void
384 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
385 {
386 uint32_t size = glsl_count_attribute_slots(var->type, false);
387 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
388 SpvId array_length = emit_uint_const(ctx, 32, size);
389 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
390 array_length);
391 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
392
393 // wrap UBO-array in a struct
394 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
395 if (var->name) {
396 char struct_name[100];
397 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
398 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
399 }
400
401 spirv_builder_emit_decoration(&ctx->builder, struct_type,
402 SpvDecorationBlock);
403 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
404
405
406 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
407 SpvStorageClassUniform,
408 struct_type);
409
410 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
411 SpvStorageClassUniform);
412 if (var->name)
413 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
414
415 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
416 ctx->ubos[ctx->num_ubos++] = var_id;
417
418 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
419 var->data.descriptor_set);
420 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
421 }
422
423 static void
424 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
425 {
426 if (var->data.mode == nir_var_mem_ubo)
427 emit_ubo(ctx, var);
428 else {
429 assert(var->data.mode == nir_var_uniform);
430 if (glsl_type_is_sampler(var->type))
431 emit_sampler(ctx, var);
432 }
433 }
434
435 static SpvId
436 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
437 {
438 assert(ssa->index < ctx->num_defs);
439 assert(ctx->defs[ssa->index] != 0);
440 return ctx->defs[ssa->index];
441 }
442
443 static SpvId
444 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
445 {
446 assert(reg->index < ctx->num_regs);
447 assert(ctx->regs[reg->index] != 0);
448 return ctx->regs[reg->index];
449 }
450
451 static SpvId
452 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
453 {
454 assert(reg->reg);
455 assert(!reg->indirect);
456 assert(!reg->base_offset);
457
458 SpvId var = get_var_from_reg(ctx, reg->reg);
459 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
460 return spirv_builder_emit_load(&ctx->builder, type, var);
461 }
462
463 static SpvId
464 get_src_uint(struct ntv_context *ctx, nir_src *src)
465 {
466 if (src->is_ssa)
467 return get_src_uint_ssa(ctx, src->ssa);
468 else
469 return get_src_uint_reg(ctx, &src->reg);
470 }
471
472 static SpvId
473 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
474 {
475 assert(!alu->src[src].negate);
476 assert(!alu->src[src].abs);
477
478 SpvId def = get_src_uint(ctx, &alu->src[src].src);
479
480 unsigned used_channels = 0;
481 bool need_swizzle = false;
482 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
483 if (!nir_alu_instr_channel_used(alu, src, i))
484 continue;
485
486 used_channels++;
487
488 if (alu->src[src].swizzle[i] != i)
489 need_swizzle = true;
490 }
491 assert(used_channels != 0);
492
493 unsigned live_channels = nir_src_num_components(alu->src[src].src);
494 if (used_channels != live_channels)
495 need_swizzle = true;
496
497 if (!need_swizzle)
498 return def;
499
500 int bit_size = nir_src_bit_size(alu->src[src].src);
501 assert(bit_size == 1 || bit_size == 32);
502
503 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
504 if (used_channels == 1) {
505 uint32_t indices[] = { alu->src[src].swizzle[0] };
506 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
507 def, indices,
508 ARRAY_SIZE(indices));
509 } else if (live_channels == 1) {
510 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
511 used_channels);
512
513 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
514 for (unsigned i = 0; i < used_channels; ++i)
515 constituents[i] = def;
516
517 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
518 constituents,
519 used_channels);
520 } else {
521 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
522 used_channels);
523
524 uint32_t components[NIR_MAX_VEC_COMPONENTS];
525 size_t num_components = 0;
526 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
527 if (!nir_alu_instr_channel_used(alu, src, i))
528 continue;
529
530 components[num_components++] = alu->src[src].swizzle[i];
531 }
532
533 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
534 def, def, components, num_components);
535 }
536 }
537
538 static void
539 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
540 {
541 assert(result != 0);
542 assert(ssa->index < ctx->num_defs);
543 ctx->defs[ssa->index] = result;
544 }
545
546 static SpvId
547 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
548 SpvId if_true, SpvId if_false)
549 {
550 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
551 }
552
553 static SpvId
554 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
555 {
556 SpvId otype = get_uvec_type(ctx, 32, num_components);
557 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
558 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
559 return emit_select(ctx, otype, value, one, zero);
560 }
561
562 static SpvId
563 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
564 {
565 SpvId type = get_bvec_type(ctx, num_components);
566 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
567 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
568 }
569
570 static SpvId
571 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
572 {
573 return emit_unop(ctx, SpvOpBitcast, type, value);
574 }
575
576 static SpvId
577 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
578 unsigned num_components)
579 {
580 SpvId type = get_uvec_type(ctx, bit_size, num_components);
581 return emit_bitcast(ctx, type, value);
582 }
583
584 static SpvId
585 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
586 unsigned num_components)
587 {
588 SpvId type = get_ivec_type(ctx, bit_size, num_components);
589 return emit_bitcast(ctx, type, value);
590 }
591
592 static SpvId
593 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
594 unsigned num_components)
595 {
596 SpvId type = get_fvec_type(ctx, bit_size, num_components);
597 return emit_bitcast(ctx, type, value);
598 }
599
600 static void
601 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
602 {
603 SpvId var = get_var_from_reg(ctx, reg->reg);
604 assert(var);
605 spirv_builder_emit_store(&ctx->builder, var, result);
606 }
607
608 static void
609 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
610 {
611 if (dest->is_ssa)
612 store_ssa_def_uint(ctx, &dest->ssa, result);
613 else
614 store_reg_def(ctx, &dest->reg, result);
615 }
616
617 static void
618 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
619 {
620 unsigned num_components = nir_dest_num_components(*dest);
621 unsigned bit_size = nir_dest_bit_size(*dest);
622
623 switch (nir_alu_type_get_base_type(type)) {
624 case nir_type_bool:
625 assert(bit_size == 1);
626 result = bvec_to_uvec(ctx, result, num_components);
627 break;
628
629 case nir_type_uint:
630 break; /* nothing to do! */
631
632 case nir_type_int:
633 case nir_type_float:
634 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
635 break;
636
637 default:
638 unreachable("unsupported nir_alu_type");
639 }
640
641 store_dest_uint(ctx, dest, result);
642 }
643
644 static SpvId
645 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
646 {
647 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
648 }
649
650 static SpvId
651 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
652 SpvId src0, SpvId src1)
653 {
654 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
655 }
656
657 static SpvId
658 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
659 SpvId src0, SpvId src1, SpvId src2)
660 {
661 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
662 }
663
664 static SpvId
665 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
666 SpvId src)
667 {
668 SpvId args[] = { src };
669 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
670 op, args, ARRAY_SIZE(args));
671 }
672
673 static SpvId
674 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
675 SpvId src0, SpvId src1)
676 {
677 SpvId args[] = { src0, src1 };
678 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
679 op, args, ARRAY_SIZE(args));
680 }
681
682 static SpvId
683 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
684 unsigned num_components, float value)
685 {
686 assert(bit_size == 32);
687
688 SpvId result = emit_float_const(ctx, bit_size, value);
689 if (num_components == 1)
690 return result;
691
692 assert(num_components > 1);
693 SpvId components[num_components];
694 for (int i = 0; i < num_components; i++)
695 components[i] = result;
696
697 SpvId type = get_fvec_type(ctx, bit_size, num_components);
698 return spirv_builder_const_composite(&ctx->builder, type, components,
699 num_components);
700 }
701
702 static SpvId
703 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
704 unsigned num_components, uint32_t value)
705 {
706 assert(bit_size == 32);
707
708 SpvId result = emit_uint_const(ctx, bit_size, value);
709 if (num_components == 1)
710 return result;
711
712 assert(num_components > 1);
713 SpvId components[num_components];
714 for (int i = 0; i < num_components; i++)
715 components[i] = result;
716
717 SpvId type = get_uvec_type(ctx, bit_size, num_components);
718 return spirv_builder_const_composite(&ctx->builder, type, components,
719 num_components);
720 }
721
722 static inline unsigned
723 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
724 {
725 if (nir_op_infos[instr->op].input_sizes[src] > 0)
726 return nir_op_infos[instr->op].input_sizes[src];
727
728 if (instr->dest.dest.is_ssa)
729 return instr->dest.dest.ssa.num_components;
730 else
731 return instr->dest.dest.reg.reg->num_components;
732 }
733
734 static SpvId
735 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
736 {
737 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
738
739 unsigned num_components = alu_instr_src_components(alu, src);
740 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
741 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
742
743 switch (nir_alu_type_get_base_type(type)) {
744 case nir_type_bool:
745 assert(bit_size == 1);
746 return uvec_to_bvec(ctx, uint_value, num_components);
747
748 case nir_type_int:
749 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
750
751 case nir_type_uint:
752 return uint_value;
753
754 case nir_type_float:
755 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
756
757 default:
758 unreachable("unknown nir_alu_type");
759 }
760 }
761
762 static void
763 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
764 {
765 assert(!alu->dest.saturate);
766 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
767 }
768
769 static SpvId
770 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
771 {
772 unsigned num_components = nir_dest_num_components(*dest);
773 unsigned bit_size = nir_dest_bit_size(*dest);
774
775 switch (nir_alu_type_get_base_type(type)) {
776 case nir_type_bool:
777 return get_bvec_type(ctx, num_components);
778
779 case nir_type_int:
780 return get_ivec_type(ctx, bit_size, num_components);
781
782 case nir_type_uint:
783 return get_uvec_type(ctx, bit_size, num_components);
784
785 case nir_type_float:
786 return get_fvec_type(ctx, bit_size, num_components);
787
788 default:
789 unreachable("unsupported nir_alu_type");
790 }
791 }
792
793 static void
794 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
795 {
796 SpvId src[nir_op_infos[alu->op].num_inputs];
797 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
798 src[i] = get_alu_src(ctx, alu, i);
799
800 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
801 nir_op_infos[alu->op].output_type);
802 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
803 unsigned num_components = nir_dest_num_components(alu->dest.dest);
804
805 SpvId result = 0;
806 switch (alu->op) {
807 case nir_op_mov:
808 assert(nir_op_infos[alu->op].num_inputs == 1);
809 result = src[0];
810 break;
811
812 #define UNOP(nir_op, spirv_op) \
813 case nir_op: \
814 assert(nir_op_infos[alu->op].num_inputs == 1); \
815 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
816 break;
817
818 UNOP(nir_op_ineg, SpvOpSNegate)
819 UNOP(nir_op_fneg, SpvOpFNegate)
820 UNOP(nir_op_fddx, SpvOpDPdx)
821 UNOP(nir_op_fddy, SpvOpDPdy)
822 UNOP(nir_op_f2i32, SpvOpConvertFToS)
823 UNOP(nir_op_f2u32, SpvOpConvertFToU)
824 UNOP(nir_op_i2f32, SpvOpConvertSToF)
825 UNOP(nir_op_u2f32, SpvOpConvertUToF)
826 UNOP(nir_op_inot, SpvOpNot)
827 #undef UNOP
828
829 case nir_op_b2i32:
830 assert(nir_op_infos[alu->op].num_inputs == 1);
831 result = emit_select(ctx, dest_type, src[0],
832 get_uvec_constant(ctx, 32, num_components, 1),
833 get_uvec_constant(ctx, 32, num_components, 0));
834 break;
835
836 case nir_op_b2f32:
837 assert(nir_op_infos[alu->op].num_inputs == 1);
838 result = emit_select(ctx, dest_type, src[0],
839 get_fvec_constant(ctx, 32, num_components, 1),
840 get_fvec_constant(ctx, 32, num_components, 0));
841 break;
842
843 #define BUILTIN_UNOP(nir_op, spirv_op) \
844 case nir_op: \
845 assert(nir_op_infos[alu->op].num_inputs == 1); \
846 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
847 break;
848
849 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
850 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
851 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
852 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
853 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
854 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
855 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
856 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
857 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
858 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
859 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
860 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
861 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
862 #undef BUILTIN_UNOP
863
864 case nir_op_frcp:
865 assert(nir_op_infos[alu->op].num_inputs == 1);
866 result = emit_binop(ctx, SpvOpFDiv, dest_type,
867 get_fvec_constant(ctx, bit_size, num_components, 1),
868 src[0]);
869 break;
870
871 case nir_op_f2b1:
872 assert(nir_op_infos[alu->op].num_inputs == 1);
873 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
874 get_fvec_constant(ctx,
875 nir_src_bit_size(alu->src[0].src),
876 num_components, 0));
877 break;
878
879
880 #define BINOP(nir_op, spirv_op) \
881 case nir_op: \
882 assert(nir_op_infos[alu->op].num_inputs == 2); \
883 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
884 break;
885
886 BINOP(nir_op_iadd, SpvOpIAdd)
887 BINOP(nir_op_isub, SpvOpISub)
888 BINOP(nir_op_imul, SpvOpIMul)
889 BINOP(nir_op_idiv, SpvOpSDiv)
890 BINOP(nir_op_udiv, SpvOpUDiv)
891 BINOP(nir_op_fadd, SpvOpFAdd)
892 BINOP(nir_op_fsub, SpvOpFSub)
893 BINOP(nir_op_fmul, SpvOpFMul)
894 BINOP(nir_op_fdiv, SpvOpFDiv)
895 BINOP(nir_op_fmod, SpvOpFMod)
896 BINOP(nir_op_ilt, SpvOpSLessThan)
897 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
898 BINOP(nir_op_ieq, SpvOpIEqual)
899 BINOP(nir_op_ine, SpvOpINotEqual)
900 BINOP(nir_op_flt, SpvOpFOrdLessThan)
901 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
902 BINOP(nir_op_feq, SpvOpFOrdEqual)
903 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
904 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
905 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
906 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
907 BINOP(nir_op_iand, SpvOpBitwiseAnd)
908 BINOP(nir_op_ior, SpvOpBitwiseOr)
909 #undef BINOP
910
911 #define BUILTIN_BINOP(nir_op, spirv_op) \
912 case nir_op: \
913 assert(nir_op_infos[alu->op].num_inputs == 2); \
914 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
915 break;
916
917 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
918 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
919 #undef BUILTIN_BINOP
920
921 case nir_op_fdot2:
922 case nir_op_fdot3:
923 case nir_op_fdot4:
924 assert(nir_op_infos[alu->op].num_inputs == 2);
925 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
926 break;
927
928 case nir_op_seq:
929 case nir_op_sne:
930 case nir_op_slt:
931 case nir_op_sge: {
932 assert(nir_op_infos[alu->op].num_inputs == 2);
933 int num_components = nir_dest_num_components(alu->dest.dest);
934 SpvId bool_type = get_bvec_type(ctx, num_components);
935
936 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
937 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
938 if (num_components > 1) {
939 SpvId zero_comps[num_components], one_comps[num_components];
940 for (int i = 0; i < num_components; i++) {
941 zero_comps[i] = zero;
942 one_comps[i] = one;
943 }
944
945 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
946 zero_comps, num_components);
947 one = spirv_builder_const_composite(&ctx->builder, dest_type,
948 one_comps, num_components);
949 }
950
951 SpvOp op;
952 switch (alu->op) {
953 case nir_op_seq: op = SpvOpFOrdEqual; break;
954 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
955 case nir_op_slt: op = SpvOpFOrdLessThan; break;
956 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
957 default: unreachable("unexpected op");
958 }
959
960 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
961 result = emit_select(ctx, dest_type, result, one, zero);
962 }
963 break;
964
965 case nir_op_fcsel:
966 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
967 get_bvec_type(ctx, num_components),
968 src[0],
969 get_fvec_constant(ctx,
970 nir_src_bit_size(alu->src[0].src),
971 num_components, 0));
972 result = emit_select(ctx, dest_type, result, src[1], src[2]);
973 break;
974
975 case nir_op_bcsel:
976 assert(nir_op_infos[alu->op].num_inputs == 3);
977 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
978 break;
979
980 case nir_op_vec2:
981 case nir_op_vec3:
982 case nir_op_vec4: {
983 int num_inputs = nir_op_infos[alu->op].num_inputs;
984 assert(2 <= num_inputs && num_inputs <= 4);
985 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
986 src, num_inputs);
987 }
988 break;
989
990 default:
991 fprintf(stderr, "emit_alu: not implemented (%s)\n",
992 nir_op_infos[alu->op].name);
993
994 unreachable("unsupported opcode");
995 return;
996 }
997
998 store_alu_result(ctx, alu, result);
999 }
1000
1001 static void
1002 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1003 {
1004 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1005 for (int i = 0; i < load_const->def.num_components; ++i)
1006 values[i] = load_const->value[i].u32;
1007
1008 unsigned bit_size = load_const->def.bit_size;
1009 unsigned num_components = load_const->def.num_components;
1010
1011 SpvId constant;
1012 if (num_components > 1) {
1013 SpvId components[num_components];
1014 for (int i = 0; i < num_components; i++)
1015 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1016
1017 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1018 constant = spirv_builder_const_composite(&ctx->builder, type,
1019 components, num_components);
1020 } else {
1021 assert(num_components == 1);
1022 constant = emit_uint_const(ctx, bit_size, values[0]);
1023 }
1024
1025 store_ssa_def_uint(ctx, &load_const->def, constant);
1026 }
1027
1028 static void
1029 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1030 {
1031 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1032 assert(const_block_index); // no dynamic indexing for now
1033 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1034
1035 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1036 if (const_offset) {
1037 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1038 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1039 SpvStorageClassUniform,
1040 uvec4_type);
1041
1042 unsigned idx = const_offset->u32;
1043 SpvId member = emit_uint_const(ctx, 32, 0);
1044 SpvId offset = emit_uint_const(ctx, 32, idx);
1045 SpvId offsets[] = { member, offset };
1046 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1047 ctx->ubos[0], offsets,
1048 ARRAY_SIZE(offsets));
1049 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1050
1051 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1052 unsigned num_components = nir_dest_num_components(intr->dest);
1053 if (num_components == 1) {
1054 uint32_t components[] = { 0 };
1055 result = spirv_builder_emit_composite_extract(&ctx->builder,
1056 type,
1057 result, components,
1058 1);
1059 } else if (num_components < 4) {
1060 SpvId constituents[num_components];
1061 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1062 for (uint32_t i = 0; i < num_components; ++i)
1063 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1064 uint_type,
1065 result, &i,
1066 1);
1067
1068 result = spirv_builder_emit_composite_construct(&ctx->builder,
1069 type,
1070 constituents,
1071 num_components);
1072 }
1073
1074 store_dest_uint(ctx, &intr->dest, result);
1075 } else
1076 unreachable("uniform-addressing not yet supported");
1077 }
1078
1079 static void
1080 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1081 {
1082 assert(ctx->block_started);
1083 spirv_builder_emit_kill(&ctx->builder);
1084 /* discard is weird in NIR, so let's just create an unreachable block after
1085 it and hope that the vulkan driver will DCE any instructinos in it. */
1086 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1087 }
1088
1089 static void
1090 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1091 {
1092 /* uint is a bit of a lie here; it's really just a pointer */
1093 SpvId ptr = get_src_uint(ctx, intr->src);
1094
1095 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1096 SpvId result = spirv_builder_emit_load(&ctx->builder,
1097 get_glsl_type(ctx, var->type),
1098 ptr);
1099 unsigned num_components = nir_dest_num_components(intr->dest);
1100 unsigned bit_size = nir_dest_bit_size(intr->dest);
1101 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1102 store_dest_uint(ctx, &intr->dest, result);
1103 }
1104
1105 static void
1106 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1107 {
1108 /* uint is a bit of a lie here; it's really just a pointer */
1109 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1110 SpvId src = get_src_uint(ctx, &intr->src[1]);
1111
1112 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1113 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1114 SpvId result = emit_bitcast(ctx, type, src);
1115 spirv_builder_emit_store(&ctx->builder, ptr, result);
1116 }
1117
1118 static void
1119 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1120 {
1121 switch (intr->intrinsic) {
1122 case nir_intrinsic_load_ubo:
1123 emit_load_ubo(ctx, intr);
1124 break;
1125
1126 case nir_intrinsic_discard:
1127 emit_discard(ctx, intr);
1128 break;
1129
1130 case nir_intrinsic_load_deref:
1131 emit_load_deref(ctx, intr);
1132 break;
1133
1134 case nir_intrinsic_store_deref:
1135 emit_store_deref(ctx, intr);
1136 break;
1137
1138 default:
1139 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1140 nir_intrinsic_infos[intr->intrinsic].name);
1141 unreachable("unsupported intrinsic");
1142 }
1143 }
1144
1145 static void
1146 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1147 {
1148 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1149 undef->def.num_components);
1150
1151 store_ssa_def_uint(ctx, &undef->def,
1152 spirv_builder_emit_undef(&ctx->builder, type));
1153 }
1154
1155 static SpvId
1156 get_src_float(struct ntv_context *ctx, nir_src *src)
1157 {
1158 SpvId def = get_src_uint(ctx, src);
1159 unsigned num_components = nir_src_num_components(*src);
1160 unsigned bit_size = nir_src_bit_size(*src);
1161 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1162 }
1163
1164 static void
1165 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1166 {
1167 assert(tex->op == nir_texop_tex ||
1168 tex->op == nir_texop_txb ||
1169 tex->op == nir_texop_txl);
1170 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1171 assert(tex->texture_index == tex->sampler_index);
1172
1173 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1174 unsigned coord_components;
1175 for (unsigned i = 0; i < tex->num_srcs; i++) {
1176 switch (tex->src[i].src_type) {
1177 case nir_tex_src_coord:
1178 coord = get_src_float(ctx, &tex->src[i].src);
1179 coord_components = nir_src_num_components(tex->src[i].src);
1180 break;
1181
1182 case nir_tex_src_projector:
1183 assert(nir_src_num_components(tex->src[i].src) == 1);
1184 proj = get_src_float(ctx, &tex->src[i].src);
1185 assert(proj != 0);
1186 break;
1187
1188 case nir_tex_src_bias:
1189 assert(tex->op == nir_texop_txb);
1190 bias = get_src_float(ctx, &tex->src[i].src);
1191 assert(bias != 0);
1192 break;
1193
1194 case nir_tex_src_lod:
1195 assert(nir_src_num_components(tex->src[i].src) == 1);
1196 lod = get_src_float(ctx, &tex->src[i].src);
1197 assert(lod != 0);
1198 break;
1199
1200 case nir_tex_src_comparator:
1201 assert(nir_src_num_components(tex->src[i].src) == 1);
1202 dref = get_src_float(ctx, &tex->src[i].src);
1203 assert(dref != 0);
1204 break;
1205
1206 default:
1207 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1208 unreachable("unknown texture source");
1209 }
1210 }
1211
1212 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1213 lod = emit_float_const(ctx, 32, 0.0f);
1214 assert(lod != 0);
1215 }
1216
1217 bool is_ms;
1218 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1219 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1220 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1221 dimension, false, tex->is_array, is_ms, 1,
1222 SpvImageFormatUnknown);
1223 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1224 image_type);
1225
1226 assert(tex->texture_index < ctx->num_samplers);
1227 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1228 ctx->samplers[tex->texture_index]);
1229
1230 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1231
1232 if (proj) {
1233 SpvId constituents[coord_components + 1];
1234 if (coord_components == 1)
1235 constituents[0] = coord;
1236 else {
1237 assert(coord_components > 1);
1238 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1239 for (uint32_t i = 0; i < coord_components; ++i)
1240 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1241 float_type,
1242 coord,
1243 &i, 1);
1244 }
1245
1246 constituents[coord_components++] = proj;
1247
1248 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1249 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1250 vec_type,
1251 constituents,
1252 coord_components);
1253 }
1254
1255 SpvId actual_dest_type = dest_type;
1256 if (dref)
1257 actual_dest_type = float_type;
1258
1259 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1260 actual_dest_type, load,
1261 coord,
1262 proj != 0,
1263 lod, bias, dref);
1264 spirv_builder_emit_decoration(&ctx->builder, result,
1265 SpvDecorationRelaxedPrecision);
1266
1267 if (dref) {
1268 SpvId components[4] = { result, result, result, result };
1269 result = spirv_builder_emit_composite_construct(&ctx->builder,
1270 dest_type,
1271 components,
1272 4);
1273 }
1274
1275 store_dest(ctx, &tex->dest, result, tex->dest_type);
1276 }
1277
1278 static void
1279 start_block(struct ntv_context *ctx, SpvId label)
1280 {
1281 /* terminate previous block if needed */
1282 if (ctx->block_started)
1283 spirv_builder_emit_branch(&ctx->builder, label);
1284
1285 /* start new block */
1286 spirv_builder_label(&ctx->builder, label);
1287 ctx->block_started = true;
1288 }
1289
1290 static void
1291 branch(struct ntv_context *ctx, SpvId label)
1292 {
1293 assert(ctx->block_started);
1294 spirv_builder_emit_branch(&ctx->builder, label);
1295 ctx->block_started = false;
1296 }
1297
1298 static void
1299 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1300 SpvId else_id)
1301 {
1302 assert(ctx->block_started);
1303 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1304 then_id, else_id);
1305 ctx->block_started = false;
1306 }
1307
1308 static void
1309 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1310 {
1311 switch (jump->type) {
1312 case nir_jump_break:
1313 assert(ctx->loop_break);
1314 branch(ctx, ctx->loop_break);
1315 break;
1316
1317 case nir_jump_continue:
1318 assert(ctx->loop_cont);
1319 branch(ctx, ctx->loop_cont);
1320 break;
1321
1322 default:
1323 unreachable("Unsupported jump type\n");
1324 }
1325 }
1326
1327 static void
1328 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1329 {
1330 assert(deref->deref_type == nir_deref_type_var);
1331
1332 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1333 assert(he);
1334 SpvId result = (SpvId)(intptr_t)he->data;
1335 /* uint is a bit of a lie here, it's really just an opaque type */
1336 store_dest_uint(ctx, &deref->dest, result);
1337 }
1338
1339 static void
1340 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1341 {
1342 assert(deref->deref_type == nir_deref_type_array);
1343 nir_variable *var = nir_deref_instr_get_variable(deref);
1344
1345 SpvStorageClass storage_class;
1346 switch (var->data.mode) {
1347 case nir_var_shader_in:
1348 storage_class = SpvStorageClassInput;
1349 break;
1350
1351 case nir_var_shader_out:
1352 storage_class = SpvStorageClassOutput;
1353 break;
1354
1355 default:
1356 unreachable("Unsupported nir_variable_mode\n");
1357 }
1358
1359 SpvId index = get_src_uint(ctx, &deref->arr.index);
1360
1361 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1362 storage_class,
1363 get_glsl_type(ctx, deref->type));
1364
1365 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1366 ptr_type,
1367 get_src_uint(ctx, &deref->parent),
1368 &index, 1);
1369 /* uint is a bit of a lie here, it's really just an opaque type */
1370 store_dest_uint(ctx, &deref->dest, result);
1371 }
1372
1373 static void
1374 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1375 {
1376 switch (deref->deref_type) {
1377 case nir_deref_type_var:
1378 emit_deref_var(ctx, deref);
1379 break;
1380
1381 case nir_deref_type_array:
1382 emit_deref_array(ctx, deref);
1383 break;
1384
1385 default:
1386 unreachable("unexpected deref_type");
1387 }
1388 }
1389
1390 static void
1391 emit_block(struct ntv_context *ctx, struct nir_block *block)
1392 {
1393 start_block(ctx, block_label(ctx, block));
1394 nir_foreach_instr(instr, block) {
1395 switch (instr->type) {
1396 case nir_instr_type_alu:
1397 emit_alu(ctx, nir_instr_as_alu(instr));
1398 break;
1399 case nir_instr_type_intrinsic:
1400 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1401 break;
1402 case nir_instr_type_load_const:
1403 emit_load_const(ctx, nir_instr_as_load_const(instr));
1404 break;
1405 case nir_instr_type_ssa_undef:
1406 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1407 break;
1408 case nir_instr_type_tex:
1409 emit_tex(ctx, nir_instr_as_tex(instr));
1410 break;
1411 case nir_instr_type_phi:
1412 unreachable("nir_instr_type_phi not supported");
1413 break;
1414 case nir_instr_type_jump:
1415 emit_jump(ctx, nir_instr_as_jump(instr));
1416 break;
1417 case nir_instr_type_call:
1418 unreachable("nir_instr_type_call not supported");
1419 break;
1420 case nir_instr_type_parallel_copy:
1421 unreachable("nir_instr_type_parallel_copy not supported");
1422 break;
1423 case nir_instr_type_deref:
1424 emit_deref(ctx, nir_instr_as_deref(instr));
1425 break;
1426 }
1427 }
1428 }
1429
1430 static void
1431 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1432
1433 static SpvId
1434 get_src_bool(struct ntv_context *ctx, nir_src *src)
1435 {
1436 SpvId def = get_src_uint(ctx, src);
1437 assert(nir_src_bit_size(*src) == 1);
1438 unsigned num_components = nir_src_num_components(*src);
1439 return uvec_to_bvec(ctx, def, num_components);
1440 }
1441
1442 static void
1443 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1444 {
1445 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1446
1447 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1448 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1449 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1450 SpvId else_id = endif_id;
1451
1452 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1453 if (has_else) {
1454 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1455 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1456 }
1457
1458 /* create a header-block */
1459 start_block(ctx, header_id);
1460 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1461 SpvSelectionControlMaskNone);
1462 branch_conditional(ctx, condition, then_id, else_id);
1463
1464 emit_cf_list(ctx, &if_stmt->then_list);
1465
1466 if (has_else) {
1467 if (ctx->block_started)
1468 branch(ctx, endif_id);
1469
1470 emit_cf_list(ctx, &if_stmt->else_list);
1471 }
1472
1473 start_block(ctx, endif_id);
1474 }
1475
1476 static void
1477 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1478 {
1479 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1480 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1481 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1482 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1483
1484 /* create a header-block */
1485 start_block(ctx, header_id);
1486 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1487 branch(ctx, begin_id);
1488
1489 SpvId save_break = ctx->loop_break;
1490 SpvId save_cont = ctx->loop_cont;
1491 ctx->loop_break = break_id;
1492 ctx->loop_cont = cont_id;
1493
1494 emit_cf_list(ctx, &loop->body);
1495
1496 ctx->loop_break = save_break;
1497 ctx->loop_cont = save_cont;
1498
1499 branch(ctx, cont_id);
1500 start_block(ctx, cont_id);
1501 branch(ctx, header_id);
1502
1503 start_block(ctx, break_id);
1504 }
1505
1506 static void
1507 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1508 {
1509 foreach_list_typed(nir_cf_node, node, node, list) {
1510 switch (node->type) {
1511 case nir_cf_node_block:
1512 emit_block(ctx, nir_cf_node_as_block(node));
1513 break;
1514
1515 case nir_cf_node_if:
1516 emit_if(ctx, nir_cf_node_as_if(node));
1517 break;
1518
1519 case nir_cf_node_loop:
1520 emit_loop(ctx, nir_cf_node_as_loop(node));
1521 break;
1522
1523 case nir_cf_node_function:
1524 unreachable("nir_cf_node_function not supported");
1525 break;
1526 }
1527 }
1528 }
1529
1530 struct spirv_shader *
1531 nir_to_spirv(struct nir_shader *s)
1532 {
1533 struct spirv_shader *ret = NULL;
1534
1535 struct ntv_context ctx = {};
1536
1537 switch (s->info.stage) {
1538 case MESA_SHADER_VERTEX:
1539 case MESA_SHADER_FRAGMENT:
1540 case MESA_SHADER_COMPUTE:
1541 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1542 break;
1543
1544 case MESA_SHADER_TESS_CTRL:
1545 case MESA_SHADER_TESS_EVAL:
1546 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1547 break;
1548
1549 case MESA_SHADER_GEOMETRY:
1550 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1551 break;
1552
1553 default:
1554 unreachable("invalid stage");
1555 }
1556
1557 // TODO: only enable when needed
1558 if (s->info.stage == MESA_SHADER_FRAGMENT)
1559 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1560
1561 ctx.stage = s->info.stage;
1562 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1563 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1564
1565 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1566 SpvMemoryModelGLSL450);
1567
1568 SpvExecutionModel exec_model;
1569 switch (s->info.stage) {
1570 case MESA_SHADER_VERTEX:
1571 exec_model = SpvExecutionModelVertex;
1572 break;
1573 case MESA_SHADER_TESS_CTRL:
1574 exec_model = SpvExecutionModelTessellationControl;
1575 break;
1576 case MESA_SHADER_TESS_EVAL:
1577 exec_model = SpvExecutionModelTessellationEvaluation;
1578 break;
1579 case MESA_SHADER_GEOMETRY:
1580 exec_model = SpvExecutionModelGeometry;
1581 break;
1582 case MESA_SHADER_FRAGMENT:
1583 exec_model = SpvExecutionModelFragment;
1584 break;
1585 case MESA_SHADER_COMPUTE:
1586 exec_model = SpvExecutionModelGLCompute;
1587 break;
1588 default:
1589 unreachable("invalid stage");
1590 }
1591
1592 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1593 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1594 NULL, 0);
1595 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1596 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1597
1598 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1599 _mesa_key_pointer_equal);
1600
1601 nir_foreach_variable(var, &s->inputs)
1602 emit_input(&ctx, var);
1603
1604 nir_foreach_variable(var, &s->outputs)
1605 emit_output(&ctx, var);
1606
1607 nir_foreach_variable(var, &s->uniforms)
1608 emit_uniform(&ctx, var);
1609
1610 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1611 "main", ctx.entry_ifaces,
1612 ctx.num_entry_ifaces);
1613 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1614 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1615 SpvExecutionModeOriginUpperLeft);
1616 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1617 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1618 SpvExecutionModeDepthReplacing);
1619 }
1620
1621
1622 spirv_builder_function(&ctx.builder, entry_point, type_void,
1623 SpvFunctionControlMaskNone,
1624 type_main);
1625
1626 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1627 nir_metadata_require(entry, nir_metadata_block_index);
1628
1629 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1630 if (!ctx.defs)
1631 goto fail;
1632 ctx.num_defs = entry->ssa_alloc;
1633
1634 nir_index_local_regs(entry);
1635 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1636 if (!ctx.regs)
1637 goto fail;
1638 ctx.num_regs = entry->reg_alloc;
1639
1640 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1641 if (!block_ids)
1642 goto fail;
1643
1644 for (int i = 0; i < entry->num_blocks; ++i)
1645 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1646
1647 ctx.block_ids = block_ids;
1648 ctx.num_blocks = entry->num_blocks;
1649
1650 /* emit a block only for the variable declarations */
1651 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1652 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1653 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1654 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1655 SpvStorageClassFunction,
1656 type);
1657 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1658 SpvStorageClassFunction);
1659
1660 ctx.regs[reg->index] = var;
1661 }
1662
1663 emit_cf_list(&ctx, &entry->body);
1664
1665 free(ctx.defs);
1666
1667 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1668 spirv_builder_function_end(&ctx.builder);
1669
1670 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1671
1672 ret = CALLOC_STRUCT(spirv_shader);
1673 if (!ret)
1674 goto fail;
1675
1676 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1677 if (!ret->words)
1678 goto fail;
1679
1680 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1681 assert(ret->num_words == num_words);
1682
1683 return ret;
1684
1685 fail:
1686
1687 if (ret)
1688 spirv_shader_delete(ret);
1689
1690 if (ctx.vars)
1691 _mesa_hash_table_destroy(ctx.vars, NULL);
1692
1693 return NULL;
1694 }
1695
1696 void
1697 spirv_shader_delete(struct spirv_shader *s)
1698 {
1699 FREE(s->words);
1700 FREE(s);
1701 }