zink/spirv: implement load_front_face
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59
60 SpvId front_face_var;
61 };
62
63 static SpvId
64 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
65 unsigned num_components, float value);
66
67 static SpvId
68 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
69 unsigned num_components, uint32_t value);
70
71 static SpvId
72 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
73 unsigned num_components, int32_t value);
74
75 static SpvId
76 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
77
78 static SpvId
79 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
80 SpvId src0, SpvId src1);
81
82 static SpvId
83 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
84 SpvId src0, SpvId src1, SpvId src2);
85
86 static SpvId
87 get_bvec_type(struct ntv_context *ctx, int num_components)
88 {
89 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
90 if (num_components > 1)
91 return spirv_builder_type_vector(&ctx->builder, bool_type,
92 num_components);
93
94 assert(num_components == 1);
95 return bool_type;
96 }
97
98 static SpvId
99 block_label(struct ntv_context *ctx, nir_block *block)
100 {
101 assert(block->index < ctx->num_blocks);
102 return ctx->block_ids[block->index];
103 }
104
105 static SpvId
106 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
107 {
108 assert(bit_size == 32);
109 return spirv_builder_const_float(&ctx->builder, bit_size, value);
110 }
111
112 static SpvId
113 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
114 {
115 assert(bit_size == 32);
116 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
117 }
118
119 static SpvId
120 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
121 {
122 assert(bit_size == 32);
123 return spirv_builder_const_int(&ctx->builder, bit_size, value);
124 }
125
126 static SpvId
127 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
128 {
129 assert(bit_size == 32); // only 32-bit floats supported so far
130
131 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
132 if (num_components > 1)
133 return spirv_builder_type_vector(&ctx->builder, float_type,
134 num_components);
135
136 assert(num_components == 1);
137 return float_type;
138 }
139
140 static SpvId
141 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
144
145 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
146 if (num_components > 1)
147 return spirv_builder_type_vector(&ctx->builder, int_type,
148 num_components);
149
150 assert(num_components == 1);
151 return int_type;
152 }
153
154 static SpvId
155 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
158
159 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
160 if (num_components > 1)
161 return spirv_builder_type_vector(&ctx->builder, uint_type,
162 num_components);
163
164 assert(num_components == 1);
165 return uint_type;
166 }
167
168 static SpvId
169 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
170 {
171 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
172 nir_dest_num_components(*dest));
173 }
174
175 static SpvId
176 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
177 {
178 switch (type) {
179 case GLSL_TYPE_BOOL:
180 return spirv_builder_type_bool(&ctx->builder);
181
182 case GLSL_TYPE_FLOAT:
183 return spirv_builder_type_float(&ctx->builder, 32);
184
185 case GLSL_TYPE_INT:
186 return spirv_builder_type_int(&ctx->builder, 32);
187
188 case GLSL_TYPE_UINT:
189 return spirv_builder_type_uint(&ctx->builder, 32);
190 /* TODO: handle more types */
191
192 default:
193 unreachable("unknown GLSL type");
194 }
195 }
196
197 static SpvId
198 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
199 {
200 assert(type);
201 if (glsl_type_is_scalar(type))
202 return get_glsl_basetype(ctx, glsl_get_base_type(type));
203
204 if (glsl_type_is_vector(type))
205 return spirv_builder_type_vector(&ctx->builder,
206 get_glsl_basetype(ctx, glsl_get_base_type(type)),
207 glsl_get_vector_elements(type));
208
209 if (glsl_type_is_array(type)) {
210 SpvId ret = spirv_builder_type_array(&ctx->builder,
211 get_glsl_type(ctx, glsl_get_array_element(type)),
212 emit_uint_const(ctx, 32, glsl_get_length(type)));
213 uint32_t stride = glsl_get_explicit_stride(type);
214 if (stride)
215 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
216 return ret;
217 }
218
219
220 unreachable("we shouldn't get here, I think...");
221 }
222
223 static void
224 emit_input(struct ntv_context *ctx, struct nir_variable *var)
225 {
226 SpvId var_type = get_glsl_type(ctx, var->type);
227 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
228 SpvStorageClassInput,
229 var_type);
230 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
231 SpvStorageClassInput);
232
233 if (var->name)
234 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
235
236 if (ctx->stage == MESA_SHADER_FRAGMENT) {
237 if (var->data.location >= VARYING_SLOT_VAR0 ||
238 (var->data.location >= VARYING_SLOT_COL0 &&
239 var->data.location <= VARYING_SLOT_TEX7)) {
240 spirv_builder_emit_location(&ctx->builder, var_id,
241 ctx->var_location++);
242 } else {
243 switch (var->data.location) {
244 case VARYING_SLOT_POS:
245 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
246 break;
247
248 case VARYING_SLOT_PNTC:
249 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
250 break;
251
252 default:
253 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
254 unreachable("unexpected varying slot");
255 }
256 }
257 } else {
258 spirv_builder_emit_location(&ctx->builder, var_id,
259 var->data.driver_location);
260 }
261
262 if (var->data.location_frac)
263 spirv_builder_emit_component(&ctx->builder, var_id,
264 var->data.location_frac);
265
266 if (var->data.interpolation == INTERP_MODE_FLAT)
267 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
268
269 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
270
271 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
272 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
273 }
274
275 static void
276 emit_output(struct ntv_context *ctx, struct nir_variable *var)
277 {
278 SpvId var_type = get_glsl_type(ctx, var->type);
279 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
280 SpvStorageClassOutput,
281 var_type);
282 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
283 SpvStorageClassOutput);
284 if (var->name)
285 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
286
287
288 if (ctx->stage == MESA_SHADER_VERTEX) {
289 if (var->data.location >= VARYING_SLOT_VAR0 ||
290 (var->data.location >= VARYING_SLOT_COL0 &&
291 var->data.location <= VARYING_SLOT_TEX7)) {
292 spirv_builder_emit_location(&ctx->builder, var_id,
293 ctx->var_location++);
294 } else {
295 switch (var->data.location) {
296 case VARYING_SLOT_POS:
297 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
298 break;
299
300 case VARYING_SLOT_PSIZ:
301 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
302 break;
303
304 case VARYING_SLOT_CLIP_DIST0:
305 assert(glsl_type_is_array(var->type));
306 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
307 break;
308
309 default:
310 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
311 unreachable("unexpected varying slot");
312 }
313 }
314 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
315 if (var->data.location >= FRAG_RESULT_DATA0)
316 spirv_builder_emit_location(&ctx->builder, var_id,
317 var->data.location - FRAG_RESULT_DATA0);
318 else {
319 switch (var->data.location) {
320 case FRAG_RESULT_COLOR:
321 spirv_builder_emit_location(&ctx->builder, var_id, 0);
322 break;
323
324 case FRAG_RESULT_DEPTH:
325 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
326 break;
327
328 default:
329 spirv_builder_emit_location(&ctx->builder, var_id,
330 var->data.driver_location);
331 }
332 }
333 }
334
335 if (var->data.location_frac)
336 spirv_builder_emit_component(&ctx->builder, var_id,
337 var->data.location_frac);
338
339 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
340
341 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
342 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
343 }
344
345 static SpvDim
346 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
347 {
348 *is_ms = false;
349 switch (gdim) {
350 case GLSL_SAMPLER_DIM_1D:
351 return SpvDim1D;
352 case GLSL_SAMPLER_DIM_2D:
353 return SpvDim2D;
354 case GLSL_SAMPLER_DIM_RECT:
355 return SpvDimRect;
356 case GLSL_SAMPLER_DIM_CUBE:
357 return SpvDimCube;
358 case GLSL_SAMPLER_DIM_3D:
359 return SpvDim3D;
360 case GLSL_SAMPLER_DIM_MS:
361 *is_ms = true;
362 return SpvDim2D;
363 default:
364 fprintf(stderr, "unknown sampler type %d\n", gdim);
365 break;
366 }
367 return SpvDim2D;
368 }
369
370 static void
371 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
372 {
373 bool is_ms;
374 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
375 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
376 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
377 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
378 SpvImageFormatUnknown);
379
380 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
381 image_type);
382 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
383 SpvStorageClassUniformConstant,
384 sampled_type);
385 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
386 SpvStorageClassUniformConstant);
387
388 if (var->name)
389 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
390
391 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
392 ctx->samplers[ctx->num_samplers++] = var_id;
393
394 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
395 var->data.descriptor_set);
396 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
397 }
398
399 static void
400 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
401 {
402 uint32_t size = glsl_count_attribute_slots(var->type, false);
403 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
404 SpvId array_length = emit_uint_const(ctx, 32, size);
405 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
406 array_length);
407 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
408
409 // wrap UBO-array in a struct
410 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
411 if (var->name) {
412 char struct_name[100];
413 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
414 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
415 }
416
417 spirv_builder_emit_decoration(&ctx->builder, struct_type,
418 SpvDecorationBlock);
419 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
420
421
422 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
423 SpvStorageClassUniform,
424 struct_type);
425
426 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
427 SpvStorageClassUniform);
428 if (var->name)
429 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
430
431 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
432 ctx->ubos[ctx->num_ubos++] = var_id;
433
434 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
435 var->data.descriptor_set);
436 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
437 }
438
439 static void
440 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
441 {
442 if (var->data.mode == nir_var_mem_ubo)
443 emit_ubo(ctx, var);
444 else {
445 assert(var->data.mode == nir_var_uniform);
446 if (glsl_type_is_sampler(var->type))
447 emit_sampler(ctx, var);
448 }
449 }
450
451 static SpvId
452 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
453 {
454 assert(ssa->index < ctx->num_defs);
455 assert(ctx->defs[ssa->index] != 0);
456 return ctx->defs[ssa->index];
457 }
458
459 static SpvId
460 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
461 {
462 assert(reg->index < ctx->num_regs);
463 assert(ctx->regs[reg->index] != 0);
464 return ctx->regs[reg->index];
465 }
466
467 static SpvId
468 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
469 {
470 assert(reg->reg);
471 assert(!reg->indirect);
472 assert(!reg->base_offset);
473
474 SpvId var = get_var_from_reg(ctx, reg->reg);
475 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
476 return spirv_builder_emit_load(&ctx->builder, type, var);
477 }
478
479 static SpvId
480 get_src_uint(struct ntv_context *ctx, nir_src *src)
481 {
482 if (src->is_ssa)
483 return get_src_uint_ssa(ctx, src->ssa);
484 else
485 return get_src_uint_reg(ctx, &src->reg);
486 }
487
488 static SpvId
489 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
490 {
491 assert(!alu->src[src].negate);
492 assert(!alu->src[src].abs);
493
494 SpvId def = get_src_uint(ctx, &alu->src[src].src);
495
496 unsigned used_channels = 0;
497 bool need_swizzle = false;
498 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
499 if (!nir_alu_instr_channel_used(alu, src, i))
500 continue;
501
502 used_channels++;
503
504 if (alu->src[src].swizzle[i] != i)
505 need_swizzle = true;
506 }
507 assert(used_channels != 0);
508
509 unsigned live_channels = nir_src_num_components(alu->src[src].src);
510 if (used_channels != live_channels)
511 need_swizzle = true;
512
513 if (!need_swizzle)
514 return def;
515
516 int bit_size = nir_src_bit_size(alu->src[src].src);
517 assert(bit_size == 1 || bit_size == 32);
518
519 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
520 if (used_channels == 1) {
521 uint32_t indices[] = { alu->src[src].swizzle[0] };
522 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
523 def, indices,
524 ARRAY_SIZE(indices));
525 } else if (live_channels == 1) {
526 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
527 used_channels);
528
529 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
530 for (unsigned i = 0; i < used_channels; ++i)
531 constituents[i] = def;
532
533 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
534 constituents,
535 used_channels);
536 } else {
537 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
538 used_channels);
539
540 uint32_t components[NIR_MAX_VEC_COMPONENTS];
541 size_t num_components = 0;
542 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
543 if (!nir_alu_instr_channel_used(alu, src, i))
544 continue;
545
546 components[num_components++] = alu->src[src].swizzle[i];
547 }
548
549 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
550 def, def, components, num_components);
551 }
552 }
553
554 static void
555 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
556 {
557 assert(result != 0);
558 assert(ssa->index < ctx->num_defs);
559 ctx->defs[ssa->index] = result;
560 }
561
562 static SpvId
563 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
564 SpvId if_true, SpvId if_false)
565 {
566 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
567 }
568
569 static SpvId
570 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
571 {
572 SpvId otype = get_uvec_type(ctx, 32, num_components);
573 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
574 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
575 return emit_select(ctx, otype, value, one, zero);
576 }
577
578 static SpvId
579 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
580 {
581 SpvId type = get_bvec_type(ctx, num_components);
582 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
583 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
584 }
585
586 static SpvId
587 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
588 {
589 return emit_unop(ctx, SpvOpBitcast, type, value);
590 }
591
592 static SpvId
593 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
594 unsigned num_components)
595 {
596 SpvId type = get_uvec_type(ctx, bit_size, num_components);
597 return emit_bitcast(ctx, type, value);
598 }
599
600 static SpvId
601 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
602 unsigned num_components)
603 {
604 SpvId type = get_ivec_type(ctx, bit_size, num_components);
605 return emit_bitcast(ctx, type, value);
606 }
607
608 static SpvId
609 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
610 unsigned num_components)
611 {
612 SpvId type = get_fvec_type(ctx, bit_size, num_components);
613 return emit_bitcast(ctx, type, value);
614 }
615
616 static void
617 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
618 {
619 SpvId var = get_var_from_reg(ctx, reg->reg);
620 assert(var);
621 spirv_builder_emit_store(&ctx->builder, var, result);
622 }
623
624 static void
625 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
626 {
627 if (dest->is_ssa)
628 store_ssa_def_uint(ctx, &dest->ssa, result);
629 else
630 store_reg_def(ctx, &dest->reg, result);
631 }
632
633 static void
634 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
635 {
636 unsigned num_components = nir_dest_num_components(*dest);
637 unsigned bit_size = nir_dest_bit_size(*dest);
638
639 switch (nir_alu_type_get_base_type(type)) {
640 case nir_type_bool:
641 assert(bit_size == 1);
642 result = bvec_to_uvec(ctx, result, num_components);
643 break;
644
645 case nir_type_uint:
646 break; /* nothing to do! */
647
648 case nir_type_int:
649 case nir_type_float:
650 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
651 break;
652
653 default:
654 unreachable("unsupported nir_alu_type");
655 }
656
657 store_dest_uint(ctx, dest, result);
658 }
659
660 static SpvId
661 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
662 {
663 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
664 }
665
666 static SpvId
667 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
668 SpvId src0, SpvId src1)
669 {
670 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
671 }
672
673 static SpvId
674 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
675 SpvId src0, SpvId src1, SpvId src2)
676 {
677 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
678 }
679
680 static SpvId
681 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
682 SpvId src)
683 {
684 SpvId args[] = { src };
685 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
686 op, args, ARRAY_SIZE(args));
687 }
688
689 static SpvId
690 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
691 SpvId src0, SpvId src1)
692 {
693 SpvId args[] = { src0, src1 };
694 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
695 op, args, ARRAY_SIZE(args));
696 }
697
698 static SpvId
699 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
700 unsigned num_components, float value)
701 {
702 assert(bit_size == 32);
703
704 SpvId result = emit_float_const(ctx, bit_size, value);
705 if (num_components == 1)
706 return result;
707
708 assert(num_components > 1);
709 SpvId components[num_components];
710 for (int i = 0; i < num_components; i++)
711 components[i] = result;
712
713 SpvId type = get_fvec_type(ctx, bit_size, num_components);
714 return spirv_builder_const_composite(&ctx->builder, type, components,
715 num_components);
716 }
717
718 static SpvId
719 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
720 unsigned num_components, uint32_t value)
721 {
722 assert(bit_size == 32);
723
724 SpvId result = emit_uint_const(ctx, bit_size, value);
725 if (num_components == 1)
726 return result;
727
728 assert(num_components > 1);
729 SpvId components[num_components];
730 for (int i = 0; i < num_components; i++)
731 components[i] = result;
732
733 SpvId type = get_uvec_type(ctx, bit_size, num_components);
734 return spirv_builder_const_composite(&ctx->builder, type, components,
735 num_components);
736 }
737
738 static SpvId
739 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
740 unsigned num_components, int32_t value)
741 {
742 assert(bit_size == 32);
743
744 SpvId result = emit_int_const(ctx, bit_size, value);
745 if (num_components == 1)
746 return result;
747
748 assert(num_components > 1);
749 SpvId components[num_components];
750 for (int i = 0; i < num_components; i++)
751 components[i] = result;
752
753 SpvId type = get_ivec_type(ctx, bit_size, num_components);
754 return spirv_builder_const_composite(&ctx->builder, type, components,
755 num_components);
756 }
757
758 static inline unsigned
759 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
760 {
761 if (nir_op_infos[instr->op].input_sizes[src] > 0)
762 return nir_op_infos[instr->op].input_sizes[src];
763
764 if (instr->dest.dest.is_ssa)
765 return instr->dest.dest.ssa.num_components;
766 else
767 return instr->dest.dest.reg.reg->num_components;
768 }
769
770 static SpvId
771 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
772 {
773 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
774
775 unsigned num_components = alu_instr_src_components(alu, src);
776 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
777 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
778
779 switch (nir_alu_type_get_base_type(type)) {
780 case nir_type_bool:
781 assert(bit_size == 1);
782 return uvec_to_bvec(ctx, uint_value, num_components);
783
784 case nir_type_int:
785 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
786
787 case nir_type_uint:
788 return uint_value;
789
790 case nir_type_float:
791 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
792
793 default:
794 unreachable("unknown nir_alu_type");
795 }
796 }
797
798 static void
799 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
800 {
801 assert(!alu->dest.saturate);
802 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
803 }
804
805 static SpvId
806 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
807 {
808 unsigned num_components = nir_dest_num_components(*dest);
809 unsigned bit_size = nir_dest_bit_size(*dest);
810
811 switch (nir_alu_type_get_base_type(type)) {
812 case nir_type_bool:
813 return get_bvec_type(ctx, num_components);
814
815 case nir_type_int:
816 return get_ivec_type(ctx, bit_size, num_components);
817
818 case nir_type_uint:
819 return get_uvec_type(ctx, bit_size, num_components);
820
821 case nir_type_float:
822 return get_fvec_type(ctx, bit_size, num_components);
823
824 default:
825 unreachable("unsupported nir_alu_type");
826 }
827 }
828
829 static void
830 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
831 {
832 SpvId src[nir_op_infos[alu->op].num_inputs];
833 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
834 src[i] = get_alu_src(ctx, alu, i);
835
836 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
837 nir_op_infos[alu->op].output_type);
838 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
839 unsigned num_components = nir_dest_num_components(alu->dest.dest);
840
841 SpvId result = 0;
842 switch (alu->op) {
843 case nir_op_mov:
844 assert(nir_op_infos[alu->op].num_inputs == 1);
845 result = src[0];
846 break;
847
848 #define UNOP(nir_op, spirv_op) \
849 case nir_op: \
850 assert(nir_op_infos[alu->op].num_inputs == 1); \
851 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
852 break;
853
854 UNOP(nir_op_ineg, SpvOpSNegate)
855 UNOP(nir_op_fneg, SpvOpFNegate)
856 UNOP(nir_op_fddx, SpvOpDPdx)
857 UNOP(nir_op_fddy, SpvOpDPdy)
858 UNOP(nir_op_f2i32, SpvOpConvertFToS)
859 UNOP(nir_op_f2u32, SpvOpConvertFToU)
860 UNOP(nir_op_i2f32, SpvOpConvertSToF)
861 UNOP(nir_op_u2f32, SpvOpConvertUToF)
862 UNOP(nir_op_inot, SpvOpNot)
863 #undef UNOP
864
865 case nir_op_b2i32:
866 assert(nir_op_infos[alu->op].num_inputs == 1);
867 result = emit_select(ctx, dest_type, src[0],
868 get_ivec_constant(ctx, 32, num_components, 1),
869 get_ivec_constant(ctx, 32, num_components, 0));
870 break;
871
872 case nir_op_b2f32:
873 assert(nir_op_infos[alu->op].num_inputs == 1);
874 result = emit_select(ctx, dest_type, src[0],
875 get_fvec_constant(ctx, 32, num_components, 1),
876 get_fvec_constant(ctx, 32, num_components, 0));
877 break;
878
879 #define BUILTIN_UNOP(nir_op, spirv_op) \
880 case nir_op: \
881 assert(nir_op_infos[alu->op].num_inputs == 1); \
882 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
883 break;
884
885 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
886 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
887 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
888 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
889 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
890 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
891 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
892 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
893 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
894 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
895 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
896 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
897 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
898 #undef BUILTIN_UNOP
899
900 case nir_op_frcp:
901 assert(nir_op_infos[alu->op].num_inputs == 1);
902 result = emit_binop(ctx, SpvOpFDiv, dest_type,
903 get_fvec_constant(ctx, bit_size, num_components, 1),
904 src[0]);
905 break;
906
907 case nir_op_f2b1:
908 assert(nir_op_infos[alu->op].num_inputs == 1);
909 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
910 get_fvec_constant(ctx,
911 nir_src_bit_size(alu->src[0].src),
912 num_components, 0));
913 break;
914
915
916 #define BINOP(nir_op, spirv_op) \
917 case nir_op: \
918 assert(nir_op_infos[alu->op].num_inputs == 2); \
919 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
920 break;
921
922 BINOP(nir_op_iadd, SpvOpIAdd)
923 BINOP(nir_op_isub, SpvOpISub)
924 BINOP(nir_op_imul, SpvOpIMul)
925 BINOP(nir_op_idiv, SpvOpSDiv)
926 BINOP(nir_op_udiv, SpvOpUDiv)
927 BINOP(nir_op_fadd, SpvOpFAdd)
928 BINOP(nir_op_fsub, SpvOpFSub)
929 BINOP(nir_op_fmul, SpvOpFMul)
930 BINOP(nir_op_fdiv, SpvOpFDiv)
931 BINOP(nir_op_fmod, SpvOpFMod)
932 BINOP(nir_op_ilt, SpvOpSLessThan)
933 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
934 BINOP(nir_op_ieq, SpvOpIEqual)
935 BINOP(nir_op_ine, SpvOpINotEqual)
936 BINOP(nir_op_flt, SpvOpFOrdLessThan)
937 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
938 BINOP(nir_op_feq, SpvOpFOrdEqual)
939 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
940 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
941 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
942 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
943 BINOP(nir_op_iand, SpvOpBitwiseAnd)
944 BINOP(nir_op_ior, SpvOpBitwiseOr)
945 #undef BINOP
946
947 #define BUILTIN_BINOP(nir_op, spirv_op) \
948 case nir_op: \
949 assert(nir_op_infos[alu->op].num_inputs == 2); \
950 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
951 break;
952
953 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
954 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
955 #undef BUILTIN_BINOP
956
957 case nir_op_fdot2:
958 case nir_op_fdot3:
959 case nir_op_fdot4:
960 assert(nir_op_infos[alu->op].num_inputs == 2);
961 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
962 break;
963
964 case nir_op_seq:
965 case nir_op_sne:
966 case nir_op_slt:
967 case nir_op_sge: {
968 assert(nir_op_infos[alu->op].num_inputs == 2);
969 int num_components = nir_dest_num_components(alu->dest.dest);
970 SpvId bool_type = get_bvec_type(ctx, num_components);
971
972 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
973 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
974 if (num_components > 1) {
975 SpvId zero_comps[num_components], one_comps[num_components];
976 for (int i = 0; i < num_components; i++) {
977 zero_comps[i] = zero;
978 one_comps[i] = one;
979 }
980
981 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
982 zero_comps, num_components);
983 one = spirv_builder_const_composite(&ctx->builder, dest_type,
984 one_comps, num_components);
985 }
986
987 SpvOp op;
988 switch (alu->op) {
989 case nir_op_seq: op = SpvOpFOrdEqual; break;
990 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
991 case nir_op_slt: op = SpvOpFOrdLessThan; break;
992 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
993 default: unreachable("unexpected op");
994 }
995
996 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
997 result = emit_select(ctx, dest_type, result, one, zero);
998 }
999 break;
1000
1001 case nir_op_fcsel:
1002 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1003 get_bvec_type(ctx, num_components),
1004 src[0],
1005 get_fvec_constant(ctx,
1006 nir_src_bit_size(alu->src[0].src),
1007 num_components, 0));
1008 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1009 break;
1010
1011 case nir_op_bcsel:
1012 assert(nir_op_infos[alu->op].num_inputs == 3);
1013 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1014 break;
1015
1016 case nir_op_vec2:
1017 case nir_op_vec3:
1018 case nir_op_vec4: {
1019 int num_inputs = nir_op_infos[alu->op].num_inputs;
1020 assert(2 <= num_inputs && num_inputs <= 4);
1021 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1022 src, num_inputs);
1023 }
1024 break;
1025
1026 default:
1027 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1028 nir_op_infos[alu->op].name);
1029
1030 unreachable("unsupported opcode");
1031 return;
1032 }
1033
1034 store_alu_result(ctx, alu, result);
1035 }
1036
1037 static void
1038 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1039 {
1040 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1041 for (int i = 0; i < load_const->def.num_components; ++i)
1042 values[i] = load_const->value[i].u32;
1043
1044 unsigned bit_size = load_const->def.bit_size;
1045 unsigned num_components = load_const->def.num_components;
1046
1047 SpvId constant;
1048 if (num_components > 1) {
1049 SpvId components[num_components];
1050 for (int i = 0; i < num_components; i++)
1051 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1052
1053 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1054 constant = spirv_builder_const_composite(&ctx->builder, type,
1055 components, num_components);
1056 } else {
1057 assert(num_components == 1);
1058 constant = emit_uint_const(ctx, bit_size, values[0]);
1059 }
1060
1061 store_ssa_def_uint(ctx, &load_const->def, constant);
1062 }
1063
1064 static void
1065 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1066 {
1067 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1068 assert(const_block_index); // no dynamic indexing for now
1069 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1070
1071 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1072 if (const_offset) {
1073 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1074 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1075 SpvStorageClassUniform,
1076 uvec4_type);
1077
1078 unsigned idx = const_offset->u32;
1079 SpvId member = emit_uint_const(ctx, 32, 0);
1080 SpvId offset = emit_uint_const(ctx, 32, idx);
1081 SpvId offsets[] = { member, offset };
1082 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1083 ctx->ubos[0], offsets,
1084 ARRAY_SIZE(offsets));
1085 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1086
1087 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1088 unsigned num_components = nir_dest_num_components(intr->dest);
1089 if (num_components == 1) {
1090 uint32_t components[] = { 0 };
1091 result = spirv_builder_emit_composite_extract(&ctx->builder,
1092 type,
1093 result, components,
1094 1);
1095 } else if (num_components < 4) {
1096 SpvId constituents[num_components];
1097 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1098 for (uint32_t i = 0; i < num_components; ++i)
1099 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1100 uint_type,
1101 result, &i,
1102 1);
1103
1104 result = spirv_builder_emit_composite_construct(&ctx->builder,
1105 type,
1106 constituents,
1107 num_components);
1108 }
1109
1110 store_dest_uint(ctx, &intr->dest, result);
1111 } else
1112 unreachable("uniform-addressing not yet supported");
1113 }
1114
1115 static void
1116 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1117 {
1118 assert(ctx->block_started);
1119 spirv_builder_emit_kill(&ctx->builder);
1120 /* discard is weird in NIR, so let's just create an unreachable block after
1121 it and hope that the vulkan driver will DCE any instructinos in it. */
1122 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1123 }
1124
1125 static void
1126 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1127 {
1128 /* uint is a bit of a lie here; it's really just a pointer */
1129 SpvId ptr = get_src_uint(ctx, intr->src);
1130
1131 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1132 SpvId result = spirv_builder_emit_load(&ctx->builder,
1133 get_glsl_type(ctx, var->type),
1134 ptr);
1135 unsigned num_components = nir_dest_num_components(intr->dest);
1136 unsigned bit_size = nir_dest_bit_size(intr->dest);
1137 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1138 store_dest_uint(ctx, &intr->dest, result);
1139 }
1140
1141 static void
1142 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1143 {
1144 /* uint is a bit of a lie here; it's really just a pointer */
1145 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1146 SpvId src = get_src_uint(ctx, &intr->src[1]);
1147
1148 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1149 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1150 SpvId result = emit_bitcast(ctx, type, src);
1151 spirv_builder_emit_store(&ctx->builder, ptr, result);
1152 }
1153
1154 static void
1155 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1156 {
1157 SpvId var_type = get_glsl_type(ctx, glsl_bool_type());
1158 if (!ctx->front_face_var) {
1159 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1160 SpvStorageClassInput,
1161 var_type);
1162 ctx->front_face_var = spirv_builder_emit_var(&ctx->builder,
1163 pointer_type,
1164 SpvStorageClassInput);
1165 spirv_builder_emit_name(&ctx->builder, ctx->front_face_var,
1166 "gl_FrontFacing");
1167 spirv_builder_emit_builtin(&ctx->builder, ctx->front_face_var,
1168 SpvBuiltInFrontFacing);
1169
1170 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1171 ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->front_face_var;
1172 }
1173
1174 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1175 ctx->front_face_var);
1176 assert(1 == nir_dest_num_components(intr->dest));
1177 result = bvec_to_uvec(ctx, result, 1);
1178 store_dest_uint(ctx, &intr->dest, result);
1179 }
1180
1181 static void
1182 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1183 {
1184 switch (intr->intrinsic) {
1185 case nir_intrinsic_load_ubo:
1186 emit_load_ubo(ctx, intr);
1187 break;
1188
1189 case nir_intrinsic_discard:
1190 emit_discard(ctx, intr);
1191 break;
1192
1193 case nir_intrinsic_load_deref:
1194 emit_load_deref(ctx, intr);
1195 break;
1196
1197 case nir_intrinsic_store_deref:
1198 emit_store_deref(ctx, intr);
1199 break;
1200
1201 case nir_intrinsic_load_front_face:
1202 emit_load_front_face(ctx, intr);
1203 break;
1204
1205 default:
1206 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1207 nir_intrinsic_infos[intr->intrinsic].name);
1208 unreachable("unsupported intrinsic");
1209 }
1210 }
1211
1212 static void
1213 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1214 {
1215 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1216 undef->def.num_components);
1217
1218 store_ssa_def_uint(ctx, &undef->def,
1219 spirv_builder_emit_undef(&ctx->builder, type));
1220 }
1221
1222 static SpvId
1223 get_src_float(struct ntv_context *ctx, nir_src *src)
1224 {
1225 SpvId def = get_src_uint(ctx, src);
1226 unsigned num_components = nir_src_num_components(*src);
1227 unsigned bit_size = nir_src_bit_size(*src);
1228 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1229 }
1230
1231 static void
1232 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1233 {
1234 assert(tex->op == nir_texop_tex ||
1235 tex->op == nir_texop_txb ||
1236 tex->op == nir_texop_txl);
1237 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1238 assert(tex->texture_index == tex->sampler_index);
1239
1240 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1241 unsigned coord_components;
1242 for (unsigned i = 0; i < tex->num_srcs; i++) {
1243 switch (tex->src[i].src_type) {
1244 case nir_tex_src_coord:
1245 coord = get_src_float(ctx, &tex->src[i].src);
1246 coord_components = nir_src_num_components(tex->src[i].src);
1247 break;
1248
1249 case nir_tex_src_projector:
1250 assert(nir_src_num_components(tex->src[i].src) == 1);
1251 proj = get_src_float(ctx, &tex->src[i].src);
1252 assert(proj != 0);
1253 break;
1254
1255 case nir_tex_src_bias:
1256 assert(tex->op == nir_texop_txb);
1257 bias = get_src_float(ctx, &tex->src[i].src);
1258 assert(bias != 0);
1259 break;
1260
1261 case nir_tex_src_lod:
1262 assert(nir_src_num_components(tex->src[i].src) == 1);
1263 lod = get_src_float(ctx, &tex->src[i].src);
1264 assert(lod != 0);
1265 break;
1266
1267 case nir_tex_src_comparator:
1268 assert(nir_src_num_components(tex->src[i].src) == 1);
1269 dref = get_src_float(ctx, &tex->src[i].src);
1270 assert(dref != 0);
1271 break;
1272
1273 default:
1274 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1275 unreachable("unknown texture source");
1276 }
1277 }
1278
1279 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1280 lod = emit_float_const(ctx, 32, 0.0f);
1281 assert(lod != 0);
1282 }
1283
1284 bool is_ms;
1285 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1286 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1287 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1288 dimension, false, tex->is_array, is_ms, 1,
1289 SpvImageFormatUnknown);
1290 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1291 image_type);
1292
1293 assert(tex->texture_index < ctx->num_samplers);
1294 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1295 ctx->samplers[tex->texture_index]);
1296
1297 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1298
1299 if (proj) {
1300 SpvId constituents[coord_components + 1];
1301 if (coord_components == 1)
1302 constituents[0] = coord;
1303 else {
1304 assert(coord_components > 1);
1305 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1306 for (uint32_t i = 0; i < coord_components; ++i)
1307 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1308 float_type,
1309 coord,
1310 &i, 1);
1311 }
1312
1313 constituents[coord_components++] = proj;
1314
1315 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1316 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1317 vec_type,
1318 constituents,
1319 coord_components);
1320 }
1321
1322 SpvId actual_dest_type = dest_type;
1323 if (dref)
1324 actual_dest_type = float_type;
1325
1326 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1327 actual_dest_type, load,
1328 coord,
1329 proj != 0,
1330 lod, bias, dref);
1331 spirv_builder_emit_decoration(&ctx->builder, result,
1332 SpvDecorationRelaxedPrecision);
1333
1334 if (dref) {
1335 SpvId components[4] = { result, result, result, result };
1336 result = spirv_builder_emit_composite_construct(&ctx->builder,
1337 dest_type,
1338 components,
1339 4);
1340 }
1341
1342 store_dest(ctx, &tex->dest, result, tex->dest_type);
1343 }
1344
1345 static void
1346 start_block(struct ntv_context *ctx, SpvId label)
1347 {
1348 /* terminate previous block if needed */
1349 if (ctx->block_started)
1350 spirv_builder_emit_branch(&ctx->builder, label);
1351
1352 /* start new block */
1353 spirv_builder_label(&ctx->builder, label);
1354 ctx->block_started = true;
1355 }
1356
1357 static void
1358 branch(struct ntv_context *ctx, SpvId label)
1359 {
1360 assert(ctx->block_started);
1361 spirv_builder_emit_branch(&ctx->builder, label);
1362 ctx->block_started = false;
1363 }
1364
1365 static void
1366 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1367 SpvId else_id)
1368 {
1369 assert(ctx->block_started);
1370 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1371 then_id, else_id);
1372 ctx->block_started = false;
1373 }
1374
1375 static void
1376 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1377 {
1378 switch (jump->type) {
1379 case nir_jump_break:
1380 assert(ctx->loop_break);
1381 branch(ctx, ctx->loop_break);
1382 break;
1383
1384 case nir_jump_continue:
1385 assert(ctx->loop_cont);
1386 branch(ctx, ctx->loop_cont);
1387 break;
1388
1389 default:
1390 unreachable("Unsupported jump type\n");
1391 }
1392 }
1393
1394 static void
1395 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1396 {
1397 assert(deref->deref_type == nir_deref_type_var);
1398
1399 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1400 assert(he);
1401 SpvId result = (SpvId)(intptr_t)he->data;
1402 /* uint is a bit of a lie here, it's really just an opaque type */
1403 store_dest_uint(ctx, &deref->dest, result);
1404 }
1405
1406 static void
1407 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1408 {
1409 assert(deref->deref_type == nir_deref_type_array);
1410 nir_variable *var = nir_deref_instr_get_variable(deref);
1411
1412 SpvStorageClass storage_class;
1413 switch (var->data.mode) {
1414 case nir_var_shader_in:
1415 storage_class = SpvStorageClassInput;
1416 break;
1417
1418 case nir_var_shader_out:
1419 storage_class = SpvStorageClassOutput;
1420 break;
1421
1422 default:
1423 unreachable("Unsupported nir_variable_mode\n");
1424 }
1425
1426 SpvId index = get_src_uint(ctx, &deref->arr.index);
1427
1428 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1429 storage_class,
1430 get_glsl_type(ctx, deref->type));
1431
1432 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1433 ptr_type,
1434 get_src_uint(ctx, &deref->parent),
1435 &index, 1);
1436 /* uint is a bit of a lie here, it's really just an opaque type */
1437 store_dest_uint(ctx, &deref->dest, result);
1438 }
1439
1440 static void
1441 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1442 {
1443 switch (deref->deref_type) {
1444 case nir_deref_type_var:
1445 emit_deref_var(ctx, deref);
1446 break;
1447
1448 case nir_deref_type_array:
1449 emit_deref_array(ctx, deref);
1450 break;
1451
1452 default:
1453 unreachable("unexpected deref_type");
1454 }
1455 }
1456
1457 static void
1458 emit_block(struct ntv_context *ctx, struct nir_block *block)
1459 {
1460 start_block(ctx, block_label(ctx, block));
1461 nir_foreach_instr(instr, block) {
1462 switch (instr->type) {
1463 case nir_instr_type_alu:
1464 emit_alu(ctx, nir_instr_as_alu(instr));
1465 break;
1466 case nir_instr_type_intrinsic:
1467 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1468 break;
1469 case nir_instr_type_load_const:
1470 emit_load_const(ctx, nir_instr_as_load_const(instr));
1471 break;
1472 case nir_instr_type_ssa_undef:
1473 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1474 break;
1475 case nir_instr_type_tex:
1476 emit_tex(ctx, nir_instr_as_tex(instr));
1477 break;
1478 case nir_instr_type_phi:
1479 unreachable("nir_instr_type_phi not supported");
1480 break;
1481 case nir_instr_type_jump:
1482 emit_jump(ctx, nir_instr_as_jump(instr));
1483 break;
1484 case nir_instr_type_call:
1485 unreachable("nir_instr_type_call not supported");
1486 break;
1487 case nir_instr_type_parallel_copy:
1488 unreachable("nir_instr_type_parallel_copy not supported");
1489 break;
1490 case nir_instr_type_deref:
1491 emit_deref(ctx, nir_instr_as_deref(instr));
1492 break;
1493 }
1494 }
1495 }
1496
1497 static void
1498 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1499
1500 static SpvId
1501 get_src_bool(struct ntv_context *ctx, nir_src *src)
1502 {
1503 SpvId def = get_src_uint(ctx, src);
1504 assert(nir_src_bit_size(*src) == 1);
1505 unsigned num_components = nir_src_num_components(*src);
1506 return uvec_to_bvec(ctx, def, num_components);
1507 }
1508
1509 static void
1510 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1511 {
1512 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1513
1514 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1515 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1516 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1517 SpvId else_id = endif_id;
1518
1519 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1520 if (has_else) {
1521 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1522 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1523 }
1524
1525 /* create a header-block */
1526 start_block(ctx, header_id);
1527 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1528 SpvSelectionControlMaskNone);
1529 branch_conditional(ctx, condition, then_id, else_id);
1530
1531 emit_cf_list(ctx, &if_stmt->then_list);
1532
1533 if (has_else) {
1534 if (ctx->block_started)
1535 branch(ctx, endif_id);
1536
1537 emit_cf_list(ctx, &if_stmt->else_list);
1538 }
1539
1540 start_block(ctx, endif_id);
1541 }
1542
1543 static void
1544 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1545 {
1546 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1547 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1548 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1549 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1550
1551 /* create a header-block */
1552 start_block(ctx, header_id);
1553 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1554 branch(ctx, begin_id);
1555
1556 SpvId save_break = ctx->loop_break;
1557 SpvId save_cont = ctx->loop_cont;
1558 ctx->loop_break = break_id;
1559 ctx->loop_cont = cont_id;
1560
1561 emit_cf_list(ctx, &loop->body);
1562
1563 ctx->loop_break = save_break;
1564 ctx->loop_cont = save_cont;
1565
1566 branch(ctx, cont_id);
1567 start_block(ctx, cont_id);
1568 branch(ctx, header_id);
1569
1570 start_block(ctx, break_id);
1571 }
1572
1573 static void
1574 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1575 {
1576 foreach_list_typed(nir_cf_node, node, node, list) {
1577 switch (node->type) {
1578 case nir_cf_node_block:
1579 emit_block(ctx, nir_cf_node_as_block(node));
1580 break;
1581
1582 case nir_cf_node_if:
1583 emit_if(ctx, nir_cf_node_as_if(node));
1584 break;
1585
1586 case nir_cf_node_loop:
1587 emit_loop(ctx, nir_cf_node_as_loop(node));
1588 break;
1589
1590 case nir_cf_node_function:
1591 unreachable("nir_cf_node_function not supported");
1592 break;
1593 }
1594 }
1595 }
1596
1597 struct spirv_shader *
1598 nir_to_spirv(struct nir_shader *s)
1599 {
1600 struct spirv_shader *ret = NULL;
1601
1602 struct ntv_context ctx = {};
1603
1604 switch (s->info.stage) {
1605 case MESA_SHADER_VERTEX:
1606 case MESA_SHADER_FRAGMENT:
1607 case MESA_SHADER_COMPUTE:
1608 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1609 break;
1610
1611 case MESA_SHADER_TESS_CTRL:
1612 case MESA_SHADER_TESS_EVAL:
1613 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1614 break;
1615
1616 case MESA_SHADER_GEOMETRY:
1617 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1618 break;
1619
1620 default:
1621 unreachable("invalid stage");
1622 }
1623
1624 // TODO: only enable when needed
1625 if (s->info.stage == MESA_SHADER_FRAGMENT)
1626 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1627
1628 ctx.stage = s->info.stage;
1629 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1630 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1631
1632 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1633 SpvMemoryModelGLSL450);
1634
1635 SpvExecutionModel exec_model;
1636 switch (s->info.stage) {
1637 case MESA_SHADER_VERTEX:
1638 exec_model = SpvExecutionModelVertex;
1639 break;
1640 case MESA_SHADER_TESS_CTRL:
1641 exec_model = SpvExecutionModelTessellationControl;
1642 break;
1643 case MESA_SHADER_TESS_EVAL:
1644 exec_model = SpvExecutionModelTessellationEvaluation;
1645 break;
1646 case MESA_SHADER_GEOMETRY:
1647 exec_model = SpvExecutionModelGeometry;
1648 break;
1649 case MESA_SHADER_FRAGMENT:
1650 exec_model = SpvExecutionModelFragment;
1651 break;
1652 case MESA_SHADER_COMPUTE:
1653 exec_model = SpvExecutionModelGLCompute;
1654 break;
1655 default:
1656 unreachable("invalid stage");
1657 }
1658
1659 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1660 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1661 NULL, 0);
1662 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1663 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1664
1665 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1666 _mesa_key_pointer_equal);
1667
1668 nir_foreach_variable(var, &s->inputs)
1669 emit_input(&ctx, var);
1670
1671 nir_foreach_variable(var, &s->outputs)
1672 emit_output(&ctx, var);
1673
1674 nir_foreach_variable(var, &s->uniforms)
1675 emit_uniform(&ctx, var);
1676
1677 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1678 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1679 SpvExecutionModeOriginUpperLeft);
1680 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1681 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1682 SpvExecutionModeDepthReplacing);
1683 }
1684
1685
1686 spirv_builder_function(&ctx.builder, entry_point, type_void,
1687 SpvFunctionControlMaskNone,
1688 type_main);
1689
1690 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1691 nir_metadata_require(entry, nir_metadata_block_index);
1692
1693 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1694 if (!ctx.defs)
1695 goto fail;
1696 ctx.num_defs = entry->ssa_alloc;
1697
1698 nir_index_local_regs(entry);
1699 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1700 if (!ctx.regs)
1701 goto fail;
1702 ctx.num_regs = entry->reg_alloc;
1703
1704 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1705 if (!block_ids)
1706 goto fail;
1707
1708 for (int i = 0; i < entry->num_blocks; ++i)
1709 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1710
1711 ctx.block_ids = block_ids;
1712 ctx.num_blocks = entry->num_blocks;
1713
1714 /* emit a block only for the variable declarations */
1715 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1716 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1717 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1718 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1719 SpvStorageClassFunction,
1720 type);
1721 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1722 SpvStorageClassFunction);
1723
1724 ctx.regs[reg->index] = var;
1725 }
1726
1727 emit_cf_list(&ctx, &entry->body);
1728
1729 free(ctx.defs);
1730
1731 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1732 spirv_builder_function_end(&ctx.builder);
1733
1734 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1735 "main", ctx.entry_ifaces,
1736 ctx.num_entry_ifaces);
1737
1738 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1739
1740 ret = CALLOC_STRUCT(spirv_shader);
1741 if (!ret)
1742 goto fail;
1743
1744 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1745 if (!ret->words)
1746 goto fail;
1747
1748 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1749 assert(ret->num_words == num_words);
1750
1751 return ret;
1752
1753 fail:
1754
1755 if (ret)
1756 spirv_shader_delete(ret);
1757
1758 if (ctx.vars)
1759 _mesa_hash_table_destroy(ctx.vars, NULL);
1760
1761 return NULL;
1762 }
1763
1764 void
1765 spirv_shader_delete(struct spirv_shader *s)
1766 {
1767 FREE(s->words);
1768 FREE(s);
1769 }