zink: do not lower io
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59 };
60
61 static SpvId
62 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
63 const float values[]);
64
65 static SpvId
66 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
67 const uint32_t values[]);
68
69 static SpvId
70 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
71
72 static SpvId
73 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
74 SpvId src0, SpvId src1);
75
76 static SpvId
77 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
78 SpvId src0, SpvId src1, SpvId src2);
79
80 static SpvId
81 get_bvec_type(struct ntv_context *ctx, int num_components)
82 {
83 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
84 if (num_components > 1)
85 return spirv_builder_type_vector(&ctx->builder, bool_type,
86 num_components);
87
88 assert(num_components == 1);
89 return bool_type;
90 }
91
92 static SpvId
93 block_label(struct ntv_context *ctx, nir_block *block)
94 {
95 assert(block->index < ctx->num_blocks);
96 return ctx->block_ids[block->index];
97 }
98
99 static SpvId
100 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
101 {
102 assert(bit_size == 32); // only 32-bit floats supported so far
103
104 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
105 if (num_components > 1)
106 return spirv_builder_type_vector(&ctx->builder, float_type,
107 num_components);
108
109 assert(num_components == 1);
110 return float_type;
111 }
112
113 static SpvId
114 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
115 {
116 assert(bit_size == 32); // only 32-bit ints supported so far
117
118 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
119 if (num_components > 1)
120 return spirv_builder_type_vector(&ctx->builder, int_type,
121 num_components);
122
123 assert(num_components == 1);
124 return int_type;
125 }
126
127 static SpvId
128 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
129 {
130 assert(bit_size == 32); // only 32-bit uints supported so far
131
132 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
133 if (num_components > 1)
134 return spirv_builder_type_vector(&ctx->builder, uint_type,
135 num_components);
136
137 assert(num_components == 1);
138 return uint_type;
139 }
140
141 static SpvId
142 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
143 {
144 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
145 nir_dest_num_components(*dest));
146 }
147
148 static SpvId
149 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
150 {
151 switch (type) {
152 case GLSL_TYPE_FLOAT:
153 return spirv_builder_type_float(&ctx->builder, 32);
154
155 case GLSL_TYPE_INT:
156 return spirv_builder_type_int(&ctx->builder, 32);
157
158 case GLSL_TYPE_UINT:
159 return spirv_builder_type_uint(&ctx->builder, 32);
160 /* TODO: handle more types */
161
162 default:
163 unreachable("unknown GLSL type");
164 }
165 }
166
167 static SpvId
168 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
169 {
170 assert(type);
171 if (glsl_type_is_scalar(type))
172 return get_glsl_basetype(ctx, glsl_get_base_type(type));
173
174 if (glsl_type_is_vector(type))
175 return spirv_builder_type_vector(&ctx->builder,
176 get_glsl_basetype(ctx, glsl_get_base_type(type)),
177 glsl_get_vector_elements(type));
178
179 unreachable("we shouldn't get here, I think...");
180 }
181
182 static void
183 emit_input(struct ntv_context *ctx, struct nir_variable *var)
184 {
185 SpvId var_type = get_glsl_type(ctx, var->type);
186 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
187 SpvStorageClassInput,
188 var_type);
189 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
190 SpvStorageClassInput);
191
192 if (var->name)
193 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
194
195 if (ctx->stage == MESA_SHADER_FRAGMENT) {
196 if (var->data.location >= VARYING_SLOT_VAR0 ||
197 (var->data.location >= VARYING_SLOT_COL0 &&
198 var->data.location <= VARYING_SLOT_TEX7)) {
199 spirv_builder_emit_location(&ctx->builder, var_id,
200 ctx->var_location++);
201 } else {
202 switch (var->data.location) {
203 case VARYING_SLOT_POS:
204 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
205 break;
206
207 case VARYING_SLOT_PNTC:
208 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
209 break;
210
211 default:
212 unreachable("unknown varying slot");
213 }
214 }
215 } else {
216 spirv_builder_emit_location(&ctx->builder, var_id,
217 var->data.driver_location);
218 }
219
220 if (var->data.location_frac)
221 spirv_builder_emit_component(&ctx->builder, var_id,
222 var->data.location_frac);
223
224 if (var->data.interpolation == INTERP_MODE_FLAT)
225 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
226
227 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
228
229 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
230 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
231 }
232
233 static void
234 emit_output(struct ntv_context *ctx, struct nir_variable *var)
235 {
236 SpvId var_type = get_glsl_type(ctx, var->type);
237 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
238 SpvStorageClassOutput,
239 var_type);
240 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
241 SpvStorageClassOutput);
242 if (var->name)
243 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
244
245
246 if (ctx->stage == MESA_SHADER_VERTEX) {
247 if (var->data.location >= VARYING_SLOT_VAR0 ||
248 (var->data.location >= VARYING_SLOT_COL0 &&
249 var->data.location <= VARYING_SLOT_TEX7)) {
250 spirv_builder_emit_location(&ctx->builder, var_id,
251 ctx->var_location++);
252 } else {
253 switch (var->data.location) {
254 case VARYING_SLOT_POS:
255 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
256 break;
257
258 case VARYING_SLOT_PSIZ:
259 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
260 break;
261
262 default:
263 unreachable("unknown varying slot");
264 }
265 }
266 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
267 switch (var->data.location) {
268 case FRAG_RESULT_DEPTH:
269 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
270 break;
271
272 default:
273 spirv_builder_emit_location(&ctx->builder, var_id,
274 var->data.driver_location);
275 }
276 }
277
278 if (var->data.location_frac)
279 spirv_builder_emit_component(&ctx->builder, var_id,
280 var->data.location_frac);
281
282 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
283
284 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
285 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
286 }
287
288 static SpvDim
289 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
290 {
291 *is_ms = false;
292 switch (gdim) {
293 case GLSL_SAMPLER_DIM_1D:
294 return SpvDim1D;
295 case GLSL_SAMPLER_DIM_2D:
296 return SpvDim2D;
297 case GLSL_SAMPLER_DIM_RECT:
298 return SpvDimRect;
299 case GLSL_SAMPLER_DIM_CUBE:
300 return SpvDimCube;
301 case GLSL_SAMPLER_DIM_3D:
302 return SpvDim3D;
303 case GLSL_SAMPLER_DIM_MS:
304 *is_ms = true;
305 return SpvDim2D;
306 default:
307 fprintf(stderr, "unknown sampler type %d\n", gdim);
308 break;
309 }
310 return SpvDim2D;
311 }
312
313 static void
314 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
315 {
316 bool is_ms;
317 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
318 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
319 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
320 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
321 SpvImageFormatUnknown);
322
323 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
324 image_type);
325 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
326 SpvStorageClassUniformConstant,
327 sampled_type);
328 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
329 SpvStorageClassUniformConstant);
330
331 if (var->name)
332 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
333
334 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
335 ctx->samplers[ctx->num_samplers++] = var_id;
336
337 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
338 var->data.descriptor_set);
339 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
340 }
341
342 static void
343 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
344 {
345 uint32_t size = glsl_count_attribute_slots(var->type, false);
346 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
347 SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
348 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
349 array_length);
350 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
351
352 // wrap UBO-array in a struct
353 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
354 if (var->name) {
355 char struct_name[100];
356 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
357 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
358 }
359
360 spirv_builder_emit_decoration(&ctx->builder, struct_type,
361 SpvDecorationBlock);
362 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
363
364
365 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
366 SpvStorageClassUniform,
367 struct_type);
368
369 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
370 SpvStorageClassUniform);
371 if (var->name)
372 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
373
374 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
375 ctx->ubos[ctx->num_ubos++] = var_id;
376
377 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
378 var->data.descriptor_set);
379 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
380 }
381
382 static void
383 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
384 {
385 if (glsl_type_is_sampler(var->type))
386 emit_sampler(ctx, var);
387 else if (var->interface_type)
388 emit_ubo(ctx, var);
389 }
390
391 static SpvId
392 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
393 {
394 assert(ssa->index < ctx->num_defs);
395 assert(ctx->defs[ssa->index] != 0);
396 return ctx->defs[ssa->index];
397 }
398
399 static SpvId
400 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
401 {
402 assert(reg->index < ctx->num_regs);
403 assert(ctx->regs[reg->index] != 0);
404 return ctx->regs[reg->index];
405 }
406
407 static SpvId
408 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
409 {
410 assert(reg->reg);
411 assert(!reg->indirect);
412 assert(!reg->base_offset);
413
414 SpvId var = get_var_from_reg(ctx, reg->reg);
415 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
416 return spirv_builder_emit_load(&ctx->builder, type, var);
417 }
418
419 static SpvId
420 get_src_uint(struct ntv_context *ctx, nir_src *src)
421 {
422 if (src->is_ssa)
423 return get_src_uint_ssa(ctx, src->ssa);
424 else
425 return get_src_uint_reg(ctx, &src->reg);
426 }
427
428 static SpvId
429 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
430 {
431 assert(!alu->src[src].negate);
432 assert(!alu->src[src].abs);
433
434 SpvId def = get_src_uint(ctx, &alu->src[src].src);
435
436 unsigned used_channels = 0;
437 bool need_swizzle = false;
438 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
439 if (!nir_alu_instr_channel_used(alu, src, i))
440 continue;
441
442 used_channels++;
443
444 if (alu->src[src].swizzle[i] != i)
445 need_swizzle = true;
446 }
447 assert(used_channels != 0);
448
449 unsigned live_channels = nir_src_num_components(alu->src[src].src);
450 if (used_channels != live_channels)
451 need_swizzle = true;
452
453 if (!need_swizzle)
454 return def;
455
456 int bit_size = nir_src_bit_size(alu->src[src].src);
457
458 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
459 if (used_channels == 1) {
460 uint32_t indices[] = { alu->src[src].swizzle[0] };
461 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
462 def, indices,
463 ARRAY_SIZE(indices));
464 } else if (live_channels == 1) {
465 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
466 used_channels);
467
468 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
469 for (unsigned i = 0; i < used_channels; ++i)
470 constituents[i] = def;
471
472 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
473 constituents,
474 used_channels);
475 } else {
476 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
477 used_channels);
478
479 uint32_t components[NIR_MAX_VEC_COMPONENTS];
480 size_t num_components = 0;
481 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
482 if (!nir_alu_instr_channel_used(alu, src, i))
483 continue;
484
485 components[num_components++] = alu->src[src].swizzle[i];
486 }
487
488 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
489 def, def, components, num_components);
490 }
491 }
492
493 static void
494 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
495 {
496 assert(result != 0);
497 assert(ssa->index < ctx->num_defs);
498 ctx->defs[ssa->index] = result;
499 }
500
501 static SpvId
502 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
503 {
504 SpvId otype = get_uvec_type(ctx, 32, num_components);
505 uint32_t zeros[4] = { 0, 0, 0, 0 };
506 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
507 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
508 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
509 return emit_triop(ctx, SpvOpSelect, otype, value, one, zero);
510 }
511
512 static SpvId
513 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
514 {
515 SpvId type = get_bvec_type(ctx, num_components);
516
517 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
518 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
519
520 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
521 }
522
523 static SpvId
524 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
525 unsigned num_components)
526 {
527 SpvId type = get_uvec_type(ctx, bit_size, num_components);
528 return emit_unop(ctx, SpvOpBitcast, type, value);
529 }
530
531 static SpvId
532 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
533 unsigned num_components)
534 {
535 SpvId type = get_ivec_type(ctx, bit_size, num_components);
536 return emit_unop(ctx, SpvOpBitcast, type, value);
537 }
538
539 static SpvId
540 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
541 unsigned num_components)
542 {
543 SpvId type = get_fvec_type(ctx, bit_size, num_components);
544 return emit_unop(ctx, SpvOpBitcast, type, value);
545 }
546
547 static void
548 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
549 {
550 SpvId var = get_var_from_reg(ctx, reg->reg);
551 assert(var);
552 spirv_builder_emit_store(&ctx->builder, var, result);
553 }
554
555 static void
556 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
557 {
558 if (dest->is_ssa)
559 store_ssa_def_uint(ctx, &dest->ssa, result);
560 else
561 store_reg_def(ctx, &dest->reg, result);
562 }
563
564 static void
565 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
566 {
567 unsigned num_components = nir_dest_num_components(*dest);
568 unsigned bit_size = nir_dest_bit_size(*dest);
569
570 switch (nir_alu_type_get_base_type(type)) {
571 case nir_type_bool:
572 assert(bit_size == 1);
573 result = bvec_to_uvec(ctx, result, num_components);
574 break;
575
576 case nir_type_uint:
577 break; /* nothing to do! */
578
579 case nir_type_int:
580 case nir_type_float:
581 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
582 break;
583
584 default:
585 unreachable("unsupported nir_alu_type");
586 }
587
588 store_dest_uint(ctx, dest, result);
589 }
590
591 static SpvId
592 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
593 {
594 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
595 }
596
597 static SpvId
598 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
599 SpvId src0, SpvId src1)
600 {
601 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
602 }
603
604 static SpvId
605 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
606 SpvId src0, SpvId src1, SpvId src2)
607 {
608 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
609 }
610
611 static SpvId
612 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
613 SpvId src)
614 {
615 SpvId args[] = { src };
616 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
617 op, args, ARRAY_SIZE(args));
618 }
619
620 static SpvId
621 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
622 SpvId src0, SpvId src1)
623 {
624 SpvId args[] = { src0, src1 };
625 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
626 op, args, ARRAY_SIZE(args));
627 }
628
629 static SpvId
630 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
631 const float values[])
632 {
633 assert(bit_size == 32);
634
635 if (num_components > 1) {
636 SpvId components[num_components];
637 for (int i = 0; i < num_components; i++)
638 components[i] = spirv_builder_const_float(&ctx->builder, bit_size,
639 values[i]);
640
641 SpvId type = get_fvec_type(ctx, bit_size, num_components);
642 return spirv_builder_const_composite(&ctx->builder, type, components,
643 num_components);
644 }
645
646 assert(num_components == 1);
647 return spirv_builder_const_float(&ctx->builder, bit_size, values[0]);
648 }
649
650 static SpvId
651 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
652 const uint32_t values[])
653 {
654 assert(bit_size == 32);
655
656 if (num_components > 1) {
657 SpvId components[num_components];
658 for (int i = 0; i < num_components; i++)
659 components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
660 values[i]);
661
662 SpvId type = get_uvec_type(ctx, bit_size, num_components);
663 return spirv_builder_const_composite(&ctx->builder, type, components,
664 num_components);
665 }
666
667 assert(num_components == 1);
668 return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
669 }
670
671 static inline unsigned
672 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
673 {
674 if (nir_op_infos[instr->op].input_sizes[src] > 0)
675 return nir_op_infos[instr->op].input_sizes[src];
676
677 if (instr->dest.dest.is_ssa)
678 return instr->dest.dest.ssa.num_components;
679 else
680 return instr->dest.dest.reg.reg->num_components;
681 }
682
683 static SpvId
684 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
685 {
686 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
687
688 unsigned num_components = alu_instr_src_components(alu, src);
689 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
690 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
691
692 switch (nir_alu_type_get_base_type(type)) {
693 case nir_type_bool:
694 assert(bit_size == 1);
695 return uvec_to_bvec(ctx, uint_value, num_components);
696
697 case nir_type_int:
698 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
699
700 case nir_type_uint:
701 return uint_value;
702
703 case nir_type_float:
704 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
705
706 default:
707 unreachable("unknown nir_alu_type");
708 }
709 }
710
711 static void
712 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
713 {
714 assert(!alu->dest.saturate);
715 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
716 }
717
718 static SpvId
719 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
720 {
721 unsigned num_components = nir_dest_num_components(*dest);
722 unsigned bit_size = nir_dest_bit_size(*dest);
723
724 switch (nir_alu_type_get_base_type(type)) {
725 case nir_type_bool:
726 return get_bvec_type(ctx, num_components);
727
728 case nir_type_int:
729 return get_ivec_type(ctx, bit_size, num_components);
730
731 case nir_type_uint:
732 return get_uvec_type(ctx, bit_size, num_components);
733
734 case nir_type_float:
735 return get_fvec_type(ctx, bit_size, num_components);
736
737 default:
738 unreachable("unsupported nir_alu_type");
739 }
740 }
741
742 static void
743 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
744 {
745 SpvId src[nir_op_infos[alu->op].num_inputs];
746 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
747 src[i] = get_alu_src(ctx, alu, i);
748
749 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
750 nir_op_infos[alu->op].output_type);
751 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
752 unsigned num_components = nir_dest_num_components(alu->dest.dest);
753
754 SpvId result = 0;
755 switch (alu->op) {
756 case nir_op_mov:
757 assert(nir_op_infos[alu->op].num_inputs == 1);
758 result = src[0];
759 break;
760
761 #define UNOP(nir_op, spirv_op) \
762 case nir_op: \
763 assert(nir_op_infos[alu->op].num_inputs == 1); \
764 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
765 break;
766
767 #define BUILTIN_UNOP(nir_op, spirv_op) \
768 case nir_op: \
769 assert(nir_op_infos[alu->op].num_inputs == 1); \
770 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
771 break;
772
773 UNOP(nir_op_fneg, SpvOpFNegate)
774 UNOP(nir_op_fddx, SpvOpDPdx)
775 UNOP(nir_op_fddy, SpvOpDPdy)
776
777 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
778 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
779 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
780 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
781 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
782 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
783 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
784 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
785 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
786 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
787 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
788 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
789 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
790
791 case nir_op_frcp: {
792 assert(nir_op_infos[alu->op].num_inputs == 1);
793 float one[4] = { 1, 1, 1, 1 };
794 src[1] = src[0];
795 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
796 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
797 }
798 break;
799
800 #undef UNOP
801 #undef BUILTIN_UNOP
802
803 #define BINOP(nir_op, spirv_op) \
804 case nir_op: \
805 assert(nir_op_infos[alu->op].num_inputs == 2); \
806 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
807 break;
808
809 #define BUILTIN_BINOP(nir_op, spirv_op) \
810 case nir_op: \
811 assert(nir_op_infos[alu->op].num_inputs == 2); \
812 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
813 break;
814
815 BINOP(nir_op_iadd, SpvOpIAdd)
816 BINOP(nir_op_isub, SpvOpISub)
817 BINOP(nir_op_imul, SpvOpIMul)
818 BINOP(nir_op_fadd, SpvOpFAdd)
819 BINOP(nir_op_fsub, SpvOpFSub)
820 BINOP(nir_op_fmul, SpvOpFMul)
821 BINOP(nir_op_fmod, SpvOpFMod)
822 BINOP(nir_op_flt, SpvOpFUnordLessThan)
823 BINOP(nir_op_fge, SpvOpFUnordGreaterThanEqual)
824
825 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
826 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
827
828 #undef BINOP
829 #undef BUILTIN_BINOP
830
831 case nir_op_fdot2:
832 case nir_op_fdot3:
833 case nir_op_fdot4:
834 assert(nir_op_infos[alu->op].num_inputs == 2);
835 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
836 break;
837
838 case nir_op_seq:
839 case nir_op_sne:
840 case nir_op_slt:
841 case nir_op_sge: {
842 assert(nir_op_infos[alu->op].num_inputs == 2);
843 int num_components = nir_dest_num_components(alu->dest.dest);
844 SpvId bool_type = get_bvec_type(ctx, num_components);
845
846 SpvId zero = spirv_builder_const_float(&ctx->builder, 32, 0.0f);
847 SpvId one = spirv_builder_const_float(&ctx->builder, 32, 1.0f);
848 if (num_components > 1) {
849 SpvId zero_comps[num_components], one_comps[num_components];
850 for (int i = 0; i < num_components; i++) {
851 zero_comps[i] = zero;
852 one_comps[i] = one;
853 }
854
855 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
856 zero_comps, num_components);
857 one = spirv_builder_const_composite(&ctx->builder, dest_type,
858 one_comps, num_components);
859 }
860
861 SpvOp op;
862 switch (alu->op) {
863 case nir_op_seq: op = SpvOpFOrdEqual; break;
864 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
865 case nir_op_slt: op = SpvOpFOrdLessThan; break;
866 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
867 default: unreachable("unexpected op");
868 }
869
870 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
871 result = emit_triop(ctx, SpvOpSelect, dest_type, result, one, zero);
872 }
873 break;
874
875 case nir_op_fcsel: {
876 assert(nir_op_infos[alu->op].num_inputs == 3);
877 int num_components = nir_dest_num_components(alu->dest.dest);
878 SpvId bool_type = get_bvec_type(ctx, num_components);
879
880 float zero[4] = { 0, 0, 0, 0 };
881 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
882 num_components, zero);
883
884 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
885 result = emit_triop(ctx, SpvOpSelect, dest_type, result, src[1], src[2]);
886 }
887 break;
888
889 case nir_op_vec2:
890 case nir_op_vec3:
891 case nir_op_vec4: {
892 int num_inputs = nir_op_infos[alu->op].num_inputs;
893 assert(2 <= num_inputs && num_inputs <= 4);
894 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
895 src, num_inputs);
896 }
897 break;
898
899 default:
900 fprintf(stderr, "emit_alu: not implemented (%s)\n",
901 nir_op_infos[alu->op].name);
902
903 unreachable("unsupported opcode");
904 return;
905 }
906
907 store_alu_result(ctx, alu, result);
908 }
909
910 static void
911 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
912 {
913 uint32_t values[NIR_MAX_VEC_COMPONENTS];
914 for (int i = 0; i < load_const->def.num_components; ++i)
915 values[i] = load_const->value[i].u32;
916
917 SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
918 load_const->def.num_components,
919 values);
920 store_ssa_def_uint(ctx, &load_const->def, constant);
921 }
922
923 static void
924 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
925 {
926 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
927 assert(const_block_index); // no dynamic indexing for now
928 assert(const_block_index->u32 == 0); // we only support the default UBO for now
929
930 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
931 if (const_offset) {
932 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
933 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
934 SpvStorageClassUniform,
935 uvec4_type);
936
937 unsigned idx = const_offset->u32;
938 SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
939 SpvId offset = spirv_builder_const_uint(&ctx->builder, 32, idx);
940 SpvId offsets[] = { member, offset };
941 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
942 ctx->ubos[0], offsets,
943 ARRAY_SIZE(offsets));
944 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
945
946 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
947 unsigned num_components = nir_dest_num_components(intr->dest);
948 if (num_components == 1) {
949 uint32_t components[] = { 0 };
950 result = spirv_builder_emit_composite_extract(&ctx->builder,
951 type,
952 result, components,
953 1);
954 } else if (num_components < 4) {
955 SpvId constituents[num_components];
956 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
957 for (uint32_t i = 0; i < num_components; ++i)
958 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
959 uint_type,
960 result, &i,
961 1);
962
963 result = spirv_builder_emit_composite_construct(&ctx->builder,
964 type,
965 constituents,
966 num_components);
967 }
968
969 store_dest_uint(ctx, &intr->dest, result);
970 } else
971 unreachable("uniform-addressing not yet supported");
972 }
973
974 static void
975 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
976 {
977 assert(ctx->block_started);
978 spirv_builder_emit_kill(&ctx->builder);
979 /* discard is weird in NIR, so let's just create an unreachable block after
980 it and hope that the vulkan driver will DCE any instructinos in it. */
981 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
982 }
983
984 static void
985 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
986 {
987 nir_variable *var = nir_intrinsic_get_var(intr, 0);
988 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var);
989 assert(he);
990 SpvId ptr = (SpvId)(intptr_t)he->data;
991
992 // SpvId ptr = get_src_uint(ctx, intr->src); /* uint is a bit of a lie here; it's really just a pointer */
993 SpvId result = spirv_builder_emit_load(&ctx->builder,
994 get_glsl_type(ctx, var->type),
995 ptr);
996 unsigned num_components = nir_dest_num_components(intr->dest);
997 unsigned bit_size = nir_dest_bit_size(intr->dest);
998 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
999 store_dest_uint(ctx, &intr->dest, result);
1000 }
1001
1002 static void
1003 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1004 {
1005 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1006 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var);
1007 assert(he);
1008 SpvId ptr = (SpvId)(intptr_t)he->data;
1009 // SpvId ptr = get_src_uint(ctx, &intr->src[0]); /* uint is a bit of a lie here; it's really just a pointer */
1010
1011 SpvId src = get_src_uint(ctx, &intr->src[1]);
1012 SpvId result = emit_unop(ctx, SpvOpBitcast, get_glsl_type(ctx, var->type),
1013 src);
1014 spirv_builder_emit_store(&ctx->builder, ptr, result);
1015 }
1016
1017 static void
1018 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1019 {
1020 switch (intr->intrinsic) {
1021 case nir_intrinsic_load_ubo:
1022 emit_load_ubo(ctx, intr);
1023 break;
1024
1025 case nir_intrinsic_discard:
1026 emit_discard(ctx, intr);
1027 break;
1028
1029 case nir_intrinsic_load_deref:
1030 emit_load_deref(ctx, intr);
1031 break;
1032
1033 case nir_intrinsic_store_deref:
1034 emit_store_deref(ctx, intr);
1035 break;
1036
1037 default:
1038 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1039 nir_intrinsic_infos[intr->intrinsic].name);
1040 unreachable("unsupported intrinsic");
1041 }
1042 }
1043
1044 static void
1045 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1046 {
1047 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1048 undef->def.num_components);
1049
1050 store_ssa_def_uint(ctx, &undef->def,
1051 spirv_builder_emit_undef(&ctx->builder, type));
1052 }
1053
1054 static SpvId
1055 get_src_float(struct ntv_context *ctx, nir_src *src)
1056 {
1057 SpvId def = get_src_uint(ctx, src);
1058 unsigned num_components = nir_src_num_components(*src);
1059 unsigned bit_size = nir_src_bit_size(*src);
1060 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1061 }
1062
1063 static void
1064 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1065 {
1066 assert(tex->op == nir_texop_tex);
1067 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1068 assert(tex->texture_index == tex->sampler_index);
1069
1070 bool has_proj = false, has_lod = false;
1071 SpvId coord = 0, proj, lod;
1072 unsigned coord_components;
1073 for (unsigned i = 0; i < tex->num_srcs; i++) {
1074 switch (tex->src[i].src_type) {
1075 case nir_tex_src_coord:
1076 coord = get_src_float(ctx, &tex->src[i].src);
1077 coord_components = nir_src_num_components(tex->src[i].src);
1078 break;
1079
1080 case nir_tex_src_projector:
1081 has_proj = true;
1082 proj = get_src_float(ctx, &tex->src[i].src);
1083 assert(nir_src_num_components(tex->src[i].src) == 1);
1084 break;
1085
1086 case nir_tex_src_lod:
1087 has_lod = true;
1088 lod = get_src_float(ctx, &tex->src[i].src);
1089 assert(nir_src_num_components(tex->src[i].src) == 1);
1090 break;
1091
1092 default:
1093 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1094 unreachable("unknown texture source");
1095 }
1096 }
1097
1098 if (!has_lod && ctx->stage != MESA_SHADER_FRAGMENT) {
1099 has_lod = true;
1100 lod = spirv_builder_const_float(&ctx->builder, 32, 0);
1101 }
1102
1103 bool is_ms;
1104 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1105 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1106 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1107 dimension, false, tex->is_array, is_ms, 1,
1108 SpvImageFormatUnknown);
1109 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1110 image_type);
1111
1112 assert(tex->texture_index < ctx->num_samplers);
1113 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1114 ctx->samplers[tex->texture_index]);
1115
1116 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1117
1118 SpvId result;
1119 if (has_proj) {
1120 SpvId constituents[coord_components + 1];
1121 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1122 for (uint32_t i = 0; i < coord_components; ++i)
1123 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1124 float_type,
1125 coord,
1126 &i, 1);
1127
1128 constituents[coord_components++] = proj;
1129
1130 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1131 SpvId merged = spirv_builder_emit_composite_construct(&ctx->builder,
1132 vec_type,
1133 constituents,
1134 coord_components);
1135
1136 if (has_lod)
1137 result = spirv_builder_emit_image_sample_proj_explicit_lod(&ctx->builder,
1138 dest_type,
1139 load,
1140 merged,
1141 lod);
1142 else
1143 result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
1144 dest_type,
1145 load,
1146 merged);
1147 } else {
1148 if (has_lod)
1149 result = spirv_builder_emit_image_sample_explicit_lod(&ctx->builder,
1150 dest_type,
1151 load,
1152 coord, lod);
1153 else
1154 result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
1155 dest_type,
1156 load,
1157 coord);
1158 }
1159 spirv_builder_emit_decoration(&ctx->builder, result,
1160 SpvDecorationRelaxedPrecision);
1161
1162 store_dest(ctx, &tex->dest, result, tex->dest_type);
1163 }
1164
1165 static void
1166 start_block(struct ntv_context *ctx, SpvId label)
1167 {
1168 /* terminate previous block if needed */
1169 if (ctx->block_started)
1170 spirv_builder_emit_branch(&ctx->builder, label);
1171
1172 /* start new block */
1173 spirv_builder_label(&ctx->builder, label);
1174 ctx->block_started = true;
1175 }
1176
1177 static void
1178 branch(struct ntv_context *ctx, SpvId label)
1179 {
1180 assert(ctx->block_started);
1181 spirv_builder_emit_branch(&ctx->builder, label);
1182 ctx->block_started = false;
1183 }
1184
1185 static void
1186 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1187 SpvId else_id)
1188 {
1189 assert(ctx->block_started);
1190 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1191 then_id, else_id);
1192 ctx->block_started = false;
1193 }
1194
1195 static void
1196 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1197 {
1198 switch (jump->type) {
1199 case nir_jump_break:
1200 assert(ctx->loop_break);
1201 branch(ctx, ctx->loop_break);
1202 break;
1203
1204 case nir_jump_continue:
1205 assert(ctx->loop_cont);
1206 branch(ctx, ctx->loop_cont);
1207 break;
1208
1209 default:
1210 unreachable("Unsupported jump type\n");
1211 }
1212 }
1213
1214 static void
1215 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1216 {
1217 assert(deref->deref_type == nir_deref_type_var);
1218
1219 SpvStorageClass storage_class;
1220 switch (deref->var->data.mode) {
1221 case nir_var_shader_in:
1222 storage_class = SpvStorageClassInput;
1223 break;
1224
1225 case nir_var_shader_out:
1226 storage_class = SpvStorageClassOutput;
1227 break;
1228
1229 default:
1230 unreachable("Unsupported nir_variable_mode\n");
1231 }
1232
1233 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1234 assert(he);
1235
1236 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1237 storage_class,
1238 get_glsl_type(ctx, deref->type));
1239
1240 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1241 ptr_type,
1242 (SpvId)(intptr_t)he->data,
1243 NULL, 0);
1244 /* uint is a bit of a lie here, it's really just an opaque type */
1245 store_dest_uint(ctx, &deref->dest, result);
1246 }
1247
1248 static void
1249 emit_block(struct ntv_context *ctx, struct nir_block *block)
1250 {
1251 start_block(ctx, block_label(ctx, block));
1252 nir_foreach_instr(instr, block) {
1253 switch (instr->type) {
1254 case nir_instr_type_alu:
1255 emit_alu(ctx, nir_instr_as_alu(instr));
1256 break;
1257 case nir_instr_type_intrinsic:
1258 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1259 break;
1260 case nir_instr_type_load_const:
1261 emit_load_const(ctx, nir_instr_as_load_const(instr));
1262 break;
1263 case nir_instr_type_ssa_undef:
1264 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1265 break;
1266 case nir_instr_type_tex:
1267 emit_tex(ctx, nir_instr_as_tex(instr));
1268 break;
1269 case nir_instr_type_phi:
1270 unreachable("nir_instr_type_phi not supported");
1271 break;
1272 case nir_instr_type_jump:
1273 emit_jump(ctx, nir_instr_as_jump(instr));
1274 break;
1275 case nir_instr_type_call:
1276 unreachable("nir_instr_type_call not supported");
1277 break;
1278 case nir_instr_type_parallel_copy:
1279 unreachable("nir_instr_type_parallel_copy not supported");
1280 break;
1281 case nir_instr_type_deref:
1282 /* these are handled in emit_{load,store}_deref */
1283 /* emit_deref(ctx, nir_instr_as_deref(instr)); */
1284 break;
1285 }
1286 }
1287 }
1288
/* Forward declaration: emit_if() and emit_loop() recurse into emit_cf_list(),
 * which is defined after them. */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1291
1292 static SpvId
1293 get_src_bool(struct ntv_context *ctx, nir_src *src)
1294 {
1295 SpvId def = get_src_uint(ctx, src);
1296 assert(nir_src_bit_size(*src) == 32);
1297 unsigned num_components = nir_src_num_components(*src);
1298 return uvec_to_bvec(ctx, def, num_components);
1299 }
1300
1301 static void
1302 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1303 {
1304 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1305
1306 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1307 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1308 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1309 SpvId else_id = endif_id;
1310
1311 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1312 if (has_else) {
1313 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1314 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1315 }
1316
1317 /* create a header-block */
1318 start_block(ctx, header_id);
1319 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1320 SpvSelectionControlMaskNone);
1321 branch_conditional(ctx, condition, then_id, else_id);
1322
1323 emit_cf_list(ctx, &if_stmt->then_list);
1324
1325 if (has_else) {
1326 if (ctx->block_started)
1327 branch(ctx, endif_id);
1328
1329 emit_cf_list(ctx, &if_stmt->else_list);
1330 }
1331
1332 start_block(ctx, endif_id);
1333 }
1334
1335 static void
1336 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1337 {
1338 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1339 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1340 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1341 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1342
1343 /* create a header-block */
1344 start_block(ctx, header_id);
1345 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1346 branch(ctx, begin_id);
1347
1348 SpvId save_break = ctx->loop_break;
1349 SpvId save_cont = ctx->loop_cont;
1350 ctx->loop_break = break_id;
1351 ctx->loop_cont = cont_id;
1352
1353 emit_cf_list(ctx, &loop->body);
1354
1355 ctx->loop_break = save_break;
1356 ctx->loop_cont = save_cont;
1357
1358 branch(ctx, cont_id);
1359 start_block(ctx, cont_id);
1360 branch(ctx, header_id);
1361
1362 start_block(ctx, break_id);
1363 }
1364
1365 static void
1366 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1367 {
1368 foreach_list_typed(nir_cf_node, node, node, list) {
1369 switch (node->type) {
1370 case nir_cf_node_block:
1371 emit_block(ctx, nir_cf_node_as_block(node));
1372 break;
1373
1374 case nir_cf_node_if:
1375 emit_if(ctx, nir_cf_node_as_if(node));
1376 break;
1377
1378 case nir_cf_node_loop:
1379 emit_loop(ctx, nir_cf_node_as_loop(node));
1380 break;
1381
1382 case nir_cf_node_function:
1383 unreachable("nir_cf_node_function not supported");
1384 break;
1385 }
1386 }
1387 }
1388
1389 struct spirv_shader *
1390 nir_to_spirv(struct nir_shader *s)
1391 {
1392 struct spirv_shader *ret = NULL;
1393
1394 struct ntv_context ctx = {};
1395
1396 switch (s->info.stage) {
1397 case MESA_SHADER_VERTEX:
1398 case MESA_SHADER_FRAGMENT:
1399 case MESA_SHADER_COMPUTE:
1400 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1401 break;
1402
1403 case MESA_SHADER_TESS_CTRL:
1404 case MESA_SHADER_TESS_EVAL:
1405 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1406 break;
1407
1408 case MESA_SHADER_GEOMETRY:
1409 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1410 break;
1411
1412 default:
1413 unreachable("invalid stage");
1414 }
1415
1416 ctx.stage = s->info.stage;
1417 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1418 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1419
1420 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1421 SpvMemoryModelGLSL450);
1422
1423 SpvExecutionModel exec_model;
1424 switch (s->info.stage) {
1425 case MESA_SHADER_VERTEX:
1426 exec_model = SpvExecutionModelVertex;
1427 break;
1428 case MESA_SHADER_TESS_CTRL:
1429 exec_model = SpvExecutionModelTessellationControl;
1430 break;
1431 case MESA_SHADER_TESS_EVAL:
1432 exec_model = SpvExecutionModelTessellationEvaluation;
1433 break;
1434 case MESA_SHADER_GEOMETRY:
1435 exec_model = SpvExecutionModelGeometry;
1436 break;
1437 case MESA_SHADER_FRAGMENT:
1438 exec_model = SpvExecutionModelFragment;
1439 break;
1440 case MESA_SHADER_COMPUTE:
1441 exec_model = SpvExecutionModelGLCompute;
1442 break;
1443 default:
1444 unreachable("invalid stage");
1445 }
1446
1447 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1448 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1449 NULL, 0);
1450 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1451 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1452
1453 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1454 _mesa_key_pointer_equal);
1455
1456 nir_foreach_variable(var, &s->inputs)
1457 emit_input(&ctx, var);
1458
1459 nir_foreach_variable(var, &s->outputs)
1460 emit_output(&ctx, var);
1461
1462 nir_foreach_variable(var, &s->uniforms)
1463 emit_uniform(&ctx, var);
1464
1465 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1466 "main", ctx.entry_ifaces,
1467 ctx.num_entry_ifaces);
1468 if (s->info.stage == MESA_SHADER_FRAGMENT)
1469 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1470 SpvExecutionModeOriginUpperLeft);
1471
1472
1473 spirv_builder_function(&ctx.builder, entry_point, type_void,
1474 SpvFunctionControlMaskNone,
1475 type_main);
1476
1477 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1478 nir_metadata_require(entry, nir_metadata_block_index);
1479
1480 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1481 if (!ctx.defs)
1482 goto fail;
1483 ctx.num_defs = entry->ssa_alloc;
1484
1485 nir_index_local_regs(entry);
1486 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1487 if (!ctx.regs)
1488 goto fail;
1489 ctx.num_regs = entry->reg_alloc;
1490
1491 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1492 if (!block_ids)
1493 goto fail;
1494
1495 for (int i = 0; i < entry->num_blocks; ++i)
1496 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1497
1498 ctx.block_ids = block_ids;
1499 ctx.num_blocks = entry->num_blocks;
1500
1501 /* emit a block only for the variable declarations */
1502 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1503 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1504 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1505 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1506 SpvStorageClassFunction,
1507 type);
1508 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1509 SpvStorageClassFunction);
1510
1511 ctx.regs[reg->index] = var;
1512 }
1513
1514 emit_cf_list(&ctx, &entry->body);
1515
1516 free(ctx.defs);
1517
1518 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1519 spirv_builder_function_end(&ctx.builder);
1520
1521 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1522
1523 ret = CALLOC_STRUCT(spirv_shader);
1524 if (!ret)
1525 goto fail;
1526
1527 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1528 if (!ret->words)
1529 goto fail;
1530
1531 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1532 assert(ret->num_words == num_words);
1533
1534 return ret;
1535
1536 fail:
1537
1538 if (ret)
1539 spirv_shader_delete(ret);
1540
1541 if (ctx.vars)
1542 _mesa_hash_table_destroy(ctx.vars, NULL);
1543
1544 return NULL;
1545 }
1546
1547 void
1548 spirv_shader_delete(struct spirv_shader *s)
1549 {
1550 FREE(s->words);
1551 FREE(s);
1552 }