zink/spirv: inline get_uvec_constant into emit_load_const
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59 };
60
61 static SpvId
62 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
63 const float values[]);
64
65 static SpvId
66 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
67 const uint32_t values[]);
68
69 static SpvId
70 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
71
72 static SpvId
73 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
74 SpvId src0, SpvId src1);
75
76 static SpvId
77 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
78 SpvId src0, SpvId src1, SpvId src2);
79
80 static SpvId
81 get_bvec_type(struct ntv_context *ctx, int num_components)
82 {
83 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
84 if (num_components > 1)
85 return spirv_builder_type_vector(&ctx->builder, bool_type,
86 num_components);
87
88 assert(num_components == 1);
89 return bool_type;
90 }
91
92 static SpvId
93 block_label(struct ntv_context *ctx, nir_block *block)
94 {
95 assert(block->index < ctx->num_blocks);
96 return ctx->block_ids[block->index];
97 }
98
99 static SpvId
100 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
101 {
102 assert(bit_size == 32);
103 return spirv_builder_const_float(&ctx->builder, bit_size, value);
104 }
105
106 static SpvId
107 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
108 {
109 assert(bit_size == 32);
110 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
111 }
112
113 static SpvId
114 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
115 {
116 assert(bit_size == 32); // only 32-bit floats supported so far
117
118 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
119 if (num_components > 1)
120 return spirv_builder_type_vector(&ctx->builder, float_type,
121 num_components);
122
123 assert(num_components == 1);
124 return float_type;
125 }
126
127 static SpvId
128 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
129 {
130 assert(bit_size == 32); // only 32-bit ints supported so far
131
132 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
133 if (num_components > 1)
134 return spirv_builder_type_vector(&ctx->builder, int_type,
135 num_components);
136
137 assert(num_components == 1);
138 return int_type;
139 }
140
141 static SpvId
142 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
143 {
144 assert(bit_size == 32); // only 32-bit uints supported so far
145
146 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
147 if (num_components > 1)
148 return spirv_builder_type_vector(&ctx->builder, uint_type,
149 num_components);
150
151 assert(num_components == 1);
152 return uint_type;
153 }
154
155 static SpvId
156 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
157 {
158 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
159 nir_dest_num_components(*dest));
160 }
161
162 static SpvId
163 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
164 {
165 switch (type) {
166 case GLSL_TYPE_FLOAT:
167 return spirv_builder_type_float(&ctx->builder, 32);
168
169 case GLSL_TYPE_INT:
170 return spirv_builder_type_int(&ctx->builder, 32);
171
172 case GLSL_TYPE_UINT:
173 return spirv_builder_type_uint(&ctx->builder, 32);
174 /* TODO: handle more types */
175
176 default:
177 unreachable("unknown GLSL type");
178 }
179 }
180
181 static SpvId
182 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
183 {
184 assert(type);
185 if (glsl_type_is_scalar(type))
186 return get_glsl_basetype(ctx, glsl_get_base_type(type));
187
188 if (glsl_type_is_vector(type))
189 return spirv_builder_type_vector(&ctx->builder,
190 get_glsl_basetype(ctx, glsl_get_base_type(type)),
191 glsl_get_vector_elements(type));
192
193 if (glsl_type_is_array(type)) {
194 SpvId ret = spirv_builder_type_array(&ctx->builder,
195 get_glsl_type(ctx, glsl_get_array_element(type)),
196 emit_uint_const(ctx, 32, glsl_get_length(type)));
197 uint32_t stride = glsl_get_explicit_stride(type);
198 if (stride)
199 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
200 return ret;
201 }
202
203
204 unreachable("we shouldn't get here, I think...");
205 }
206
207 static void
208 emit_input(struct ntv_context *ctx, struct nir_variable *var)
209 {
210 SpvId var_type = get_glsl_type(ctx, var->type);
211 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
212 SpvStorageClassInput,
213 var_type);
214 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
215 SpvStorageClassInput);
216
217 if (var->name)
218 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
219
220 if (ctx->stage == MESA_SHADER_FRAGMENT) {
221 if (var->data.location >= VARYING_SLOT_VAR0 ||
222 (var->data.location >= VARYING_SLOT_COL0 &&
223 var->data.location <= VARYING_SLOT_TEX7)) {
224 spirv_builder_emit_location(&ctx->builder, var_id,
225 ctx->var_location++);
226 } else {
227 switch (var->data.location) {
228 case VARYING_SLOT_POS:
229 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
230 break;
231
232 case VARYING_SLOT_PNTC:
233 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
234 break;
235
236 default:
237 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
238 unreachable("unexpected varying slot");
239 }
240 }
241 } else {
242 spirv_builder_emit_location(&ctx->builder, var_id,
243 var->data.driver_location);
244 }
245
246 if (var->data.location_frac)
247 spirv_builder_emit_component(&ctx->builder, var_id,
248 var->data.location_frac);
249
250 if (var->data.interpolation == INTERP_MODE_FLAT)
251 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
252
253 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
254
255 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
256 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
257 }
258
259 static void
260 emit_output(struct ntv_context *ctx, struct nir_variable *var)
261 {
262 SpvId var_type = get_glsl_type(ctx, var->type);
263 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
264 SpvStorageClassOutput,
265 var_type);
266 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
267 SpvStorageClassOutput);
268 if (var->name)
269 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
270
271
272 if (ctx->stage == MESA_SHADER_VERTEX) {
273 if (var->data.location >= VARYING_SLOT_VAR0 ||
274 (var->data.location >= VARYING_SLOT_COL0 &&
275 var->data.location <= VARYING_SLOT_TEX7)) {
276 spirv_builder_emit_location(&ctx->builder, var_id,
277 ctx->var_location++);
278 } else {
279 switch (var->data.location) {
280 case VARYING_SLOT_POS:
281 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
282 break;
283
284 case VARYING_SLOT_PSIZ:
285 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
286 break;
287
288 case VARYING_SLOT_CLIP_DIST0:
289 assert(glsl_type_is_array(var->type));
290 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
291 break;
292
293 default:
294 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
295 unreachable("unexpected varying slot");
296 }
297 }
298 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
299 if (var->data.location >= FRAG_RESULT_DATA0)
300 spirv_builder_emit_location(&ctx->builder, var_id,
301 var->data.location - FRAG_RESULT_DATA0);
302 else {
303 switch (var->data.location) {
304 case FRAG_RESULT_COLOR:
305 spirv_builder_emit_location(&ctx->builder, var_id, 0);
306 break;
307
308 case FRAG_RESULT_DEPTH:
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
310 break;
311
312 default:
313 spirv_builder_emit_location(&ctx->builder, var_id,
314 var->data.driver_location);
315 }
316 }
317 }
318
319 if (var->data.location_frac)
320 spirv_builder_emit_component(&ctx->builder, var_id,
321 var->data.location_frac);
322
323 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
324
325 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
326 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
327 }
328
329 static SpvDim
330 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
331 {
332 *is_ms = false;
333 switch (gdim) {
334 case GLSL_SAMPLER_DIM_1D:
335 return SpvDim1D;
336 case GLSL_SAMPLER_DIM_2D:
337 return SpvDim2D;
338 case GLSL_SAMPLER_DIM_RECT:
339 return SpvDimRect;
340 case GLSL_SAMPLER_DIM_CUBE:
341 return SpvDimCube;
342 case GLSL_SAMPLER_DIM_3D:
343 return SpvDim3D;
344 case GLSL_SAMPLER_DIM_MS:
345 *is_ms = true;
346 return SpvDim2D;
347 default:
348 fprintf(stderr, "unknown sampler type %d\n", gdim);
349 break;
350 }
351 return SpvDim2D;
352 }
353
354 static void
355 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
356 {
357 bool is_ms;
358 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
359 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
360 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
361 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
362 SpvImageFormatUnknown);
363
364 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
365 image_type);
366 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
367 SpvStorageClassUniformConstant,
368 sampled_type);
369 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
370 SpvStorageClassUniformConstant);
371
372 if (var->name)
373 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
374
375 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
376 ctx->samplers[ctx->num_samplers++] = var_id;
377
378 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
379 var->data.descriptor_set);
380 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
381 }
382
383 static void
384 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
385 {
386 uint32_t size = glsl_count_attribute_slots(var->type, false);
387 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
388 SpvId array_length = emit_uint_const(ctx, 32, size);
389 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
390 array_length);
391 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
392
393 // wrap UBO-array in a struct
394 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
395 if (var->name) {
396 char struct_name[100];
397 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
398 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
399 }
400
401 spirv_builder_emit_decoration(&ctx->builder, struct_type,
402 SpvDecorationBlock);
403 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
404
405
406 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
407 SpvStorageClassUniform,
408 struct_type);
409
410 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
411 SpvStorageClassUniform);
412 if (var->name)
413 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
414
415 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
416 ctx->ubos[ctx->num_ubos++] = var_id;
417
418 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
419 var->data.descriptor_set);
420 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
421 }
422
423 static void
424 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
425 {
426 if (var->data.mode == nir_var_mem_ubo)
427 emit_ubo(ctx, var);
428 else {
429 assert(var->data.mode == nir_var_uniform);
430 if (glsl_type_is_sampler(var->type))
431 emit_sampler(ctx, var);
432 }
433 }
434
435 static SpvId
436 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
437 {
438 assert(ssa->index < ctx->num_defs);
439 assert(ctx->defs[ssa->index] != 0);
440 return ctx->defs[ssa->index];
441 }
442
443 static SpvId
444 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
445 {
446 assert(reg->index < ctx->num_regs);
447 assert(ctx->regs[reg->index] != 0);
448 return ctx->regs[reg->index];
449 }
450
451 static SpvId
452 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
453 {
454 assert(reg->reg);
455 assert(!reg->indirect);
456 assert(!reg->base_offset);
457
458 SpvId var = get_var_from_reg(ctx, reg->reg);
459 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
460 return spirv_builder_emit_load(&ctx->builder, type, var);
461 }
462
463 static SpvId
464 get_src_uint(struct ntv_context *ctx, nir_src *src)
465 {
466 if (src->is_ssa)
467 return get_src_uint_ssa(ctx, src->ssa);
468 else
469 return get_src_uint_reg(ctx, &src->reg);
470 }
471
472 static SpvId
473 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
474 {
475 assert(!alu->src[src].negate);
476 assert(!alu->src[src].abs);
477
478 SpvId def = get_src_uint(ctx, &alu->src[src].src);
479
480 unsigned used_channels = 0;
481 bool need_swizzle = false;
482 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
483 if (!nir_alu_instr_channel_used(alu, src, i))
484 continue;
485
486 used_channels++;
487
488 if (alu->src[src].swizzle[i] != i)
489 need_swizzle = true;
490 }
491 assert(used_channels != 0);
492
493 unsigned live_channels = nir_src_num_components(alu->src[src].src);
494 if (used_channels != live_channels)
495 need_swizzle = true;
496
497 if (!need_swizzle)
498 return def;
499
500 int bit_size = nir_src_bit_size(alu->src[src].src);
501 assert(bit_size == 32);
502
503 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
504 if (used_channels == 1) {
505 uint32_t indices[] = { alu->src[src].swizzle[0] };
506 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
507 def, indices,
508 ARRAY_SIZE(indices));
509 } else if (live_channels == 1) {
510 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
511 used_channels);
512
513 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
514 for (unsigned i = 0; i < used_channels; ++i)
515 constituents[i] = def;
516
517 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
518 constituents,
519 used_channels);
520 } else {
521 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
522 used_channels);
523
524 uint32_t components[NIR_MAX_VEC_COMPONENTS];
525 size_t num_components = 0;
526 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
527 if (!nir_alu_instr_channel_used(alu, src, i))
528 continue;
529
530 components[num_components++] = alu->src[src].swizzle[i];
531 }
532
533 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
534 def, def, components, num_components);
535 }
536 }
537
538 static void
539 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
540 {
541 assert(result != 0);
542 assert(ssa->index < ctx->num_defs);
543 ctx->defs[ssa->index] = result;
544 }
545
546 static SpvId
547 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
548 SpvId if_true, SpvId if_false)
549 {
550 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
551 }
552
553 static SpvId
554 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
555 {
556 SpvId otype = get_uvec_type(ctx, 32, num_components);
557 uint32_t zeros[4] = { 0, 0, 0, 0 };
558 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
559 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
560 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
561 return emit_select(ctx, otype, value, one, zero);
562 }
563
564 static SpvId
565 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
566 {
567 SpvId type = get_bvec_type(ctx, num_components);
568
569 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
570 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
571
572 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
573 }
574
575 static SpvId
576 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
577 {
578 return emit_unop(ctx, SpvOpBitcast, type, value);
579 }
580
581 static SpvId
582 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
583 unsigned num_components)
584 {
585 SpvId type = get_uvec_type(ctx, bit_size, num_components);
586 return emit_bitcast(ctx, type, value);
587 }
588
589 static SpvId
590 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
591 unsigned num_components)
592 {
593 SpvId type = get_ivec_type(ctx, bit_size, num_components);
594 return emit_bitcast(ctx, type, value);
595 }
596
597 static SpvId
598 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
599 unsigned num_components)
600 {
601 SpvId type = get_fvec_type(ctx, bit_size, num_components);
602 return emit_bitcast(ctx, type, value);
603 }
604
605 static void
606 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
607 {
608 SpvId var = get_var_from_reg(ctx, reg->reg);
609 assert(var);
610 spirv_builder_emit_store(&ctx->builder, var, result);
611 }
612
613 static void
614 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
615 {
616 if (dest->is_ssa)
617 store_ssa_def_uint(ctx, &dest->ssa, result);
618 else
619 store_reg_def(ctx, &dest->reg, result);
620 }
621
622 static void
623 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
624 {
625 unsigned num_components = nir_dest_num_components(*dest);
626 unsigned bit_size = nir_dest_bit_size(*dest);
627
628 switch (nir_alu_type_get_base_type(type)) {
629 case nir_type_bool:
630 assert(bit_size == 1);
631 result = bvec_to_uvec(ctx, result, num_components);
632 break;
633
634 case nir_type_uint:
635 break; /* nothing to do! */
636
637 case nir_type_int:
638 case nir_type_float:
639 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
640 break;
641
642 default:
643 unreachable("unsupported nir_alu_type");
644 }
645
646 store_dest_uint(ctx, dest, result);
647 }
648
649 static SpvId
650 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
651 {
652 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
653 }
654
655 static SpvId
656 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
657 SpvId src0, SpvId src1)
658 {
659 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
660 }
661
662 static SpvId
663 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
664 SpvId src0, SpvId src1, SpvId src2)
665 {
666 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
667 }
668
669 static SpvId
670 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
671 SpvId src)
672 {
673 SpvId args[] = { src };
674 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
675 op, args, ARRAY_SIZE(args));
676 }
677
678 static SpvId
679 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
680 SpvId src0, SpvId src1)
681 {
682 SpvId args[] = { src0, src1 };
683 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
684 op, args, ARRAY_SIZE(args));
685 }
686
687 static SpvId
688 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
689 const float values[])
690 {
691 assert(bit_size == 32);
692
693 if (num_components > 1) {
694 SpvId components[num_components];
695 for (int i = 0; i < num_components; i++)
696 components[i] = emit_float_const(ctx, bit_size, values[i]);
697
698 SpvId type = get_fvec_type(ctx, bit_size, num_components);
699 return spirv_builder_const_composite(&ctx->builder, type, components,
700 num_components);
701 }
702
703 assert(num_components == 1);
704 return emit_float_const(ctx, bit_size, values[0]);
705 }
706
707 static SpvId
708 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
709 const uint32_t values[])
710 {
711 assert(bit_size == 32);
712
713 if (num_components > 1) {
714 SpvId components[num_components];
715 for (int i = 0; i < num_components; i++)
716 components[i] = emit_uint_const(ctx, bit_size, values[i]);
717
718 SpvId type = get_uvec_type(ctx, bit_size, num_components);
719 return spirv_builder_const_composite(&ctx->builder, type, components,
720 num_components);
721 }
722
723 assert(num_components == 1);
724 return emit_uint_const(ctx, bit_size, values[0]);
725 }
726
727 static inline unsigned
728 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
729 {
730 if (nir_op_infos[instr->op].input_sizes[src] > 0)
731 return nir_op_infos[instr->op].input_sizes[src];
732
733 if (instr->dest.dest.is_ssa)
734 return instr->dest.dest.ssa.num_components;
735 else
736 return instr->dest.dest.reg.reg->num_components;
737 }
738
739 static SpvId
740 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
741 {
742 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
743
744 unsigned num_components = alu_instr_src_components(alu, src);
745 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
746 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
747
748 switch (nir_alu_type_get_base_type(type)) {
749 case nir_type_bool:
750 assert(bit_size == 1);
751 return uvec_to_bvec(ctx, uint_value, num_components);
752
753 case nir_type_int:
754 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
755
756 case nir_type_uint:
757 return uint_value;
758
759 case nir_type_float:
760 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
761
762 default:
763 unreachable("unknown nir_alu_type");
764 }
765 }
766
767 static void
768 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
769 {
770 assert(!alu->dest.saturate);
771 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
772 }
773
774 static SpvId
775 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
776 {
777 unsigned num_components = nir_dest_num_components(*dest);
778 unsigned bit_size = nir_dest_bit_size(*dest);
779
780 switch (nir_alu_type_get_base_type(type)) {
781 case nir_type_bool:
782 return get_bvec_type(ctx, num_components);
783
784 case nir_type_int:
785 return get_ivec_type(ctx, bit_size, num_components);
786
787 case nir_type_uint:
788 return get_uvec_type(ctx, bit_size, num_components);
789
790 case nir_type_float:
791 return get_fvec_type(ctx, bit_size, num_components);
792
793 default:
794 unreachable("unsupported nir_alu_type");
795 }
796 }
797
798 static void
799 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
800 {
801 SpvId src[nir_op_infos[alu->op].num_inputs];
802 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
803 src[i] = get_alu_src(ctx, alu, i);
804
805 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
806 nir_op_infos[alu->op].output_type);
807 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
808 unsigned num_components = nir_dest_num_components(alu->dest.dest);
809
810 SpvId result = 0;
811 switch (alu->op) {
812 case nir_op_mov:
813 assert(nir_op_infos[alu->op].num_inputs == 1);
814 result = src[0];
815 break;
816
817 #define UNOP(nir_op, spirv_op) \
818 case nir_op: \
819 assert(nir_op_infos[alu->op].num_inputs == 1); \
820 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
821 break;
822
823 UNOP(nir_op_ineg, SpvOpSNegate)
824 UNOP(nir_op_fneg, SpvOpFNegate)
825 UNOP(nir_op_fddx, SpvOpDPdx)
826 UNOP(nir_op_fddy, SpvOpDPdy)
827 UNOP(nir_op_f2i32, SpvOpConvertFToS)
828 UNOP(nir_op_f2u32, SpvOpConvertFToU)
829 UNOP(nir_op_i2f32, SpvOpConvertSToF)
830 UNOP(nir_op_u2f32, SpvOpConvertUToF)
831 UNOP(nir_op_inot, SpvOpNot)
832 #undef UNOP
833
834 case nir_op_b2i32:
835 assert(nir_op_infos[alu->op].num_inputs == 1);
836 result = bvec_to_uvec(ctx, src[0], num_components);
837 break;
838
839 #define BUILTIN_UNOP(nir_op, spirv_op) \
840 case nir_op: \
841 assert(nir_op_infos[alu->op].num_inputs == 1); \
842 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
843 break;
844
845 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
846 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
847 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
848 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
849 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
850 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
851 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
852 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
853 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
854 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
855 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
856 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
857 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
858 #undef BUILTIN_UNOP
859
860 case nir_op_frcp: {
861 assert(nir_op_infos[alu->op].num_inputs == 1);
862 float one[4] = { 1, 1, 1, 1 };
863 src[1] = src[0];
864 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
865 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
866 }
867 break;
868
869 case nir_op_f2b1: {
870 assert(nir_op_infos[alu->op].num_inputs == 1);
871 float values[NIR_MAX_VEC_COMPONENTS] = { 0 };
872 SpvId zero = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
873 num_components, values);
874 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0], zero);
875 } break;
876
877
878 #define BINOP(nir_op, spirv_op) \
879 case nir_op: \
880 assert(nir_op_infos[alu->op].num_inputs == 2); \
881 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
882 break;
883
884 BINOP(nir_op_iadd, SpvOpIAdd)
885 BINOP(nir_op_isub, SpvOpISub)
886 BINOP(nir_op_imul, SpvOpIMul)
887 BINOP(nir_op_idiv, SpvOpSDiv)
888 BINOP(nir_op_udiv, SpvOpUDiv)
889 BINOP(nir_op_fadd, SpvOpFAdd)
890 BINOP(nir_op_fsub, SpvOpFSub)
891 BINOP(nir_op_fmul, SpvOpFMul)
892 BINOP(nir_op_fdiv, SpvOpFDiv)
893 BINOP(nir_op_fmod, SpvOpFMod)
894 BINOP(nir_op_ilt, SpvOpSLessThan)
895 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
896 BINOP(nir_op_ieq, SpvOpIEqual)
897 BINOP(nir_op_ine, SpvOpINotEqual)
898 BINOP(nir_op_flt, SpvOpFOrdLessThan)
899 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
900 BINOP(nir_op_feq, SpvOpFOrdEqual)
901 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
902 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
903 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
904 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
905 BINOP(nir_op_iand, SpvOpBitwiseAnd)
906 BINOP(nir_op_ior, SpvOpBitwiseOr)
907 #undef BINOP
908
909 #define BUILTIN_BINOP(nir_op, spirv_op) \
910 case nir_op: \
911 assert(nir_op_infos[alu->op].num_inputs == 2); \
912 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
913 break;
914
915 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
916 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
917 #undef BUILTIN_BINOP
918
919 case nir_op_fdot2:
920 case nir_op_fdot3:
921 case nir_op_fdot4:
922 assert(nir_op_infos[alu->op].num_inputs == 2);
923 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
924 break;
925
926 case nir_op_seq:
927 case nir_op_sne:
928 case nir_op_slt:
929 case nir_op_sge: {
930 assert(nir_op_infos[alu->op].num_inputs == 2);
931 int num_components = nir_dest_num_components(alu->dest.dest);
932 SpvId bool_type = get_bvec_type(ctx, num_components);
933
934 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
935 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
936 if (num_components > 1) {
937 SpvId zero_comps[num_components], one_comps[num_components];
938 for (int i = 0; i < num_components; i++) {
939 zero_comps[i] = zero;
940 one_comps[i] = one;
941 }
942
943 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
944 zero_comps, num_components);
945 one = spirv_builder_const_composite(&ctx->builder, dest_type,
946 one_comps, num_components);
947 }
948
949 SpvOp op;
950 switch (alu->op) {
951 case nir_op_seq: op = SpvOpFOrdEqual; break;
952 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
953 case nir_op_slt: op = SpvOpFOrdLessThan; break;
954 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
955 default: unreachable("unexpected op");
956 }
957
958 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
959 result = emit_select(ctx, dest_type, result, one, zero);
960 }
961 break;
962
963 case nir_op_fcsel: {
964 assert(nir_op_infos[alu->op].num_inputs == 3);
965 int num_components = nir_dest_num_components(alu->dest.dest);
966 SpvId bool_type = get_bvec_type(ctx, num_components);
967
968 float zero[4] = { 0, 0, 0, 0 };
969 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
970 num_components, zero);
971
972 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
973 result = emit_select(ctx, dest_type, result, src[1], src[2]);
974 }
975 break;
976
977 case nir_op_bcsel:
978 assert(nir_op_infos[alu->op].num_inputs == 3);
979 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
980 break;
981
982 case nir_op_vec2:
983 case nir_op_vec3:
984 case nir_op_vec4: {
985 int num_inputs = nir_op_infos[alu->op].num_inputs;
986 assert(2 <= num_inputs && num_inputs <= 4);
987 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
988 src, num_inputs);
989 }
990 break;
991
992 default:
993 fprintf(stderr, "emit_alu: not implemented (%s)\n",
994 nir_op_infos[alu->op].name);
995
996 unreachable("unsupported opcode");
997 return;
998 }
999
1000 store_alu_result(ctx, alu, result);
1001 }
1002
1003 static void
1004 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1005 {
1006 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1007 for (int i = 0; i < load_const->def.num_components; ++i)
1008 values[i] = load_const->value[i].u32;
1009
1010 unsigned bit_size = load_const->def.bit_size;
1011 unsigned num_components = load_const->def.num_components;
1012
1013 SpvId constant;
1014 if (num_components > 1) {
1015 SpvId components[num_components];
1016 for (int i = 0; i < num_components; i++)
1017 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1018
1019 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1020 constant = spirv_builder_const_composite(&ctx->builder, type,
1021 components, num_components);
1022 } else {
1023 assert(num_components == 1);
1024 constant = emit_uint_const(ctx, bit_size, values[0]);
1025 }
1026
1027 store_ssa_def_uint(ctx, &load_const->def, constant);
1028 }
1029
1030 static void
1031 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1032 {
1033 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1034 assert(const_block_index); // no dynamic indexing for now
1035 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1036
1037 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1038 if (const_offset) {
1039 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1040 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1041 SpvStorageClassUniform,
1042 uvec4_type);
1043
1044 unsigned idx = const_offset->u32;
1045 SpvId member = emit_uint_const(ctx, 32, 0);
1046 SpvId offset = emit_uint_const(ctx, 32, idx);
1047 SpvId offsets[] = { member, offset };
1048 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1049 ctx->ubos[0], offsets,
1050 ARRAY_SIZE(offsets));
1051 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1052
1053 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1054 unsigned num_components = nir_dest_num_components(intr->dest);
1055 if (num_components == 1) {
1056 uint32_t components[] = { 0 };
1057 result = spirv_builder_emit_composite_extract(&ctx->builder,
1058 type,
1059 result, components,
1060 1);
1061 } else if (num_components < 4) {
1062 SpvId constituents[num_components];
1063 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1064 for (uint32_t i = 0; i < num_components; ++i)
1065 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1066 uint_type,
1067 result, &i,
1068 1);
1069
1070 result = spirv_builder_emit_composite_construct(&ctx->builder,
1071 type,
1072 constituents,
1073 num_components);
1074 }
1075
1076 store_dest_uint(ctx, &intr->dest, result);
1077 } else
1078 unreachable("uniform-addressing not yet supported");
1079 }
1080
1081 static void
1082 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1083 {
1084 assert(ctx->block_started);
1085 spirv_builder_emit_kill(&ctx->builder);
1086 /* discard is weird in NIR, so let's just create an unreachable block after
1087 it and hope that the vulkan driver will DCE any instructinos in it. */
1088 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1089 }
1090
1091 static void
1092 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1093 {
1094 /* uint is a bit of a lie here; it's really just a pointer */
1095 SpvId ptr = get_src_uint(ctx, intr->src);
1096
1097 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1098 SpvId result = spirv_builder_emit_load(&ctx->builder,
1099 get_glsl_type(ctx, var->type),
1100 ptr);
1101 unsigned num_components = nir_dest_num_components(intr->dest);
1102 unsigned bit_size = nir_dest_bit_size(intr->dest);
1103 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1104 store_dest_uint(ctx, &intr->dest, result);
1105 }
1106
1107 static void
1108 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1109 {
1110 /* uint is a bit of a lie here; it's really just a pointer */
1111 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1112 SpvId src = get_src_uint(ctx, &intr->src[1]);
1113
1114 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1115 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1116 SpvId result = emit_bitcast(ctx, type, src);
1117 spirv_builder_emit_store(&ctx->builder, ptr, result);
1118 }
1119
1120 static void
1121 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1122 {
1123 switch (intr->intrinsic) {
1124 case nir_intrinsic_load_ubo:
1125 emit_load_ubo(ctx, intr);
1126 break;
1127
1128 case nir_intrinsic_discard:
1129 emit_discard(ctx, intr);
1130 break;
1131
1132 case nir_intrinsic_load_deref:
1133 emit_load_deref(ctx, intr);
1134 break;
1135
1136 case nir_intrinsic_store_deref:
1137 emit_store_deref(ctx, intr);
1138 break;
1139
1140 default:
1141 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1142 nir_intrinsic_infos[intr->intrinsic].name);
1143 unreachable("unsupported intrinsic");
1144 }
1145 }
1146
1147 static void
1148 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1149 {
1150 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1151 undef->def.num_components);
1152
1153 store_ssa_def_uint(ctx, &undef->def,
1154 spirv_builder_emit_undef(&ctx->builder, type));
1155 }
1156
1157 static SpvId
1158 get_src_float(struct ntv_context *ctx, nir_src *src)
1159 {
1160 SpvId def = get_src_uint(ctx, src);
1161 unsigned num_components = nir_src_num_components(*src);
1162 unsigned bit_size = nir_src_bit_size(*src);
1163 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1164 }
1165
1166 static void
1167 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1168 {
1169 assert(tex->op == nir_texop_tex ||
1170 tex->op == nir_texop_txb ||
1171 tex->op == nir_texop_txl);
1172 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1173 assert(tex->texture_index == tex->sampler_index);
1174
1175 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1176 unsigned coord_components;
1177 for (unsigned i = 0; i < tex->num_srcs; i++) {
1178 switch (tex->src[i].src_type) {
1179 case nir_tex_src_coord:
1180 coord = get_src_float(ctx, &tex->src[i].src);
1181 coord_components = nir_src_num_components(tex->src[i].src);
1182 break;
1183
1184 case nir_tex_src_projector:
1185 assert(nir_src_num_components(tex->src[i].src) == 1);
1186 proj = get_src_float(ctx, &tex->src[i].src);
1187 assert(proj != 0);
1188 break;
1189
1190 case nir_tex_src_bias:
1191 assert(tex->op == nir_texop_txb);
1192 bias = get_src_float(ctx, &tex->src[i].src);
1193 assert(bias != 0);
1194 break;
1195
1196 case nir_tex_src_lod:
1197 assert(nir_src_num_components(tex->src[i].src) == 1);
1198 lod = get_src_float(ctx, &tex->src[i].src);
1199 assert(lod != 0);
1200 break;
1201
1202 case nir_tex_src_comparator:
1203 assert(nir_src_num_components(tex->src[i].src) == 1);
1204 dref = get_src_float(ctx, &tex->src[i].src);
1205 assert(dref != 0);
1206 break;
1207
1208 default:
1209 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1210 unreachable("unknown texture source");
1211 }
1212 }
1213
1214 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1215 lod = emit_float_const(ctx, 32, 0.0f);
1216 assert(lod != 0);
1217 }
1218
1219 bool is_ms;
1220 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1221 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1222 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1223 dimension, false, tex->is_array, is_ms, 1,
1224 SpvImageFormatUnknown);
1225 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1226 image_type);
1227
1228 assert(tex->texture_index < ctx->num_samplers);
1229 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1230 ctx->samplers[tex->texture_index]);
1231
1232 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1233
1234 if (proj) {
1235 SpvId constituents[coord_components + 1];
1236 if (coord_components == 1)
1237 constituents[0] = coord;
1238 else {
1239 assert(coord_components > 1);
1240 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1241 for (uint32_t i = 0; i < coord_components; ++i)
1242 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1243 float_type,
1244 coord,
1245 &i, 1);
1246 }
1247
1248 constituents[coord_components++] = proj;
1249
1250 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1251 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1252 vec_type,
1253 constituents,
1254 coord_components);
1255 }
1256
1257 SpvId actual_dest_type = dest_type;
1258 if (dref)
1259 actual_dest_type = float_type;
1260
1261 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1262 actual_dest_type, load,
1263 coord,
1264 proj != 0,
1265 lod, bias, dref);
1266 spirv_builder_emit_decoration(&ctx->builder, result,
1267 SpvDecorationRelaxedPrecision);
1268
1269 if (dref) {
1270 SpvId components[4] = { result, result, result, result };
1271 result = spirv_builder_emit_composite_construct(&ctx->builder,
1272 dest_type,
1273 components,
1274 4);
1275 }
1276
1277 store_dest(ctx, &tex->dest, result, tex->dest_type);
1278 }
1279
1280 static void
1281 start_block(struct ntv_context *ctx, SpvId label)
1282 {
1283 /* terminate previous block if needed */
1284 if (ctx->block_started)
1285 spirv_builder_emit_branch(&ctx->builder, label);
1286
1287 /* start new block */
1288 spirv_builder_label(&ctx->builder, label);
1289 ctx->block_started = true;
1290 }
1291
1292 static void
1293 branch(struct ntv_context *ctx, SpvId label)
1294 {
1295 assert(ctx->block_started);
1296 spirv_builder_emit_branch(&ctx->builder, label);
1297 ctx->block_started = false;
1298 }
1299
1300 static void
1301 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1302 SpvId else_id)
1303 {
1304 assert(ctx->block_started);
1305 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1306 then_id, else_id);
1307 ctx->block_started = false;
1308 }
1309
1310 static void
1311 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1312 {
1313 switch (jump->type) {
1314 case nir_jump_break:
1315 assert(ctx->loop_break);
1316 branch(ctx, ctx->loop_break);
1317 break;
1318
1319 case nir_jump_continue:
1320 assert(ctx->loop_cont);
1321 branch(ctx, ctx->loop_cont);
1322 break;
1323
1324 default:
1325 unreachable("Unsupported jump type\n");
1326 }
1327 }
1328
1329 static void
1330 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1331 {
1332 assert(deref->deref_type == nir_deref_type_var);
1333
1334 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1335 assert(he);
1336 SpvId result = (SpvId)(intptr_t)he->data;
1337 /* uint is a bit of a lie here, it's really just an opaque type */
1338 store_dest_uint(ctx, &deref->dest, result);
1339 }
1340
1341 static void
1342 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1343 {
1344 assert(deref->deref_type == nir_deref_type_array);
1345 nir_variable *var = nir_deref_instr_get_variable(deref);
1346
1347 SpvStorageClass storage_class;
1348 switch (var->data.mode) {
1349 case nir_var_shader_in:
1350 storage_class = SpvStorageClassInput;
1351 break;
1352
1353 case nir_var_shader_out:
1354 storage_class = SpvStorageClassOutput;
1355 break;
1356
1357 default:
1358 unreachable("Unsupported nir_variable_mode\n");
1359 }
1360
1361 SpvId index = get_src_uint(ctx, &deref->arr.index);
1362
1363 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1364 storage_class,
1365 get_glsl_type(ctx, deref->type));
1366
1367 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1368 ptr_type,
1369 get_src_uint(ctx, &deref->parent),
1370 &index, 1);
1371 /* uint is a bit of a lie here, it's really just an opaque type */
1372 store_dest_uint(ctx, &deref->dest, result);
1373 }
1374
1375 static void
1376 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1377 {
1378 switch (deref->deref_type) {
1379 case nir_deref_type_var:
1380 emit_deref_var(ctx, deref);
1381 break;
1382
1383 case nir_deref_type_array:
1384 emit_deref_array(ctx, deref);
1385 break;
1386
1387 default:
1388 unreachable("unexpected deref_type");
1389 }
1390 }
1391
1392 static void
1393 emit_block(struct ntv_context *ctx, struct nir_block *block)
1394 {
1395 start_block(ctx, block_label(ctx, block));
1396 nir_foreach_instr(instr, block) {
1397 switch (instr->type) {
1398 case nir_instr_type_alu:
1399 emit_alu(ctx, nir_instr_as_alu(instr));
1400 break;
1401 case nir_instr_type_intrinsic:
1402 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1403 break;
1404 case nir_instr_type_load_const:
1405 emit_load_const(ctx, nir_instr_as_load_const(instr));
1406 break;
1407 case nir_instr_type_ssa_undef:
1408 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1409 break;
1410 case nir_instr_type_tex:
1411 emit_tex(ctx, nir_instr_as_tex(instr));
1412 break;
1413 case nir_instr_type_phi:
1414 unreachable("nir_instr_type_phi not supported");
1415 break;
1416 case nir_instr_type_jump:
1417 emit_jump(ctx, nir_instr_as_jump(instr));
1418 break;
1419 case nir_instr_type_call:
1420 unreachable("nir_instr_type_call not supported");
1421 break;
1422 case nir_instr_type_parallel_copy:
1423 unreachable("nir_instr_type_parallel_copy not supported");
1424 break;
1425 case nir_instr_type_deref:
1426 emit_deref(ctx, nir_instr_as_deref(instr));
1427 break;
1428 }
1429 }
1430 }
1431
/* forward declaration: emit_if() and emit_loop() call back into this */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1434
1435 static SpvId
1436 get_src_bool(struct ntv_context *ctx, nir_src *src)
1437 {
1438 SpvId def = get_src_uint(ctx, src);
1439 assert(nir_src_bit_size(*src) == 32);
1440 unsigned num_components = nir_src_num_components(*src);
1441 return uvec_to_bvec(ctx, def, num_components);
1442 }
1443
1444 static void
1445 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1446 {
1447 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1448
1449 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1450 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1451 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1452 SpvId else_id = endif_id;
1453
1454 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1455 if (has_else) {
1456 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1457 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1458 }
1459
1460 /* create a header-block */
1461 start_block(ctx, header_id);
1462 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1463 SpvSelectionControlMaskNone);
1464 branch_conditional(ctx, condition, then_id, else_id);
1465
1466 emit_cf_list(ctx, &if_stmt->then_list);
1467
1468 if (has_else) {
1469 if (ctx->block_started)
1470 branch(ctx, endif_id);
1471
1472 emit_cf_list(ctx, &if_stmt->else_list);
1473 }
1474
1475 start_block(ctx, endif_id);
1476 }
1477
1478 static void
1479 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1480 {
1481 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1482 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1483 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1484 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1485
1486 /* create a header-block */
1487 start_block(ctx, header_id);
1488 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1489 branch(ctx, begin_id);
1490
1491 SpvId save_break = ctx->loop_break;
1492 SpvId save_cont = ctx->loop_cont;
1493 ctx->loop_break = break_id;
1494 ctx->loop_cont = cont_id;
1495
1496 emit_cf_list(ctx, &loop->body);
1497
1498 ctx->loop_break = save_break;
1499 ctx->loop_cont = save_cont;
1500
1501 branch(ctx, cont_id);
1502 start_block(ctx, cont_id);
1503 branch(ctx, header_id);
1504
1505 start_block(ctx, break_id);
1506 }
1507
1508 static void
1509 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1510 {
1511 foreach_list_typed(nir_cf_node, node, node, list) {
1512 switch (node->type) {
1513 case nir_cf_node_block:
1514 emit_block(ctx, nir_cf_node_as_block(node));
1515 break;
1516
1517 case nir_cf_node_if:
1518 emit_if(ctx, nir_cf_node_as_if(node));
1519 break;
1520
1521 case nir_cf_node_loop:
1522 emit_loop(ctx, nir_cf_node_as_loop(node));
1523 break;
1524
1525 case nir_cf_node_function:
1526 unreachable("nir_cf_node_function not supported");
1527 break;
1528 }
1529 }
1530 }
1531
1532 struct spirv_shader *
1533 nir_to_spirv(struct nir_shader *s)
1534 {
1535 struct spirv_shader *ret = NULL;
1536
1537 struct ntv_context ctx = {};
1538
1539 switch (s->info.stage) {
1540 case MESA_SHADER_VERTEX:
1541 case MESA_SHADER_FRAGMENT:
1542 case MESA_SHADER_COMPUTE:
1543 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1544 break;
1545
1546 case MESA_SHADER_TESS_CTRL:
1547 case MESA_SHADER_TESS_EVAL:
1548 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1549 break;
1550
1551 case MESA_SHADER_GEOMETRY:
1552 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1553 break;
1554
1555 default:
1556 unreachable("invalid stage");
1557 }
1558
1559 // TODO: only enable when needed
1560 if (s->info.stage == MESA_SHADER_FRAGMENT)
1561 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1562
1563 ctx.stage = s->info.stage;
1564 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1565 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1566
1567 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1568 SpvMemoryModelGLSL450);
1569
1570 SpvExecutionModel exec_model;
1571 switch (s->info.stage) {
1572 case MESA_SHADER_VERTEX:
1573 exec_model = SpvExecutionModelVertex;
1574 break;
1575 case MESA_SHADER_TESS_CTRL:
1576 exec_model = SpvExecutionModelTessellationControl;
1577 break;
1578 case MESA_SHADER_TESS_EVAL:
1579 exec_model = SpvExecutionModelTessellationEvaluation;
1580 break;
1581 case MESA_SHADER_GEOMETRY:
1582 exec_model = SpvExecutionModelGeometry;
1583 break;
1584 case MESA_SHADER_FRAGMENT:
1585 exec_model = SpvExecutionModelFragment;
1586 break;
1587 case MESA_SHADER_COMPUTE:
1588 exec_model = SpvExecutionModelGLCompute;
1589 break;
1590 default:
1591 unreachable("invalid stage");
1592 }
1593
1594 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1595 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1596 NULL, 0);
1597 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1598 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1599
1600 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1601 _mesa_key_pointer_equal);
1602
1603 nir_foreach_variable(var, &s->inputs)
1604 emit_input(&ctx, var);
1605
1606 nir_foreach_variable(var, &s->outputs)
1607 emit_output(&ctx, var);
1608
1609 nir_foreach_variable(var, &s->uniforms)
1610 emit_uniform(&ctx, var);
1611
1612 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1613 "main", ctx.entry_ifaces,
1614 ctx.num_entry_ifaces);
1615 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1616 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1617 SpvExecutionModeOriginUpperLeft);
1618 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1619 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1620 SpvExecutionModeDepthReplacing);
1621 }
1622
1623
1624 spirv_builder_function(&ctx.builder, entry_point, type_void,
1625 SpvFunctionControlMaskNone,
1626 type_main);
1627
1628 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1629 nir_metadata_require(entry, nir_metadata_block_index);
1630
1631 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1632 if (!ctx.defs)
1633 goto fail;
1634 ctx.num_defs = entry->ssa_alloc;
1635
1636 nir_index_local_regs(entry);
1637 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1638 if (!ctx.regs)
1639 goto fail;
1640 ctx.num_regs = entry->reg_alloc;
1641
1642 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1643 if (!block_ids)
1644 goto fail;
1645
1646 for (int i = 0; i < entry->num_blocks; ++i)
1647 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1648
1649 ctx.block_ids = block_ids;
1650 ctx.num_blocks = entry->num_blocks;
1651
1652 /* emit a block only for the variable declarations */
1653 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1654 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1655 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1656 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1657 SpvStorageClassFunction,
1658 type);
1659 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1660 SpvStorageClassFunction);
1661
1662 ctx.regs[reg->index] = var;
1663 }
1664
1665 emit_cf_list(&ctx, &entry->body);
1666
1667 free(ctx.defs);
1668
1669 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1670 spirv_builder_function_end(&ctx.builder);
1671
1672 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1673
1674 ret = CALLOC_STRUCT(spirv_shader);
1675 if (!ret)
1676 goto fail;
1677
1678 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1679 if (!ret->words)
1680 goto fail;
1681
1682 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1683 assert(ret->num_words == num_words);
1684
1685 return ret;
1686
1687 fail:
1688
1689 if (ret)
1690 spirv_shader_delete(ret);
1691
1692 if (ctx.vars)
1693 _mesa_hash_table_destroy(ctx.vars, NULL);
1694
1695 return NULL;
1696 }
1697
1698 void
1699 spirv_shader_delete(struct spirv_shader *s)
1700 {
1701 FREE(s->words);
1702 FREE(s);
1703 }