zink/spirv: fixup b2i32
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59 };
60
61 static SpvId
62 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
63 unsigned num_components, float value);
64
65 static SpvId
66 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
67 unsigned num_components, uint32_t value);
68
69 static SpvId
70 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
71 unsigned num_components, int32_t value);
72
73 static SpvId
74 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
75
76 static SpvId
77 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
78 SpvId src0, SpvId src1);
79
80 static SpvId
81 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
82 SpvId src0, SpvId src1, SpvId src2);
83
84 static SpvId
85 get_bvec_type(struct ntv_context *ctx, int num_components)
86 {
87 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
88 if (num_components > 1)
89 return spirv_builder_type_vector(&ctx->builder, bool_type,
90 num_components);
91
92 assert(num_components == 1);
93 return bool_type;
94 }
95
96 static SpvId
97 block_label(struct ntv_context *ctx, nir_block *block)
98 {
99 assert(block->index < ctx->num_blocks);
100 return ctx->block_ids[block->index];
101 }
102
103 static SpvId
104 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
105 {
106 assert(bit_size == 32);
107 return spirv_builder_const_float(&ctx->builder, bit_size, value);
108 }
109
110 static SpvId
111 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
112 {
113 assert(bit_size == 32);
114 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
115 }
116
117 static SpvId
118 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
119 {
120 assert(bit_size == 32);
121 return spirv_builder_const_int(&ctx->builder, bit_size, value);
122 }
123
124 static SpvId
125 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
126 {
127 assert(bit_size == 32); // only 32-bit floats supported so far
128
129 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
130 if (num_components > 1)
131 return spirv_builder_type_vector(&ctx->builder, float_type,
132 num_components);
133
134 assert(num_components == 1);
135 return float_type;
136 }
137
138 static SpvId
139 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
140 {
141 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
142
143 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
144 if (num_components > 1)
145 return spirv_builder_type_vector(&ctx->builder, int_type,
146 num_components);
147
148 assert(num_components == 1);
149 return int_type;
150 }
151
152 static SpvId
153 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
154 {
155 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
156
157 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
158 if (num_components > 1)
159 return spirv_builder_type_vector(&ctx->builder, uint_type,
160 num_components);
161
162 assert(num_components == 1);
163 return uint_type;
164 }
165
166 static SpvId
167 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
168 {
169 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
170 nir_dest_num_components(*dest));
171 }
172
173 static SpvId
174 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
175 {
176 switch (type) {
177 case GLSL_TYPE_FLOAT:
178 return spirv_builder_type_float(&ctx->builder, 32);
179
180 case GLSL_TYPE_INT:
181 return spirv_builder_type_int(&ctx->builder, 32);
182
183 case GLSL_TYPE_UINT:
184 return spirv_builder_type_uint(&ctx->builder, 32);
185 /* TODO: handle more types */
186
187 default:
188 unreachable("unknown GLSL type");
189 }
190 }
191
192 static SpvId
193 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
194 {
195 assert(type);
196 if (glsl_type_is_scalar(type))
197 return get_glsl_basetype(ctx, glsl_get_base_type(type));
198
199 if (glsl_type_is_vector(type))
200 return spirv_builder_type_vector(&ctx->builder,
201 get_glsl_basetype(ctx, glsl_get_base_type(type)),
202 glsl_get_vector_elements(type));
203
204 if (glsl_type_is_array(type)) {
205 SpvId ret = spirv_builder_type_array(&ctx->builder,
206 get_glsl_type(ctx, glsl_get_array_element(type)),
207 emit_uint_const(ctx, 32, glsl_get_length(type)));
208 uint32_t stride = glsl_get_explicit_stride(type);
209 if (stride)
210 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
211 return ret;
212 }
213
214
215 unreachable("we shouldn't get here, I think...");
216 }
217
218 static void
219 emit_input(struct ntv_context *ctx, struct nir_variable *var)
220 {
221 SpvId var_type = get_glsl_type(ctx, var->type);
222 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
223 SpvStorageClassInput,
224 var_type);
225 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
226 SpvStorageClassInput);
227
228 if (var->name)
229 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
230
231 if (ctx->stage == MESA_SHADER_FRAGMENT) {
232 if (var->data.location >= VARYING_SLOT_VAR0 ||
233 (var->data.location >= VARYING_SLOT_COL0 &&
234 var->data.location <= VARYING_SLOT_TEX7)) {
235 spirv_builder_emit_location(&ctx->builder, var_id,
236 ctx->var_location++);
237 } else {
238 switch (var->data.location) {
239 case VARYING_SLOT_POS:
240 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
241 break;
242
243 case VARYING_SLOT_PNTC:
244 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
245 break;
246
247 default:
248 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
249 unreachable("unexpected varying slot");
250 }
251 }
252 } else {
253 spirv_builder_emit_location(&ctx->builder, var_id,
254 var->data.driver_location);
255 }
256
257 if (var->data.location_frac)
258 spirv_builder_emit_component(&ctx->builder, var_id,
259 var->data.location_frac);
260
261 if (var->data.interpolation == INTERP_MODE_FLAT)
262 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
263
264 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
265
266 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
267 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
268 }
269
270 static void
271 emit_output(struct ntv_context *ctx, struct nir_variable *var)
272 {
273 SpvId var_type = get_glsl_type(ctx, var->type);
274 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
275 SpvStorageClassOutput,
276 var_type);
277 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
278 SpvStorageClassOutput);
279 if (var->name)
280 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
281
282
283 if (ctx->stage == MESA_SHADER_VERTEX) {
284 if (var->data.location >= VARYING_SLOT_VAR0 ||
285 (var->data.location >= VARYING_SLOT_COL0 &&
286 var->data.location <= VARYING_SLOT_TEX7)) {
287 spirv_builder_emit_location(&ctx->builder, var_id,
288 ctx->var_location++);
289 } else {
290 switch (var->data.location) {
291 case VARYING_SLOT_POS:
292 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
293 break;
294
295 case VARYING_SLOT_PSIZ:
296 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
297 break;
298
299 case VARYING_SLOT_CLIP_DIST0:
300 assert(glsl_type_is_array(var->type));
301 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
302 break;
303
304 default:
305 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
306 unreachable("unexpected varying slot");
307 }
308 }
309 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
310 if (var->data.location >= FRAG_RESULT_DATA0)
311 spirv_builder_emit_location(&ctx->builder, var_id,
312 var->data.location - FRAG_RESULT_DATA0);
313 else {
314 switch (var->data.location) {
315 case FRAG_RESULT_COLOR:
316 spirv_builder_emit_location(&ctx->builder, var_id, 0);
317 break;
318
319 case FRAG_RESULT_DEPTH:
320 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
321 break;
322
323 default:
324 spirv_builder_emit_location(&ctx->builder, var_id,
325 var->data.driver_location);
326 }
327 }
328 }
329
330 if (var->data.location_frac)
331 spirv_builder_emit_component(&ctx->builder, var_id,
332 var->data.location_frac);
333
334 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
335
336 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
337 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
338 }
339
340 static SpvDim
341 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
342 {
343 *is_ms = false;
344 switch (gdim) {
345 case GLSL_SAMPLER_DIM_1D:
346 return SpvDim1D;
347 case GLSL_SAMPLER_DIM_2D:
348 return SpvDim2D;
349 case GLSL_SAMPLER_DIM_RECT:
350 return SpvDimRect;
351 case GLSL_SAMPLER_DIM_CUBE:
352 return SpvDimCube;
353 case GLSL_SAMPLER_DIM_3D:
354 return SpvDim3D;
355 case GLSL_SAMPLER_DIM_MS:
356 *is_ms = true;
357 return SpvDim2D;
358 default:
359 fprintf(stderr, "unknown sampler type %d\n", gdim);
360 break;
361 }
362 return SpvDim2D;
363 }
364
365 static void
366 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
367 {
368 bool is_ms;
369 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
370 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
371 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
372 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
373 SpvImageFormatUnknown);
374
375 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
376 image_type);
377 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
378 SpvStorageClassUniformConstant,
379 sampled_type);
380 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
381 SpvStorageClassUniformConstant);
382
383 if (var->name)
384 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
385
386 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
387 ctx->samplers[ctx->num_samplers++] = var_id;
388
389 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
390 var->data.descriptor_set);
391 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
392 }
393
394 static void
395 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
396 {
397 uint32_t size = glsl_count_attribute_slots(var->type, false);
398 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
399 SpvId array_length = emit_uint_const(ctx, 32, size);
400 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
401 array_length);
402 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
403
404 // wrap UBO-array in a struct
405 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
406 if (var->name) {
407 char struct_name[100];
408 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
409 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
410 }
411
412 spirv_builder_emit_decoration(&ctx->builder, struct_type,
413 SpvDecorationBlock);
414 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
415
416
417 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
418 SpvStorageClassUniform,
419 struct_type);
420
421 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
422 SpvStorageClassUniform);
423 if (var->name)
424 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
425
426 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
427 ctx->ubos[ctx->num_ubos++] = var_id;
428
429 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
430 var->data.descriptor_set);
431 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
432 }
433
434 static void
435 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
436 {
437 if (var->data.mode == nir_var_mem_ubo)
438 emit_ubo(ctx, var);
439 else {
440 assert(var->data.mode == nir_var_uniform);
441 if (glsl_type_is_sampler(var->type))
442 emit_sampler(ctx, var);
443 }
444 }
445
446 static SpvId
447 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
448 {
449 assert(ssa->index < ctx->num_defs);
450 assert(ctx->defs[ssa->index] != 0);
451 return ctx->defs[ssa->index];
452 }
453
454 static SpvId
455 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
456 {
457 assert(reg->index < ctx->num_regs);
458 assert(ctx->regs[reg->index] != 0);
459 return ctx->regs[reg->index];
460 }
461
462 static SpvId
463 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
464 {
465 assert(reg->reg);
466 assert(!reg->indirect);
467 assert(!reg->base_offset);
468
469 SpvId var = get_var_from_reg(ctx, reg->reg);
470 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
471 return spirv_builder_emit_load(&ctx->builder, type, var);
472 }
473
474 static SpvId
475 get_src_uint(struct ntv_context *ctx, nir_src *src)
476 {
477 if (src->is_ssa)
478 return get_src_uint_ssa(ctx, src->ssa);
479 else
480 return get_src_uint_reg(ctx, &src->reg);
481 }
482
483 static SpvId
484 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
485 {
486 assert(!alu->src[src].negate);
487 assert(!alu->src[src].abs);
488
489 SpvId def = get_src_uint(ctx, &alu->src[src].src);
490
491 unsigned used_channels = 0;
492 bool need_swizzle = false;
493 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
494 if (!nir_alu_instr_channel_used(alu, src, i))
495 continue;
496
497 used_channels++;
498
499 if (alu->src[src].swizzle[i] != i)
500 need_swizzle = true;
501 }
502 assert(used_channels != 0);
503
504 unsigned live_channels = nir_src_num_components(alu->src[src].src);
505 if (used_channels != live_channels)
506 need_swizzle = true;
507
508 if (!need_swizzle)
509 return def;
510
511 int bit_size = nir_src_bit_size(alu->src[src].src);
512 assert(bit_size == 1 || bit_size == 32);
513
514 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
515 if (used_channels == 1) {
516 uint32_t indices[] = { alu->src[src].swizzle[0] };
517 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
518 def, indices,
519 ARRAY_SIZE(indices));
520 } else if (live_channels == 1) {
521 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
522 used_channels);
523
524 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
525 for (unsigned i = 0; i < used_channels; ++i)
526 constituents[i] = def;
527
528 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
529 constituents,
530 used_channels);
531 } else {
532 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
533 used_channels);
534
535 uint32_t components[NIR_MAX_VEC_COMPONENTS];
536 size_t num_components = 0;
537 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
538 if (!nir_alu_instr_channel_used(alu, src, i))
539 continue;
540
541 components[num_components++] = alu->src[src].swizzle[i];
542 }
543
544 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
545 def, def, components, num_components);
546 }
547 }
548
549 static void
550 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
551 {
552 assert(result != 0);
553 assert(ssa->index < ctx->num_defs);
554 ctx->defs[ssa->index] = result;
555 }
556
557 static SpvId
558 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
559 SpvId if_true, SpvId if_false)
560 {
561 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
562 }
563
564 static SpvId
565 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
566 {
567 SpvId otype = get_uvec_type(ctx, 32, num_components);
568 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
569 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
570 return emit_select(ctx, otype, value, one, zero);
571 }
572
573 static SpvId
574 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
575 {
576 SpvId type = get_bvec_type(ctx, num_components);
577 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
578 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
579 }
580
581 static SpvId
582 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
583 {
584 return emit_unop(ctx, SpvOpBitcast, type, value);
585 }
586
587 static SpvId
588 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
589 unsigned num_components)
590 {
591 SpvId type = get_uvec_type(ctx, bit_size, num_components);
592 return emit_bitcast(ctx, type, value);
593 }
594
595 static SpvId
596 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
597 unsigned num_components)
598 {
599 SpvId type = get_ivec_type(ctx, bit_size, num_components);
600 return emit_bitcast(ctx, type, value);
601 }
602
603 static SpvId
604 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
605 unsigned num_components)
606 {
607 SpvId type = get_fvec_type(ctx, bit_size, num_components);
608 return emit_bitcast(ctx, type, value);
609 }
610
611 static void
612 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
613 {
614 SpvId var = get_var_from_reg(ctx, reg->reg);
615 assert(var);
616 spirv_builder_emit_store(&ctx->builder, var, result);
617 }
618
619 static void
620 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
621 {
622 if (dest->is_ssa)
623 store_ssa_def_uint(ctx, &dest->ssa, result);
624 else
625 store_reg_def(ctx, &dest->reg, result);
626 }
627
628 static void
629 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
630 {
631 unsigned num_components = nir_dest_num_components(*dest);
632 unsigned bit_size = nir_dest_bit_size(*dest);
633
634 switch (nir_alu_type_get_base_type(type)) {
635 case nir_type_bool:
636 assert(bit_size == 1);
637 result = bvec_to_uvec(ctx, result, num_components);
638 break;
639
640 case nir_type_uint:
641 break; /* nothing to do! */
642
643 case nir_type_int:
644 case nir_type_float:
645 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
646 break;
647
648 default:
649 unreachable("unsupported nir_alu_type");
650 }
651
652 store_dest_uint(ctx, dest, result);
653 }
654
655 static SpvId
656 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
657 {
658 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
659 }
660
661 static SpvId
662 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
663 SpvId src0, SpvId src1)
664 {
665 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
666 }
667
668 static SpvId
669 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
670 SpvId src0, SpvId src1, SpvId src2)
671 {
672 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
673 }
674
675 static SpvId
676 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
677 SpvId src)
678 {
679 SpvId args[] = { src };
680 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
681 op, args, ARRAY_SIZE(args));
682 }
683
684 static SpvId
685 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
686 SpvId src0, SpvId src1)
687 {
688 SpvId args[] = { src0, src1 };
689 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
690 op, args, ARRAY_SIZE(args));
691 }
692
693 static SpvId
694 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
695 unsigned num_components, float value)
696 {
697 assert(bit_size == 32);
698
699 SpvId result = emit_float_const(ctx, bit_size, value);
700 if (num_components == 1)
701 return result;
702
703 assert(num_components > 1);
704 SpvId components[num_components];
705 for (int i = 0; i < num_components; i++)
706 components[i] = result;
707
708 SpvId type = get_fvec_type(ctx, bit_size, num_components);
709 return spirv_builder_const_composite(&ctx->builder, type, components,
710 num_components);
711 }
712
713 static SpvId
714 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
715 unsigned num_components, uint32_t value)
716 {
717 assert(bit_size == 32);
718
719 SpvId result = emit_uint_const(ctx, bit_size, value);
720 if (num_components == 1)
721 return result;
722
723 assert(num_components > 1);
724 SpvId components[num_components];
725 for (int i = 0; i < num_components; i++)
726 components[i] = result;
727
728 SpvId type = get_uvec_type(ctx, bit_size, num_components);
729 return spirv_builder_const_composite(&ctx->builder, type, components,
730 num_components);
731 }
732
733 static SpvId
734 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
735 unsigned num_components, int32_t value)
736 {
737 assert(bit_size == 32);
738
739 SpvId result = emit_int_const(ctx, bit_size, value);
740 if (num_components == 1)
741 return result;
742
743 assert(num_components > 1);
744 SpvId components[num_components];
745 for (int i = 0; i < num_components; i++)
746 components[i] = result;
747
748 SpvId type = get_ivec_type(ctx, bit_size, num_components);
749 return spirv_builder_const_composite(&ctx->builder, type, components,
750 num_components);
751 }
752
753 static inline unsigned
754 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
755 {
756 if (nir_op_infos[instr->op].input_sizes[src] > 0)
757 return nir_op_infos[instr->op].input_sizes[src];
758
759 if (instr->dest.dest.is_ssa)
760 return instr->dest.dest.ssa.num_components;
761 else
762 return instr->dest.dest.reg.reg->num_components;
763 }
764
765 static SpvId
766 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
767 {
768 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
769
770 unsigned num_components = alu_instr_src_components(alu, src);
771 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
772 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
773
774 switch (nir_alu_type_get_base_type(type)) {
775 case nir_type_bool:
776 assert(bit_size == 1);
777 return uvec_to_bvec(ctx, uint_value, num_components);
778
779 case nir_type_int:
780 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
781
782 case nir_type_uint:
783 return uint_value;
784
785 case nir_type_float:
786 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
787
788 default:
789 unreachable("unknown nir_alu_type");
790 }
791 }
792
793 static void
794 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
795 {
796 assert(!alu->dest.saturate);
797 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
798 }
799
800 static SpvId
801 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
802 {
803 unsigned num_components = nir_dest_num_components(*dest);
804 unsigned bit_size = nir_dest_bit_size(*dest);
805
806 switch (nir_alu_type_get_base_type(type)) {
807 case nir_type_bool:
808 return get_bvec_type(ctx, num_components);
809
810 case nir_type_int:
811 return get_ivec_type(ctx, bit_size, num_components);
812
813 case nir_type_uint:
814 return get_uvec_type(ctx, bit_size, num_components);
815
816 case nir_type_float:
817 return get_fvec_type(ctx, bit_size, num_components);
818
819 default:
820 unreachable("unsupported nir_alu_type");
821 }
822 }
823
824 static void
825 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
826 {
827 SpvId src[nir_op_infos[alu->op].num_inputs];
828 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
829 src[i] = get_alu_src(ctx, alu, i);
830
831 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
832 nir_op_infos[alu->op].output_type);
833 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
834 unsigned num_components = nir_dest_num_components(alu->dest.dest);
835
836 SpvId result = 0;
837 switch (alu->op) {
838 case nir_op_mov:
839 assert(nir_op_infos[alu->op].num_inputs == 1);
840 result = src[0];
841 break;
842
843 #define UNOP(nir_op, spirv_op) \
844 case nir_op: \
845 assert(nir_op_infos[alu->op].num_inputs == 1); \
846 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
847 break;
848
849 UNOP(nir_op_ineg, SpvOpSNegate)
850 UNOP(nir_op_fneg, SpvOpFNegate)
851 UNOP(nir_op_fddx, SpvOpDPdx)
852 UNOP(nir_op_fddy, SpvOpDPdy)
853 UNOP(nir_op_f2i32, SpvOpConvertFToS)
854 UNOP(nir_op_f2u32, SpvOpConvertFToU)
855 UNOP(nir_op_i2f32, SpvOpConvertSToF)
856 UNOP(nir_op_u2f32, SpvOpConvertUToF)
857 UNOP(nir_op_inot, SpvOpNot)
858 #undef UNOP
859
860 case nir_op_b2i32:
861 assert(nir_op_infos[alu->op].num_inputs == 1);
862 result = emit_select(ctx, dest_type, src[0],
863 get_ivec_constant(ctx, 32, num_components, 1),
864 get_ivec_constant(ctx, 32, num_components, 0));
865 break;
866
867 case nir_op_b2f32:
868 assert(nir_op_infos[alu->op].num_inputs == 1);
869 result = emit_select(ctx, dest_type, src[0],
870 get_fvec_constant(ctx, 32, num_components, 1),
871 get_fvec_constant(ctx, 32, num_components, 0));
872 break;
873
874 #define BUILTIN_UNOP(nir_op, spirv_op) \
875 case nir_op: \
876 assert(nir_op_infos[alu->op].num_inputs == 1); \
877 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
878 break;
879
880 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
881 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
882 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
883 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
884 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
885 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
886 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
887 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
888 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
889 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
890 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
891 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
892 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
893 #undef BUILTIN_UNOP
894
895 case nir_op_frcp:
896 assert(nir_op_infos[alu->op].num_inputs == 1);
897 result = emit_binop(ctx, SpvOpFDiv, dest_type,
898 get_fvec_constant(ctx, bit_size, num_components, 1),
899 src[0]);
900 break;
901
902 case nir_op_f2b1:
903 assert(nir_op_infos[alu->op].num_inputs == 1);
904 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
905 get_fvec_constant(ctx,
906 nir_src_bit_size(alu->src[0].src),
907 num_components, 0));
908 break;
909
910
911 #define BINOP(nir_op, spirv_op) \
912 case nir_op: \
913 assert(nir_op_infos[alu->op].num_inputs == 2); \
914 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
915 break;
916
917 BINOP(nir_op_iadd, SpvOpIAdd)
918 BINOP(nir_op_isub, SpvOpISub)
919 BINOP(nir_op_imul, SpvOpIMul)
920 BINOP(nir_op_idiv, SpvOpSDiv)
921 BINOP(nir_op_udiv, SpvOpUDiv)
922 BINOP(nir_op_fadd, SpvOpFAdd)
923 BINOP(nir_op_fsub, SpvOpFSub)
924 BINOP(nir_op_fmul, SpvOpFMul)
925 BINOP(nir_op_fdiv, SpvOpFDiv)
926 BINOP(nir_op_fmod, SpvOpFMod)
927 BINOP(nir_op_ilt, SpvOpSLessThan)
928 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
929 BINOP(nir_op_ieq, SpvOpIEqual)
930 BINOP(nir_op_ine, SpvOpINotEqual)
931 BINOP(nir_op_flt, SpvOpFOrdLessThan)
932 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
933 BINOP(nir_op_feq, SpvOpFOrdEqual)
934 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
935 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
936 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
937 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
938 BINOP(nir_op_iand, SpvOpBitwiseAnd)
939 BINOP(nir_op_ior, SpvOpBitwiseOr)
940 #undef BINOP
941
942 #define BUILTIN_BINOP(nir_op, spirv_op) \
943 case nir_op: \
944 assert(nir_op_infos[alu->op].num_inputs == 2); \
945 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
946 break;
947
948 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
949 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
950 #undef BUILTIN_BINOP
951
952 case nir_op_fdot2:
953 case nir_op_fdot3:
954 case nir_op_fdot4:
955 assert(nir_op_infos[alu->op].num_inputs == 2);
956 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
957 break;
958
959 case nir_op_seq:
960 case nir_op_sne:
961 case nir_op_slt:
962 case nir_op_sge: {
963 assert(nir_op_infos[alu->op].num_inputs == 2);
964 int num_components = nir_dest_num_components(alu->dest.dest);
965 SpvId bool_type = get_bvec_type(ctx, num_components);
966
967 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
968 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
969 if (num_components > 1) {
970 SpvId zero_comps[num_components], one_comps[num_components];
971 for (int i = 0; i < num_components; i++) {
972 zero_comps[i] = zero;
973 one_comps[i] = one;
974 }
975
976 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
977 zero_comps, num_components);
978 one = spirv_builder_const_composite(&ctx->builder, dest_type,
979 one_comps, num_components);
980 }
981
982 SpvOp op;
983 switch (alu->op) {
984 case nir_op_seq: op = SpvOpFOrdEqual; break;
985 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
986 case nir_op_slt: op = SpvOpFOrdLessThan; break;
987 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
988 default: unreachable("unexpected op");
989 }
990
991 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
992 result = emit_select(ctx, dest_type, result, one, zero);
993 }
994 break;
995
996 case nir_op_fcsel:
997 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
998 get_bvec_type(ctx, num_components),
999 src[0],
1000 get_fvec_constant(ctx,
1001 nir_src_bit_size(alu->src[0].src),
1002 num_components, 0));
1003 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1004 break;
1005
1006 case nir_op_bcsel:
1007 assert(nir_op_infos[alu->op].num_inputs == 3);
1008 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1009 break;
1010
1011 case nir_op_vec2:
1012 case nir_op_vec3:
1013 case nir_op_vec4: {
1014 int num_inputs = nir_op_infos[alu->op].num_inputs;
1015 assert(2 <= num_inputs && num_inputs <= 4);
1016 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1017 src, num_inputs);
1018 }
1019 break;
1020
1021 default:
1022 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1023 nir_op_infos[alu->op].name);
1024
1025 unreachable("unsupported opcode");
1026 return;
1027 }
1028
1029 store_alu_result(ctx, alu, result);
1030 }
1031
1032 static void
1033 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1034 {
1035 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1036 for (int i = 0; i < load_const->def.num_components; ++i)
1037 values[i] = load_const->value[i].u32;
1038
1039 unsigned bit_size = load_const->def.bit_size;
1040 unsigned num_components = load_const->def.num_components;
1041
1042 SpvId constant;
1043 if (num_components > 1) {
1044 SpvId components[num_components];
1045 for (int i = 0; i < num_components; i++)
1046 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1047
1048 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1049 constant = spirv_builder_const_composite(&ctx->builder, type,
1050 components, num_components);
1051 } else {
1052 assert(num_components == 1);
1053 constant = emit_uint_const(ctx, bit_size, values[0]);
1054 }
1055
1056 store_ssa_def_uint(ctx, &load_const->def, constant);
1057 }
1058
1059 static void
1060 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1061 {
1062 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1063 assert(const_block_index); // no dynamic indexing for now
1064 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1065
1066 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1067 if (const_offset) {
1068 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1069 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1070 SpvStorageClassUniform,
1071 uvec4_type);
1072
1073 unsigned idx = const_offset->u32;
1074 SpvId member = emit_uint_const(ctx, 32, 0);
1075 SpvId offset = emit_uint_const(ctx, 32, idx);
1076 SpvId offsets[] = { member, offset };
1077 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1078 ctx->ubos[0], offsets,
1079 ARRAY_SIZE(offsets));
1080 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1081
1082 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1083 unsigned num_components = nir_dest_num_components(intr->dest);
1084 if (num_components == 1) {
1085 uint32_t components[] = { 0 };
1086 result = spirv_builder_emit_composite_extract(&ctx->builder,
1087 type,
1088 result, components,
1089 1);
1090 } else if (num_components < 4) {
1091 SpvId constituents[num_components];
1092 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1093 for (uint32_t i = 0; i < num_components; ++i)
1094 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1095 uint_type,
1096 result, &i,
1097 1);
1098
1099 result = spirv_builder_emit_composite_construct(&ctx->builder,
1100 type,
1101 constituents,
1102 num_components);
1103 }
1104
1105 store_dest_uint(ctx, &intr->dest, result);
1106 } else
1107 unreachable("uniform-addressing not yet supported");
1108 }
1109
1110 static void
1111 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1112 {
1113 assert(ctx->block_started);
1114 spirv_builder_emit_kill(&ctx->builder);
1115 /* discard is weird in NIR, so let's just create an unreachable block after
1116 it and hope that the vulkan driver will DCE any instructinos in it. */
1117 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1118 }
1119
1120 static void
1121 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1122 {
1123 /* uint is a bit of a lie here; it's really just a pointer */
1124 SpvId ptr = get_src_uint(ctx, intr->src);
1125
1126 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1127 SpvId result = spirv_builder_emit_load(&ctx->builder,
1128 get_glsl_type(ctx, var->type),
1129 ptr);
1130 unsigned num_components = nir_dest_num_components(intr->dest);
1131 unsigned bit_size = nir_dest_bit_size(intr->dest);
1132 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1133 store_dest_uint(ctx, &intr->dest, result);
1134 }
1135
1136 static void
1137 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1138 {
1139 /* uint is a bit of a lie here; it's really just a pointer */
1140 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1141 SpvId src = get_src_uint(ctx, &intr->src[1]);
1142
1143 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1144 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1145 SpvId result = emit_bitcast(ctx, type, src);
1146 spirv_builder_emit_store(&ctx->builder, ptr, result);
1147 }
1148
1149 static void
1150 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1151 {
1152 switch (intr->intrinsic) {
1153 case nir_intrinsic_load_ubo:
1154 emit_load_ubo(ctx, intr);
1155 break;
1156
1157 case nir_intrinsic_discard:
1158 emit_discard(ctx, intr);
1159 break;
1160
1161 case nir_intrinsic_load_deref:
1162 emit_load_deref(ctx, intr);
1163 break;
1164
1165 case nir_intrinsic_store_deref:
1166 emit_store_deref(ctx, intr);
1167 break;
1168
1169 default:
1170 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1171 nir_intrinsic_infos[intr->intrinsic].name);
1172 unreachable("unsupported intrinsic");
1173 }
1174 }
1175
1176 static void
1177 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1178 {
1179 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1180 undef->def.num_components);
1181
1182 store_ssa_def_uint(ctx, &undef->def,
1183 spirv_builder_emit_undef(&ctx->builder, type));
1184 }
1185
1186 static SpvId
1187 get_src_float(struct ntv_context *ctx, nir_src *src)
1188 {
1189 SpvId def = get_src_uint(ctx, src);
1190 unsigned num_components = nir_src_num_components(*src);
1191 unsigned bit_size = nir_src_bit_size(*src);
1192 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1193 }
1194
1195 static void
1196 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1197 {
1198 assert(tex->op == nir_texop_tex ||
1199 tex->op == nir_texop_txb ||
1200 tex->op == nir_texop_txl);
1201 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1202 assert(tex->texture_index == tex->sampler_index);
1203
1204 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1205 unsigned coord_components;
1206 for (unsigned i = 0; i < tex->num_srcs; i++) {
1207 switch (tex->src[i].src_type) {
1208 case nir_tex_src_coord:
1209 coord = get_src_float(ctx, &tex->src[i].src);
1210 coord_components = nir_src_num_components(tex->src[i].src);
1211 break;
1212
1213 case nir_tex_src_projector:
1214 assert(nir_src_num_components(tex->src[i].src) == 1);
1215 proj = get_src_float(ctx, &tex->src[i].src);
1216 assert(proj != 0);
1217 break;
1218
1219 case nir_tex_src_bias:
1220 assert(tex->op == nir_texop_txb);
1221 bias = get_src_float(ctx, &tex->src[i].src);
1222 assert(bias != 0);
1223 break;
1224
1225 case nir_tex_src_lod:
1226 assert(nir_src_num_components(tex->src[i].src) == 1);
1227 lod = get_src_float(ctx, &tex->src[i].src);
1228 assert(lod != 0);
1229 break;
1230
1231 case nir_tex_src_comparator:
1232 assert(nir_src_num_components(tex->src[i].src) == 1);
1233 dref = get_src_float(ctx, &tex->src[i].src);
1234 assert(dref != 0);
1235 break;
1236
1237 default:
1238 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1239 unreachable("unknown texture source");
1240 }
1241 }
1242
1243 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1244 lod = emit_float_const(ctx, 32, 0.0f);
1245 assert(lod != 0);
1246 }
1247
1248 bool is_ms;
1249 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1250 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1251 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1252 dimension, false, tex->is_array, is_ms, 1,
1253 SpvImageFormatUnknown);
1254 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1255 image_type);
1256
1257 assert(tex->texture_index < ctx->num_samplers);
1258 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1259 ctx->samplers[tex->texture_index]);
1260
1261 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1262
1263 if (proj) {
1264 SpvId constituents[coord_components + 1];
1265 if (coord_components == 1)
1266 constituents[0] = coord;
1267 else {
1268 assert(coord_components > 1);
1269 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1270 for (uint32_t i = 0; i < coord_components; ++i)
1271 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1272 float_type,
1273 coord,
1274 &i, 1);
1275 }
1276
1277 constituents[coord_components++] = proj;
1278
1279 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1280 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1281 vec_type,
1282 constituents,
1283 coord_components);
1284 }
1285
1286 SpvId actual_dest_type = dest_type;
1287 if (dref)
1288 actual_dest_type = float_type;
1289
1290 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1291 actual_dest_type, load,
1292 coord,
1293 proj != 0,
1294 lod, bias, dref);
1295 spirv_builder_emit_decoration(&ctx->builder, result,
1296 SpvDecorationRelaxedPrecision);
1297
1298 if (dref) {
1299 SpvId components[4] = { result, result, result, result };
1300 result = spirv_builder_emit_composite_construct(&ctx->builder,
1301 dest_type,
1302 components,
1303 4);
1304 }
1305
1306 store_dest(ctx, &tex->dest, result, tex->dest_type);
1307 }
1308
1309 static void
1310 start_block(struct ntv_context *ctx, SpvId label)
1311 {
1312 /* terminate previous block if needed */
1313 if (ctx->block_started)
1314 spirv_builder_emit_branch(&ctx->builder, label);
1315
1316 /* start new block */
1317 spirv_builder_label(&ctx->builder, label);
1318 ctx->block_started = true;
1319 }
1320
1321 static void
1322 branch(struct ntv_context *ctx, SpvId label)
1323 {
1324 assert(ctx->block_started);
1325 spirv_builder_emit_branch(&ctx->builder, label);
1326 ctx->block_started = false;
1327 }
1328
1329 static void
1330 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1331 SpvId else_id)
1332 {
1333 assert(ctx->block_started);
1334 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1335 then_id, else_id);
1336 ctx->block_started = false;
1337 }
1338
1339 static void
1340 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1341 {
1342 switch (jump->type) {
1343 case nir_jump_break:
1344 assert(ctx->loop_break);
1345 branch(ctx, ctx->loop_break);
1346 break;
1347
1348 case nir_jump_continue:
1349 assert(ctx->loop_cont);
1350 branch(ctx, ctx->loop_cont);
1351 break;
1352
1353 default:
1354 unreachable("Unsupported jump type\n");
1355 }
1356 }
1357
1358 static void
1359 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1360 {
1361 assert(deref->deref_type == nir_deref_type_var);
1362
1363 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1364 assert(he);
1365 SpvId result = (SpvId)(intptr_t)he->data;
1366 /* uint is a bit of a lie here, it's really just an opaque type */
1367 store_dest_uint(ctx, &deref->dest, result);
1368 }
1369
1370 static void
1371 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1372 {
1373 assert(deref->deref_type == nir_deref_type_array);
1374 nir_variable *var = nir_deref_instr_get_variable(deref);
1375
1376 SpvStorageClass storage_class;
1377 switch (var->data.mode) {
1378 case nir_var_shader_in:
1379 storage_class = SpvStorageClassInput;
1380 break;
1381
1382 case nir_var_shader_out:
1383 storage_class = SpvStorageClassOutput;
1384 break;
1385
1386 default:
1387 unreachable("Unsupported nir_variable_mode\n");
1388 }
1389
1390 SpvId index = get_src_uint(ctx, &deref->arr.index);
1391
1392 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1393 storage_class,
1394 get_glsl_type(ctx, deref->type));
1395
1396 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1397 ptr_type,
1398 get_src_uint(ctx, &deref->parent),
1399 &index, 1);
1400 /* uint is a bit of a lie here, it's really just an opaque type */
1401 store_dest_uint(ctx, &deref->dest, result);
1402 }
1403
1404 static void
1405 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1406 {
1407 switch (deref->deref_type) {
1408 case nir_deref_type_var:
1409 emit_deref_var(ctx, deref);
1410 break;
1411
1412 case nir_deref_type_array:
1413 emit_deref_array(ctx, deref);
1414 break;
1415
1416 default:
1417 unreachable("unexpected deref_type");
1418 }
1419 }
1420
1421 static void
1422 emit_block(struct ntv_context *ctx, struct nir_block *block)
1423 {
1424 start_block(ctx, block_label(ctx, block));
1425 nir_foreach_instr(instr, block) {
1426 switch (instr->type) {
1427 case nir_instr_type_alu:
1428 emit_alu(ctx, nir_instr_as_alu(instr));
1429 break;
1430 case nir_instr_type_intrinsic:
1431 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1432 break;
1433 case nir_instr_type_load_const:
1434 emit_load_const(ctx, nir_instr_as_load_const(instr));
1435 break;
1436 case nir_instr_type_ssa_undef:
1437 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1438 break;
1439 case nir_instr_type_tex:
1440 emit_tex(ctx, nir_instr_as_tex(instr));
1441 break;
1442 case nir_instr_type_phi:
1443 unreachable("nir_instr_type_phi not supported");
1444 break;
1445 case nir_instr_type_jump:
1446 emit_jump(ctx, nir_instr_as_jump(instr));
1447 break;
1448 case nir_instr_type_call:
1449 unreachable("nir_instr_type_call not supported");
1450 break;
1451 case nir_instr_type_parallel_copy:
1452 unreachable("nir_instr_type_parallel_copy not supported");
1453 break;
1454 case nir_instr_type_deref:
1455 emit_deref(ctx, nir_instr_as_deref(instr));
1456 break;
1457 }
1458 }
1459 }
1460
/* Forward declaration: emit_if() and emit_loop() below call emit_cf_list()
 * for their nested control-flow lists, and emit_cf_list() in turn calls
 * them, so it has to be declared before its definition. */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1463
1464 static SpvId
1465 get_src_bool(struct ntv_context *ctx, nir_src *src)
1466 {
1467 SpvId def = get_src_uint(ctx, src);
1468 assert(nir_src_bit_size(*src) == 1);
1469 unsigned num_components = nir_src_num_components(*src);
1470 return uvec_to_bvec(ctx, def, num_components);
1471 }
1472
1473 static void
1474 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1475 {
1476 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1477
1478 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1479 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1480 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1481 SpvId else_id = endif_id;
1482
1483 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1484 if (has_else) {
1485 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1486 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1487 }
1488
1489 /* create a header-block */
1490 start_block(ctx, header_id);
1491 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1492 SpvSelectionControlMaskNone);
1493 branch_conditional(ctx, condition, then_id, else_id);
1494
1495 emit_cf_list(ctx, &if_stmt->then_list);
1496
1497 if (has_else) {
1498 if (ctx->block_started)
1499 branch(ctx, endif_id);
1500
1501 emit_cf_list(ctx, &if_stmt->else_list);
1502 }
1503
1504 start_block(ctx, endif_id);
1505 }
1506
1507 static void
1508 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1509 {
1510 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1511 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1512 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1513 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1514
1515 /* create a header-block */
1516 start_block(ctx, header_id);
1517 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1518 branch(ctx, begin_id);
1519
1520 SpvId save_break = ctx->loop_break;
1521 SpvId save_cont = ctx->loop_cont;
1522 ctx->loop_break = break_id;
1523 ctx->loop_cont = cont_id;
1524
1525 emit_cf_list(ctx, &loop->body);
1526
1527 ctx->loop_break = save_break;
1528 ctx->loop_cont = save_cont;
1529
1530 branch(ctx, cont_id);
1531 start_block(ctx, cont_id);
1532 branch(ctx, header_id);
1533
1534 start_block(ctx, break_id);
1535 }
1536
1537 static void
1538 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1539 {
1540 foreach_list_typed(nir_cf_node, node, node, list) {
1541 switch (node->type) {
1542 case nir_cf_node_block:
1543 emit_block(ctx, nir_cf_node_as_block(node));
1544 break;
1545
1546 case nir_cf_node_if:
1547 emit_if(ctx, nir_cf_node_as_if(node));
1548 break;
1549
1550 case nir_cf_node_loop:
1551 emit_loop(ctx, nir_cf_node_as_loop(node));
1552 break;
1553
1554 case nir_cf_node_function:
1555 unreachable("nir_cf_node_function not supported");
1556 break;
1557 }
1558 }
1559 }
1560
1561 struct spirv_shader *
1562 nir_to_spirv(struct nir_shader *s)
1563 {
1564 struct spirv_shader *ret = NULL;
1565
1566 struct ntv_context ctx = {};
1567
1568 switch (s->info.stage) {
1569 case MESA_SHADER_VERTEX:
1570 case MESA_SHADER_FRAGMENT:
1571 case MESA_SHADER_COMPUTE:
1572 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1573 break;
1574
1575 case MESA_SHADER_TESS_CTRL:
1576 case MESA_SHADER_TESS_EVAL:
1577 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1578 break;
1579
1580 case MESA_SHADER_GEOMETRY:
1581 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1582 break;
1583
1584 default:
1585 unreachable("invalid stage");
1586 }
1587
1588 // TODO: only enable when needed
1589 if (s->info.stage == MESA_SHADER_FRAGMENT)
1590 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1591
1592 ctx.stage = s->info.stage;
1593 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1594 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1595
1596 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1597 SpvMemoryModelGLSL450);
1598
1599 SpvExecutionModel exec_model;
1600 switch (s->info.stage) {
1601 case MESA_SHADER_VERTEX:
1602 exec_model = SpvExecutionModelVertex;
1603 break;
1604 case MESA_SHADER_TESS_CTRL:
1605 exec_model = SpvExecutionModelTessellationControl;
1606 break;
1607 case MESA_SHADER_TESS_EVAL:
1608 exec_model = SpvExecutionModelTessellationEvaluation;
1609 break;
1610 case MESA_SHADER_GEOMETRY:
1611 exec_model = SpvExecutionModelGeometry;
1612 break;
1613 case MESA_SHADER_FRAGMENT:
1614 exec_model = SpvExecutionModelFragment;
1615 break;
1616 case MESA_SHADER_COMPUTE:
1617 exec_model = SpvExecutionModelGLCompute;
1618 break;
1619 default:
1620 unreachable("invalid stage");
1621 }
1622
1623 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1624 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1625 NULL, 0);
1626 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1627 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1628
1629 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1630 _mesa_key_pointer_equal);
1631
1632 nir_foreach_variable(var, &s->inputs)
1633 emit_input(&ctx, var);
1634
1635 nir_foreach_variable(var, &s->outputs)
1636 emit_output(&ctx, var);
1637
1638 nir_foreach_variable(var, &s->uniforms)
1639 emit_uniform(&ctx, var);
1640
1641 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1642 "main", ctx.entry_ifaces,
1643 ctx.num_entry_ifaces);
1644 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1645 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1646 SpvExecutionModeOriginUpperLeft);
1647 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1648 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1649 SpvExecutionModeDepthReplacing);
1650 }
1651
1652
1653 spirv_builder_function(&ctx.builder, entry_point, type_void,
1654 SpvFunctionControlMaskNone,
1655 type_main);
1656
1657 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1658 nir_metadata_require(entry, nir_metadata_block_index);
1659
1660 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1661 if (!ctx.defs)
1662 goto fail;
1663 ctx.num_defs = entry->ssa_alloc;
1664
1665 nir_index_local_regs(entry);
1666 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1667 if (!ctx.regs)
1668 goto fail;
1669 ctx.num_regs = entry->reg_alloc;
1670
1671 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1672 if (!block_ids)
1673 goto fail;
1674
1675 for (int i = 0; i < entry->num_blocks; ++i)
1676 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1677
1678 ctx.block_ids = block_ids;
1679 ctx.num_blocks = entry->num_blocks;
1680
1681 /* emit a block only for the variable declarations */
1682 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1683 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1684 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1685 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1686 SpvStorageClassFunction,
1687 type);
1688 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1689 SpvStorageClassFunction);
1690
1691 ctx.regs[reg->index] = var;
1692 }
1693
1694 emit_cf_list(&ctx, &entry->body);
1695
1696 free(ctx.defs);
1697
1698 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1699 spirv_builder_function_end(&ctx.builder);
1700
1701 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1702
1703 ret = CALLOC_STRUCT(spirv_shader);
1704 if (!ret)
1705 goto fail;
1706
1707 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1708 if (!ret->words)
1709 goto fail;
1710
1711 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1712 assert(ret->num_words == num_words);
1713
1714 return ret;
1715
1716 fail:
1717
1718 if (ret)
1719 spirv_shader_delete(ret);
1720
1721 if (ctx.vars)
1722 _mesa_hash_table_destroy(ctx.vars, NULL);
1723
1724 return NULL;
1725 }
1726
1727 void
1728 spirv_shader_delete(struct spirv_shader *s)
1729 {
1730 FREE(s->words);
1731 FREE(s);
1732 }