zink/spirv: implement if-statements
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30
31 struct ntv_context {
32 struct spirv_builder builder;
33
34 SpvId GLSL_std_450;
35
36 gl_shader_stage stage;
37 SpvId inputs[PIPE_MAX_SHADER_INPUTS][4];
38 SpvId input_types[PIPE_MAX_SHADER_INPUTS][4];
39 SpvId outputs[PIPE_MAX_SHADER_OUTPUTS][4];
40 SpvId output_types[PIPE_MAX_SHADER_OUTPUTS][4];
41
42 SpvId ubos[128];
43 size_t num_ubos;
44 SpvId samplers[PIPE_MAX_SAMPLERS];
45 size_t num_samplers;
46 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
47 size_t num_entry_ifaces;
48
49 SpvId *defs;
50 size_t num_defs;
51
52 struct hash_table *vars;
53
54 const SpvId *block_ids;
55 size_t num_blocks;
56 bool block_started;
57 };
58
59 static SpvId
60 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
61 const float values[]);
62
63 static SpvId
64 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
65 const uint32_t values[]);
66
67 static SpvId
68 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
69
70 static SpvId
71 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
72 SpvId src0, SpvId src1);
73
74 static SpvId
75 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
76 SpvId src0, SpvId src1, SpvId src2);
77
78 static SpvId
79 get_bvec_type(struct ntv_context *ctx, int num_components)
80 {
81 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
82 if (num_components > 1)
83 return spirv_builder_type_vector(&ctx->builder, bool_type,
84 num_components);
85
86 assert(num_components == 1);
87 return bool_type;
88 }
89
90 static SpvId
91 block_label(struct ntv_context *ctx, nir_block *block)
92 {
93 assert(block->index < ctx->num_blocks);
94 return ctx->block_ids[block->index];
95 }
96
97 static SpvId
98 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
99 {
100 assert(bit_size == 32); // only 32-bit floats supported so far
101
102 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
103 if (num_components > 1)
104 return spirv_builder_type_vector(&ctx->builder, float_type,
105 num_components);
106
107 assert(num_components == 1);
108 return float_type;
109 }
110
111 static SpvId
112 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
113 {
114 assert(bit_size == 32); // only 32-bit ints supported so far
115
116 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
117 if (num_components > 1)
118 return spirv_builder_type_vector(&ctx->builder, int_type,
119 num_components);
120
121 assert(num_components == 1);
122 return int_type;
123 }
124
125 static SpvId
126 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
127 {
128 assert(bit_size == 32); // only 32-bit uints supported so far
129
130 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
131 if (num_components > 1)
132 return spirv_builder_type_vector(&ctx->builder, uint_type,
133 num_components);
134
135 assert(num_components == 1);
136 return uint_type;
137 }
138
139 static SpvId
140 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
141 {
142 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
143 nir_dest_num_components(*dest));
144 }
145
146 static SpvId
147 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
148 {
149 switch (type) {
150 case GLSL_TYPE_FLOAT:
151 return spirv_builder_type_float(&ctx->builder, 32);
152
153 case GLSL_TYPE_INT:
154 return spirv_builder_type_int(&ctx->builder, 32);
155
156 case GLSL_TYPE_UINT:
157 return spirv_builder_type_uint(&ctx->builder, 32);
158 /* TODO: handle more types */
159
160 default:
161 unreachable("unknown GLSL type");
162 }
163 }
164
165 static SpvId
166 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
167 {
168 assert(type);
169 if (glsl_type_is_scalar(type))
170 return get_glsl_basetype(ctx, glsl_get_base_type(type));
171
172 if (glsl_type_is_vector(type))
173 return spirv_builder_type_vector(&ctx->builder,
174 get_glsl_basetype(ctx, glsl_get_base_type(type)),
175 glsl_get_vector_elements(type));
176
177 unreachable("we shouldn't get here, I think...");
178 }
179
180 static void
181 emit_input(struct ntv_context *ctx, struct nir_variable *var)
182 {
183 SpvId vec_type = get_glsl_type(ctx, var->type);
184 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
185 SpvStorageClassInput,
186 vec_type);
187 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
188 SpvStorageClassInput);
189
190 if (var->name)
191 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
192
193 if (ctx->stage == MESA_SHADER_FRAGMENT) {
194 switch (var->data.location) {
195 case VARYING_SLOT_POS:
196 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
197 break;
198
199 case VARYING_SLOT_PNTC:
200 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
201 break;
202
203 default:
204 spirv_builder_emit_location(&ctx->builder, var_id,
205 var->data.driver_location);
206 break;
207 }
208 } else {
209 spirv_builder_emit_location(&ctx->builder, var_id,
210 var->data.driver_location);
211 }
212
213 if (var->data.location_frac)
214 spirv_builder_emit_component(&ctx->builder, var_id,
215 var->data.location_frac);
216
217 if (var->data.interpolation == INTERP_MODE_FLAT)
218 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
219
220 assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
221 assert(var->data.location_frac < 4);
222 assert(ctx->inputs[var->data.driver_location][var->data.location_frac] == 0);
223 ctx->inputs[var->data.driver_location][var->data.location_frac] = var_id;
224 ctx->input_types[var->data.driver_location][var->data.location_frac] = vec_type;
225
226 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
227 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
228 }
229
230 static void
231 emit_output(struct ntv_context *ctx, struct nir_variable *var)
232 {
233 SpvId vec_type = get_glsl_type(ctx, var->type);
234 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
235 SpvStorageClassOutput,
236 vec_type);
237 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
238 SpvStorageClassOutput);
239 if (var->name)
240 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
241
242
243 if (ctx->stage == MESA_SHADER_VERTEX) {
244 switch (var->data.location) {
245 case VARYING_SLOT_POS:
246 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
247 break;
248
249 case VARYING_SLOT_PSIZ:
250 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
251 break;
252
253 default:
254 spirv_builder_emit_location(&ctx->builder, var_id,
255 var->data.driver_location - 1);
256 }
257 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
258 switch (var->data.location) {
259 case FRAG_RESULT_DEPTH:
260 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
261 break;
262
263 default:
264 spirv_builder_emit_location(&ctx->builder, var_id,
265 var->data.driver_location);
266 }
267 }
268
269 if (var->data.location_frac)
270 spirv_builder_emit_component(&ctx->builder, var_id,
271 var->data.location_frac);
272
273 assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
274 assert(var->data.location_frac < 4);
275 assert(ctx->outputs[var->data.driver_location][var->data.location_frac] == 0);
276 ctx->outputs[var->data.driver_location][var->data.location_frac] = var_id;
277 ctx->output_types[var->data.driver_location][var->data.location_frac] = vec_type;
278
279 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
280 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
281 }
282
283 static SpvDim
284 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
285 {
286 *is_ms = false;
287 switch (gdim) {
288 case GLSL_SAMPLER_DIM_1D:
289 return SpvDim1D;
290 case GLSL_SAMPLER_DIM_2D:
291 return SpvDim2D;
292 case GLSL_SAMPLER_DIM_RECT:
293 return SpvDimRect;
294 case GLSL_SAMPLER_DIM_CUBE:
295 return SpvDimCube;
296 case GLSL_SAMPLER_DIM_3D:
297 return SpvDim3D;
298 case GLSL_SAMPLER_DIM_MS:
299 *is_ms = true;
300 return SpvDim2D;
301 default:
302 fprintf(stderr, "unknown sampler type %d\n", gdim);
303 break;
304 }
305 return SpvDim2D;
306 }
307
308 static void
309 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
310 {
311 bool is_ms;
312 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
313 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
314 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
315 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
316 SpvImageFormatUnknown);
317
318 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
319 image_type);
320 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
321 SpvStorageClassUniformConstant,
322 sampled_type);
323 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
324 SpvStorageClassUniformConstant);
325
326 if (var->name)
327 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
328
329 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
330 ctx->samplers[ctx->num_samplers++] = var_id;
331
332 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
333 var->data.descriptor_set);
334 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
335 }
336
337 static void
338 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
339 {
340 uint32_t size = glsl_count_attribute_slots(var->type, false);
341 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
342 SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
343 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
344 array_length);
345 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
346
347 // wrap UBO-array in a struct
348 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
349 if (var->name) {
350 char struct_name[100];
351 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
352 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
353 }
354
355 spirv_builder_emit_decoration(&ctx->builder, struct_type,
356 SpvDecorationBlock);
357 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
358
359
360 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
361 SpvStorageClassUniform,
362 struct_type);
363
364 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
365 SpvStorageClassUniform);
366 if (var->name)
367 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
368
369 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
370 ctx->ubos[ctx->num_ubos++] = var_id;
371
372 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
373 var->data.descriptor_set);
374 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
375 }
376
377 static void
378 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
379 {
380 if (glsl_type_is_sampler(var->type))
381 emit_sampler(ctx, var);
382 else if (var->interface_type)
383 emit_ubo(ctx, var);
384 }
385
386 static SpvId
387 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
388 {
389 assert(ssa->index < ctx->num_defs);
390 assert(ctx->defs[ssa->index] != 0);
391 return ctx->defs[ssa->index];
392 }
393
394 static SpvId
395 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
396 {
397 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, reg);
398 if (!he) {
399 SpvId type = get_uvec_type(ctx, reg->bit_size, reg->num_components);
400 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
401 SpvStorageClassFunction,
402 type);
403
404 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
405 SpvStorageClassFunction);
406
407 he = _mesa_hash_table_insert(ctx->vars, reg, (void *)(intptr_t)var);
408 }
409 return (SpvId)(intptr_t)he->data;
410 }
411
412 static SpvId
413 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
414 {
415 assert(reg->reg);
416 assert(!reg->indirect);
417 assert(!reg->base_offset);
418
419 SpvId var = get_var_from_reg(ctx, reg->reg);
420 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
421 return spirv_builder_emit_load(&ctx->builder, type, var);
422 }
423
424 static SpvId
425 get_src_uint(struct ntv_context *ctx, nir_src *src)
426 {
427 if (src->is_ssa)
428 return get_src_uint_ssa(ctx, src->ssa);
429 else
430 return get_src_uint_reg(ctx, &src->reg);
431 }
432
433 static SpvId
434 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
435 {
436 assert(!alu->src[src].negate);
437 assert(!alu->src[src].abs);
438
439 SpvId def = get_src_uint(ctx, &alu->src[src].src);
440
441 unsigned used_channels = 0;
442 bool need_swizzle = false;
443 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
444 if (!nir_alu_instr_channel_used(alu, src, i))
445 continue;
446
447 used_channels++;
448
449 if (alu->src[src].swizzle[i] != i)
450 need_swizzle = true;
451 }
452 assert(used_channels != 0);
453
454 unsigned live_channels = nir_src_num_components(alu->src[src].src);
455 if (used_channels != live_channels)
456 need_swizzle = true;
457
458 if (!need_swizzle)
459 return def;
460
461 int bit_size = nir_src_bit_size(alu->src[src].src);
462
463 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
464 if (used_channels == 1) {
465 uint32_t indices[] = { alu->src[src].swizzle[0] };
466 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
467 def, indices,
468 ARRAY_SIZE(indices));
469 } else if (live_channels == 1) {
470 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
471 used_channels);
472
473 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
474 for (unsigned i = 0; i < used_channels; ++i)
475 constituents[i] = def;
476
477 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
478 constituents,
479 used_channels);
480 } else {
481 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
482 used_channels);
483
484 uint32_t components[NIR_MAX_VEC_COMPONENTS];
485 size_t num_components = 0;
486 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
487 if (!nir_alu_instr_channel_used(alu, src, i))
488 continue;
489
490 components[num_components++] = alu->src[src].swizzle[i];
491 }
492
493 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
494 def, def, components, num_components);
495 }
496 }
497
498 static void
499 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
500 {
501 assert(result != 0);
502 assert(ssa->index < ctx->num_defs);
503 ctx->defs[ssa->index] = result;
504 }
505
506 static SpvId
507 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
508 {
509 SpvId otype = get_uvec_type(ctx, 32, num_components);
510 uint32_t zeros[4] = { 0, 0, 0, 0 };
511 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
512 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
513 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
514 return emit_triop(ctx, SpvOpSelect, otype, value, one, zero);
515 }
516
517 static SpvId
518 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
519 {
520 SpvId type = get_bvec_type(ctx, num_components);
521
522 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
523 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
524
525 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
526 }
527
528 static SpvId
529 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
530 unsigned num_components)
531 {
532 SpvId type = get_uvec_type(ctx, bit_size, num_components);
533 return emit_unop(ctx, SpvOpBitcast, type, value);
534 }
535
536 static SpvId
537 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
538 unsigned num_components)
539 {
540 SpvId type = get_ivec_type(ctx, bit_size, num_components);
541 return emit_unop(ctx, SpvOpBitcast, type, value);
542 }
543
544 static SpvId
545 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
546 unsigned num_components)
547 {
548 SpvId type = get_fvec_type(ctx, bit_size, num_components);
549 return emit_unop(ctx, SpvOpBitcast, type, value);
550 }
551
552 static void
553 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
554 {
555 SpvId var = get_var_from_reg(ctx, reg->reg);
556 assert(var);
557 spirv_builder_emit_store(&ctx->builder, var, result);
558 }
559
560 static void
561 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
562 {
563 if (dest->is_ssa)
564 store_ssa_def_uint(ctx, &dest->ssa, result);
565 else
566 store_reg_def(ctx, &dest->reg, result);
567 }
568
569 static void
570 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
571 {
572 unsigned num_components = nir_dest_num_components(*dest);
573 unsigned bit_size = nir_dest_bit_size(*dest);
574
575 switch (nir_alu_type_get_base_type(type)) {
576 case nir_type_bool:
577 assert(bit_size == 1);
578 result = bvec_to_uvec(ctx, result, num_components);
579 break;
580
581 case nir_type_uint:
582 break; /* nothing to do! */
583
584 case nir_type_int:
585 case nir_type_float:
586 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
587 break;
588
589 default:
590 unreachable("unsupported nir_alu_type");
591 }
592
593 store_dest_uint(ctx, dest, result);
594 }
595
596 static SpvId
597 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
598 {
599 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
600 }
601
602 static SpvId
603 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
604 SpvId src0, SpvId src1)
605 {
606 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
607 }
608
609 static SpvId
610 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
611 SpvId src0, SpvId src1, SpvId src2)
612 {
613 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
614 }
615
616 static SpvId
617 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
618 SpvId src)
619 {
620 SpvId args[] = { src };
621 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
622 op, args, ARRAY_SIZE(args));
623 }
624
625 static SpvId
626 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
627 SpvId src0, SpvId src1)
628 {
629 SpvId args[] = { src0, src1 };
630 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
631 op, args, ARRAY_SIZE(args));
632 }
633
634 static SpvId
635 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
636 const float values[])
637 {
638 assert(bit_size == 32);
639
640 if (num_components > 1) {
641 SpvId components[num_components];
642 for (int i = 0; i < num_components; i++)
643 components[i] = spirv_builder_const_float(&ctx->builder, bit_size,
644 values[i]);
645
646 SpvId type = get_fvec_type(ctx, bit_size, num_components);
647 return spirv_builder_const_composite(&ctx->builder, type, components,
648 num_components);
649 }
650
651 assert(num_components == 1);
652 return spirv_builder_const_float(&ctx->builder, bit_size, values[0]);
653 }
654
655 static SpvId
656 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
657 const uint32_t values[])
658 {
659 assert(bit_size == 32);
660
661 if (num_components > 1) {
662 SpvId components[num_components];
663 for (int i = 0; i < num_components; i++)
664 components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
665 values[i]);
666
667 SpvId type = get_uvec_type(ctx, bit_size, num_components);
668 return spirv_builder_const_composite(&ctx->builder, type, components,
669 num_components);
670 }
671
672 assert(num_components == 1);
673 return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
674 }
675
676 static inline unsigned
677 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
678 {
679 if (nir_op_infos[instr->op].input_sizes[src] > 0)
680 return nir_op_infos[instr->op].input_sizes[src];
681
682 if (instr->dest.dest.is_ssa)
683 return instr->dest.dest.ssa.num_components;
684 else
685 return instr->dest.dest.reg.reg->num_components;
686 }
687
688 static SpvId
689 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
690 {
691 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
692
693 unsigned num_components = alu_instr_src_components(alu, src);
694 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
695 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
696
697 switch (nir_alu_type_get_base_type(type)) {
698 case nir_type_bool:
699 assert(bit_size == 1);
700 return uvec_to_bvec(ctx, uint_value, num_components);
701
702 case nir_type_int:
703 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
704
705 case nir_type_uint:
706 return uint_value;
707
708 case nir_type_float:
709 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
710
711 default:
712 unreachable("unknown nir_alu_type");
713 }
714 }
715
716 static void
717 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
718 {
719 assert(!alu->dest.saturate);
720 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
721 }
722
723 static SpvId
724 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
725 {
726 unsigned num_components = nir_dest_num_components(*dest);
727 unsigned bit_size = nir_dest_bit_size(*dest);
728
729 switch (nir_alu_type_get_base_type(type)) {
730 case nir_type_bool:
731 return get_bvec_type(ctx, num_components);
732
733 case nir_type_int:
734 return get_ivec_type(ctx, bit_size, num_components);
735
736 case nir_type_uint:
737 return get_uvec_type(ctx, bit_size, num_components);
738
739 case nir_type_float:
740 return get_fvec_type(ctx, bit_size, num_components);
741
742 default:
743 unreachable("unsupported nir_alu_type");
744 }
745 }
746
747 static void
748 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
749 {
750 SpvId src[nir_op_infos[alu->op].num_inputs];
751 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
752 src[i] = get_alu_src(ctx, alu, i);
753
754 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
755 nir_op_infos[alu->op].output_type);
756 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
757 unsigned num_components = nir_dest_num_components(alu->dest.dest);
758
759 SpvId result = 0;
760 switch (alu->op) {
761 case nir_op_mov:
762 assert(nir_op_infos[alu->op].num_inputs == 1);
763 result = src[0];
764 break;
765
766 #define UNOP(nir_op, spirv_op) \
767 case nir_op: \
768 assert(nir_op_infos[alu->op].num_inputs == 1); \
769 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
770 break;
771
772 #define BUILTIN_UNOP(nir_op, spirv_op) \
773 case nir_op: \
774 assert(nir_op_infos[alu->op].num_inputs == 1); \
775 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
776 break;
777
778 UNOP(nir_op_fneg, SpvOpFNegate)
779 UNOP(nir_op_fddx, SpvOpDPdx)
780 UNOP(nir_op_fddy, SpvOpDPdy)
781
782 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
783 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
784 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
785 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
786 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
787 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
788 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
789 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
790 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
791 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
792 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
793 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
794 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
795
796 case nir_op_frcp: {
797 assert(nir_op_infos[alu->op].num_inputs == 1);
798 float one[4] = { 1, 1, 1, 1 };
799 src[1] = src[0];
800 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
801 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
802 }
803 break;
804
805 #undef UNOP
806 #undef BUILTIN_UNOP
807
808 #define BINOP(nir_op, spirv_op) \
809 case nir_op: \
810 assert(nir_op_infos[alu->op].num_inputs == 2); \
811 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
812 break;
813
814 #define BUILTIN_BINOP(nir_op, spirv_op) \
815 case nir_op: \
816 assert(nir_op_infos[alu->op].num_inputs == 2); \
817 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
818 break;
819
820 BINOP(nir_op_iadd, SpvOpIAdd)
821 BINOP(nir_op_isub, SpvOpISub)
822 BINOP(nir_op_imul, SpvOpIMul)
823 BINOP(nir_op_fadd, SpvOpFAdd)
824 BINOP(nir_op_fsub, SpvOpFSub)
825 BINOP(nir_op_fmul, SpvOpFMul)
826 BINOP(nir_op_flt, SpvOpFUnordLessThan)
827 BINOP(nir_op_fge, SpvOpFUnordGreaterThanEqual)
828
829 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
830 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
831
832 #undef BINOP
833 #undef BUILTIN_BINOP
834
835 case nir_op_fdot2:
836 case nir_op_fdot3:
837 case nir_op_fdot4:
838 assert(nir_op_infos[alu->op].num_inputs == 2);
839 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
840 break;
841
842 case nir_op_seq:
843 case nir_op_sne:
844 case nir_op_slt:
845 case nir_op_sge: {
846 assert(nir_op_infos[alu->op].num_inputs == 2);
847 int num_components = nir_dest_num_components(alu->dest.dest);
848 SpvId bool_type = get_bvec_type(ctx, num_components);
849
850 SpvId zero = spirv_builder_const_float(&ctx->builder, 32, 0.0f);
851 SpvId one = spirv_builder_const_float(&ctx->builder, 32, 1.0f);
852 if (num_components > 1) {
853 SpvId zero_comps[num_components], one_comps[num_components];
854 for (int i = 0; i < num_components; i++) {
855 zero_comps[i] = zero;
856 one_comps[i] = one;
857 }
858
859 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
860 zero_comps, num_components);
861 one = spirv_builder_const_composite(&ctx->builder, dest_type,
862 one_comps, num_components);
863 }
864
865 SpvOp op;
866 switch (alu->op) {
867 case nir_op_seq: op = SpvOpFOrdEqual; break;
868 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
869 case nir_op_slt: op = SpvOpFOrdLessThan; break;
870 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
871 default: unreachable("unexpected op");
872 }
873
874 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
875 result = emit_triop(ctx, SpvOpSelect, dest_type, result, one, zero);
876 }
877 break;
878
879 case nir_op_fcsel: {
880 assert(nir_op_infos[alu->op].num_inputs == 3);
881 int num_components = nir_dest_num_components(alu->dest.dest);
882 SpvId bool_type = get_bvec_type(ctx, num_components);
883
884 float zero[4] = { 0, 0, 0, 0 };
885 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
886 num_components, zero);
887
888 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
889 result = emit_triop(ctx, SpvOpSelect, dest_type, result, src[1], src[2]);
890 }
891 break;
892
893 case nir_op_vec2:
894 case nir_op_vec3:
895 case nir_op_vec4: {
896 int num_inputs = nir_op_infos[alu->op].num_inputs;
897 assert(2 <= num_inputs && num_inputs <= 4);
898 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
899 src, num_inputs);
900 }
901 break;
902
903 default:
904 fprintf(stderr, "emit_alu: not implemented (%s)\n",
905 nir_op_infos[alu->op].name);
906
907 unreachable("unsupported opcode");
908 return;
909 }
910
911 store_alu_result(ctx, alu, result);
912 }
913
914 static void
915 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
916 {
917 uint32_t values[NIR_MAX_VEC_COMPONENTS];
918 for (int i = 0; i < load_const->def.num_components; ++i)
919 values[i] = load_const->value[i].u32;
920
921 SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
922 load_const->def.num_components,
923 values);
924 store_ssa_def_uint(ctx, &load_const->def, constant);
925 }
926
927 static void
928 emit_load_input(struct ntv_context *ctx, nir_intrinsic_instr *intr)
929 {
930 nir_const_value *const_offset = nir_src_as_const_value(intr->src[0]);
931 if (const_offset) {
932 int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
933 assert(driver_location < PIPE_MAX_SHADER_INPUTS);
934 int location_frac = nir_intrinsic_component(intr);
935 assert(location_frac < 4);
936
937 SpvId ptr = ctx->inputs[driver_location][location_frac];
938 SpvId type = ctx->input_types[driver_location][location_frac];
939 assert(ptr && type);
940
941 SpvId result = spirv_builder_emit_load(&ctx->builder, type, ptr);
942
943 unsigned num_components = nir_dest_num_components(intr->dest);
944 unsigned bit_size = nir_dest_bit_size(intr->dest);
945 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
946
947 store_dest_uint(ctx, &intr->dest, result);
948 } else
949 unreachable("input-addressing not yet supported");
950 }
951
952 static void
953 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
954 {
955 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
956 assert(const_block_index); // no dynamic indexing for now
957 assert(const_block_index->u32 == 0); // we only support the default UBO for now
958
959 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
960 if (const_offset) {
961 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
962 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
963 SpvStorageClassUniform,
964 uvec4_type);
965
966 unsigned idx = const_offset->u32;
967 SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
968 SpvId offset = spirv_builder_const_uint(&ctx->builder, 32, idx);
969 SpvId offsets[] = { member, offset };
970 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
971 ctx->ubos[0], offsets,
972 ARRAY_SIZE(offsets));
973 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
974
975 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
976 unsigned num_components = nir_dest_num_components(intr->dest);
977 if (num_components == 1) {
978 uint32_t components[] = { 0 };
979 result = spirv_builder_emit_composite_extract(&ctx->builder,
980 type,
981 result, components,
982 1);
983 } else if (num_components < 4) {
984 SpvId constituents[num_components];
985 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
986 for (uint32_t i = 0; i < num_components; ++i)
987 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
988 uint_type,
989 result, &i,
990 1);
991
992 result = spirv_builder_emit_composite_construct(&ctx->builder,
993 type,
994 constituents,
995 num_components);
996 }
997
998 store_dest_uint(ctx, &intr->dest, result);
999 } else
1000 unreachable("uniform-addressing not yet supported");
1001 }
1002
1003 static void
1004 emit_store_output(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1005 {
1006 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1007 if (const_offset) {
1008 int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
1009 assert(driver_location < PIPE_MAX_SHADER_OUTPUTS);
1010 int location_frac = nir_intrinsic_component(intr);
1011 assert(location_frac < 4);
1012
1013 SpvId ptr = ctx->outputs[driver_location][location_frac];
1014 assert(ptr > 0);
1015
1016 SpvId src = get_src_uint(ctx, &intr->src[0]);
1017 SpvId spirv_type = ctx->output_types[driver_location][location_frac];
1018 SpvId result = emit_unop(ctx, SpvOpBitcast, spirv_type, src);
1019 spirv_builder_emit_store(&ctx->builder, ptr, result);
1020 } else
1021 unreachable("output-addressing not yet supported");
1022 }
1023
1024 static void
1025 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1026 {
1027 switch (intr->intrinsic) {
1028 case nir_intrinsic_load_input:
1029 emit_load_input(ctx, intr);
1030 break;
1031
1032 case nir_intrinsic_load_ubo:
1033 emit_load_ubo(ctx, intr);
1034 break;
1035
1036 case nir_intrinsic_store_output:
1037 emit_store_output(ctx, intr);
1038 break;
1039
1040 default:
1041 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1042 nir_intrinsic_infos[intr->intrinsic].name);
1043 unreachable("unsupported intrinsic");
1044 }
1045 }
1046
1047 static void
1048 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1049 {
1050 SpvId type = get_fvec_type(ctx, undef->def.bit_size,
1051 undef->def.num_components);
1052
1053 store_ssa_def_uint(ctx, &undef->def,
1054 spirv_builder_emit_undef(&ctx->builder, type));
1055 }
1056
1057 static SpvId
1058 get_src_float(struct ntv_context *ctx, nir_src *src)
1059 {
1060 SpvId def = get_src_uint(ctx, src);
1061 unsigned num_components = nir_src_num_components(*src);
1062 unsigned bit_size = nir_src_bit_size(*src);
1063 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1064 }
1065
1066 static void
1067 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1068 {
1069 assert(tex->op == nir_texop_tex);
1070 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1071 assert(tex->texture_index == tex->sampler_index);
1072
1073 bool has_proj = false;
1074 SpvId coord = 0, proj;
1075 unsigned coord_components;
1076 for (unsigned i = 0; i < tex->num_srcs; i++) {
1077 switch (tex->src[i].src_type) {
1078 case nir_tex_src_coord:
1079 coord = get_src_float(ctx, &tex->src[i].src);
1080 coord_components = nir_src_num_components(tex->src[i].src);
1081 break;
1082
1083 case nir_tex_src_projector:
1084 has_proj = true;
1085 proj = get_src_float(ctx, &tex->src[i].src);
1086 assert(nir_src_num_components(tex->src[i].src) == 1);
1087 break;
1088
1089 default:
1090 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1091 unreachable("unknown texture source");
1092 }
1093 }
1094
1095 bool is_ms;
1096 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1097 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1098 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1099 dimension, false, tex->is_array, is_ms, 1,
1100 SpvImageFormatUnknown);
1101 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1102 image_type);
1103
1104 assert(tex->texture_index < ctx->num_samplers);
1105 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1106 ctx->samplers[tex->texture_index]);
1107
1108 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1109
1110 SpvId result;
1111 if (has_proj) {
1112 SpvId constituents[coord_components + 1];
1113 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1114 for (uint32_t i = 0; i < coord_components; ++i)
1115 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1116 float_type,
1117 coord,
1118 &i, 1);
1119
1120 constituents[coord_components++] = proj;
1121
1122 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1123 SpvId merged = spirv_builder_emit_composite_construct(&ctx->builder,
1124 vec_type,
1125 constituents,
1126 coord_components);
1127
1128 result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
1129 dest_type,
1130 load,
1131 merged);
1132 } else
1133 result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
1134 dest_type, load,
1135 coord);
1136 spirv_builder_emit_decoration(&ctx->builder, result,
1137 SpvDecorationRelaxedPrecision);
1138
1139 store_dest(ctx, &tex->dest, result, tex->dest_type);
1140 }
1141
1142 static void
1143 start_block(struct ntv_context *ctx, SpvId label)
1144 {
1145 /* terminate previous block if needed */
1146 if (ctx->block_started)
1147 spirv_builder_emit_branch(&ctx->builder, label);
1148
1149 /* start new block */
1150 spirv_builder_label(&ctx->builder, label);
1151 ctx->block_started = true;
1152 }
1153
1154 static void
1155 branch(struct ntv_context *ctx, SpvId label)
1156 {
1157 assert(ctx->block_started);
1158 spirv_builder_emit_branch(&ctx->builder, label);
1159 ctx->block_started = false;
1160 }
1161
1162 static void
1163 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1164 SpvId else_id)
1165 {
1166 assert(ctx->block_started);
1167 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1168 then_id, else_id);
1169 ctx->block_started = false;
1170 }
1171
1172 static void
1173 emit_block(struct ntv_context *ctx, struct nir_block *block)
1174 {
1175 start_block(ctx, block_label(ctx, block));
1176 nir_foreach_instr(instr, block) {
1177 switch (instr->type) {
1178 case nir_instr_type_alu:
1179 emit_alu(ctx, nir_instr_as_alu(instr));
1180 break;
1181 case nir_instr_type_intrinsic:
1182 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1183 break;
1184 case nir_instr_type_load_const:
1185 emit_load_const(ctx, nir_instr_as_load_const(instr));
1186 break;
1187 case nir_instr_type_ssa_undef:
1188 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1189 break;
1190 case nir_instr_type_tex:
1191 emit_tex(ctx, nir_instr_as_tex(instr));
1192 break;
1193 case nir_instr_type_phi:
1194 unreachable("nir_instr_type_phi not supported");
1195 break;
1196 case nir_instr_type_jump:
1197 unreachable("nir_instr_type_jump not supported");
1198 break;
1199 case nir_instr_type_call:
1200 unreachable("nir_instr_type_call not supported");
1201 break;
1202 case nir_instr_type_parallel_copy:
1203 unreachable("nir_instr_type_parallel_copy not supported");
1204 break;
1205 case nir_instr_type_deref:
1206 unreachable("nir_instr_type_deref not supported");
1207 break;
1208 }
1209 }
1210 }
1211
/* Forward declaration: emit_if() recurses into emit_cf_list() for its
 * then/else bodies, while emit_cf_list() dispatches if-nodes to emit_if(). */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1214
1215 static SpvId
1216 get_src_bool(struct ntv_context *ctx, nir_src *src)
1217 {
1218 SpvId def = get_src_uint(ctx, src);
1219 assert(nir_src_bit_size(*src) == 32);
1220 unsigned num_components = nir_src_num_components(*src);
1221 return uvec_to_bvec(ctx, def, num_components);
1222 }
1223
1224 static void
1225 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1226 {
1227 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1228
1229 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1230 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1231 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1232 SpvId else_id = endif_id;
1233
1234 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1235 if (has_else) {
1236 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1237 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1238 }
1239
1240 /* create a header-block */
1241 start_block(ctx, header_id);
1242 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1243 SpvSelectionControlMaskNone);
1244 branch_conditional(ctx, condition, then_id, else_id);
1245
1246 emit_cf_list(ctx, &if_stmt->then_list);
1247
1248 if (has_else) {
1249 branch(ctx, endif_id);
1250 emit_cf_list(ctx, &if_stmt->else_list);
1251 }
1252
1253 start_block(ctx, endif_id);
1254 }
1255
1256 static void
1257 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1258 {
1259 foreach_list_typed(nir_cf_node, node, node, list) {
1260 switch (node->type) {
1261 case nir_cf_node_block:
1262 emit_block(ctx, nir_cf_node_as_block(node));
1263 break;
1264
1265 case nir_cf_node_if:
1266 emit_if(ctx, nir_cf_node_as_if(node));
1267 break;
1268
1269 case nir_cf_node_loop:
1270 unreachable("nir_cf_node_loop not supported");
1271 break;
1272
1273 case nir_cf_node_function:
1274 unreachable("nir_cf_node_function not supported");
1275 break;
1276 }
1277 }
1278 }
1279
1280 struct spirv_shader *
1281 nir_to_spirv(struct nir_shader *s)
1282 {
1283 struct spirv_shader *ret = NULL;
1284
1285 struct ntv_context ctx = {};
1286
1287 switch (s->info.stage) {
1288 case MESA_SHADER_VERTEX:
1289 case MESA_SHADER_FRAGMENT:
1290 case MESA_SHADER_COMPUTE:
1291 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1292 break;
1293
1294 case MESA_SHADER_TESS_CTRL:
1295 case MESA_SHADER_TESS_EVAL:
1296 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1297 break;
1298
1299 case MESA_SHADER_GEOMETRY:
1300 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1301 break;
1302
1303 default:
1304 unreachable("invalid stage");
1305 }
1306
1307 ctx.stage = s->info.stage;
1308 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1309 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1310
1311 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1312 SpvMemoryModelGLSL450);
1313
1314 SpvExecutionModel exec_model;
1315 switch (s->info.stage) {
1316 case MESA_SHADER_VERTEX:
1317 exec_model = SpvExecutionModelVertex;
1318 break;
1319 case MESA_SHADER_TESS_CTRL:
1320 exec_model = SpvExecutionModelTessellationControl;
1321 break;
1322 case MESA_SHADER_TESS_EVAL:
1323 exec_model = SpvExecutionModelTessellationEvaluation;
1324 break;
1325 case MESA_SHADER_GEOMETRY:
1326 exec_model = SpvExecutionModelGeometry;
1327 break;
1328 case MESA_SHADER_FRAGMENT:
1329 exec_model = SpvExecutionModelFragment;
1330 break;
1331 case MESA_SHADER_COMPUTE:
1332 exec_model = SpvExecutionModelGLCompute;
1333 break;
1334 default:
1335 unreachable("invalid stage");
1336 }
1337
1338 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1339 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1340 NULL, 0);
1341 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1342 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1343
1344 nir_foreach_variable(var, &s->inputs)
1345 emit_input(&ctx, var);
1346
1347 nir_foreach_variable(var, &s->outputs)
1348 emit_output(&ctx, var);
1349
1350 nir_foreach_variable(var, &s->uniforms)
1351 emit_uniform(&ctx, var);
1352
1353 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1354 "main", ctx.entry_ifaces,
1355 ctx.num_entry_ifaces);
1356 if (s->info.stage == MESA_SHADER_FRAGMENT)
1357 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1358 SpvExecutionModeOriginUpperLeft);
1359
1360
1361 spirv_builder_function(&ctx.builder, entry_point, type_void,
1362 SpvFunctionControlMaskNone,
1363 type_main);
1364
1365 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1366 nir_metadata_require(entry, nir_metadata_block_index);
1367
1368 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1369 if (!ctx.defs)
1370 goto fail;
1371 ctx.num_defs = entry->ssa_alloc;
1372
1373 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1374 _mesa_key_pointer_equal);
1375 if (!ctx.vars)
1376 goto fail;
1377
1378 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1379 if (!block_ids)
1380 goto fail;
1381
1382 for (int i = 0; i < entry->num_blocks; ++i)
1383 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1384
1385 ctx.block_ids = block_ids;
1386 ctx.num_blocks = entry->num_blocks;
1387
1388 emit_cf_list(&ctx, &entry->body);
1389
1390 free(ctx.defs);
1391
1392 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1393 spirv_builder_function_end(&ctx.builder);
1394
1395 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1396
1397 ret = CALLOC_STRUCT(spirv_shader);
1398 if (!ret)
1399 goto fail;
1400
1401 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1402 if (!ret->words)
1403 goto fail;
1404
1405 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1406 assert(ret->num_words == num_words);
1407
1408 return ret;
1409
1410 fail:
1411
1412 if (ret)
1413 spirv_shader_delete(ret);
1414
1415 return NULL;
1416 }
1417
1418 void
1419 spirv_shader_delete(struct spirv_shader *s)
1420 {
1421 FREE(s->words);
1422 FREE(s);
1423 }