zink/spirv: implement emit_float_const helper
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
/* Translation state shared by the NIR -> SPIR-V helpers in this file. */
struct ntv_context {
   struct spirv_builder builder;

   /* passed as the instruction-set id to spirv_builder_emit_ext_inst() */
   SpvId GLSL_std_450;

   gl_shader_stage stage;
   /* post-incremented each time a generic varying is given a location */
   int var_location;

   SpvId ubos[128]; /* variable ids appended by emit_ubo() */
   size_t num_ubos;
   SpvId samplers[PIPE_MAX_SAMPLERS]; /* variable ids appended by emit_sampler() */
   size_t num_samplers;
   /* variable ids appended by emit_input() and emit_output() */
   SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
   size_t num_entry_ifaces;

   SpvId *defs; /* indexed by nir_ssa_def index; readers assert the entry is non-zero */
   size_t num_defs;

   SpvId *regs; /* indexed by nir_register index; readers assert the entry is non-zero */
   size_t num_regs;

   struct hash_table *vars; /* nir_variable -> SpvId */

   const SpvId *block_ids; /* indexed by nir_block index, see block_label() */
   size_t num_blocks;
   bool block_started;
   /* NOTE(review): presumably the labels targeted by break/continue of the
    * innermost loop -- not used in the visible part of the file, confirm. */
   SpvId loop_break, loop_cont;
};
60
/*
 * Forward declarations: these helpers are defined further down but are
 * already called by functions that appear before their definitions.
 */
static SpvId
get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
                  const float values[]);

static SpvId
get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
                  const uint32_t values[]);

static SpvId
emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);

static SpvId
emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1);

static SpvId
emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1, SpvId src2);
79
80 static SpvId
81 get_bvec_type(struct ntv_context *ctx, int num_components)
82 {
83 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
84 if (num_components > 1)
85 return spirv_builder_type_vector(&ctx->builder, bool_type,
86 num_components);
87
88 assert(num_components == 1);
89 return bool_type;
90 }
91
92 static SpvId
93 block_label(struct ntv_context *ctx, nir_block *block)
94 {
95 assert(block->index < ctx->num_blocks);
96 return ctx->block_ids[block->index];
97 }
98
99 static SpvId
100 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
101 {
102 assert(bit_size == 32); // only 32-bit floats supported so far
103
104 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
105 if (num_components > 1)
106 return spirv_builder_type_vector(&ctx->builder, float_type,
107 num_components);
108
109 assert(num_components == 1);
110 return float_type;
111 }
112
113 static SpvId
114 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
115 {
116 assert(bit_size == 32); // only 32-bit ints supported so far
117
118 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
119 if (num_components > 1)
120 return spirv_builder_type_vector(&ctx->builder, int_type,
121 num_components);
122
123 assert(num_components == 1);
124 return int_type;
125 }
126
127 static SpvId
128 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
129 {
130 assert(bit_size == 32); // only 32-bit uints supported so far
131
132 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
133 if (num_components > 1)
134 return spirv_builder_type_vector(&ctx->builder, uint_type,
135 num_components);
136
137 assert(num_components == 1);
138 return uint_type;
139 }
140
141 static SpvId
142 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
143 {
144 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
145 nir_dest_num_components(*dest));
146 }
147
148 static SpvId
149 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
150 {
151 switch (type) {
152 case GLSL_TYPE_FLOAT:
153 return spirv_builder_type_float(&ctx->builder, 32);
154
155 case GLSL_TYPE_INT:
156 return spirv_builder_type_int(&ctx->builder, 32);
157
158 case GLSL_TYPE_UINT:
159 return spirv_builder_type_uint(&ctx->builder, 32);
160 /* TODO: handle more types */
161
162 default:
163 unreachable("unknown GLSL type");
164 }
165 }
166
167 static SpvId
168 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
169 {
170 assert(type);
171 if (glsl_type_is_scalar(type))
172 return get_glsl_basetype(ctx, glsl_get_base_type(type));
173
174 if (glsl_type_is_vector(type))
175 return spirv_builder_type_vector(&ctx->builder,
176 get_glsl_basetype(ctx, glsl_get_base_type(type)),
177 glsl_get_vector_elements(type));
178
179 if (glsl_type_is_array(type)) {
180 SpvId ret = spirv_builder_type_array(&ctx->builder,
181 get_glsl_type(ctx, glsl_get_array_element(type)),
182 spirv_builder_const_uint(&ctx->builder, 32, glsl_get_length(type)));
183 uint32_t stride = glsl_get_explicit_stride(type);
184 if (stride)
185 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
186 return ret;
187 }
188
189
190 unreachable("we shouldn't get here, I think...");
191 }
192
193 static void
194 emit_input(struct ntv_context *ctx, struct nir_variable *var)
195 {
196 SpvId var_type = get_glsl_type(ctx, var->type);
197 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
198 SpvStorageClassInput,
199 var_type);
200 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
201 SpvStorageClassInput);
202
203 if (var->name)
204 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
205
206 if (ctx->stage == MESA_SHADER_FRAGMENT) {
207 if (var->data.location >= VARYING_SLOT_VAR0 ||
208 (var->data.location >= VARYING_SLOT_COL0 &&
209 var->data.location <= VARYING_SLOT_TEX7)) {
210 spirv_builder_emit_location(&ctx->builder, var_id,
211 ctx->var_location++);
212 } else {
213 switch (var->data.location) {
214 case VARYING_SLOT_POS:
215 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
216 break;
217
218 case VARYING_SLOT_PNTC:
219 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
220 break;
221
222 default:
223 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
224 unreachable("unexpected varying slot");
225 }
226 }
227 } else {
228 spirv_builder_emit_location(&ctx->builder, var_id,
229 var->data.driver_location);
230 }
231
232 if (var->data.location_frac)
233 spirv_builder_emit_component(&ctx->builder, var_id,
234 var->data.location_frac);
235
236 if (var->data.interpolation == INTERP_MODE_FLAT)
237 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
238
239 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
240
241 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
242 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
243 }
244
245 static void
246 emit_output(struct ntv_context *ctx, struct nir_variable *var)
247 {
248 SpvId var_type = get_glsl_type(ctx, var->type);
249 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
250 SpvStorageClassOutput,
251 var_type);
252 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
253 SpvStorageClassOutput);
254 if (var->name)
255 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
256
257
258 if (ctx->stage == MESA_SHADER_VERTEX) {
259 if (var->data.location >= VARYING_SLOT_VAR0 ||
260 (var->data.location >= VARYING_SLOT_COL0 &&
261 var->data.location <= VARYING_SLOT_TEX7)) {
262 spirv_builder_emit_location(&ctx->builder, var_id,
263 ctx->var_location++);
264 } else {
265 switch (var->data.location) {
266 case VARYING_SLOT_POS:
267 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
268 break;
269
270 case VARYING_SLOT_PSIZ:
271 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
272 break;
273
274 case VARYING_SLOT_CLIP_DIST0:
275 assert(glsl_type_is_array(var->type));
276 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
277 break;
278
279 default:
280 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
281 unreachable("unexpected varying slot");
282 }
283 }
284 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
285 if (var->data.location >= FRAG_RESULT_DATA0)
286 spirv_builder_emit_location(&ctx->builder, var_id,
287 var->data.location - FRAG_RESULT_DATA0);
288 else {
289 switch (var->data.location) {
290 case FRAG_RESULT_COLOR:
291 spirv_builder_emit_location(&ctx->builder, var_id, 0);
292 break;
293
294 case FRAG_RESULT_DEPTH:
295 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
296 break;
297
298 default:
299 spirv_builder_emit_location(&ctx->builder, var_id,
300 var->data.driver_location);
301 }
302 }
303 }
304
305 if (var->data.location_frac)
306 spirv_builder_emit_component(&ctx->builder, var_id,
307 var->data.location_frac);
308
309 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
310
311 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
312 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
313 }
314
315 static SpvDim
316 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
317 {
318 *is_ms = false;
319 switch (gdim) {
320 case GLSL_SAMPLER_DIM_1D:
321 return SpvDim1D;
322 case GLSL_SAMPLER_DIM_2D:
323 return SpvDim2D;
324 case GLSL_SAMPLER_DIM_RECT:
325 return SpvDimRect;
326 case GLSL_SAMPLER_DIM_CUBE:
327 return SpvDimCube;
328 case GLSL_SAMPLER_DIM_3D:
329 return SpvDim3D;
330 case GLSL_SAMPLER_DIM_MS:
331 *is_ms = true;
332 return SpvDim2D;
333 default:
334 fprintf(stderr, "unknown sampler type %d\n", gdim);
335 break;
336 }
337 return SpvDim2D;
338 }
339
340 static void
341 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
342 {
343 bool is_ms;
344 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
345 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
346 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
347 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
348 SpvImageFormatUnknown);
349
350 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
351 image_type);
352 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
353 SpvStorageClassUniformConstant,
354 sampled_type);
355 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
356 SpvStorageClassUniformConstant);
357
358 if (var->name)
359 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
360
361 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
362 ctx->samplers[ctx->num_samplers++] = var_id;
363
364 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
365 var->data.descriptor_set);
366 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
367 }
368
369 static void
370 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
371 {
372 uint32_t size = glsl_count_attribute_slots(var->type, false);
373 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
374 SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
375 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
376 array_length);
377 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
378
379 // wrap UBO-array in a struct
380 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
381 if (var->name) {
382 char struct_name[100];
383 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
384 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
385 }
386
387 spirv_builder_emit_decoration(&ctx->builder, struct_type,
388 SpvDecorationBlock);
389 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
390
391
392 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
393 SpvStorageClassUniform,
394 struct_type);
395
396 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
397 SpvStorageClassUniform);
398 if (var->name)
399 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
400
401 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
402 ctx->ubos[ctx->num_ubos++] = var_id;
403
404 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
405 var->data.descriptor_set);
406 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
407 }
408
409 static void
410 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
411 {
412 if (var->data.mode == nir_var_mem_ubo)
413 emit_ubo(ctx, var);
414 else {
415 assert(var->data.mode == nir_var_uniform);
416 if (glsl_type_is_sampler(var->type))
417 emit_sampler(ctx, var);
418 }
419 }
420
421 static SpvId
422 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
423 {
424 assert(ssa->index < ctx->num_defs);
425 assert(ctx->defs[ssa->index] != 0);
426 return ctx->defs[ssa->index];
427 }
428
429 static SpvId
430 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
431 {
432 assert(reg->index < ctx->num_regs);
433 assert(ctx->regs[reg->index] != 0);
434 return ctx->regs[reg->index];
435 }
436
437 static SpvId
438 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
439 {
440 assert(reg->reg);
441 assert(!reg->indirect);
442 assert(!reg->base_offset);
443
444 SpvId var = get_var_from_reg(ctx, reg->reg);
445 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
446 return spirv_builder_emit_load(&ctx->builder, type, var);
447 }
448
449 static SpvId
450 get_src_uint(struct ntv_context *ctx, nir_src *src)
451 {
452 if (src->is_ssa)
453 return get_src_uint_ssa(ctx, src->ssa);
454 else
455 return get_src_uint_reg(ctx, &src->reg);
456 }
457
458 static SpvId
459 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
460 {
461 assert(!alu->src[src].negate);
462 assert(!alu->src[src].abs);
463
464 SpvId def = get_src_uint(ctx, &alu->src[src].src);
465
466 unsigned used_channels = 0;
467 bool need_swizzle = false;
468 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
469 if (!nir_alu_instr_channel_used(alu, src, i))
470 continue;
471
472 used_channels++;
473
474 if (alu->src[src].swizzle[i] != i)
475 need_swizzle = true;
476 }
477 assert(used_channels != 0);
478
479 unsigned live_channels = nir_src_num_components(alu->src[src].src);
480 if (used_channels != live_channels)
481 need_swizzle = true;
482
483 if (!need_swizzle)
484 return def;
485
486 int bit_size = nir_src_bit_size(alu->src[src].src);
487 assert(bit_size == 32);
488
489 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
490 if (used_channels == 1) {
491 uint32_t indices[] = { alu->src[src].swizzle[0] };
492 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
493 def, indices,
494 ARRAY_SIZE(indices));
495 } else if (live_channels == 1) {
496 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
497 used_channels);
498
499 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
500 for (unsigned i = 0; i < used_channels; ++i)
501 constituents[i] = def;
502
503 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
504 constituents,
505 used_channels);
506 } else {
507 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
508 used_channels);
509
510 uint32_t components[NIR_MAX_VEC_COMPONENTS];
511 size_t num_components = 0;
512 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
513 if (!nir_alu_instr_channel_used(alu, src, i))
514 continue;
515
516 components[num_components++] = alu->src[src].swizzle[i];
517 }
518
519 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
520 def, def, components, num_components);
521 }
522 }
523
524 static void
525 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
526 {
527 assert(result != 0);
528 assert(ssa->index < ctx->num_defs);
529 ctx->defs[ssa->index] = result;
530 }
531
532 static SpvId
533 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
534 SpvId if_true, SpvId if_false)
535 {
536 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
537 }
538
539 static SpvId
540 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
541 {
542 SpvId otype = get_uvec_type(ctx, 32, num_components);
543 uint32_t zeros[4] = { 0, 0, 0, 0 };
544 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
545 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
546 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
547 return emit_select(ctx, otype, value, one, zero);
548 }
549
550 static SpvId
551 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
552 {
553 SpvId type = get_bvec_type(ctx, num_components);
554
555 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
556 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
557
558 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
559 }
560
561 static SpvId
562 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
563 unsigned num_components)
564 {
565 SpvId type = get_uvec_type(ctx, bit_size, num_components);
566 return emit_unop(ctx, SpvOpBitcast, type, value);
567 }
568
569 static SpvId
570 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
571 unsigned num_components)
572 {
573 SpvId type = get_ivec_type(ctx, bit_size, num_components);
574 return emit_unop(ctx, SpvOpBitcast, type, value);
575 }
576
577 static SpvId
578 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
579 unsigned num_components)
580 {
581 SpvId type = get_fvec_type(ctx, bit_size, num_components);
582 return emit_unop(ctx, SpvOpBitcast, type, value);
583 }
584
585 static void
586 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
587 {
588 SpvId var = get_var_from_reg(ctx, reg->reg);
589 assert(var);
590 spirv_builder_emit_store(&ctx->builder, var, result);
591 }
592
593 static void
594 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
595 {
596 if (dest->is_ssa)
597 store_ssa_def_uint(ctx, &dest->ssa, result);
598 else
599 store_reg_def(ctx, &dest->reg, result);
600 }
601
602 static void
603 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
604 {
605 unsigned num_components = nir_dest_num_components(*dest);
606 unsigned bit_size = nir_dest_bit_size(*dest);
607
608 switch (nir_alu_type_get_base_type(type)) {
609 case nir_type_bool:
610 assert(bit_size == 1);
611 result = bvec_to_uvec(ctx, result, num_components);
612 break;
613
614 case nir_type_uint:
615 break; /* nothing to do! */
616
617 case nir_type_int:
618 case nir_type_float:
619 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
620 break;
621
622 default:
623 unreachable("unsupported nir_alu_type");
624 }
625
626 store_dest_uint(ctx, dest, result);
627 }
628
629 static SpvId
630 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
631 {
632 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
633 }
634
635 static SpvId
636 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
637 SpvId src0, SpvId src1)
638 {
639 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
640 }
641
642 static SpvId
643 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
644 SpvId src0, SpvId src1, SpvId src2)
645 {
646 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
647 }
648
649 static SpvId
650 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
651 SpvId src)
652 {
653 SpvId args[] = { src };
654 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
655 op, args, ARRAY_SIZE(args));
656 }
657
658 static SpvId
659 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
660 SpvId src0, SpvId src1)
661 {
662 SpvId args[] = { src0, src1 };
663 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
664 op, args, ARRAY_SIZE(args));
665 }
666
667 static SpvId
668 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
669 {
670 assert(bit_size == 32);
671 return spirv_builder_const_float(&ctx->builder, bit_size, value);
672 }
673
674 static SpvId
675 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
676 const float values[])
677 {
678 assert(bit_size == 32);
679
680 if (num_components > 1) {
681 SpvId components[num_components];
682 for (int i = 0; i < num_components; i++)
683 components[i] = emit_float_const(ctx, bit_size, values[i]);
684
685 SpvId type = get_fvec_type(ctx, bit_size, num_components);
686 return spirv_builder_const_composite(&ctx->builder, type, components,
687 num_components);
688 }
689
690 assert(num_components == 1);
691 return emit_float_const(ctx, bit_size, values[0]);
692 }
693
694 static SpvId
695 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
696 const uint32_t values[])
697 {
698 assert(bit_size == 32);
699
700 if (num_components > 1) {
701 SpvId components[num_components];
702 for (int i = 0; i < num_components; i++)
703 components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
704 values[i]);
705
706 SpvId type = get_uvec_type(ctx, bit_size, num_components);
707 return spirv_builder_const_composite(&ctx->builder, type, components,
708 num_components);
709 }
710
711 assert(num_components == 1);
712 return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
713 }
714
715 static inline unsigned
716 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
717 {
718 if (nir_op_infos[instr->op].input_sizes[src] > 0)
719 return nir_op_infos[instr->op].input_sizes[src];
720
721 if (instr->dest.dest.is_ssa)
722 return instr->dest.dest.ssa.num_components;
723 else
724 return instr->dest.dest.reg.reg->num_components;
725 }
726
727 static SpvId
728 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
729 {
730 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
731
732 unsigned num_components = alu_instr_src_components(alu, src);
733 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
734 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
735
736 switch (nir_alu_type_get_base_type(type)) {
737 case nir_type_bool:
738 assert(bit_size == 1);
739 return uvec_to_bvec(ctx, uint_value, num_components);
740
741 case nir_type_int:
742 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
743
744 case nir_type_uint:
745 return uint_value;
746
747 case nir_type_float:
748 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
749
750 default:
751 unreachable("unknown nir_alu_type");
752 }
753 }
754
755 static void
756 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
757 {
758 assert(!alu->dest.saturate);
759 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
760 }
761
762 static SpvId
763 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
764 {
765 unsigned num_components = nir_dest_num_components(*dest);
766 unsigned bit_size = nir_dest_bit_size(*dest);
767
768 switch (nir_alu_type_get_base_type(type)) {
769 case nir_type_bool:
770 return get_bvec_type(ctx, num_components);
771
772 case nir_type_int:
773 return get_ivec_type(ctx, bit_size, num_components);
774
775 case nir_type_uint:
776 return get_uvec_type(ctx, bit_size, num_components);
777
778 case nir_type_float:
779 return get_fvec_type(ctx, bit_size, num_components);
780
781 default:
782 unreachable("unsupported nir_alu_type");
783 }
784 }
785
786 static void
787 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
788 {
789 SpvId src[nir_op_infos[alu->op].num_inputs];
790 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
791 src[i] = get_alu_src(ctx, alu, i);
792
793 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
794 nir_op_infos[alu->op].output_type);
795 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
796 unsigned num_components = nir_dest_num_components(alu->dest.dest);
797
798 SpvId result = 0;
799 switch (alu->op) {
800 case nir_op_mov:
801 assert(nir_op_infos[alu->op].num_inputs == 1);
802 result = src[0];
803 break;
804
805 #define UNOP(nir_op, spirv_op) \
806 case nir_op: \
807 assert(nir_op_infos[alu->op].num_inputs == 1); \
808 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
809 break;
810
811 UNOP(nir_op_ineg, SpvOpSNegate)
812 UNOP(nir_op_fneg, SpvOpFNegate)
813 UNOP(nir_op_fddx, SpvOpDPdx)
814 UNOP(nir_op_fddy, SpvOpDPdy)
815 UNOP(nir_op_f2i32, SpvOpConvertFToS)
816 UNOP(nir_op_f2u32, SpvOpConvertFToU)
817 UNOP(nir_op_i2f32, SpvOpConvertSToF)
818 UNOP(nir_op_u2f32, SpvOpConvertUToF)
819 UNOP(nir_op_inot, SpvOpNot)
820 #undef UNOP
821
822 case nir_op_b2i32:
823 assert(nir_op_infos[alu->op].num_inputs == 1);
824 result = bvec_to_uvec(ctx, src[0], num_components);
825 break;
826
827 #define BUILTIN_UNOP(nir_op, spirv_op) \
828 case nir_op: \
829 assert(nir_op_infos[alu->op].num_inputs == 1); \
830 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
831 break;
832
833 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
834 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
835 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
836 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
837 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
838 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
839 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
840 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
841 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
842 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
843 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
844 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
845 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
846 #undef BUILTIN_UNOP
847
848 case nir_op_frcp: {
849 assert(nir_op_infos[alu->op].num_inputs == 1);
850 float one[4] = { 1, 1, 1, 1 };
851 src[1] = src[0];
852 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
853 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
854 }
855 break;
856
857 case nir_op_f2b1: {
858 assert(nir_op_infos[alu->op].num_inputs == 1);
859 float values[NIR_MAX_VEC_COMPONENTS] = { 0 };
860 SpvId zero = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
861 num_components, values);
862 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0], zero);
863 } break;
864
865
866 #define BINOP(nir_op, spirv_op) \
867 case nir_op: \
868 assert(nir_op_infos[alu->op].num_inputs == 2); \
869 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
870 break;
871
872 BINOP(nir_op_iadd, SpvOpIAdd)
873 BINOP(nir_op_isub, SpvOpISub)
874 BINOP(nir_op_imul, SpvOpIMul)
875 BINOP(nir_op_idiv, SpvOpSDiv)
876 BINOP(nir_op_udiv, SpvOpUDiv)
877 BINOP(nir_op_fadd, SpvOpFAdd)
878 BINOP(nir_op_fsub, SpvOpFSub)
879 BINOP(nir_op_fmul, SpvOpFMul)
880 BINOP(nir_op_fdiv, SpvOpFDiv)
881 BINOP(nir_op_fmod, SpvOpFMod)
882 BINOP(nir_op_ilt, SpvOpSLessThan)
883 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
884 BINOP(nir_op_ieq, SpvOpIEqual)
885 BINOP(nir_op_ine, SpvOpINotEqual)
886 BINOP(nir_op_flt, SpvOpFOrdLessThan)
887 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
888 BINOP(nir_op_feq, SpvOpFOrdEqual)
889 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
890 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
891 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
892 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
893 BINOP(nir_op_iand, SpvOpBitwiseAnd)
894 BINOP(nir_op_ior, SpvOpBitwiseOr)
895 #undef BINOP
896
897 #define BUILTIN_BINOP(nir_op, spirv_op) \
898 case nir_op: \
899 assert(nir_op_infos[alu->op].num_inputs == 2); \
900 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
901 break;
902
903 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
904 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
905 #undef BUILTIN_BINOP
906
907 case nir_op_fdot2:
908 case nir_op_fdot3:
909 case nir_op_fdot4:
910 assert(nir_op_infos[alu->op].num_inputs == 2);
911 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
912 break;
913
914 case nir_op_seq:
915 case nir_op_sne:
916 case nir_op_slt:
917 case nir_op_sge: {
918 assert(nir_op_infos[alu->op].num_inputs == 2);
919 int num_components = nir_dest_num_components(alu->dest.dest);
920 SpvId bool_type = get_bvec_type(ctx, num_components);
921
922 SpvId zero = emit_float_const(ctx, 32, 0.0f);
923 SpvId one = emit_float_const(ctx, 32, 1.0f);
924 if (num_components > 1) {
925 SpvId zero_comps[num_components], one_comps[num_components];
926 for (int i = 0; i < num_components; i++) {
927 zero_comps[i] = zero;
928 one_comps[i] = one;
929 }
930
931 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
932 zero_comps, num_components);
933 one = spirv_builder_const_composite(&ctx->builder, dest_type,
934 one_comps, num_components);
935 }
936
937 SpvOp op;
938 switch (alu->op) {
939 case nir_op_seq: op = SpvOpFOrdEqual; break;
940 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
941 case nir_op_slt: op = SpvOpFOrdLessThan; break;
942 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
943 default: unreachable("unexpected op");
944 }
945
946 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
947 result = emit_select(ctx, dest_type, result, one, zero);
948 }
949 break;
950
951 case nir_op_fcsel: {
952 assert(nir_op_infos[alu->op].num_inputs == 3);
953 int num_components = nir_dest_num_components(alu->dest.dest);
954 SpvId bool_type = get_bvec_type(ctx, num_components);
955
956 float zero[4] = { 0, 0, 0, 0 };
957 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
958 num_components, zero);
959
960 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
961 result = emit_select(ctx, dest_type, result, src[1], src[2]);
962 }
963 break;
964
965 case nir_op_bcsel:
966 assert(nir_op_infos[alu->op].num_inputs == 3);
967 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
968 break;
969
970 case nir_op_vec2:
971 case nir_op_vec3:
972 case nir_op_vec4: {
973 int num_inputs = nir_op_infos[alu->op].num_inputs;
974 assert(2 <= num_inputs && num_inputs <= 4);
975 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
976 src, num_inputs);
977 }
978 break;
979
980 default:
981 fprintf(stderr, "emit_alu: not implemented (%s)\n",
982 nir_op_infos[alu->op].name);
983
984 unreachable("unsupported opcode");
985 return;
986 }
987
988 store_alu_result(ctx, alu, result);
989 }
990
991 static void
992 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
993 {
994 uint32_t values[NIR_MAX_VEC_COMPONENTS];
995 for (int i = 0; i < load_const->def.num_components; ++i)
996 values[i] = load_const->value[i].u32;
997
998 SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
999 load_const->def.num_components,
1000 values);
1001 store_ssa_def_uint(ctx, &load_const->def, constant);
1002 }
1003
1004 static void
1005 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1006 {
1007 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1008 assert(const_block_index); // no dynamic indexing for now
1009 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1010
1011 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1012 if (const_offset) {
1013 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1014 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1015 SpvStorageClassUniform,
1016 uvec4_type);
1017
1018 unsigned idx = const_offset->u32;
1019 SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
1020 SpvId offset = spirv_builder_const_uint(&ctx->builder, 32, idx);
1021 SpvId offsets[] = { member, offset };
1022 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1023 ctx->ubos[0], offsets,
1024 ARRAY_SIZE(offsets));
1025 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1026
1027 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1028 unsigned num_components = nir_dest_num_components(intr->dest);
1029 if (num_components == 1) {
1030 uint32_t components[] = { 0 };
1031 result = spirv_builder_emit_composite_extract(&ctx->builder,
1032 type,
1033 result, components,
1034 1);
1035 } else if (num_components < 4) {
1036 SpvId constituents[num_components];
1037 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1038 for (uint32_t i = 0; i < num_components; ++i)
1039 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1040 uint_type,
1041 result, &i,
1042 1);
1043
1044 result = spirv_builder_emit_composite_construct(&ctx->builder,
1045 type,
1046 constituents,
1047 num_components);
1048 }
1049
1050 store_dest_uint(ctx, &intr->dest, result);
1051 } else
1052 unreachable("uniform-addressing not yet supported");
1053 }
1054
1055 static void
1056 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1057 {
1058 assert(ctx->block_started);
1059 spirv_builder_emit_kill(&ctx->builder);
1060 /* discard is weird in NIR, so let's just create an unreachable block after
1061 it and hope that the vulkan driver will DCE any instructinos in it. */
1062 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1063 }
1064
1065 static void
1066 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1067 {
1068 /* uint is a bit of a lie here; it's really just a pointer */
1069 SpvId ptr = get_src_uint(ctx, intr->src);
1070
1071 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1072 SpvId result = spirv_builder_emit_load(&ctx->builder,
1073 get_glsl_type(ctx, var->type),
1074 ptr);
1075 unsigned num_components = nir_dest_num_components(intr->dest);
1076 unsigned bit_size = nir_dest_bit_size(intr->dest);
1077 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1078 store_dest_uint(ctx, &intr->dest, result);
1079 }
1080
1081 static void
1082 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1083 {
1084 /* uint is a bit of a lie here; it's really just a pointer */
1085 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1086 SpvId src = get_src_uint(ctx, &intr->src[1]);
1087
1088 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1089 SpvId result = emit_unop(ctx, SpvOpBitcast,
1090 get_glsl_type(ctx, glsl_without_array(var->type)),
1091 src);
1092 spirv_builder_emit_store(&ctx->builder, ptr, result);
1093 }
1094
1095 static void
1096 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1097 {
1098 switch (intr->intrinsic) {
1099 case nir_intrinsic_load_ubo:
1100 emit_load_ubo(ctx, intr);
1101 break;
1102
1103 case nir_intrinsic_discard:
1104 emit_discard(ctx, intr);
1105 break;
1106
1107 case nir_intrinsic_load_deref:
1108 emit_load_deref(ctx, intr);
1109 break;
1110
1111 case nir_intrinsic_store_deref:
1112 emit_store_deref(ctx, intr);
1113 break;
1114
1115 default:
1116 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1117 nir_intrinsic_infos[intr->intrinsic].name);
1118 unreachable("unsupported intrinsic");
1119 }
1120 }
1121
1122 static void
1123 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1124 {
1125 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1126 undef->def.num_components);
1127
1128 store_ssa_def_uint(ctx, &undef->def,
1129 spirv_builder_emit_undef(&ctx->builder, type));
1130 }
1131
1132 static SpvId
1133 get_src_float(struct ntv_context *ctx, nir_src *src)
1134 {
1135 SpvId def = get_src_uint(ctx, src);
1136 unsigned num_components = nir_src_num_components(*src);
1137 unsigned bit_size = nir_src_bit_size(*src);
1138 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1139 }
1140
1141 static void
1142 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1143 {
1144 assert(tex->op == nir_texop_tex ||
1145 tex->op == nir_texop_txb ||
1146 tex->op == nir_texop_txl);
1147 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1148 assert(tex->texture_index == tex->sampler_index);
1149
1150 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1151 unsigned coord_components;
1152 for (unsigned i = 0; i < tex->num_srcs; i++) {
1153 switch (tex->src[i].src_type) {
1154 case nir_tex_src_coord:
1155 coord = get_src_float(ctx, &tex->src[i].src);
1156 coord_components = nir_src_num_components(tex->src[i].src);
1157 break;
1158
1159 case nir_tex_src_projector:
1160 assert(nir_src_num_components(tex->src[i].src) == 1);
1161 proj = get_src_float(ctx, &tex->src[i].src);
1162 assert(proj != 0);
1163 break;
1164
1165 case nir_tex_src_bias:
1166 assert(tex->op == nir_texop_txb);
1167 bias = get_src_float(ctx, &tex->src[i].src);
1168 assert(bias != 0);
1169 break;
1170
1171 case nir_tex_src_lod:
1172 assert(nir_src_num_components(tex->src[i].src) == 1);
1173 lod = get_src_float(ctx, &tex->src[i].src);
1174 assert(lod != 0);
1175 break;
1176
1177 case nir_tex_src_comparator:
1178 assert(nir_src_num_components(tex->src[i].src) == 1);
1179 dref = get_src_float(ctx, &tex->src[i].src);
1180 assert(dref != 0);
1181 break;
1182
1183 default:
1184 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1185 unreachable("unknown texture source");
1186 }
1187 }
1188
1189 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1190 lod = emit_float_const(ctx, 32, 0.0f);
1191 assert(lod != 0);
1192 }
1193
1194 bool is_ms;
1195 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1196 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1197 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1198 dimension, false, tex->is_array, is_ms, 1,
1199 SpvImageFormatUnknown);
1200 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1201 image_type);
1202
1203 assert(tex->texture_index < ctx->num_samplers);
1204 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1205 ctx->samplers[tex->texture_index]);
1206
1207 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1208
1209 if (proj) {
1210 SpvId constituents[coord_components + 1];
1211 if (coord_components == 1)
1212 constituents[0] = coord;
1213 else {
1214 assert(coord_components > 1);
1215 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1216 for (uint32_t i = 0; i < coord_components; ++i)
1217 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1218 float_type,
1219 coord,
1220 &i, 1);
1221 }
1222
1223 constituents[coord_components++] = proj;
1224
1225 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1226 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1227 vec_type,
1228 constituents,
1229 coord_components);
1230 }
1231
1232 SpvId actual_dest_type = dest_type;
1233 if (dref)
1234 actual_dest_type = float_type;
1235
1236 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1237 actual_dest_type, load,
1238 coord,
1239 proj != 0,
1240 lod, bias, dref);
1241 spirv_builder_emit_decoration(&ctx->builder, result,
1242 SpvDecorationRelaxedPrecision);
1243
1244 if (dref) {
1245 SpvId components[4] = { result, result, result, result };
1246 result = spirv_builder_emit_composite_construct(&ctx->builder,
1247 dest_type,
1248 components,
1249 4);
1250 }
1251
1252 store_dest(ctx, &tex->dest, result, tex->dest_type);
1253 }
1254
1255 static void
1256 start_block(struct ntv_context *ctx, SpvId label)
1257 {
1258 /* terminate previous block if needed */
1259 if (ctx->block_started)
1260 spirv_builder_emit_branch(&ctx->builder, label);
1261
1262 /* start new block */
1263 spirv_builder_label(&ctx->builder, label);
1264 ctx->block_started = true;
1265 }
1266
1267 static void
1268 branch(struct ntv_context *ctx, SpvId label)
1269 {
1270 assert(ctx->block_started);
1271 spirv_builder_emit_branch(&ctx->builder, label);
1272 ctx->block_started = false;
1273 }
1274
1275 static void
1276 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1277 SpvId else_id)
1278 {
1279 assert(ctx->block_started);
1280 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1281 then_id, else_id);
1282 ctx->block_started = false;
1283 }
1284
1285 static void
1286 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1287 {
1288 switch (jump->type) {
1289 case nir_jump_break:
1290 assert(ctx->loop_break);
1291 branch(ctx, ctx->loop_break);
1292 break;
1293
1294 case nir_jump_continue:
1295 assert(ctx->loop_cont);
1296 branch(ctx, ctx->loop_cont);
1297 break;
1298
1299 default:
1300 unreachable("Unsupported jump type\n");
1301 }
1302 }
1303
1304 static void
1305 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1306 {
1307 assert(deref->deref_type == nir_deref_type_var);
1308
1309 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1310 assert(he);
1311 SpvId result = (SpvId)(intptr_t)he->data;
1312 /* uint is a bit of a lie here, it's really just an opaque type */
1313 store_dest_uint(ctx, &deref->dest, result);
1314 }
1315
1316 static void
1317 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1318 {
1319 assert(deref->deref_type == nir_deref_type_array);
1320 nir_variable *var = nir_deref_instr_get_variable(deref);
1321
1322 SpvStorageClass storage_class;
1323 switch (var->data.mode) {
1324 case nir_var_shader_in:
1325 storage_class = SpvStorageClassInput;
1326 break;
1327
1328 case nir_var_shader_out:
1329 storage_class = SpvStorageClassOutput;
1330 break;
1331
1332 default:
1333 unreachable("Unsupported nir_variable_mode\n");
1334 }
1335
1336 SpvId index = get_src_uint(ctx, &deref->arr.index);
1337
1338 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1339 storage_class,
1340 get_glsl_type(ctx, deref->type));
1341
1342 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1343 ptr_type,
1344 get_src_uint(ctx, &deref->parent),
1345 &index, 1);
1346 /* uint is a bit of a lie here, it's really just an opaque type */
1347 store_dest_uint(ctx, &deref->dest, result);
1348 }
1349
1350 static void
1351 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1352 {
1353 switch (deref->deref_type) {
1354 case nir_deref_type_var:
1355 emit_deref_var(ctx, deref);
1356 break;
1357
1358 case nir_deref_type_array:
1359 emit_deref_array(ctx, deref);
1360 break;
1361
1362 default:
1363 unreachable("unexpected deref_type");
1364 }
1365 }
1366
1367 static void
1368 emit_block(struct ntv_context *ctx, struct nir_block *block)
1369 {
1370 start_block(ctx, block_label(ctx, block));
1371 nir_foreach_instr(instr, block) {
1372 switch (instr->type) {
1373 case nir_instr_type_alu:
1374 emit_alu(ctx, nir_instr_as_alu(instr));
1375 break;
1376 case nir_instr_type_intrinsic:
1377 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1378 break;
1379 case nir_instr_type_load_const:
1380 emit_load_const(ctx, nir_instr_as_load_const(instr));
1381 break;
1382 case nir_instr_type_ssa_undef:
1383 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1384 break;
1385 case nir_instr_type_tex:
1386 emit_tex(ctx, nir_instr_as_tex(instr));
1387 break;
1388 case nir_instr_type_phi:
1389 unreachable("nir_instr_type_phi not supported");
1390 break;
1391 case nir_instr_type_jump:
1392 emit_jump(ctx, nir_instr_as_jump(instr));
1393 break;
1394 case nir_instr_type_call:
1395 unreachable("nir_instr_type_call not supported");
1396 break;
1397 case nir_instr_type_parallel_copy:
1398 unreachable("nir_instr_type_parallel_copy not supported");
1399 break;
1400 case nir_instr_type_deref:
1401 emit_deref(ctx, nir_instr_as_deref(instr));
1402 break;
1403 }
1404 }
1405 }
1406
/* forward declaration: emit_if() and emit_loop() call emit_cf_list(), which
   is defined after them */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1409
1410 static SpvId
1411 get_src_bool(struct ntv_context *ctx, nir_src *src)
1412 {
1413 SpvId def = get_src_uint(ctx, src);
1414 assert(nir_src_bit_size(*src) == 32);
1415 unsigned num_components = nir_src_num_components(*src);
1416 return uvec_to_bvec(ctx, def, num_components);
1417 }
1418
1419 static void
1420 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1421 {
1422 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1423
1424 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1425 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1426 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1427 SpvId else_id = endif_id;
1428
1429 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1430 if (has_else) {
1431 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1432 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1433 }
1434
1435 /* create a header-block */
1436 start_block(ctx, header_id);
1437 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1438 SpvSelectionControlMaskNone);
1439 branch_conditional(ctx, condition, then_id, else_id);
1440
1441 emit_cf_list(ctx, &if_stmt->then_list);
1442
1443 if (has_else) {
1444 if (ctx->block_started)
1445 branch(ctx, endif_id);
1446
1447 emit_cf_list(ctx, &if_stmt->else_list);
1448 }
1449
1450 start_block(ctx, endif_id);
1451 }
1452
1453 static void
1454 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1455 {
1456 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1457 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1458 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1459 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1460
1461 /* create a header-block */
1462 start_block(ctx, header_id);
1463 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1464 branch(ctx, begin_id);
1465
1466 SpvId save_break = ctx->loop_break;
1467 SpvId save_cont = ctx->loop_cont;
1468 ctx->loop_break = break_id;
1469 ctx->loop_cont = cont_id;
1470
1471 emit_cf_list(ctx, &loop->body);
1472
1473 ctx->loop_break = save_break;
1474 ctx->loop_cont = save_cont;
1475
1476 branch(ctx, cont_id);
1477 start_block(ctx, cont_id);
1478 branch(ctx, header_id);
1479
1480 start_block(ctx, break_id);
1481 }
1482
1483 static void
1484 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1485 {
1486 foreach_list_typed(nir_cf_node, node, node, list) {
1487 switch (node->type) {
1488 case nir_cf_node_block:
1489 emit_block(ctx, nir_cf_node_as_block(node));
1490 break;
1491
1492 case nir_cf_node_if:
1493 emit_if(ctx, nir_cf_node_as_if(node));
1494 break;
1495
1496 case nir_cf_node_loop:
1497 emit_loop(ctx, nir_cf_node_as_loop(node));
1498 break;
1499
1500 case nir_cf_node_function:
1501 unreachable("nir_cf_node_function not supported");
1502 break;
1503 }
1504 }
1505 }
1506
1507 struct spirv_shader *
1508 nir_to_spirv(struct nir_shader *s)
1509 {
1510 struct spirv_shader *ret = NULL;
1511
1512 struct ntv_context ctx = {};
1513
1514 switch (s->info.stage) {
1515 case MESA_SHADER_VERTEX:
1516 case MESA_SHADER_FRAGMENT:
1517 case MESA_SHADER_COMPUTE:
1518 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1519 break;
1520
1521 case MESA_SHADER_TESS_CTRL:
1522 case MESA_SHADER_TESS_EVAL:
1523 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1524 break;
1525
1526 case MESA_SHADER_GEOMETRY:
1527 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1528 break;
1529
1530 default:
1531 unreachable("invalid stage");
1532 }
1533
1534 // TODO: only enable when needed
1535 if (s->info.stage == MESA_SHADER_FRAGMENT)
1536 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1537
1538 ctx.stage = s->info.stage;
1539 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1540 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1541
1542 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1543 SpvMemoryModelGLSL450);
1544
1545 SpvExecutionModel exec_model;
1546 switch (s->info.stage) {
1547 case MESA_SHADER_VERTEX:
1548 exec_model = SpvExecutionModelVertex;
1549 break;
1550 case MESA_SHADER_TESS_CTRL:
1551 exec_model = SpvExecutionModelTessellationControl;
1552 break;
1553 case MESA_SHADER_TESS_EVAL:
1554 exec_model = SpvExecutionModelTessellationEvaluation;
1555 break;
1556 case MESA_SHADER_GEOMETRY:
1557 exec_model = SpvExecutionModelGeometry;
1558 break;
1559 case MESA_SHADER_FRAGMENT:
1560 exec_model = SpvExecutionModelFragment;
1561 break;
1562 case MESA_SHADER_COMPUTE:
1563 exec_model = SpvExecutionModelGLCompute;
1564 break;
1565 default:
1566 unreachable("invalid stage");
1567 }
1568
1569 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1570 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1571 NULL, 0);
1572 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1573 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1574
1575 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1576 _mesa_key_pointer_equal);
1577
1578 nir_foreach_variable(var, &s->inputs)
1579 emit_input(&ctx, var);
1580
1581 nir_foreach_variable(var, &s->outputs)
1582 emit_output(&ctx, var);
1583
1584 nir_foreach_variable(var, &s->uniforms)
1585 emit_uniform(&ctx, var);
1586
1587 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1588 "main", ctx.entry_ifaces,
1589 ctx.num_entry_ifaces);
1590 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1591 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1592 SpvExecutionModeOriginUpperLeft);
1593 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1594 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1595 SpvExecutionModeDepthReplacing);
1596 }
1597
1598
1599 spirv_builder_function(&ctx.builder, entry_point, type_void,
1600 SpvFunctionControlMaskNone,
1601 type_main);
1602
1603 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1604 nir_metadata_require(entry, nir_metadata_block_index);
1605
1606 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1607 if (!ctx.defs)
1608 goto fail;
1609 ctx.num_defs = entry->ssa_alloc;
1610
1611 nir_index_local_regs(entry);
1612 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1613 if (!ctx.regs)
1614 goto fail;
1615 ctx.num_regs = entry->reg_alloc;
1616
1617 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1618 if (!block_ids)
1619 goto fail;
1620
1621 for (int i = 0; i < entry->num_blocks; ++i)
1622 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1623
1624 ctx.block_ids = block_ids;
1625 ctx.num_blocks = entry->num_blocks;
1626
1627 /* emit a block only for the variable declarations */
1628 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1629 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1630 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1631 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1632 SpvStorageClassFunction,
1633 type);
1634 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1635 SpvStorageClassFunction);
1636
1637 ctx.regs[reg->index] = var;
1638 }
1639
1640 emit_cf_list(&ctx, &entry->body);
1641
1642 free(ctx.defs);
1643
1644 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1645 spirv_builder_function_end(&ctx.builder);
1646
1647 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1648
1649 ret = CALLOC_STRUCT(spirv_shader);
1650 if (!ret)
1651 goto fail;
1652
1653 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1654 if (!ret->words)
1655 goto fail;
1656
1657 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1658 assert(ret->num_words == num_words);
1659
1660 return ret;
1661
1662 fail:
1663
1664 if (ret)
1665 spirv_shader_delete(ret);
1666
1667 if (ctx.vars)
1668 _mesa_hash_table_destroy(ctx.vars, NULL);
1669
1670 return NULL;
1671 }
1672
1673 void
1674 spirv_shader_delete(struct spirv_shader *s)
1675 {
1676 FREE(s->words);
1677 FREE(s);
1678 }