zink: more comparison-ops
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38 int var_location;
39
40 SpvId ubos[128];
41 size_t num_ubos;
42 SpvId samplers[PIPE_MAX_SAMPLERS];
43 size_t num_samplers;
44 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
45 size_t num_entry_ifaces;
46
47 SpvId *defs;
48 size_t num_defs;
49
50 SpvId *regs;
51 size_t num_regs;
52
53 struct hash_table *vars; /* nir_variable -> SpvId */
54
55 const SpvId *block_ids;
56 size_t num_blocks;
57 bool block_started;
58 SpvId loop_break, loop_cont;
59 };
60
61 static SpvId
62 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
63 const float values[]);
64
65 static SpvId
66 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
67 const uint32_t values[]);
68
69 static SpvId
70 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
71
72 static SpvId
73 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
74 SpvId src0, SpvId src1);
75
76 static SpvId
77 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
78 SpvId src0, SpvId src1, SpvId src2);
79
80 static SpvId
81 get_bvec_type(struct ntv_context *ctx, int num_components)
82 {
83 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
84 if (num_components > 1)
85 return spirv_builder_type_vector(&ctx->builder, bool_type,
86 num_components);
87
88 assert(num_components == 1);
89 return bool_type;
90 }
91
92 static SpvId
93 block_label(struct ntv_context *ctx, nir_block *block)
94 {
95 assert(block->index < ctx->num_blocks);
96 return ctx->block_ids[block->index];
97 }
98
99 static SpvId
100 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
101 {
102 assert(bit_size == 32); // only 32-bit floats supported so far
103
104 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
105 if (num_components > 1)
106 return spirv_builder_type_vector(&ctx->builder, float_type,
107 num_components);
108
109 assert(num_components == 1);
110 return float_type;
111 }
112
113 static SpvId
114 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
115 {
116 assert(bit_size == 32); // only 32-bit ints supported so far
117
118 SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
119 if (num_components > 1)
120 return spirv_builder_type_vector(&ctx->builder, int_type,
121 num_components);
122
123 assert(num_components == 1);
124 return int_type;
125 }
126
127 static SpvId
128 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
129 {
130 assert(bit_size == 32); // only 32-bit uints supported so far
131
132 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
133 if (num_components > 1)
134 return spirv_builder_type_vector(&ctx->builder, uint_type,
135 num_components);
136
137 assert(num_components == 1);
138 return uint_type;
139 }
140
141 static SpvId
142 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
143 {
144 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
145 nir_dest_num_components(*dest));
146 }
147
148 static SpvId
149 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
150 {
151 switch (type) {
152 case GLSL_TYPE_FLOAT:
153 return spirv_builder_type_float(&ctx->builder, 32);
154
155 case GLSL_TYPE_INT:
156 return spirv_builder_type_int(&ctx->builder, 32);
157
158 case GLSL_TYPE_UINT:
159 return spirv_builder_type_uint(&ctx->builder, 32);
160 /* TODO: handle more types */
161
162 default:
163 unreachable("unknown GLSL type");
164 }
165 }
166
167 static SpvId
168 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
169 {
170 assert(type);
171 if (glsl_type_is_scalar(type))
172 return get_glsl_basetype(ctx, glsl_get_base_type(type));
173
174 if (glsl_type_is_vector(type))
175 return spirv_builder_type_vector(&ctx->builder,
176 get_glsl_basetype(ctx, glsl_get_base_type(type)),
177 glsl_get_vector_elements(type));
178
179 if (glsl_type_is_array(type)) {
180 SpvId ret = spirv_builder_type_array(&ctx->builder,
181 get_glsl_type(ctx, glsl_get_array_element(type)),
182 spirv_builder_const_uint(&ctx->builder, 32, glsl_get_length(type)));
183 uint32_t stride = glsl_get_explicit_stride(type);
184 if (stride)
185 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
186 return ret;
187 }
188
189
190 unreachable("we shouldn't get here, I think...");
191 }
192
193 static void
194 emit_input(struct ntv_context *ctx, struct nir_variable *var)
195 {
196 SpvId var_type = get_glsl_type(ctx, var->type);
197 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
198 SpvStorageClassInput,
199 var_type);
200 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
201 SpvStorageClassInput);
202
203 if (var->name)
204 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
205
206 if (ctx->stage == MESA_SHADER_FRAGMENT) {
207 if (var->data.location >= VARYING_SLOT_VAR0 ||
208 (var->data.location >= VARYING_SLOT_COL0 &&
209 var->data.location <= VARYING_SLOT_TEX7)) {
210 spirv_builder_emit_location(&ctx->builder, var_id,
211 ctx->var_location++);
212 } else {
213 switch (var->data.location) {
214 case VARYING_SLOT_POS:
215 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
216 break;
217
218 case VARYING_SLOT_PNTC:
219 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
220 break;
221
222 default:
223 unreachable("unknown varying slot");
224 }
225 }
226 } else {
227 spirv_builder_emit_location(&ctx->builder, var_id,
228 var->data.driver_location);
229 }
230
231 if (var->data.location_frac)
232 spirv_builder_emit_component(&ctx->builder, var_id,
233 var->data.location_frac);
234
235 if (var->data.interpolation == INTERP_MODE_FLAT)
236 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
237
238 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
239
240 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
241 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
242 }
243
244 static void
245 emit_output(struct ntv_context *ctx, struct nir_variable *var)
246 {
247 SpvId var_type = get_glsl_type(ctx, var->type);
248 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
249 SpvStorageClassOutput,
250 var_type);
251 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
252 SpvStorageClassOutput);
253 if (var->name)
254 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
255
256
257 if (ctx->stage == MESA_SHADER_VERTEX) {
258 if (var->data.location >= VARYING_SLOT_VAR0 ||
259 (var->data.location >= VARYING_SLOT_COL0 &&
260 var->data.location <= VARYING_SLOT_TEX7)) {
261 spirv_builder_emit_location(&ctx->builder, var_id,
262 ctx->var_location++);
263 } else {
264 switch (var->data.location) {
265 case VARYING_SLOT_POS:
266 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
267 break;
268
269 case VARYING_SLOT_PSIZ:
270 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
271 break;
272
273 case VARYING_SLOT_CLIP_DIST0:
274 assert(glsl_type_is_array(var->type));
275 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
276 break;
277
278 default:
279 unreachable("unknown varying slot");
280 }
281 }
282 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
283 switch (var->data.location) {
284 case FRAG_RESULT_DEPTH:
285 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
286 break;
287
288 default:
289 spirv_builder_emit_location(&ctx->builder, var_id,
290 var->data.driver_location);
291 }
292 }
293
294 if (var->data.location_frac)
295 spirv_builder_emit_component(&ctx->builder, var_id,
296 var->data.location_frac);
297
298 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
299
300 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
301 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
302 }
303
304 static SpvDim
305 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
306 {
307 *is_ms = false;
308 switch (gdim) {
309 case GLSL_SAMPLER_DIM_1D:
310 return SpvDim1D;
311 case GLSL_SAMPLER_DIM_2D:
312 return SpvDim2D;
313 case GLSL_SAMPLER_DIM_RECT:
314 return SpvDimRect;
315 case GLSL_SAMPLER_DIM_CUBE:
316 return SpvDimCube;
317 case GLSL_SAMPLER_DIM_3D:
318 return SpvDim3D;
319 case GLSL_SAMPLER_DIM_MS:
320 *is_ms = true;
321 return SpvDim2D;
322 default:
323 fprintf(stderr, "unknown sampler type %d\n", gdim);
324 break;
325 }
326 return SpvDim2D;
327 }
328
329 static void
330 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
331 {
332 bool is_ms;
333 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
334 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
335 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
336 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
337 SpvImageFormatUnknown);
338
339 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
340 image_type);
341 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
342 SpvStorageClassUniformConstant,
343 sampled_type);
344 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
345 SpvStorageClassUniformConstant);
346
347 if (var->name)
348 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
349
350 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
351 ctx->samplers[ctx->num_samplers++] = var_id;
352
353 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
354 var->data.descriptor_set);
355 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
356 }
357
358 static void
359 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
360 {
361 uint32_t size = glsl_count_attribute_slots(var->type, false);
362 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
363 SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
364 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
365 array_length);
366 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
367
368 // wrap UBO-array in a struct
369 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
370 if (var->name) {
371 char struct_name[100];
372 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
373 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
374 }
375
376 spirv_builder_emit_decoration(&ctx->builder, struct_type,
377 SpvDecorationBlock);
378 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
379
380
381 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
382 SpvStorageClassUniform,
383 struct_type);
384
385 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
386 SpvStorageClassUniform);
387 if (var->name)
388 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
389
390 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
391 ctx->ubos[ctx->num_ubos++] = var_id;
392
393 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
394 var->data.descriptor_set);
395 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
396 }
397
398 static void
399 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
400 {
401 if (var->data.mode == nir_var_mem_ubo)
402 emit_ubo(ctx, var);
403 else {
404 assert(var->data.mode == nir_var_uniform);
405 if (glsl_type_is_sampler(var->type))
406 emit_sampler(ctx, var);
407 }
408 }
409
410 static SpvId
411 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
412 {
413 assert(ssa->index < ctx->num_defs);
414 assert(ctx->defs[ssa->index] != 0);
415 return ctx->defs[ssa->index];
416 }
417
418 static SpvId
419 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
420 {
421 assert(reg->index < ctx->num_regs);
422 assert(ctx->regs[reg->index] != 0);
423 return ctx->regs[reg->index];
424 }
425
426 static SpvId
427 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
428 {
429 assert(reg->reg);
430 assert(!reg->indirect);
431 assert(!reg->base_offset);
432
433 SpvId var = get_var_from_reg(ctx, reg->reg);
434 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
435 return spirv_builder_emit_load(&ctx->builder, type, var);
436 }
437
438 static SpvId
439 get_src_uint(struct ntv_context *ctx, nir_src *src)
440 {
441 if (src->is_ssa)
442 return get_src_uint_ssa(ctx, src->ssa);
443 else
444 return get_src_uint_reg(ctx, &src->reg);
445 }
446
447 static SpvId
448 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
449 {
450 assert(!alu->src[src].negate);
451 assert(!alu->src[src].abs);
452
453 SpvId def = get_src_uint(ctx, &alu->src[src].src);
454
455 unsigned used_channels = 0;
456 bool need_swizzle = false;
457 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
458 if (!nir_alu_instr_channel_used(alu, src, i))
459 continue;
460
461 used_channels++;
462
463 if (alu->src[src].swizzle[i] != i)
464 need_swizzle = true;
465 }
466 assert(used_channels != 0);
467
468 unsigned live_channels = nir_src_num_components(alu->src[src].src);
469 if (used_channels != live_channels)
470 need_swizzle = true;
471
472 if (!need_swizzle)
473 return def;
474
475 int bit_size = nir_src_bit_size(alu->src[src].src);
476
477 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
478 if (used_channels == 1) {
479 uint32_t indices[] = { alu->src[src].swizzle[0] };
480 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
481 def, indices,
482 ARRAY_SIZE(indices));
483 } else if (live_channels == 1) {
484 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
485 used_channels);
486
487 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
488 for (unsigned i = 0; i < used_channels; ++i)
489 constituents[i] = def;
490
491 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
492 constituents,
493 used_channels);
494 } else {
495 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
496 used_channels);
497
498 uint32_t components[NIR_MAX_VEC_COMPONENTS];
499 size_t num_components = 0;
500 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
501 if (!nir_alu_instr_channel_used(alu, src, i))
502 continue;
503
504 components[num_components++] = alu->src[src].swizzle[i];
505 }
506
507 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
508 def, def, components, num_components);
509 }
510 }
511
512 static void
513 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
514 {
515 assert(result != 0);
516 assert(ssa->index < ctx->num_defs);
517 ctx->defs[ssa->index] = result;
518 }
519
520 static SpvId
521 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
522 {
523 SpvId otype = get_uvec_type(ctx, 32, num_components);
524 uint32_t zeros[4] = { 0, 0, 0, 0 };
525 uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
526 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
527 SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
528 return emit_triop(ctx, SpvOpSelect, otype, value, one, zero);
529 }
530
531 static SpvId
532 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
533 {
534 SpvId type = get_bvec_type(ctx, num_components);
535
536 uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
537 SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
538
539 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
540 }
541
542 static SpvId
543 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
544 unsigned num_components)
545 {
546 SpvId type = get_uvec_type(ctx, bit_size, num_components);
547 return emit_unop(ctx, SpvOpBitcast, type, value);
548 }
549
550 static SpvId
551 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
552 unsigned num_components)
553 {
554 SpvId type = get_ivec_type(ctx, bit_size, num_components);
555 return emit_unop(ctx, SpvOpBitcast, type, value);
556 }
557
558 static SpvId
559 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
560 unsigned num_components)
561 {
562 SpvId type = get_fvec_type(ctx, bit_size, num_components);
563 return emit_unop(ctx, SpvOpBitcast, type, value);
564 }
565
566 static void
567 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
568 {
569 SpvId var = get_var_from_reg(ctx, reg->reg);
570 assert(var);
571 spirv_builder_emit_store(&ctx->builder, var, result);
572 }
573
574 static void
575 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
576 {
577 if (dest->is_ssa)
578 store_ssa_def_uint(ctx, &dest->ssa, result);
579 else
580 store_reg_def(ctx, &dest->reg, result);
581 }
582
583 static void
584 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
585 {
586 unsigned num_components = nir_dest_num_components(*dest);
587 unsigned bit_size = nir_dest_bit_size(*dest);
588
589 switch (nir_alu_type_get_base_type(type)) {
590 case nir_type_bool:
591 assert(bit_size == 1);
592 result = bvec_to_uvec(ctx, result, num_components);
593 break;
594
595 case nir_type_uint:
596 break; /* nothing to do! */
597
598 case nir_type_int:
599 case nir_type_float:
600 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
601 break;
602
603 default:
604 unreachable("unsupported nir_alu_type");
605 }
606
607 store_dest_uint(ctx, dest, result);
608 }
609
610 static SpvId
611 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
612 {
613 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
614 }
615
616 static SpvId
617 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
618 SpvId src0, SpvId src1)
619 {
620 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
621 }
622
623 static SpvId
624 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
625 SpvId src0, SpvId src1, SpvId src2)
626 {
627 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
628 }
629
630 static SpvId
631 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
632 SpvId src)
633 {
634 SpvId args[] = { src };
635 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
636 op, args, ARRAY_SIZE(args));
637 }
638
639 static SpvId
640 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
641 SpvId src0, SpvId src1)
642 {
643 SpvId args[] = { src0, src1 };
644 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
645 op, args, ARRAY_SIZE(args));
646 }
647
648 static SpvId
649 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
650 const float values[])
651 {
652 assert(bit_size == 32);
653
654 if (num_components > 1) {
655 SpvId components[num_components];
656 for (int i = 0; i < num_components; i++)
657 components[i] = spirv_builder_const_float(&ctx->builder, bit_size,
658 values[i]);
659
660 SpvId type = get_fvec_type(ctx, bit_size, num_components);
661 return spirv_builder_const_composite(&ctx->builder, type, components,
662 num_components);
663 }
664
665 assert(num_components == 1);
666 return spirv_builder_const_float(&ctx->builder, bit_size, values[0]);
667 }
668
669 static SpvId
670 get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
671 const uint32_t values[])
672 {
673 assert(bit_size == 32);
674
675 if (num_components > 1) {
676 SpvId components[num_components];
677 for (int i = 0; i < num_components; i++)
678 components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
679 values[i]);
680
681 SpvId type = get_uvec_type(ctx, bit_size, num_components);
682 return spirv_builder_const_composite(&ctx->builder, type, components,
683 num_components);
684 }
685
686 assert(num_components == 1);
687 return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
688 }
689
690 static inline unsigned
691 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
692 {
693 if (nir_op_infos[instr->op].input_sizes[src] > 0)
694 return nir_op_infos[instr->op].input_sizes[src];
695
696 if (instr->dest.dest.is_ssa)
697 return instr->dest.dest.ssa.num_components;
698 else
699 return instr->dest.dest.reg.reg->num_components;
700 }
701
702 static SpvId
703 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
704 {
705 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
706
707 unsigned num_components = alu_instr_src_components(alu, src);
708 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
709 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
710
711 switch (nir_alu_type_get_base_type(type)) {
712 case nir_type_bool:
713 assert(bit_size == 1);
714 return uvec_to_bvec(ctx, uint_value, num_components);
715
716 case nir_type_int:
717 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
718
719 case nir_type_uint:
720 return uint_value;
721
722 case nir_type_float:
723 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
724
725 default:
726 unreachable("unknown nir_alu_type");
727 }
728 }
729
730 static void
731 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
732 {
733 assert(!alu->dest.saturate);
734 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
735 }
736
737 static SpvId
738 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
739 {
740 unsigned num_components = nir_dest_num_components(*dest);
741 unsigned bit_size = nir_dest_bit_size(*dest);
742
743 switch (nir_alu_type_get_base_type(type)) {
744 case nir_type_bool:
745 return get_bvec_type(ctx, num_components);
746
747 case nir_type_int:
748 return get_ivec_type(ctx, bit_size, num_components);
749
750 case nir_type_uint:
751 return get_uvec_type(ctx, bit_size, num_components);
752
753 case nir_type_float:
754 return get_fvec_type(ctx, bit_size, num_components);
755
756 default:
757 unreachable("unsupported nir_alu_type");
758 }
759 }
760
761 static void
762 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
763 {
764 SpvId src[nir_op_infos[alu->op].num_inputs];
765 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
766 src[i] = get_alu_src(ctx, alu, i);
767
768 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
769 nir_op_infos[alu->op].output_type);
770 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
771 unsigned num_components = nir_dest_num_components(alu->dest.dest);
772
773 SpvId result = 0;
774 switch (alu->op) {
775 case nir_op_mov:
776 assert(nir_op_infos[alu->op].num_inputs == 1);
777 result = src[0];
778 break;
779
780 #define UNOP(nir_op, spirv_op) \
781 case nir_op: \
782 assert(nir_op_infos[alu->op].num_inputs == 1); \
783 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
784 break;
785
786 UNOP(nir_op_ineg, SpvOpSNegate)
787 UNOP(nir_op_fneg, SpvOpFNegate)
788 UNOP(nir_op_fddx, SpvOpDPdx)
789 UNOP(nir_op_fddy, SpvOpDPdy)
790 UNOP(nir_op_f2i32, SpvOpConvertFToS)
791 UNOP(nir_op_f2u32, SpvOpConvertFToU)
792 #undef UNOP
793
794 #define BUILTIN_UNOP(nir_op, spirv_op) \
795 case nir_op: \
796 assert(nir_op_infos[alu->op].num_inputs == 1); \
797 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
798 break;
799
800 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
801 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
802 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
803 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
804 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
805 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
806 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
807 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
808 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
809 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
810 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
811 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
812 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
813 #undef BUILTIN_UNOP
814
815 case nir_op_frcp: {
816 assert(nir_op_infos[alu->op].num_inputs == 1);
817 float one[4] = { 1, 1, 1, 1 };
818 src[1] = src[0];
819 src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
820 result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
821 }
822 break;
823
824 #define BINOP(nir_op, spirv_op) \
825 case nir_op: \
826 assert(nir_op_infos[alu->op].num_inputs == 2); \
827 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
828 break;
829
830 BINOP(nir_op_iadd, SpvOpIAdd)
831 BINOP(nir_op_isub, SpvOpISub)
832 BINOP(nir_op_imul, SpvOpIMul)
833 BINOP(nir_op_idiv, SpvOpSDiv)
834 BINOP(nir_op_udiv, SpvOpUDiv)
835 BINOP(nir_op_fadd, SpvOpFAdd)
836 BINOP(nir_op_fsub, SpvOpFSub)
837 BINOP(nir_op_fmul, SpvOpFMul)
838 BINOP(nir_op_fdiv, SpvOpFDiv)
839 BINOP(nir_op_fmod, SpvOpFMod)
840 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
841 BINOP(nir_op_ieq, SpvOpIEqual)
842 BINOP(nir_op_ine, SpvOpINotEqual)
843 BINOP(nir_op_flt, SpvOpFUnordLessThan)
844 BINOP(nir_op_fge, SpvOpFUnordGreaterThanEqual)
845 BINOP(nir_op_feq, SpvOpFOrdEqual)
846 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
847 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
848 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
849 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
850 #undef BINOP
851
852 #define BUILTIN_BINOP(nir_op, spirv_op) \
853 case nir_op: \
854 assert(nir_op_infos[alu->op].num_inputs == 2); \
855 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
856 break;
857
858 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
859 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
860 #undef BUILTIN_BINOP
861
862 case nir_op_fdot2:
863 case nir_op_fdot3:
864 case nir_op_fdot4:
865 assert(nir_op_infos[alu->op].num_inputs == 2);
866 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
867 break;
868
869 case nir_op_seq:
870 case nir_op_sne:
871 case nir_op_slt:
872 case nir_op_sge: {
873 assert(nir_op_infos[alu->op].num_inputs == 2);
874 int num_components = nir_dest_num_components(alu->dest.dest);
875 SpvId bool_type = get_bvec_type(ctx, num_components);
876
877 SpvId zero = spirv_builder_const_float(&ctx->builder, 32, 0.0f);
878 SpvId one = spirv_builder_const_float(&ctx->builder, 32, 1.0f);
879 if (num_components > 1) {
880 SpvId zero_comps[num_components], one_comps[num_components];
881 for (int i = 0; i < num_components; i++) {
882 zero_comps[i] = zero;
883 one_comps[i] = one;
884 }
885
886 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
887 zero_comps, num_components);
888 one = spirv_builder_const_composite(&ctx->builder, dest_type,
889 one_comps, num_components);
890 }
891
892 SpvOp op;
893 switch (alu->op) {
894 case nir_op_seq: op = SpvOpFOrdEqual; break;
895 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
896 case nir_op_slt: op = SpvOpFOrdLessThan; break;
897 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
898 default: unreachable("unexpected op");
899 }
900
901 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
902 result = emit_triop(ctx, SpvOpSelect, dest_type, result, one, zero);
903 }
904 break;
905
906 case nir_op_fcsel: {
907 assert(nir_op_infos[alu->op].num_inputs == 3);
908 int num_components = nir_dest_num_components(alu->dest.dest);
909 SpvId bool_type = get_bvec_type(ctx, num_components);
910
911 float zero[4] = { 0, 0, 0, 0 };
912 SpvId cmp = get_fvec_constant(ctx, nir_src_bit_size(alu->src[0].src),
913 num_components, zero);
914
915 result = emit_binop(ctx, SpvOpFOrdGreaterThan, bool_type, src[0], cmp);
916 result = emit_triop(ctx, SpvOpSelect, dest_type, result, src[1], src[2]);
917 }
918 break;
919
920 case nir_op_vec2:
921 case nir_op_vec3:
922 case nir_op_vec4: {
923 int num_inputs = nir_op_infos[alu->op].num_inputs;
924 assert(2 <= num_inputs && num_inputs <= 4);
925 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
926 src, num_inputs);
927 }
928 break;
929
930 default:
931 fprintf(stderr, "emit_alu: not implemented (%s)\n",
932 nir_op_infos[alu->op].name);
933
934 unreachable("unsupported opcode");
935 return;
936 }
937
938 store_alu_result(ctx, alu, result);
939 }
940
941 static void
942 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
943 {
944 uint32_t values[NIR_MAX_VEC_COMPONENTS];
945 for (int i = 0; i < load_const->def.num_components; ++i)
946 values[i] = load_const->value[i].u32;
947
948 SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
949 load_const->def.num_components,
950 values);
951 store_ssa_def_uint(ctx, &load_const->def, constant);
952 }
953
954 static void
955 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
956 {
957 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
958 assert(const_block_index); // no dynamic indexing for now
959 assert(const_block_index->u32 == 0); // we only support the default UBO for now
960
961 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
962 if (const_offset) {
963 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
964 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
965 SpvStorageClassUniform,
966 uvec4_type);
967
968 unsigned idx = const_offset->u32;
969 SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
970 SpvId offset = spirv_builder_const_uint(&ctx->builder, 32, idx);
971 SpvId offsets[] = { member, offset };
972 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
973 ctx->ubos[0], offsets,
974 ARRAY_SIZE(offsets));
975 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
976
977 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
978 unsigned num_components = nir_dest_num_components(intr->dest);
979 if (num_components == 1) {
980 uint32_t components[] = { 0 };
981 result = spirv_builder_emit_composite_extract(&ctx->builder,
982 type,
983 result, components,
984 1);
985 } else if (num_components < 4) {
986 SpvId constituents[num_components];
987 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
988 for (uint32_t i = 0; i < num_components; ++i)
989 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
990 uint_type,
991 result, &i,
992 1);
993
994 result = spirv_builder_emit_composite_construct(&ctx->builder,
995 type,
996 constituents,
997 num_components);
998 }
999
1000 store_dest_uint(ctx, &intr->dest, result);
1001 } else
1002 unreachable("uniform-addressing not yet supported");
1003 }
1004
1005 static void
1006 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1007 {
1008 assert(ctx->block_started);
1009 spirv_builder_emit_kill(&ctx->builder);
1010 /* discard is weird in NIR, so let's just create an unreachable block after
1011 it and hope that the vulkan driver will DCE any instructinos in it. */
1012 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1013 }
1014
1015 static void
1016 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1017 {
1018 /* uint is a bit of a lie here; it's really just a pointer */
1019 SpvId ptr = get_src_uint(ctx, intr->src);
1020
1021 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1022 SpvId result = spirv_builder_emit_load(&ctx->builder,
1023 get_glsl_type(ctx, var->type),
1024 ptr);
1025 unsigned num_components = nir_dest_num_components(intr->dest);
1026 unsigned bit_size = nir_dest_bit_size(intr->dest);
1027 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1028 store_dest_uint(ctx, &intr->dest, result);
1029 }
1030
1031 static void
1032 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1033 {
1034 /* uint is a bit of a lie here; it's really just a pointer */
1035 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1036 SpvId src = get_src_uint(ctx, &intr->src[1]);
1037
1038 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1039 SpvId result = emit_unop(ctx, SpvOpBitcast,
1040 get_glsl_type(ctx, glsl_without_array(var->type)),
1041 src);
1042 spirv_builder_emit_store(&ctx->builder, ptr, result);
1043 }
1044
1045 static void
1046 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1047 {
1048 switch (intr->intrinsic) {
1049 case nir_intrinsic_load_ubo:
1050 emit_load_ubo(ctx, intr);
1051 break;
1052
1053 case nir_intrinsic_discard:
1054 emit_discard(ctx, intr);
1055 break;
1056
1057 case nir_intrinsic_load_deref:
1058 emit_load_deref(ctx, intr);
1059 break;
1060
1061 case nir_intrinsic_store_deref:
1062 emit_store_deref(ctx, intr);
1063 break;
1064
1065 default:
1066 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1067 nir_intrinsic_infos[intr->intrinsic].name);
1068 unreachable("unsupported intrinsic");
1069 }
1070 }
1071
1072 static void
1073 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1074 {
1075 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1076 undef->def.num_components);
1077
1078 store_ssa_def_uint(ctx, &undef->def,
1079 spirv_builder_emit_undef(&ctx->builder, type));
1080 }
1081
1082 static SpvId
1083 get_src_float(struct ntv_context *ctx, nir_src *src)
1084 {
1085 SpvId def = get_src_uint(ctx, src);
1086 unsigned num_components = nir_src_num_components(*src);
1087 unsigned bit_size = nir_src_bit_size(*src);
1088 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1089 }
1090
1091 static void
1092 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1093 {
1094 assert(tex->op == nir_texop_tex);
1095 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1096 assert(tex->texture_index == tex->sampler_index);
1097
1098 bool has_proj = false, has_lod = false;
1099 SpvId coord = 0, proj, lod;
1100 unsigned coord_components;
1101 for (unsigned i = 0; i < tex->num_srcs; i++) {
1102 switch (tex->src[i].src_type) {
1103 case nir_tex_src_coord:
1104 coord = get_src_float(ctx, &tex->src[i].src);
1105 coord_components = nir_src_num_components(tex->src[i].src);
1106 break;
1107
1108 case nir_tex_src_projector:
1109 has_proj = true;
1110 proj = get_src_float(ctx, &tex->src[i].src);
1111 assert(nir_src_num_components(tex->src[i].src) == 1);
1112 break;
1113
1114 case nir_tex_src_lod:
1115 has_lod = true;
1116 lod = get_src_float(ctx, &tex->src[i].src);
1117 assert(nir_src_num_components(tex->src[i].src) == 1);
1118 break;
1119
1120 default:
1121 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1122 unreachable("unknown texture source");
1123 }
1124 }
1125
1126 if (!has_lod && ctx->stage != MESA_SHADER_FRAGMENT) {
1127 has_lod = true;
1128 lod = spirv_builder_const_float(&ctx->builder, 32, 0);
1129 }
1130
1131 bool is_ms;
1132 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1133 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1134 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1135 dimension, false, tex->is_array, is_ms, 1,
1136 SpvImageFormatUnknown);
1137 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1138 image_type);
1139
1140 assert(tex->texture_index < ctx->num_samplers);
1141 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1142 ctx->samplers[tex->texture_index]);
1143
1144 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1145
1146 SpvId result;
1147 if (has_proj) {
1148 SpvId constituents[coord_components + 1];
1149 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1150 for (uint32_t i = 0; i < coord_components; ++i)
1151 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1152 float_type,
1153 coord,
1154 &i, 1);
1155
1156 constituents[coord_components++] = proj;
1157
1158 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1159 SpvId merged = spirv_builder_emit_composite_construct(&ctx->builder,
1160 vec_type,
1161 constituents,
1162 coord_components);
1163
1164 if (has_lod)
1165 result = spirv_builder_emit_image_sample_proj_explicit_lod(&ctx->builder,
1166 dest_type,
1167 load,
1168 merged,
1169 lod);
1170 else
1171 result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
1172 dest_type,
1173 load,
1174 merged);
1175 } else {
1176 if (has_lod)
1177 result = spirv_builder_emit_image_sample_explicit_lod(&ctx->builder,
1178 dest_type,
1179 load,
1180 coord, lod);
1181 else
1182 result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
1183 dest_type,
1184 load,
1185 coord);
1186 }
1187 spirv_builder_emit_decoration(&ctx->builder, result,
1188 SpvDecorationRelaxedPrecision);
1189
1190 store_dest(ctx, &tex->dest, result, tex->dest_type);
1191 }
1192
1193 static void
1194 start_block(struct ntv_context *ctx, SpvId label)
1195 {
1196 /* terminate previous block if needed */
1197 if (ctx->block_started)
1198 spirv_builder_emit_branch(&ctx->builder, label);
1199
1200 /* start new block */
1201 spirv_builder_label(&ctx->builder, label);
1202 ctx->block_started = true;
1203 }
1204
1205 static void
1206 branch(struct ntv_context *ctx, SpvId label)
1207 {
1208 assert(ctx->block_started);
1209 spirv_builder_emit_branch(&ctx->builder, label);
1210 ctx->block_started = false;
1211 }
1212
1213 static void
1214 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1215 SpvId else_id)
1216 {
1217 assert(ctx->block_started);
1218 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1219 then_id, else_id);
1220 ctx->block_started = false;
1221 }
1222
1223 static void
1224 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1225 {
1226 switch (jump->type) {
1227 case nir_jump_break:
1228 assert(ctx->loop_break);
1229 branch(ctx, ctx->loop_break);
1230 break;
1231
1232 case nir_jump_continue:
1233 assert(ctx->loop_cont);
1234 branch(ctx, ctx->loop_cont);
1235 break;
1236
1237 default:
1238 unreachable("Unsupported jump type\n");
1239 }
1240 }
1241
1242 static void
1243 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1244 {
1245 assert(deref->deref_type == nir_deref_type_var);
1246
1247 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1248 assert(he);
1249 SpvId result = (SpvId)(intptr_t)he->data;
1250 /* uint is a bit of a lie here, it's really just an opaque type */
1251 store_dest_uint(ctx, &deref->dest, result);
1252 }
1253
1254 static void
1255 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1256 {
1257 assert(deref->deref_type == nir_deref_type_array);
1258 nir_variable *var = nir_deref_instr_get_variable(deref);
1259
1260 SpvStorageClass storage_class;
1261 switch (var->data.mode) {
1262 case nir_var_shader_in:
1263 storage_class = SpvStorageClassInput;
1264 break;
1265
1266 case nir_var_shader_out:
1267 storage_class = SpvStorageClassOutput;
1268 break;
1269
1270 default:
1271 unreachable("Unsupported nir_variable_mode\n");
1272 }
1273
1274 SpvId index = get_src_uint(ctx, &deref->arr.index);
1275
1276 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1277 storage_class,
1278 get_glsl_type(ctx, deref->type));
1279
1280 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1281 ptr_type,
1282 get_src_uint(ctx, &deref->parent),
1283 &index, 1);
1284 /* uint is a bit of a lie here, it's really just an opaque type */
1285 store_dest_uint(ctx, &deref->dest, result);
1286 }
1287
1288 static void
1289 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1290 {
1291 switch (deref->deref_type) {
1292 case nir_deref_type_var:
1293 emit_deref_var(ctx, deref);
1294 break;
1295
1296 case nir_deref_type_array:
1297 emit_deref_array(ctx, deref);
1298 break;
1299
1300 default:
1301 unreachable("unexpected deref_type");
1302 }
1303 }
1304
1305 static void
1306 emit_block(struct ntv_context *ctx, struct nir_block *block)
1307 {
1308 start_block(ctx, block_label(ctx, block));
1309 nir_foreach_instr(instr, block) {
1310 switch (instr->type) {
1311 case nir_instr_type_alu:
1312 emit_alu(ctx, nir_instr_as_alu(instr));
1313 break;
1314 case nir_instr_type_intrinsic:
1315 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1316 break;
1317 case nir_instr_type_load_const:
1318 emit_load_const(ctx, nir_instr_as_load_const(instr));
1319 break;
1320 case nir_instr_type_ssa_undef:
1321 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1322 break;
1323 case nir_instr_type_tex:
1324 emit_tex(ctx, nir_instr_as_tex(instr));
1325 break;
1326 case nir_instr_type_phi:
1327 unreachable("nir_instr_type_phi not supported");
1328 break;
1329 case nir_instr_type_jump:
1330 emit_jump(ctx, nir_instr_as_jump(instr));
1331 break;
1332 case nir_instr_type_call:
1333 unreachable("nir_instr_type_call not supported");
1334 break;
1335 case nir_instr_type_parallel_copy:
1336 unreachable("nir_instr_type_parallel_copy not supported");
1337 break;
1338 case nir_instr_type_deref:
1339 emit_deref(ctx, nir_instr_as_deref(instr));
1340 break;
1341 }
1342 }
1343 }
1344
1345 static void
1346 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1347
1348 static SpvId
1349 get_src_bool(struct ntv_context *ctx, nir_src *src)
1350 {
1351 SpvId def = get_src_uint(ctx, src);
1352 assert(nir_src_bit_size(*src) == 32);
1353 unsigned num_components = nir_src_num_components(*src);
1354 return uvec_to_bvec(ctx, def, num_components);
1355 }
1356
/* Emit a NIR if as a SPIR-V structured selection.
 *
 * Layout: a dedicated header block (OpSelectionMerge endif_id followed by
 * OpBranchConditional), then the then-list, then the optional else-list,
 * and finally the endif_id merge block, where emission continues. Without
 * an else-list, the false edge of the conditional goes straight to the
 * merge block.
 */
static void
emit_if(struct ntv_context *ctx, nir_if *if_stmt)
{
   SpvId condition = get_src_bool(ctx, &if_stmt->condition);

   /* header_id and endif_id are fresh labels; then/else reuse the labels
    * that block_label() returns for the first NIR block of each list */
   SpvId header_id = spirv_builder_new_id(&ctx->builder);
   SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
   SpvId endif_id = spirv_builder_new_id(&ctx->builder);
   SpvId else_id = endif_id;

   /* NOTE(review): NIR may always give an if a (possibly empty) else
    * block, in which case has_else is always true — confirm. */
   bool has_else = !exec_list_is_empty(&if_stmt->else_list);
   if (has_else) {
      assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
      else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
   }

   /* create a header-block */
   /* (start_block() terminates any open block with a branch here, so the
    * merge instruction and the conditional branch form their own block) */
   start_block(ctx, header_id);
   spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
                                      SpvSelectionControlMaskNone);
   branch_conditional(ctx, condition, then_id, else_id);

   emit_cf_list(ctx, &if_stmt->then_list);

   if (has_else) {
      /* the then-list may already have ended its block (e.g. via a jump);
       * otherwise jump over the else-list to the merge block */
      if (ctx->block_started)
         branch(ctx, endif_id);

      emit_cf_list(ctx, &if_stmt->else_list);
   }

   /* falls through from whichever branch is still open, if any */
   start_block(ctx, endif_id);
}
1390
1391 static void
1392 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1393 {
1394 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1395 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1396 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1397 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1398
1399 /* create a header-block */
1400 start_block(ctx, header_id);
1401 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1402 branch(ctx, begin_id);
1403
1404 SpvId save_break = ctx->loop_break;
1405 SpvId save_cont = ctx->loop_cont;
1406 ctx->loop_break = break_id;
1407 ctx->loop_cont = cont_id;
1408
1409 emit_cf_list(ctx, &loop->body);
1410
1411 ctx->loop_break = save_break;
1412 ctx->loop_cont = save_cont;
1413
1414 branch(ctx, cont_id);
1415 start_block(ctx, cont_id);
1416 branch(ctx, header_id);
1417
1418 start_block(ctx, break_id);
1419 }
1420
1421 static void
1422 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1423 {
1424 foreach_list_typed(nir_cf_node, node, node, list) {
1425 switch (node->type) {
1426 case nir_cf_node_block:
1427 emit_block(ctx, nir_cf_node_as_block(node));
1428 break;
1429
1430 case nir_cf_node_if:
1431 emit_if(ctx, nir_cf_node_as_if(node));
1432 break;
1433
1434 case nir_cf_node_loop:
1435 emit_loop(ctx, nir_cf_node_as_loop(node));
1436 break;
1437
1438 case nir_cf_node_function:
1439 unreachable("nir_cf_node_function not supported");
1440 break;
1441 }
1442 }
1443 }
1444
1445 struct spirv_shader *
1446 nir_to_spirv(struct nir_shader *s)
1447 {
1448 struct spirv_shader *ret = NULL;
1449
1450 struct ntv_context ctx = {};
1451
1452 switch (s->info.stage) {
1453 case MESA_SHADER_VERTEX:
1454 case MESA_SHADER_FRAGMENT:
1455 case MESA_SHADER_COMPUTE:
1456 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1457 break;
1458
1459 case MESA_SHADER_TESS_CTRL:
1460 case MESA_SHADER_TESS_EVAL:
1461 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1462 break;
1463
1464 case MESA_SHADER_GEOMETRY:
1465 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1466 break;
1467
1468 default:
1469 unreachable("invalid stage");
1470 }
1471
1472 ctx.stage = s->info.stage;
1473 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1474 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1475
1476 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1477 SpvMemoryModelGLSL450);
1478
1479 SpvExecutionModel exec_model;
1480 switch (s->info.stage) {
1481 case MESA_SHADER_VERTEX:
1482 exec_model = SpvExecutionModelVertex;
1483 break;
1484 case MESA_SHADER_TESS_CTRL:
1485 exec_model = SpvExecutionModelTessellationControl;
1486 break;
1487 case MESA_SHADER_TESS_EVAL:
1488 exec_model = SpvExecutionModelTessellationEvaluation;
1489 break;
1490 case MESA_SHADER_GEOMETRY:
1491 exec_model = SpvExecutionModelGeometry;
1492 break;
1493 case MESA_SHADER_FRAGMENT:
1494 exec_model = SpvExecutionModelFragment;
1495 break;
1496 case MESA_SHADER_COMPUTE:
1497 exec_model = SpvExecutionModelGLCompute;
1498 break;
1499 default:
1500 unreachable("invalid stage");
1501 }
1502
1503 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1504 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1505 NULL, 0);
1506 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1507 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1508
1509 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1510 _mesa_key_pointer_equal);
1511
1512 nir_foreach_variable(var, &s->inputs)
1513 emit_input(&ctx, var);
1514
1515 nir_foreach_variable(var, &s->outputs)
1516 emit_output(&ctx, var);
1517
1518 nir_foreach_variable(var, &s->uniforms)
1519 emit_uniform(&ctx, var);
1520
1521 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1522 "main", ctx.entry_ifaces,
1523 ctx.num_entry_ifaces);
1524 if (s->info.stage == MESA_SHADER_FRAGMENT)
1525 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1526 SpvExecutionModeOriginUpperLeft);
1527
1528
1529 spirv_builder_function(&ctx.builder, entry_point, type_void,
1530 SpvFunctionControlMaskNone,
1531 type_main);
1532
1533 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1534 nir_metadata_require(entry, nir_metadata_block_index);
1535
1536 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1537 if (!ctx.defs)
1538 goto fail;
1539 ctx.num_defs = entry->ssa_alloc;
1540
1541 nir_index_local_regs(entry);
1542 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1543 if (!ctx.regs)
1544 goto fail;
1545 ctx.num_regs = entry->reg_alloc;
1546
1547 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1548 if (!block_ids)
1549 goto fail;
1550
1551 for (int i = 0; i < entry->num_blocks; ++i)
1552 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1553
1554 ctx.block_ids = block_ids;
1555 ctx.num_blocks = entry->num_blocks;
1556
1557 /* emit a block only for the variable declarations */
1558 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1559 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1560 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1561 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1562 SpvStorageClassFunction,
1563 type);
1564 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1565 SpvStorageClassFunction);
1566
1567 ctx.regs[reg->index] = var;
1568 }
1569
1570 emit_cf_list(&ctx, &entry->body);
1571
1572 free(ctx.defs);
1573
1574 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1575 spirv_builder_function_end(&ctx.builder);
1576
1577 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1578
1579 ret = CALLOC_STRUCT(spirv_shader);
1580 if (!ret)
1581 goto fail;
1582
1583 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1584 if (!ret->words)
1585 goto fail;
1586
1587 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1588 assert(ret->num_words == num_words);
1589
1590 return ret;
1591
1592 fail:
1593
1594 if (ret)
1595 spirv_shader_delete(ret);
1596
1597 if (ctx.vars)
1598 _mesa_hash_table_destroy(ctx.vars, NULL);
1599
1600 return NULL;
1601 }
1602
1603 void
1604 spirv_shader_delete(struct spirv_shader *s)
1605 {
1606 FREE(s->words);
1607 FREE(s);
1608 }