zink/spirv: var -> regs
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30
31 struct ntv_context {
32 struct spirv_builder builder;
33
34 SpvId GLSL_std_450;
35
36 gl_shader_stage stage;
37 SpvId inputs[PIPE_MAX_SHADER_INPUTS][4];
38 SpvId input_types[PIPE_MAX_SHADER_INPUTS][4];
39 SpvId outputs[PIPE_MAX_SHADER_OUTPUTS][4];
40 SpvId output_types[PIPE_MAX_SHADER_OUTPUTS][4];
41 int var_location;
42
43 SpvId ubos[128];
44 size_t num_ubos;
45 SpvId samplers[PIPE_MAX_SAMPLERS];
46 size_t num_samplers;
47 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
48 size_t num_entry_ifaces;
49
50 SpvId *defs;
51 size_t num_defs;
52
53 SpvId *regs;
54 size_t num_regs;
55
56 const SpvId *block_ids;
57 size_t num_blocks;
58 bool block_started;
59 SpvId loop_break, loop_cont;
60 };
61
62 static SpvId
63 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
64 const float values[]);
65
66 static SpvId
67 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
68 const uint32_t values[]);
69
70 static SpvId
71 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
72
73 static SpvId
74 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
75 SpvId src0, SpvId src1);
76
77 static SpvId
78 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
79 SpvId src0, SpvId src1, SpvId src2);
80
81 static SpvId
82 get_bvec_type(struct ntv_context *ctx, int num_components)
83 {
84 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
85 if (num_components > 1)
86 return spirv_builder_type_vector(&ctx->builder, bool_type,
87 num_components);
88
89 assert(num_components == 1);
90 return bool_type;
91 }
92
93 static SpvId
94 block_label(struct ntv_context *ctx, nir_block *block)
95 {
96 assert(block->index < ctx->num_blocks);
97 return ctx->block_ids[block->index];
98 }
99
100 static SpvId
101 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
102 {
103 assert(bit_size == 32); // only 32-bit floats supported so far
104
105 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
106 if (num_components > 1)
107 return spirv_builder_type_vector(&ctx->builder, float_type,
108 num_components);
109
110 assert(num_components == 1);
111 return float_type;
112 }
113
114 static SpvId
115 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
116 {
117 assert(bit_size == 32); // only 32-bit ints supported so far
118
119 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
120 if (num_components > 1)
121 return spirv_builder_type_vector(&ctx->builder, int_type,
122 num_components);
123
124 assert(num_components == 1);
125 return int_type;
126 }
127
128 static SpvId
129 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
130 {
131 assert(bit_size == 32); // only 32-bit uints supported so far
132
133 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
134 if (num_components > 1)
135 return spirv_builder_type_vector(&ctx->builder, uint_type,
136 num_components);
137
138 assert(num_components == 1);
139 return uint_type;
140 }
141
142 static SpvId
143 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
144 {
145 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
146 nir_dest_num_components(*dest));
147 }
148
149 static SpvId
150 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
151 {
152 switch (type) {
153 case GLSL_TYPE_FLOAT:
154 return spirv_builder_type_float(&ctx->builder, 32);
155
156 case GLSL_TYPE_INT:
157 return spirv_builder_type_int(&ctx->builder, 32);
158
159 case GLSL_TYPE_UINT:
160 return spirv_builder_type_uint(&ctx->builder, 32);
161 /* TODO: handle more types */
162
163 default:
164 unreachable("unknown GLSL type");
165 }
166 }
167
168 static SpvId
169 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
170 {
171 assert(type);
172 if (glsl_type_is_scalar(type))
173 return get_glsl_basetype(ctx, glsl_get_base_type(type));
174
175 if (glsl_type_is_vector(type))
176 return spirv_builder_type_vector(&ctx->builder,
177 get_glsl_basetype(ctx, glsl_get_base_type(type)),
178 glsl_get_vector_elements(type));
179
180 unreachable("we shouldn't get here, I think...");
181 }
182
183 static void
184 emit_input(struct ntv_context *ctx, struct nir_variable *var)
185 {
186 SpvId vec_type = get_glsl_type(ctx, var->type);
187 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
188 SpvStorageClassInput,
189 vec_type);
190 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
191 SpvStorageClassInput);
192
193 if (var->name)
194 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
195
196 if (ctx->stage == MESA_SHADER_FRAGMENT) {
197 if (var->data.location >= VARYING_SLOT_VAR0 ||
198 (var->data.location >= VARYING_SLOT_COL0 &&
199 var->data.location <= VARYING_SLOT_TEX7)) {
200 spirv_builder_emit_location(&ctx->builder, var_id,
201 ctx->var_location++);
202 } else {
203 switch (var->data.location) {
204 case VARYING_SLOT_POS:
205 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
206 break;
207
208 case VARYING_SLOT_PNTC:
209 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
210 break;
211
212 default:
213 unreachable("unknown varying slot");
214 }
215 }
216 } else {
217 spirv_builder_emit_location(&ctx->builder, var_id,
218 var->data.driver_location);
219 }
220
221 if (var->data.location_frac)
222 spirv_builder_emit_component(&ctx->builder, var_id,
223 var->data.location_frac);
224
225 if (var->data.interpolation == INTERP_MODE_FLAT)
226 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
227
228 assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
229 assert(var->data.location_frac < 4);
230 assert(ctx->inputs[var->data.driver_location][var->data.location_frac] == 0);
231 ctx->inputs[var->data.driver_location][var->data.location_frac] = var_id;
232 ctx->input_types[var->data.driver_location][var->data.location_frac] = vec_type;
233
234 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
235 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
236 }
237
238 static void
239 emit_output(struct ntv_context *ctx, struct nir_variable *var)
240 {
241 SpvId vec_type = get_glsl_type(ctx, var->type);
242 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
243 SpvStorageClassOutput,
244 vec_type);
245 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
246 SpvStorageClassOutput);
247 if (var->name)
248 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
249
250
251 if (ctx->stage == MESA_SHADER_VERTEX) {
252 if (var->data.location >= VARYING_SLOT_VAR0 ||
253 (var->data.location >= VARYING_SLOT_COL0 &&
254 var->data.location <= VARYING_SLOT_TEX7)) {
255 spirv_builder_emit_location(&ctx->builder, var_id,
256 ctx->var_location++);
257 } else {
258 switch (var->data.location) {
259 case VARYING_SLOT_POS:
260 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
261 break;
262
263 case VARYING_SLOT_PSIZ:
264 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
265 break;
266
267 default:
268 unreachable("unknown varying slot");
269 }
270 }
271 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
272 switch (var->data.location) {
273 case FRAG_RESULT_DEPTH:
274 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
275 break;
276
277 default:
278 spirv_builder_emit_location(&ctx->builder, var_id,
279 var->data.driver_location);
280 }
281 }
282
283 if (var->data.location_frac)
284 spirv_builder_emit_component(&ctx->builder, var_id,
285 var->data.location_frac);
286
287 assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
288 assert(var->data.location_frac < 4);
289 assert(ctx->outputs[var->data.driver_location][var->data.location_frac] == 0);
290 ctx->outputs[var->data.driver_location][var->data.location_frac] = var_id;
291 ctx->output_types[var->data.driver_location][var->data.location_frac] = vec_type;
292
293 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
294 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
295 }
296
297 static SpvDim
298 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
299 {
300 *is_ms = false;
301 switch (gdim) {
302 case GLSL_SAMPLER_DIM_1D:
303 return SpvDim1D;
304 case GLSL_SAMPLER_DIM_2D:
305 return SpvDim2D;
306 case GLSL_SAMPLER_DIM_RECT:
307 return SpvDimRect;
308 case GLSL_SAMPLER_DIM_CUBE:
309 return SpvDimCube;
310 case GLSL_SAMPLER_DIM_3D:
311 return SpvDim3D;
312 case GLSL_SAMPLER_DIM_MS:
313 *is_ms = true;
314 return SpvDim2D;
315 default:
316 fprintf(stderr, "unknown sampler type %d\n", gdim);
317 break;
318 }
319 return SpvDim2D;
320 }
321
322 static void
323 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
324 {
325 bool is_ms;
326 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
327 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
328 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
329 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
330 SpvImageFormatUnknown);
331
332 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
333 image_type);
334 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
335 SpvStorageClassUniformConstant,
336 sampled_type);
337 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
338 SpvStorageClassUniformConstant);
339
340 if (var->name)
341 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
342
343 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
344 ctx->samplers[ctx->num_samplers++] = var_id;
345
346 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
347 var->data.descriptor_set);
348 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
349 }
350
351 static void
352 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
353 {
354 uint32_t size = glsl_count_attribute_slots(var->type, false);
355 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
356 SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
357 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
358 array_length);
359 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
360
361 // wrap UBO-array in a struct
362 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
363 if (var->name) {
364 char struct_name[100];
365 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
366 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
367 }
368
369 spirv_builder_emit_decoration(&ctx->builder, struct_type,
370 SpvDecorationBlock);
371 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
372
373
374 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
375 SpvStorageClassUniform,
376 struct_type);
377
378 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
379 SpvStorageClassUniform);
380 if (var->name)
381 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
382
383 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
384 ctx->ubos[ctx->num_ubos++] = var_id;
385
386 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
387 var->data.descriptor_set);
388 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
389 }
390
391 static void
392 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
393 {
394 if (glsl_type_is_sampler(var->type))
395 emit_sampler(ctx, var);
396 else if (var->interface_type)
397 emit_ubo(ctx, var);
398 }
399
400 static SpvId
401 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
402 {
403 assert(ssa->index < ctx->num_defs);
404 assert(ctx->defs[ssa->index] != 0);
405 return ctx->defs[ssa->index];
406 }
407
408 static SpvId
409 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
410 {
411 assert(reg->index < ctx->num_regs);
412 assert(ctx->regs[reg->index] != 0);
413 return ctx->regs[reg->index];
414 }
415
416 static SpvId
417 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
418 {
419 assert(reg->reg);
420 assert(!reg->indirect);
421 assert(!reg->base_offset);
422
423 SpvId var = get_var_from_reg(ctx, reg->reg);
424 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
425 return spirv_builder_emit_load(&ctx->builder, type, var);
426 }
427
428 static SpvId
429 get_src_uint(struct ntv_context *ctx, nir_src *src)
430 {
431 if (src->is_ssa)
432 return get_src_uint_ssa(ctx, src->ssa);
433 else
434 return get_src_uint_reg(ctx, &src->reg);
435 }
436
437 static SpvId
438 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
439 {
440 assert(!alu->src[src].negate);
441 assert(!alu->src[src].abs);
442
443 SpvId def = get_src_uint(ctx, &alu->src[src].src);
444
445 unsigned used_channels = 0;
446 bool need_swizzle = false;
447 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
448 if (!nir_alu_instr_channel_used(alu, src, i))
449 continue;
450
451 used_channels++;
452
453 if (alu->src[src].swizzle[i] != i)
454 need_swizzle = true;
455 }
456 assert(used_channels != 0);
457
458 unsigned live_channels = nir_src_num_components(alu->src[src].src);
459 if (used_channels != live_channels)
460 need_swizzle = true;
461
462 if (!need_swizzle)
463 return def;
464
465 int bit_size = nir_src_bit_size(alu->src[src].src);
466
467 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
468 if (used_channels == 1) {
469 uint32_t indices[] = { alu->src[src].swizzle[0] };
470 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
471 def, indices,
472 ARRAY_SIZE(indices));
473 } else if (live_channels == 1) {
474 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
475 used_channels);
476
477 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
478 for (unsigned i = 0; i < used_channels; ++i)
479 constituents[i] = def;
480
481 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
482 constituents,
483 used_channels);
484 } else {
485 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
486 used_channels);
487
488 uint32_t components[NIR_MAX_VEC_COMPONENTS];
489 size_t num_components = 0;
490 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
491 if (!nir_alu_instr_channel_used(alu, src, i))
492 continue;
493
494 components[num_components++] = alu->src[src].swizzle[i];
495 }
496
497 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
498 def, def, components, num_components);
499 }
500 }
501
502 static void
503 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
504 {
505 assert(result != 0);
506 assert(ssa->index < ctx->num_defs);
507 ctx->defs[ssa->index] = result;
508 }
509
510 static SpvId
511 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
512 {
513 SpvId otype = get_uvec_type(ctx, 32, num_components);
514 uint32_t zeros[4] = { 0, 0, 0, 0 };
515 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
516 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
517 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
518 return emit_triop(ctx, SpvOpSelect, otype, value, one, zero);
519 }
520
521 static SpvId
522 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
523 {
524 SpvId type = get_bvec_type(ctx, num_components);
525
526 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
527 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
528
529 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
530 }
531
532 static SpvId
533 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
534 unsigned num_components)
535 {
536 SpvId type = get_uvec_type(ctx, bit_size, num_components);
537 return emit_unop(ctx, SpvOpBitcast, type, value);
538 }
539
540 static SpvId
541 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
542 unsigned num_components)
543 {
544 SpvId type = get_ivec_type(ctx, bit_size, num_components);
545 return emit_unop(ctx, SpvOpBitcast, type, value);
546 }
547
548 static SpvId
549 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
550 unsigned num_components)
551 {
552 SpvId type = get_fvec_type(ctx, bit_size, num_components);
553 return emit_unop(ctx, SpvOpBitcast, type, value);
554 }
555
556 static void
557 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
558 {
559 SpvId var = get_var_from_reg(ctx, reg->reg);
560 assert(var);
561 spirv_builder_emit_store(&ctx->builder, var, result);
562 }
563
564 static void
565 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
566 {
567 if (dest->is_ssa)
568 store_ssa_def_uint(ctx, &dest->ssa, result);
569 else
570 store_reg_def(ctx, &dest->reg, result);
571 }
572
573 static void
574 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
575 {
576 unsigned num_components = nir_dest_num_components(*dest);
577 unsigned bit_size = nir_dest_bit_size(*dest);
578
579 switch (nir_alu_type_get_base_type(type)) {
580 case nir_type_bool:
581 assert(bit_size == 1);
582 result = bvec_to_uvec(ctx, result, num_components);
583 break;
584
585 case nir_type_uint:
586 break; /* nothing to do! */
587
588 case nir_type_int:
589 case nir_type_float:
590 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
591 break;
592
593 default:
594 unreachable("unsupported nir_alu_type");
595 }
596
597 store_dest_uint(ctx, dest, result);
598 }
599
600 static SpvId
601 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
602 {
603 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
604 }
605
606 static SpvId
607 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
608 SpvId src0, SpvId src1)
609 {
610 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
611 }
612
613 static SpvId
614 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
615 SpvId src0, SpvId src1, SpvId src2)
616 {
617 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
618 }
619
620 static SpvId
621 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
622 SpvId src)
623 {
624 SpvId args[] = { src };
625 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
626 op, args, ARRAY_SIZE(args));
627 }
628
629 static SpvId
630 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
631 SpvId src0, SpvId src1)
632 {
633 SpvId args[] = { src0, src1 };
634 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
635 op, args, ARRAY_SIZE(args));
636 }
637
638 static SpvId
639 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
640 const float values[])
641 {
642 assert(bit_size == 32);
643
644 if (num_components > 1) {
645 SpvId components[num_components];
646 for (int i = 0; i < num_components; i++)
647 components[i] = spirv_builder_const_float(&ctx->builder, bit_size,
648 values[i]);
649
650 SpvId type = get_fvec_type(ctx, bit_size, num_components);
651 return spirv_builder_const_composite(&ctx->builder, type, components,
652 num_components);
653 }
654
655 assert(num_components == 1);
656 return spirv_builder_const_float(&ctx->builder, bit_size, values[0]);
657 }
658
659 static SpvId
660 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
661 const uint32_t values[])
662 {
663 assert(bit_size == 32);
664
665 if (num_components > 1) {
666 SpvId components[num_components];
667 for (int i = 0; i < num_components; i++)
668 components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
669 values[i]);
670
671 SpvId type = get_uvec_type(ctx, bit_size, num_components);
672 return spirv_builder_const_composite(&ctx->builder, type, components,
673 num_components);
674 }
675
676 assert(num_components == 1);
677 return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
678 }
679
680 static inline unsigned
681 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
682 {
683 if (nir_op_infos[instr->op].input_sizes[src] > 0)
684 return nir_op_infos[instr->op].input_sizes[src];
685
686 if (instr->dest.dest.is_ssa)
687 return instr->dest.dest.ssa.num_components;
688 else
689 return instr->dest.dest.reg.reg->num_components;
690 }
691
692 static SpvId
693 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
694 {
695 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
696
697 unsigned num_components = alu_instr_src_components(alu, src);
698 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
699 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
700
701 switch (nir_alu_type_get_base_type(type)) {
702 case nir_type_bool:
703 assert(bit_size == 1);
704 return uvec_to_bvec(ctx, uint_value, num_components);
705
706 case nir_type_int:
707 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
708
709 case nir_type_uint:
710 return uint_value;
711
712 case nir_type_float:
713 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
714
715 default:
716 unreachable("unknown nir_alu_type");
717 }
718 }
719
720 static void
721 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
722 {
723 assert(!alu->dest.saturate);
724 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
725 }
726
727 static SpvId
728 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
729 {
730 unsigned num_components = nir_dest_num_components(*dest);
731 unsigned bit_size = nir_dest_bit_size(*dest);
732
733 switch (nir_alu_type_get_base_type(type)) {
734 case nir_type_bool:
735 return get_bvec_type(ctx, num_components);
736
737 case nir_type_int:
738 return get_ivec_type(ctx, bit_size, num_components);
739
740 case nir_type_uint:
741 return get_uvec_type(ctx, bit_size, num_components);
742
743 case nir_type_float:
744 return get_fvec_type(ctx, bit_size, num_components);
745
746 default:
747 unreachable("unsupported nir_alu_type");
748 }
749 }
750
751 static void
752 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
753 {
754 SpvId src[nir_op_infos[alu->op].num_inputs];
755 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
756 src[i] = get_alu_src(ctx, alu, i);
757
758 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
759 nir_op_infos[alu->op].output_type);
760 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
761 unsigned num_components = nir_dest_num_components(alu->dest.dest);
762
763 SpvId result = 0;
764 switch (alu->op) {
765 case nir_op_mov:
766 assert(nir_op_infos[alu->op].num_inputs == 1);
767 result = src[0];
768 break;
769
770 #define UNOP(nir_op, spirv_op) \
771 case nir_op: \
772 assert(nir_op_infos[alu->op].num_inputs == 1); \
773 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
774 break;
775
776 #define BUILTIN_UNOP(nir_op, spirv_op) \
777 case nir_op: \
778 assert(nir_op_infos[alu->op].num_inputs == 1); \
779 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
780 break;
781
782 UNOP(nir_op_fneg, SpvOpFNegate)
783 UNOP(nir_op_fddx, SpvOpDPdx)
784 UNOP(nir_op_fddy, SpvOpDPdy)
785
786 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
787 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
788 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
789 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
790 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
791 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
792 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
793 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
794 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
795 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
796 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
797 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
798 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
799
800 case nir_op_frcp: {
801 assert(nir_op_infos[alu->op].num_inputs == 1);
802 float one[4] = { 1, 1, 1, 1 };
803 src[1] = src[0];
804 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
805 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
806 }
807 break;
808
809 #undef UNOP
810 #undef BUILTIN_UNOP
811
812 #define BINOP(nir_op, spirv_op) \
813 case nir_op: \
814 assert(nir_op_infos[alu->op].num_inputs == 2); \
815 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
816 break;
817
818 #define BUILTIN_BINOP(nir_op, spirv_op) \
819 case nir_op: \
820 assert(nir_op_infos[alu->op].num_inputs == 2); \
821 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
822 break;
823
824 BINOP(nir_op_iadd, SpvOpIAdd)
825 BINOP(nir_op_isub, SpvOpISub)
826 BINOP(nir_op_imul, SpvOpIMul)
827 BINOP(nir_op_fadd, SpvOpFAdd)
828 BINOP(nir_op_fsub, SpvOpFSub)
829 BINOP(nir_op_fmul, SpvOpFMul)
830 BINOP(nir_op_fmod, SpvOpFMod)
831 BINOP(nir_op_flt, SpvOpFUnordLessThan)
832 BINOP(nir_op_fge, SpvOpFUnordGreaterThanEqual)
833
834 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
835 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
836
837 #undef BINOP
838 #undef BUILTIN_BINOP
839
840 case nir_op_fdot2:
841 case nir_op_fdot3:
842 case nir_op_fdot4:
843 assert(nir_op_infos[alu->op].num_inputs == 2);
844 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
845 break;
846
847 case nir_op_seq:
848 case nir_op_sne:
849 case nir_op_slt:
850 case nir_op_sge: {
851 assert(nir_op_infos[alu->op].num_inputs == 2);
852 int num_components = nir_dest_num_components(alu->dest.dest);
853 SpvId bool_type = get_bvec_type(ctx, num_components);
854
855 SpvId zero = spirv_builder_const_float(&ctx->builder, 32, 0.0f);
856 SpvId one = spirv_builder_const_float(&ctx->builder, 32, 1.0f);
857 if (num_components > 1) {
858 SpvId zero_comps[num_components], one_comps[num_components];
859 for (int i = 0; i < num_components; i++) {
860 zero_comps[i] = zero;
861 one_comps[i] = one;
862 }
863
864 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
865 zero_comps, num_components);
866 one = spirv_builder_const_composite(&ctx->builder, dest_type,
867 one_comps, num_components);
868 }
869
870 SpvOp op;
871 switch (alu->op) {
872 case nir_op_seq: op = SpvOpFOrdEqual; break;
873 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
874 case nir_op_slt: op = SpvOpFOrdLessThan; break;
875 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
876 default: unreachable("unexpected op");
877 }
878
879 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
880 result = emit_triop(ctx, SpvOpSelect, dest_type, result, one, zero);
881 }
882 break;
883
884 case nir_op_fcsel: {
885 assert(nir_op_infos[alu->op].num_inputs == 3);
886 int num_components = nir_dest_num_components(alu->dest.dest);
887 SpvId bool_type = get_bvec_type(ctx, num_components);
888
889 float zero[4] = { 0, 0, 0, 0 };
890 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
891 num_components, zero);
892
893 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
894 result = emit_triop(ctx, SpvOpSelect, dest_type, result, src[1], src[2]);
895 }
896 break;
897
898 case nir_op_vec2:
899 case nir_op_vec3:
900 case nir_op_vec4: {
901 int num_inputs = nir_op_infos[alu->op].num_inputs;
902 assert(2 <= num_inputs && num_inputs <= 4);
903 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
904 src, num_inputs);
905 }
906 break;
907
908 default:
909 fprintf(stderr, "emit_alu: not implemented (%s)\n",
910 nir_op_infos[alu->op].name);
911
912 unreachable("unsupported opcode");
913 return;
914 }
915
916 store_alu_result(ctx, alu, result);
917 }
918
919 static void
920 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
921 {
922 uint32_t values[NIR_MAX_VEC_COMPONENTS];
923 for (int i = 0; i < load_const->def.num_components; ++i)
924 values[i] = load_const->value[i].u32;
925
926 SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
927 load_const->def.num_components,
928 values);
929 store_ssa_def_uint(ctx, &load_const->def, constant);
930 }
931
932 static void
933 emit_load_input(struct ntv_context *ctx, nir_intrinsic_instr *intr)
934 {
935 nir_const_value *const_offset = nir_src_as_const_value(intr->src[0]);
936 if (const_offset) {
937 int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
938 assert(driver_location < PIPE_MAX_SHADER_INPUTS);
939 int location_frac = nir_intrinsic_component(intr);
940 assert(location_frac < 4);
941
942 SpvId ptr = ctx->inputs[driver_location][location_frac];
943 SpvId type = ctx->input_types[driver_location][location_frac];
944 assert(ptr && type);
945
946 SpvId result = spirv_builder_emit_load(&ctx->builder, type, ptr);
947
948 unsigned num_components = nir_dest_num_components(intr->dest);
949 unsigned bit_size = nir_dest_bit_size(intr->dest);
950 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
951
952 store_dest_uint(ctx, &intr->dest, result);
953 } else
954 unreachable("input-addressing not yet supported");
955 }
956
957 static void
958 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
959 {
960 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
961 assert(const_block_index); // no dynamic indexing for now
962 assert(const_block_index->u32 == 0); // we only support the default UBO for now
963
964 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
965 if (const_offset) {
966 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
967 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
968 SpvStorageClassUniform,
969 uvec4_type);
970
971 unsigned idx = const_offset->u32;
972 SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
973 SpvId offset = spirv_builder_const_uint(&ctx->builder, 32, idx);
974 SpvId offsets[] = { member, offset };
975 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
976 ctx->ubos[0], offsets,
977 ARRAY_SIZE(offsets));
978 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
979
980 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
981 unsigned num_components = nir_dest_num_components(intr->dest);
982 if (num_components == 1) {
983 uint32_t components[] = { 0 };
984 result = spirv_builder_emit_composite_extract(&ctx->builder,
985 type,
986 result, components,
987 1);
988 } else if (num_components < 4) {
989 SpvId constituents[num_components];
990 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
991 for (uint32_t i = 0; i < num_components; ++i)
992 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
993 uint_type,
994 result, &i,
995 1);
996
997 result = spirv_builder_emit_composite_construct(&ctx->builder,
998 type,
999 constituents,
1000 num_components);
1001 }
1002
1003 store_dest_uint(ctx, &intr->dest, result);
1004 } else
1005 unreachable("uniform-addressing not yet supported");
1006 }
1007
1008 static void
1009 emit_store_output(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1010 {
1011 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1012 if (const_offset) {
1013 int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
1014 assert(driver_location < PIPE_MAX_SHADER_OUTPUTS);
1015 int location_frac = nir_intrinsic_component(intr);
1016 assert(location_frac < 4);
1017
1018 SpvId ptr = ctx->outputs[driver_location][location_frac];
1019 assert(ptr > 0);
1020
1021 SpvId src = get_src_uint(ctx, &intr->src[0]);
1022 SpvId spirv_type = ctx->output_types[driver_location][location_frac];
1023 SpvId result = emit_unop(ctx, SpvOpBitcast, spirv_type, src);
1024 spirv_builder_emit_store(&ctx->builder, ptr, result);
1025 } else
1026 unreachable("output-addressing not yet supported");
1027 }
1028
1029 static void
1030 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1031 {
1032 assert(ctx->block_started);
1033 spirv_builder_emit_kill(&ctx->builder);
1034 /* discard is weird in NIR, so let's just create an unreachable block after
1035 it and hope that the vulkan driver will DCE any instructinos in it. */
1036 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1037 }
1038
1039 static void
1040 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1041 {
1042 switch (intr->intrinsic) {
1043 case nir_intrinsic_load_input:
1044 emit_load_input(ctx, intr);
1045 break;
1046
1047 case nir_intrinsic_load_ubo:
1048 emit_load_ubo(ctx, intr);
1049 break;
1050
1051 case nir_intrinsic_store_output:
1052 emit_store_output(ctx, intr);
1053 break;
1054
1055 case nir_intrinsic_discard:
1056 emit_discard(ctx, intr);
1057 break;
1058
1059 default:
1060 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1061 nir_intrinsic_infos[intr->intrinsic].name);
1062 unreachable("unsupported intrinsic");
1063 }
1064 }
1065
1066 static void
1067 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1068 {
1069 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1070 undef->def.num_components);
1071
1072 store_ssa_def_uint(ctx, &undef->def,
1073 spirv_builder_emit_undef(&ctx->builder, type));
1074 }
1075
1076 static SpvId
1077 get_src_float(struct ntv_context *ctx, nir_src *src)
1078 {
1079 SpvId def = get_src_uint(ctx, src);
1080 unsigned num_components = nir_src_num_components(*src);
1081 unsigned bit_size = nir_src_bit_size(*src);
1082 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1083 }
1084
1085 static void
1086 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1087 {
1088 assert(tex->op == nir_texop_tex);
1089 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1090 assert(tex->texture_index == tex->sampler_index);
1091
1092 bool has_proj = false, has_lod = false;
1093 SpvId coord = 0, proj, lod;
1094 unsigned coord_components;
1095 for (unsigned i = 0; i < tex->num_srcs; i++) {
1096 switch (tex->src[i].src_type) {
1097 case nir_tex_src_coord:
1098 coord = get_src_float(ctx, &tex->src[i].src);
1099 coord_components = nir_src_num_components(tex->src[i].src);
1100 break;
1101
1102 case nir_tex_src_projector:
1103 has_proj = true;
1104 proj = get_src_float(ctx, &tex->src[i].src);
1105 assert(nir_src_num_components(tex->src[i].src) == 1);
1106 break;
1107
1108 case nir_tex_src_lod:
1109 has_lod = true;
1110 lod = get_src_float(ctx, &tex->src[i].src);
1111 assert(nir_src_num_components(tex->src[i].src) == 1);
1112 break;
1113
1114 default:
1115 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1116 unreachable("unknown texture source");
1117 }
1118 }
1119
1120 if (!has_lod && ctx->stage != MESA_SHADER_FRAGMENT) {
1121 has_lod = true;
1122 lod = spirv_builder_const_float(&ctx->builder, 32, 0);
1123 }
1124
1125 bool is_ms;
1126 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1127 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1128 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1129 dimension, false, tex->is_array, is_ms, 1,
1130 SpvImageFormatUnknown);
1131 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1132 image_type);
1133
1134 assert(tex->texture_index < ctx->num_samplers);
1135 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1136 ctx->samplers[tex->texture_index]);
1137
1138 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1139
1140 SpvId result;
1141 if (has_proj) {
1142 SpvId constituents[coord_components + 1];
1143 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1144 for (uint32_t i = 0; i < coord_components; ++i)
1145 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1146 float_type,
1147 coord,
1148 &i, 1);
1149
1150 constituents[coord_components++] = proj;
1151
1152 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1153 SpvId merged = spirv_builder_emit_composite_construct(&ctx->builder,
1154 vec_type,
1155 constituents,
1156 coord_components);
1157
1158 if (has_lod)
1159 result = spirv_builder_emit_image_sample_proj_explicit_lod(&ctx->builder,
1160 dest_type,
1161 load,
1162 merged,
1163 lod);
1164 else
1165 result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
1166 dest_type,
1167 load,
1168 merged);
1169 } else {
1170 if (has_lod)
1171 result = spirv_builder_emit_image_sample_explicit_lod(&ctx->builder,
1172 dest_type,
1173 load,
1174 coord, lod);
1175 else
1176 result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
1177 dest_type,
1178 load,
1179 coord);
1180 }
1181 spirv_builder_emit_decoration(&ctx->builder, result,
1182 SpvDecorationRelaxedPrecision);
1183
1184 store_dest(ctx, &tex->dest, result, tex->dest_type);
1185 }
1186
1187 static void
1188 start_block(struct ntv_context *ctx, SpvId label)
1189 {
1190 /* terminate previous block if needed */
1191 if (ctx->block_started)
1192 spirv_builder_emit_branch(&ctx->builder, label);
1193
1194 /* start new block */
1195 spirv_builder_label(&ctx->builder, label);
1196 ctx->block_started = true;
1197 }
1198
1199 static void
1200 branch(struct ntv_context *ctx, SpvId label)
1201 {
1202 assert(ctx->block_started);
1203 spirv_builder_emit_branch(&ctx->builder, label);
1204 ctx->block_started = false;
1205 }
1206
1207 static void
1208 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1209 SpvId else_id)
1210 {
1211 assert(ctx->block_started);
1212 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1213 then_id, else_id);
1214 ctx->block_started = false;
1215 }
1216
1217 static void
1218 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1219 {
1220 switch (jump->type) {
1221 case nir_jump_break:
1222 assert(ctx->loop_break);
1223 branch(ctx, ctx->loop_break);
1224 break;
1225
1226 case nir_jump_continue:
1227 assert(ctx->loop_cont);
1228 branch(ctx, ctx->loop_cont);
1229 break;
1230
1231 default:
1232 unreachable("Unsupported jump type\n");
1233 }
1234 }
1235
1236 static void
1237 emit_block(struct ntv_context *ctx, struct nir_block *block)
1238 {
1239 start_block(ctx, block_label(ctx, block));
1240 nir_foreach_instr(instr, block) {
1241 switch (instr->type) {
1242 case nir_instr_type_alu:
1243 emit_alu(ctx, nir_instr_as_alu(instr));
1244 break;
1245 case nir_instr_type_intrinsic:
1246 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1247 break;
1248 case nir_instr_type_load_const:
1249 emit_load_const(ctx, nir_instr_as_load_const(instr));
1250 break;
1251 case nir_instr_type_ssa_undef:
1252 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1253 break;
1254 case nir_instr_type_tex:
1255 emit_tex(ctx, nir_instr_as_tex(instr));
1256 break;
1257 case nir_instr_type_phi:
1258 unreachable("nir_instr_type_phi not supported");
1259 break;
1260 case nir_instr_type_jump:
1261 emit_jump(ctx, nir_instr_as_jump(instr));
1262 break;
1263 case nir_instr_type_call:
1264 unreachable("nir_instr_type_call not supported");
1265 break;
1266 case nir_instr_type_parallel_copy:
1267 unreachable("nir_instr_type_parallel_copy not supported");
1268 break;
1269 case nir_instr_type_deref:
1270 unreachable("nir_instr_type_deref not supported");
1271 break;
1272 }
1273 }
1274 }
1275
1276 static void
1277 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1278
1279 static SpvId
1280 get_src_bool(struct ntv_context *ctx, nir_src *src)
1281 {
1282 SpvId def = get_src_uint(ctx, src);
1283 assert(nir_src_bit_size(*src) == 32);
1284 unsigned num_components = nir_src_num_components(*src);
1285 return uvec_to_bvec(ctx, def, num_components);
1286 }
1287
1288 static void
1289 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1290 {
1291 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1292
1293 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1294 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1295 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1296 SpvId else_id = endif_id;
1297
1298 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1299 if (has_else) {
1300 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1301 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1302 }
1303
1304 /* create a header-block */
1305 start_block(ctx, header_id);
1306 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1307 SpvSelectionControlMaskNone);
1308 branch_conditional(ctx, condition, then_id, else_id);
1309
1310 emit_cf_list(ctx, &if_stmt->then_list);
1311
1312 if (has_else) {
1313 if (ctx->block_started)
1314 branch(ctx, endif_id);
1315
1316 emit_cf_list(ctx, &if_stmt->else_list);
1317 }
1318
1319 start_block(ctx, endif_id);
1320 }
1321
1322 static void
1323 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1324 {
1325 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1326 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1327 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1328 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1329
1330 /* create a header-block */
1331 start_block(ctx, header_id);
1332 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1333 branch(ctx, begin_id);
1334
1335 SpvId save_break = ctx->loop_break;
1336 SpvId save_cont = ctx->loop_cont;
1337 ctx->loop_break = break_id;
1338 ctx->loop_cont = cont_id;
1339
1340 emit_cf_list(ctx, &loop->body);
1341
1342 ctx->loop_break = save_break;
1343 ctx->loop_cont = save_cont;
1344
1345 branch(ctx, cont_id);
1346 start_block(ctx, cont_id);
1347 branch(ctx, header_id);
1348
1349 start_block(ctx, break_id);
1350 }
1351
1352 static void
1353 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1354 {
1355 foreach_list_typed(nir_cf_node, node, node, list) {
1356 switch (node->type) {
1357 case nir_cf_node_block:
1358 emit_block(ctx, nir_cf_node_as_block(node));
1359 break;
1360
1361 case nir_cf_node_if:
1362 emit_if(ctx, nir_cf_node_as_if(node));
1363 break;
1364
1365 case nir_cf_node_loop:
1366 emit_loop(ctx, nir_cf_node_as_loop(node));
1367 break;
1368
1369 case nir_cf_node_function:
1370 unreachable("nir_cf_node_function not supported");
1371 break;
1372 }
1373 }
1374 }
1375
1376 struct spirv_shader *
1377 nir_to_spirv(struct nir_shader *s)
1378 {
1379 struct spirv_shader *ret = NULL;
1380
1381 struct ntv_context ctx = {};
1382
1383 switch (s->info.stage) {
1384 case MESA_SHADER_VERTEX:
1385 case MESA_SHADER_FRAGMENT:
1386 case MESA_SHADER_COMPUTE:
1387 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1388 break;
1389
1390 case MESA_SHADER_TESS_CTRL:
1391 case MESA_SHADER_TESS_EVAL:
1392 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1393 break;
1394
1395 case MESA_SHADER_GEOMETRY:
1396 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1397 break;
1398
1399 default:
1400 unreachable("invalid stage");
1401 }
1402
1403 ctx.stage = s->info.stage;
1404 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1405 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1406
1407 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1408 SpvMemoryModelGLSL450);
1409
1410 SpvExecutionModel exec_model;
1411 switch (s->info.stage) {
1412 case MESA_SHADER_VERTEX:
1413 exec_model = SpvExecutionModelVertex;
1414 break;
1415 case MESA_SHADER_TESS_CTRL:
1416 exec_model = SpvExecutionModelTessellationControl;
1417 break;
1418 case MESA_SHADER_TESS_EVAL:
1419 exec_model = SpvExecutionModelTessellationEvaluation;
1420 break;
1421 case MESA_SHADER_GEOMETRY:
1422 exec_model = SpvExecutionModelGeometry;
1423 break;
1424 case MESA_SHADER_FRAGMENT:
1425 exec_model = SpvExecutionModelFragment;
1426 break;
1427 case MESA_SHADER_COMPUTE:
1428 exec_model = SpvExecutionModelGLCompute;
1429 break;
1430 default:
1431 unreachable("invalid stage");
1432 }
1433
1434 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1435 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1436 NULL, 0);
1437 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1438 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1439
1440 nir_foreach_variable(var, &s->inputs)
1441 emit_input(&ctx, var);
1442
1443 nir_foreach_variable(var, &s->outputs)
1444 emit_output(&ctx, var);
1445
1446 nir_foreach_variable(var, &s->uniforms)
1447 emit_uniform(&ctx, var);
1448
1449 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1450 "main", ctx.entry_ifaces,
1451 ctx.num_entry_ifaces);
1452 if (s->info.stage == MESA_SHADER_FRAGMENT)
1453 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1454 SpvExecutionModeOriginUpperLeft);
1455
1456
1457 spirv_builder_function(&ctx.builder, entry_point, type_void,
1458 SpvFunctionControlMaskNone,
1459 type_main);
1460
1461 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1462 nir_metadata_require(entry, nir_metadata_block_index);
1463
1464 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1465 if (!ctx.defs)
1466 goto fail;
1467 ctx.num_defs = entry->ssa_alloc;
1468
1469 nir_index_local_regs(entry);
1470 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1471 if (!ctx.regs)
1472 goto fail;
1473 ctx.num_regs = entry->reg_alloc;
1474
1475 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1476 if (!block_ids)
1477 goto fail;
1478
1479 for (int i = 0; i < entry->num_blocks; ++i)
1480 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1481
1482 ctx.block_ids = block_ids;
1483 ctx.num_blocks = entry->num_blocks;
1484
1485 /* emit a block only for the variable declarations */
1486 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1487 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1488 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1489 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1490 SpvStorageClassFunction,
1491 type);
1492 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1493 SpvStorageClassFunction);
1494
1495 ctx.regs[reg->index] = var;
1496 }
1497
1498 emit_cf_list(&ctx, &entry->body);
1499
1500 free(ctx.defs);
1501
1502 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1503 spirv_builder_function_end(&ctx.builder);
1504
1505 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1506
1507 ret = CALLOC_STRUCT(spirv_shader);
1508 if (!ret)
1509 goto fail;
1510
1511 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1512 if (!ret->words)
1513 goto fail;
1514
1515 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1516 assert(ret->num_words == num_words);
1517
1518 return ret;
1519
1520 fail:
1521
1522 if (ret)
1523 spirv_shader_delete(ret);
1524
1525 return NULL;
1526 }
1527
1528 void
1529 spirv_shader_delete(struct spirv_shader *s)
1530 {
1531 FREE(s->words);
1532 FREE(s);
1533 }