zink/spirv: alias generic varyings on non-generic ones
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId samplers[PIPE_MAX_SAMPLERS];
42 size_t num_samplers;
43 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
44 size_t num_entry_ifaces;
45
46 SpvId *defs;
47 size_t num_defs;
48
49 SpvId *regs;
50 size_t num_regs;
51
52 struct hash_table *vars; /* nir_variable -> SpvId */
53
54 const SpvId *block_ids;
55 size_t num_blocks;
56 bool block_started;
57 SpvId loop_break, loop_cont;
58
59 SpvId front_face_var;
60 };
61
62 static SpvId
63 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
64 unsigned num_components, float value);
65
66 static SpvId
67 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
68 unsigned num_components, uint32_t value);
69
70 static SpvId
71 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
72 unsigned num_components, int32_t value);
73
74 static SpvId
75 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
76
77 static SpvId
78 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
79 SpvId src0, SpvId src1);
80
81 static SpvId
82 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
83 SpvId src0, SpvId src1, SpvId src2);
84
85 static SpvId
86 get_bvec_type(struct ntv_context *ctx, int num_components)
87 {
88 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
89 if (num_components > 1)
90 return spirv_builder_type_vector(&ctx->builder, bool_type,
91 num_components);
92
93 assert(num_components == 1);
94 return bool_type;
95 }
96
97 static SpvId
98 block_label(struct ntv_context *ctx, nir_block *block)
99 {
100 assert(block->index < ctx->num_blocks);
101 return ctx->block_ids[block->index];
102 }
103
104 static SpvId
105 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
106 {
107 assert(bit_size == 32);
108 return spirv_builder_const_float(&ctx->builder, bit_size, value);
109 }
110
111 static SpvId
112 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
113 {
114 assert(bit_size == 32);
115 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
116 }
117
118 static SpvId
119 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
120 {
121 assert(bit_size == 32);
122 return spirv_builder_const_int(&ctx->builder, bit_size, value);
123 }
124
125 static SpvId
126 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
127 {
128 assert(bit_size == 32); // only 32-bit floats supported so far
129
130 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
131 if (num_components > 1)
132 return spirv_builder_type_vector(&ctx->builder, float_type,
133 num_components);
134
135 assert(num_components == 1);
136 return float_type;
137 }
138
139 static SpvId
140 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
141 {
142 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
143
144 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
145 if (num_components > 1)
146 return spirv_builder_type_vector(&ctx->builder, int_type,
147 num_components);
148
149 assert(num_components == 1);
150 return int_type;
151 }
152
153 static SpvId
154 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
155 {
156 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
157
158 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
159 if (num_components > 1)
160 return spirv_builder_type_vector(&ctx->builder, uint_type,
161 num_components);
162
163 assert(num_components == 1);
164 return uint_type;
165 }
166
167 static SpvId
168 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
169 {
170 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
171 nir_dest_num_components(*dest));
172 }
173
174 static SpvId
175 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
176 {
177 switch (type) {
178 case GLSL_TYPE_BOOL:
179 return spirv_builder_type_bool(&ctx->builder);
180
181 case GLSL_TYPE_FLOAT:
182 return spirv_builder_type_float(&ctx->builder, 32);
183
184 case GLSL_TYPE_INT:
185 return spirv_builder_type_int(&ctx->builder, 32);
186
187 case GLSL_TYPE_UINT:
188 return spirv_builder_type_uint(&ctx->builder, 32);
189 /* TODO: handle more types */
190
191 default:
192 unreachable("unknown GLSL type");
193 }
194 }
195
196 static SpvId
197 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
198 {
199 assert(type);
200 if (glsl_type_is_scalar(type))
201 return get_glsl_basetype(ctx, glsl_get_base_type(type));
202
203 if (glsl_type_is_vector(type))
204 return spirv_builder_type_vector(&ctx->builder,
205 get_glsl_basetype(ctx, glsl_get_base_type(type)),
206 glsl_get_vector_elements(type));
207
208 if (glsl_type_is_array(type)) {
209 SpvId ret = spirv_builder_type_array(&ctx->builder,
210 get_glsl_type(ctx, glsl_get_array_element(type)),
211 emit_uint_const(ctx, 32, glsl_get_length(type)));
212 uint32_t stride = glsl_get_explicit_stride(type);
213 if (stride)
214 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
215 return ret;
216 }
217
218
219 unreachable("we shouldn't get here, I think...");
220 }
221
222 static void
223 emit_input(struct ntv_context *ctx, struct nir_variable *var)
224 {
225 SpvId var_type = get_glsl_type(ctx, var->type);
226 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
227 SpvStorageClassInput,
228 var_type);
229 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
230 SpvStorageClassInput);
231
232 if (var->name)
233 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
234
235 if (ctx->stage == MESA_SHADER_FRAGMENT) {
236 if (var->data.location >= VARYING_SLOT_VAR0)
237 spirv_builder_emit_location(&ctx->builder, var_id,
238 var->data.location - VARYING_SLOT_VAR0);
239 else if (var->data.location >= VARYING_SLOT_COL0 &&
240 var->data.location <= VARYING_SLOT_TEX7) {
241 spirv_builder_emit_location(&ctx->builder, var_id,
242 var->data.location);
243 } else {
244 switch (var->data.location) {
245 case VARYING_SLOT_POS:
246 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
247 break;
248
249 case VARYING_SLOT_PNTC:
250 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
251 break;
252
253 default:
254 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
255 unreachable("unexpected varying slot");
256 }
257 }
258 } else {
259 spirv_builder_emit_location(&ctx->builder, var_id,
260 var->data.driver_location);
261 }
262
263 if (var->data.location_frac)
264 spirv_builder_emit_component(&ctx->builder, var_id,
265 var->data.location_frac);
266
267 if (var->data.interpolation == INTERP_MODE_FLAT)
268 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
269
270 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
271
272 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
273 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
274 }
275
276 static void
277 emit_output(struct ntv_context *ctx, struct nir_variable *var)
278 {
279 SpvId var_type = get_glsl_type(ctx, var->type);
280 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
281 SpvStorageClassOutput,
282 var_type);
283 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
284 SpvStorageClassOutput);
285 if (var->name)
286 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
287
288
289 if (ctx->stage == MESA_SHADER_VERTEX) {
290 if (var->data.location >= VARYING_SLOT_VAR0)
291 spirv_builder_emit_location(&ctx->builder, var_id,
292 var->data.location - VARYING_SLOT_VAR0);
293 else if (var->data.location >= VARYING_SLOT_COL0 &&
294 var->data.location <= VARYING_SLOT_TEX7) {
295 spirv_builder_emit_location(&ctx->builder, var_id,
296 var->data.location);
297 } else {
298 switch (var->data.location) {
299 case VARYING_SLOT_POS:
300 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
301 break;
302
303 case VARYING_SLOT_PSIZ:
304 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
305 break;
306
307 case VARYING_SLOT_CLIP_DIST0:
308 assert(glsl_type_is_array(var->type));
309 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
310 break;
311
312 default:
313 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
314 unreachable("unexpected varying slot");
315 }
316 }
317 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
318 if (var->data.location >= FRAG_RESULT_DATA0)
319 spirv_builder_emit_location(&ctx->builder, var_id,
320 var->data.location - FRAG_RESULT_DATA0);
321 else {
322 switch (var->data.location) {
323 case FRAG_RESULT_COLOR:
324 spirv_builder_emit_location(&ctx->builder, var_id, 0);
325 break;
326
327 case FRAG_RESULT_DEPTH:
328 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
329 break;
330
331 default:
332 spirv_builder_emit_location(&ctx->builder, var_id,
333 var->data.driver_location);
334 }
335 }
336 }
337
338 if (var->data.location_frac)
339 spirv_builder_emit_component(&ctx->builder, var_id,
340 var->data.location_frac);
341
342 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
343
344 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
345 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
346 }
347
348 static SpvDim
349 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
350 {
351 *is_ms = false;
352 switch (gdim) {
353 case GLSL_SAMPLER_DIM_1D:
354 return SpvDim1D;
355 case GLSL_SAMPLER_DIM_2D:
356 return SpvDim2D;
357 case GLSL_SAMPLER_DIM_RECT:
358 return SpvDimRect;
359 case GLSL_SAMPLER_DIM_CUBE:
360 return SpvDimCube;
361 case GLSL_SAMPLER_DIM_3D:
362 return SpvDim3D;
363 case GLSL_SAMPLER_DIM_MS:
364 *is_ms = true;
365 return SpvDim2D;
366 default:
367 fprintf(stderr, "unknown sampler type %d\n", gdim);
368 break;
369 }
370 return SpvDim2D;
371 }
372
373 static void
374 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
375 {
376 bool is_ms;
377 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
378 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
379 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
380 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
381 SpvImageFormatUnknown);
382
383 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
384 image_type);
385 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
386 SpvStorageClassUniformConstant,
387 sampled_type);
388 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
389 SpvStorageClassUniformConstant);
390
391 if (var->name)
392 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
393
394 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
395 ctx->samplers[ctx->num_samplers++] = var_id;
396
397 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
398 var->data.descriptor_set);
399 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
400 }
401
402 static void
403 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
404 {
405 uint32_t size = glsl_count_attribute_slots(var->type, false);
406 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
407 SpvId array_length = emit_uint_const(ctx, 32, size);
408 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
409 array_length);
410 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
411
412 // wrap UBO-array in a struct
413 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
414 if (var->name) {
415 char struct_name[100];
416 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
417 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
418 }
419
420 spirv_builder_emit_decoration(&ctx->builder, struct_type,
421 SpvDecorationBlock);
422 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
423
424
425 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
426 SpvStorageClassUniform,
427 struct_type);
428
429 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
430 SpvStorageClassUniform);
431 if (var->name)
432 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
433
434 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
435 ctx->ubos[ctx->num_ubos++] = var_id;
436
437 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
438 var->data.descriptor_set);
439 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
440 }
441
442 static void
443 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
444 {
445 if (var->data.mode == nir_var_mem_ubo)
446 emit_ubo(ctx, var);
447 else {
448 assert(var->data.mode == nir_var_uniform);
449 if (glsl_type_is_sampler(var->type))
450 emit_sampler(ctx, var);
451 }
452 }
453
454 static SpvId
455 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
456 {
457 assert(ssa->index < ctx->num_defs);
458 assert(ctx->defs[ssa->index] != 0);
459 return ctx->defs[ssa->index];
460 }
461
462 static SpvId
463 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
464 {
465 assert(reg->index < ctx->num_regs);
466 assert(ctx->regs[reg->index] != 0);
467 return ctx->regs[reg->index];
468 }
469
470 static SpvId
471 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
472 {
473 assert(reg->reg);
474 assert(!reg->indirect);
475 assert(!reg->base_offset);
476
477 SpvId var = get_var_from_reg(ctx, reg->reg);
478 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
479 return spirv_builder_emit_load(&ctx->builder, type, var);
480 }
481
482 static SpvId
483 get_src_uint(struct ntv_context *ctx, nir_src *src)
484 {
485 if (src->is_ssa)
486 return get_src_uint_ssa(ctx, src->ssa);
487 else
488 return get_src_uint_reg(ctx, &src->reg);
489 }
490
491 static SpvId
492 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
493 {
494 assert(!alu->src[src].negate);
495 assert(!alu->src[src].abs);
496
497 SpvId def = get_src_uint(ctx, &alu->src[src].src);
498
499 unsigned used_channels = 0;
500 bool need_swizzle = false;
501 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
502 if (!nir_alu_instr_channel_used(alu, src, i))
503 continue;
504
505 used_channels++;
506
507 if (alu->src[src].swizzle[i] != i)
508 need_swizzle = true;
509 }
510 assert(used_channels != 0);
511
512 unsigned live_channels = nir_src_num_components(alu->src[src].src);
513 if (used_channels != live_channels)
514 need_swizzle = true;
515
516 if (!need_swizzle)
517 return def;
518
519 int bit_size = nir_src_bit_size(alu->src[src].src);
520 assert(bit_size == 1 || bit_size == 32);
521
522 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
523 if (used_channels == 1) {
524 uint32_t indices[] = { alu->src[src].swizzle[0] };
525 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
526 def, indices,
527 ARRAY_SIZE(indices));
528 } else if (live_channels == 1) {
529 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
530 used_channels);
531
532 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
533 for (unsigned i = 0; i < used_channels; ++i)
534 constituents[i] = def;
535
536 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
537 constituents,
538 used_channels);
539 } else {
540 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
541 used_channels);
542
543 uint32_t components[NIR_MAX_VEC_COMPONENTS];
544 size_t num_components = 0;
545 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
546 if (!nir_alu_instr_channel_used(alu, src, i))
547 continue;
548
549 components[num_components++] = alu->src[src].swizzle[i];
550 }
551
552 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
553 def, def, components, num_components);
554 }
555 }
556
557 static void
558 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
559 {
560 assert(result != 0);
561 assert(ssa->index < ctx->num_defs);
562 ctx->defs[ssa->index] = result;
563 }
564
565 static SpvId
566 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
567 SpvId if_true, SpvId if_false)
568 {
569 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
570 }
571
572 static SpvId
573 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
574 {
575 SpvId otype = get_uvec_type(ctx, 32, num_components);
576 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
577 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
578 return emit_select(ctx, otype, value, one, zero);
579 }
580
581 static SpvId
582 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
583 {
584 SpvId type = get_bvec_type(ctx, num_components);
585 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
586 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
587 }
588
589 static SpvId
590 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
591 {
592 return emit_unop(ctx, SpvOpBitcast, type, value);
593 }
594
595 static SpvId
596 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
597 unsigned num_components)
598 {
599 SpvId type = get_uvec_type(ctx, bit_size, num_components);
600 return emit_bitcast(ctx, type, value);
601 }
602
603 static SpvId
604 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
605 unsigned num_components)
606 {
607 SpvId type = get_ivec_type(ctx, bit_size, num_components);
608 return emit_bitcast(ctx, type, value);
609 }
610
611 static SpvId
612 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
613 unsigned num_components)
614 {
615 SpvId type = get_fvec_type(ctx, bit_size, num_components);
616 return emit_bitcast(ctx, type, value);
617 }
618
619 static void
620 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
621 {
622 SpvId var = get_var_from_reg(ctx, reg->reg);
623 assert(var);
624 spirv_builder_emit_store(&ctx->builder, var, result);
625 }
626
627 static void
628 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
629 {
630 if (dest->is_ssa)
631 store_ssa_def_uint(ctx, &dest->ssa, result);
632 else
633 store_reg_def(ctx, &dest->reg, result);
634 }
635
636 static void
637 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
638 {
639 unsigned num_components = nir_dest_num_components(*dest);
640 unsigned bit_size = nir_dest_bit_size(*dest);
641
642 switch (nir_alu_type_get_base_type(type)) {
643 case nir_type_bool:
644 assert(bit_size == 1);
645 result = bvec_to_uvec(ctx, result, num_components);
646 break;
647
648 case nir_type_uint:
649 break; /* nothing to do! */
650
651 case nir_type_int:
652 case nir_type_float:
653 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
654 break;
655
656 default:
657 unreachable("unsupported nir_alu_type");
658 }
659
660 store_dest_uint(ctx, dest, result);
661 }
662
663 static SpvId
664 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
665 {
666 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
667 }
668
669 static SpvId
670 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
671 SpvId src0, SpvId src1)
672 {
673 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
674 }
675
676 static SpvId
677 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
678 SpvId src0, SpvId src1, SpvId src2)
679 {
680 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
681 }
682
683 static SpvId
684 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
685 SpvId src)
686 {
687 SpvId args[] = { src };
688 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
689 op, args, ARRAY_SIZE(args));
690 }
691
692 static SpvId
693 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
694 SpvId src0, SpvId src1)
695 {
696 SpvId args[] = { src0, src1 };
697 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
698 op, args, ARRAY_SIZE(args));
699 }
700
701 static SpvId
702 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
703 unsigned num_components, float value)
704 {
705 assert(bit_size == 32);
706
707 SpvId result = emit_float_const(ctx, bit_size, value);
708 if (num_components == 1)
709 return result;
710
711 assert(num_components > 1);
712 SpvId components[num_components];
713 for (int i = 0; i < num_components; i++)
714 components[i] = result;
715
716 SpvId type = get_fvec_type(ctx, bit_size, num_components);
717 return spirv_builder_const_composite(&ctx->builder, type, components,
718 num_components);
719 }
720
721 static SpvId
722 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
723 unsigned num_components, uint32_t value)
724 {
725 assert(bit_size == 32);
726
727 SpvId result = emit_uint_const(ctx, bit_size, value);
728 if (num_components == 1)
729 return result;
730
731 assert(num_components > 1);
732 SpvId components[num_components];
733 for (int i = 0; i < num_components; i++)
734 components[i] = result;
735
736 SpvId type = get_uvec_type(ctx, bit_size, num_components);
737 return spirv_builder_const_composite(&ctx->builder, type, components,
738 num_components);
739 }
740
741 static SpvId
742 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
743 unsigned num_components, int32_t value)
744 {
745 assert(bit_size == 32);
746
747 SpvId result = emit_int_const(ctx, bit_size, value);
748 if (num_components == 1)
749 return result;
750
751 assert(num_components > 1);
752 SpvId components[num_components];
753 for (int i = 0; i < num_components; i++)
754 components[i] = result;
755
756 SpvId type = get_ivec_type(ctx, bit_size, num_components);
757 return spirv_builder_const_composite(&ctx->builder, type, components,
758 num_components);
759 }
760
761 static inline unsigned
762 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
763 {
764 if (nir_op_infos[instr->op].input_sizes[src] > 0)
765 return nir_op_infos[instr->op].input_sizes[src];
766
767 if (instr->dest.dest.is_ssa)
768 return instr->dest.dest.ssa.num_components;
769 else
770 return instr->dest.dest.reg.reg->num_components;
771 }
772
773 static SpvId
774 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
775 {
776 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
777
778 unsigned num_components = alu_instr_src_components(alu, src);
779 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
780 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
781
782 switch (nir_alu_type_get_base_type(type)) {
783 case nir_type_bool:
784 assert(bit_size == 1);
785 return uvec_to_bvec(ctx, uint_value, num_components);
786
787 case nir_type_int:
788 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
789
790 case nir_type_uint:
791 return uint_value;
792
793 case nir_type_float:
794 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
795
796 default:
797 unreachable("unknown nir_alu_type");
798 }
799 }
800
801 static void
802 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
803 {
804 assert(!alu->dest.saturate);
805 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
806 }
807
808 static SpvId
809 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
810 {
811 unsigned num_components = nir_dest_num_components(*dest);
812 unsigned bit_size = nir_dest_bit_size(*dest);
813
814 switch (nir_alu_type_get_base_type(type)) {
815 case nir_type_bool:
816 return get_bvec_type(ctx, num_components);
817
818 case nir_type_int:
819 return get_ivec_type(ctx, bit_size, num_components);
820
821 case nir_type_uint:
822 return get_uvec_type(ctx, bit_size, num_components);
823
824 case nir_type_float:
825 return get_fvec_type(ctx, bit_size, num_components);
826
827 default:
828 unreachable("unsupported nir_alu_type");
829 }
830 }
831
832 static void
833 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
834 {
835 SpvId src[nir_op_infos[alu->op].num_inputs];
836 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
837 src[i] = get_alu_src(ctx, alu, i);
838
839 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
840 nir_op_infos[alu->op].output_type);
841 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
842 unsigned num_components = nir_dest_num_components(alu->dest.dest);
843
844 SpvId result = 0;
845 switch (alu->op) {
846 case nir_op_mov:
847 assert(nir_op_infos[alu->op].num_inputs == 1);
848 result = src[0];
849 break;
850
851 #define UNOP(nir_op, spirv_op) \
852 case nir_op: \
853 assert(nir_op_infos[alu->op].num_inputs == 1); \
854 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
855 break;
856
857 UNOP(nir_op_ineg, SpvOpSNegate)
858 UNOP(nir_op_fneg, SpvOpFNegate)
859 UNOP(nir_op_fddx, SpvOpDPdx)
860 UNOP(nir_op_fddy, SpvOpDPdy)
861 UNOP(nir_op_f2i32, SpvOpConvertFToS)
862 UNOP(nir_op_f2u32, SpvOpConvertFToU)
863 UNOP(nir_op_i2f32, SpvOpConvertSToF)
864 UNOP(nir_op_u2f32, SpvOpConvertUToF)
865 UNOP(nir_op_inot, SpvOpNot)
866 #undef UNOP
867
868 case nir_op_b2i32:
869 assert(nir_op_infos[alu->op].num_inputs == 1);
870 result = emit_select(ctx, dest_type, src[0],
871 get_ivec_constant(ctx, 32, num_components, 1),
872 get_ivec_constant(ctx, 32, num_components, 0));
873 break;
874
875 case nir_op_b2f32:
876 assert(nir_op_infos[alu->op].num_inputs == 1);
877 result = emit_select(ctx, dest_type, src[0],
878 get_fvec_constant(ctx, 32, num_components, 1),
879 get_fvec_constant(ctx, 32, num_components, 0));
880 break;
881
882 #define BUILTIN_UNOP(nir_op, spirv_op) \
883 case nir_op: \
884 assert(nir_op_infos[alu->op].num_inputs == 1); \
885 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
886 break;
887
888 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
889 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
890 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
891 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
892 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
893 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
894 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
895 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
896 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
897 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
898 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
899 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
900 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
901 #undef BUILTIN_UNOP
902
903 case nir_op_frcp:
904 assert(nir_op_infos[alu->op].num_inputs == 1);
905 result = emit_binop(ctx, SpvOpFDiv, dest_type,
906 get_fvec_constant(ctx, bit_size, num_components, 1),
907 src[0]);
908 break;
909
910 case nir_op_f2b1:
911 assert(nir_op_infos[alu->op].num_inputs == 1);
912 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
913 get_fvec_constant(ctx,
914 nir_src_bit_size(alu->src[0].src),
915 num_components, 0));
916 break;
917
918
919 #define BINOP(nir_op, spirv_op) \
920 case nir_op: \
921 assert(nir_op_infos[alu->op].num_inputs == 2); \
922 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
923 break;
924
925 BINOP(nir_op_iadd, SpvOpIAdd)
926 BINOP(nir_op_isub, SpvOpISub)
927 BINOP(nir_op_imul, SpvOpIMul)
928 BINOP(nir_op_idiv, SpvOpSDiv)
929 BINOP(nir_op_udiv, SpvOpUDiv)
930 BINOP(nir_op_fadd, SpvOpFAdd)
931 BINOP(nir_op_fsub, SpvOpFSub)
932 BINOP(nir_op_fmul, SpvOpFMul)
933 BINOP(nir_op_fdiv, SpvOpFDiv)
934 BINOP(nir_op_fmod, SpvOpFMod)
935 BINOP(nir_op_ilt, SpvOpSLessThan)
936 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
937 BINOP(nir_op_ieq, SpvOpIEqual)
938 BINOP(nir_op_ine, SpvOpINotEqual)
939 BINOP(nir_op_flt, SpvOpFOrdLessThan)
940 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
941 BINOP(nir_op_feq, SpvOpFOrdEqual)
942 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
943 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
944 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
945 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
946 BINOP(nir_op_iand, SpvOpBitwiseAnd)
947 BINOP(nir_op_ior, SpvOpBitwiseOr)
948 #undef BINOP
949
950 #define BUILTIN_BINOP(nir_op, spirv_op) \
951 case nir_op: \
952 assert(nir_op_infos[alu->op].num_inputs == 2); \
953 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
954 break;
955
956 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
957 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
958 #undef BUILTIN_BINOP
959
960 case nir_op_fdot2:
961 case nir_op_fdot3:
962 case nir_op_fdot4:
963 assert(nir_op_infos[alu->op].num_inputs == 2);
964 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
965 break;
966
967 case nir_op_seq:
968 case nir_op_sne:
969 case nir_op_slt:
970 case nir_op_sge: {
971 assert(nir_op_infos[alu->op].num_inputs == 2);
972 int num_components = nir_dest_num_components(alu->dest.dest);
973 SpvId bool_type = get_bvec_type(ctx, num_components);
974
975 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
976 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
977 if (num_components > 1) {
978 SpvId zero_comps[num_components], one_comps[num_components];
979 for (int i = 0; i < num_components; i++) {
980 zero_comps[i] = zero;
981 one_comps[i] = one;
982 }
983
984 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
985 zero_comps, num_components);
986 one = spirv_builder_const_composite(&ctx->builder, dest_type,
987 one_comps, num_components);
988 }
989
990 SpvOp op;
991 switch (alu->op) {
992 case nir_op_seq: op = SpvOpFOrdEqual; break;
993 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
994 case nir_op_slt: op = SpvOpFOrdLessThan; break;
995 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
996 default: unreachable("unexpected op");
997 }
998
999 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1000 result = emit_select(ctx, dest_type, result, one, zero);
1001 }
1002 break;
1003
1004 case nir_op_fcsel:
1005 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1006 get_bvec_type(ctx, num_components),
1007 src[0],
1008 get_fvec_constant(ctx,
1009 nir_src_bit_size(alu->src[0].src),
1010 num_components, 0));
1011 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1012 break;
1013
1014 case nir_op_bcsel:
1015 assert(nir_op_infos[alu->op].num_inputs == 3);
1016 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1017 break;
1018
1019 case nir_op_vec2:
1020 case nir_op_vec3:
1021 case nir_op_vec4: {
1022 int num_inputs = nir_op_infos[alu->op].num_inputs;
1023 assert(2 <= num_inputs && num_inputs <= 4);
1024 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1025 src, num_inputs);
1026 }
1027 break;
1028
1029 default:
1030 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1031 nir_op_infos[alu->op].name);
1032
1033 unreachable("unsupported opcode");
1034 return;
1035 }
1036
1037 store_alu_result(ctx, alu, result);
1038 }
1039
1040 static void
1041 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1042 {
1043 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1044 for (int i = 0; i < load_const->def.num_components; ++i)
1045 values[i] = load_const->value[i].u32;
1046
1047 unsigned bit_size = load_const->def.bit_size;
1048 unsigned num_components = load_const->def.num_components;
1049
1050 SpvId constant;
1051 if (num_components > 1) {
1052 SpvId components[num_components];
1053 for (int i = 0; i < num_components; i++)
1054 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1055
1056 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1057 constant = spirv_builder_const_composite(&ctx->builder, type,
1058 components, num_components);
1059 } else {
1060 assert(num_components == 1);
1061 constant = emit_uint_const(ctx, bit_size, values[0]);
1062 }
1063
1064 store_ssa_def_uint(ctx, &load_const->def, constant);
1065 }
1066
1067 static void
1068 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1069 {
1070 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1071 assert(const_block_index); // no dynamic indexing for now
1072 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1073
1074 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1075 if (const_offset) {
1076 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1077 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1078 SpvStorageClassUniform,
1079 uvec4_type);
1080
1081 unsigned idx = const_offset->u32;
1082 SpvId member = emit_uint_const(ctx, 32, 0);
1083 SpvId offset = emit_uint_const(ctx, 32, idx);
1084 SpvId offsets[] = { member, offset };
1085 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1086 ctx->ubos[0], offsets,
1087 ARRAY_SIZE(offsets));
1088 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1089
1090 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1091 unsigned num_components = nir_dest_num_components(intr->dest);
1092 if (num_components == 1) {
1093 uint32_t components[] = { 0 };
1094 result = spirv_builder_emit_composite_extract(&ctx->builder,
1095 type,
1096 result, components,
1097 1);
1098 } else if (num_components < 4) {
1099 SpvId constituents[num_components];
1100 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1101 for (uint32_t i = 0; i < num_components; ++i)
1102 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1103 uint_type,
1104 result, &i,
1105 1);
1106
1107 result = spirv_builder_emit_composite_construct(&ctx->builder,
1108 type,
1109 constituents,
1110 num_components);
1111 }
1112
1113 store_dest_uint(ctx, &intr->dest, result);
1114 } else
1115 unreachable("uniform-addressing not yet supported");
1116 }
1117
1118 static void
1119 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1120 {
1121 assert(ctx->block_started);
1122 spirv_builder_emit_kill(&ctx->builder);
1123 /* discard is weird in NIR, so let's just create an unreachable block after
1124 it and hope that the vulkan driver will DCE any instructinos in it. */
1125 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1126 }
1127
1128 static void
1129 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1130 {
1131 /* uint is a bit of a lie here; it's really just a pointer */
1132 SpvId ptr = get_src_uint(ctx, intr->src);
1133
1134 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1135 SpvId result = spirv_builder_emit_load(&ctx->builder,
1136 get_glsl_type(ctx, var->type),
1137 ptr);
1138 unsigned num_components = nir_dest_num_components(intr->dest);
1139 unsigned bit_size = nir_dest_bit_size(intr->dest);
1140 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1141 store_dest_uint(ctx, &intr->dest, result);
1142 }
1143
1144 static void
1145 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1146 {
1147 /* uint is a bit of a lie here; it's really just a pointer */
1148 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1149 SpvId src = get_src_uint(ctx, &intr->src[1]);
1150
1151 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1152 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1153 SpvId result = emit_bitcast(ctx, type, src);
1154 spirv_builder_emit_store(&ctx->builder, ptr, result);
1155 }
1156
1157 static void
1158 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1159 {
1160 SpvId var_type = get_glsl_type(ctx, glsl_bool_type());
1161 if (!ctx->front_face_var) {
1162 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1163 SpvStorageClassInput,
1164 var_type);
1165 ctx->front_face_var = spirv_builder_emit_var(&ctx->builder,
1166 pointer_type,
1167 SpvStorageClassInput);
1168 spirv_builder_emit_name(&ctx->builder, ctx->front_face_var,
1169 "gl_FrontFacing");
1170 spirv_builder_emit_builtin(&ctx->builder, ctx->front_face_var,
1171 SpvBuiltInFrontFacing);
1172
1173 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1174 ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->front_face_var;
1175 }
1176
1177 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1178 ctx->front_face_var);
1179 assert(1 == nir_dest_num_components(intr->dest));
1180 result = bvec_to_uvec(ctx, result, 1);
1181 store_dest_uint(ctx, &intr->dest, result);
1182 }
1183
1184 static void
1185 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1186 {
1187 switch (intr->intrinsic) {
1188 case nir_intrinsic_load_ubo:
1189 emit_load_ubo(ctx, intr);
1190 break;
1191
1192 case nir_intrinsic_discard:
1193 emit_discard(ctx, intr);
1194 break;
1195
1196 case nir_intrinsic_load_deref:
1197 emit_load_deref(ctx, intr);
1198 break;
1199
1200 case nir_intrinsic_store_deref:
1201 emit_store_deref(ctx, intr);
1202 break;
1203
1204 case nir_intrinsic_load_front_face:
1205 emit_load_front_face(ctx, intr);
1206 break;
1207
1208 default:
1209 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1210 nir_intrinsic_infos[intr->intrinsic].name);
1211 unreachable("unsupported intrinsic");
1212 }
1213 }
1214
1215 static void
1216 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1217 {
1218 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1219 undef->def.num_components);
1220
1221 store_ssa_def_uint(ctx, &undef->def,
1222 spirv_builder_emit_undef(&ctx->builder, type));
1223 }
1224
1225 static SpvId
1226 get_src_float(struct ntv_context *ctx, nir_src *src)
1227 {
1228 SpvId def = get_src_uint(ctx, src);
1229 unsigned num_components = nir_src_num_components(*src);
1230 unsigned bit_size = nir_src_bit_size(*src);
1231 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1232 }
1233
1234 static void
1235 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1236 {
1237 assert(tex->op == nir_texop_tex ||
1238 tex->op == nir_texop_txb ||
1239 tex->op == nir_texop_txl);
1240 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1241 assert(tex->texture_index == tex->sampler_index);
1242
1243 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1244 unsigned coord_components;
1245 for (unsigned i = 0; i < tex->num_srcs; i++) {
1246 switch (tex->src[i].src_type) {
1247 case nir_tex_src_coord:
1248 coord = get_src_float(ctx, &tex->src[i].src);
1249 coord_components = nir_src_num_components(tex->src[i].src);
1250 break;
1251
1252 case nir_tex_src_projector:
1253 assert(nir_src_num_components(tex->src[i].src) == 1);
1254 proj = get_src_float(ctx, &tex->src[i].src);
1255 assert(proj != 0);
1256 break;
1257
1258 case nir_tex_src_bias:
1259 assert(tex->op == nir_texop_txb);
1260 bias = get_src_float(ctx, &tex->src[i].src);
1261 assert(bias != 0);
1262 break;
1263
1264 case nir_tex_src_lod:
1265 assert(nir_src_num_components(tex->src[i].src) == 1);
1266 lod = get_src_float(ctx, &tex->src[i].src);
1267 assert(lod != 0);
1268 break;
1269
1270 case nir_tex_src_comparator:
1271 assert(nir_src_num_components(tex->src[i].src) == 1);
1272 dref = get_src_float(ctx, &tex->src[i].src);
1273 assert(dref != 0);
1274 break;
1275
1276 default:
1277 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1278 unreachable("unknown texture source");
1279 }
1280 }
1281
1282 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1283 lod = emit_float_const(ctx, 32, 0.0f);
1284 assert(lod != 0);
1285 }
1286
1287 bool is_ms;
1288 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1289 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1290 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1291 dimension, false, tex->is_array, is_ms, 1,
1292 SpvImageFormatUnknown);
1293 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1294 image_type);
1295
1296 assert(tex->texture_index < ctx->num_samplers);
1297 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1298 ctx->samplers[tex->texture_index]);
1299
1300 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1301
1302 if (proj) {
1303 SpvId constituents[coord_components + 1];
1304 if (coord_components == 1)
1305 constituents[0] = coord;
1306 else {
1307 assert(coord_components > 1);
1308 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1309 for (uint32_t i = 0; i < coord_components; ++i)
1310 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1311 float_type,
1312 coord,
1313 &i, 1);
1314 }
1315
1316 constituents[coord_components++] = proj;
1317
1318 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1319 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1320 vec_type,
1321 constituents,
1322 coord_components);
1323 }
1324
1325 SpvId actual_dest_type = dest_type;
1326 if (dref)
1327 actual_dest_type = float_type;
1328
1329 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1330 actual_dest_type, load,
1331 coord,
1332 proj != 0,
1333 lod, bias, dref);
1334 spirv_builder_emit_decoration(&ctx->builder, result,
1335 SpvDecorationRelaxedPrecision);
1336
1337 if (dref) {
1338 SpvId components[4] = { result, result, result, result };
1339 result = spirv_builder_emit_composite_construct(&ctx->builder,
1340 dest_type,
1341 components,
1342 4);
1343 }
1344
1345 store_dest(ctx, &tex->dest, result, tex->dest_type);
1346 }
1347
1348 static void
1349 start_block(struct ntv_context *ctx, SpvId label)
1350 {
1351 /* terminate previous block if needed */
1352 if (ctx->block_started)
1353 spirv_builder_emit_branch(&ctx->builder, label);
1354
1355 /* start new block */
1356 spirv_builder_label(&ctx->builder, label);
1357 ctx->block_started = true;
1358 }
1359
1360 static void
1361 branch(struct ntv_context *ctx, SpvId label)
1362 {
1363 assert(ctx->block_started);
1364 spirv_builder_emit_branch(&ctx->builder, label);
1365 ctx->block_started = false;
1366 }
1367
1368 static void
1369 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1370 SpvId else_id)
1371 {
1372 assert(ctx->block_started);
1373 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1374 then_id, else_id);
1375 ctx->block_started = false;
1376 }
1377
1378 static void
1379 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1380 {
1381 switch (jump->type) {
1382 case nir_jump_break:
1383 assert(ctx->loop_break);
1384 branch(ctx, ctx->loop_break);
1385 break;
1386
1387 case nir_jump_continue:
1388 assert(ctx->loop_cont);
1389 branch(ctx, ctx->loop_cont);
1390 break;
1391
1392 default:
1393 unreachable("Unsupported jump type\n");
1394 }
1395 }
1396
1397 static void
1398 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1399 {
1400 assert(deref->deref_type == nir_deref_type_var);
1401
1402 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1403 assert(he);
1404 SpvId result = (SpvId)(intptr_t)he->data;
1405 /* uint is a bit of a lie here, it's really just an opaque type */
1406 store_dest_uint(ctx, &deref->dest, result);
1407 }
1408
1409 static void
1410 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1411 {
1412 assert(deref->deref_type == nir_deref_type_array);
1413 nir_variable *var = nir_deref_instr_get_variable(deref);
1414
1415 SpvStorageClass storage_class;
1416 switch (var->data.mode) {
1417 case nir_var_shader_in:
1418 storage_class = SpvStorageClassInput;
1419 break;
1420
1421 case nir_var_shader_out:
1422 storage_class = SpvStorageClassOutput;
1423 break;
1424
1425 default:
1426 unreachable("Unsupported nir_variable_mode\n");
1427 }
1428
1429 SpvId index = get_src_uint(ctx, &deref->arr.index);
1430
1431 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1432 storage_class,
1433 get_glsl_type(ctx, deref->type));
1434
1435 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1436 ptr_type,
1437 get_src_uint(ctx, &deref->parent),
1438 &index, 1);
1439 /* uint is a bit of a lie here, it's really just an opaque type */
1440 store_dest_uint(ctx, &deref->dest, result);
1441 }
1442
1443 static void
1444 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1445 {
1446 switch (deref->deref_type) {
1447 case nir_deref_type_var:
1448 emit_deref_var(ctx, deref);
1449 break;
1450
1451 case nir_deref_type_array:
1452 emit_deref_array(ctx, deref);
1453 break;
1454
1455 default:
1456 unreachable("unexpected deref_type");
1457 }
1458 }
1459
1460 static void
1461 emit_block(struct ntv_context *ctx, struct nir_block *block)
1462 {
1463 start_block(ctx, block_label(ctx, block));
1464 nir_foreach_instr(instr, block) {
1465 switch (instr->type) {
1466 case nir_instr_type_alu:
1467 emit_alu(ctx, nir_instr_as_alu(instr));
1468 break;
1469 case nir_instr_type_intrinsic:
1470 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1471 break;
1472 case nir_instr_type_load_const:
1473 emit_load_const(ctx, nir_instr_as_load_const(instr));
1474 break;
1475 case nir_instr_type_ssa_undef:
1476 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1477 break;
1478 case nir_instr_type_tex:
1479 emit_tex(ctx, nir_instr_as_tex(instr));
1480 break;
1481 case nir_instr_type_phi:
1482 unreachable("nir_instr_type_phi not supported");
1483 break;
1484 case nir_instr_type_jump:
1485 emit_jump(ctx, nir_instr_as_jump(instr));
1486 break;
1487 case nir_instr_type_call:
1488 unreachable("nir_instr_type_call not supported");
1489 break;
1490 case nir_instr_type_parallel_copy:
1491 unreachable("nir_instr_type_parallel_copy not supported");
1492 break;
1493 case nir_instr_type_deref:
1494 emit_deref(ctx, nir_instr_as_deref(instr));
1495 break;
1496 }
1497 }
1498 }
1499
/* Forward declaration: emit_if() and emit_loop() below recurse into
 * emit_cf_list() for their nested control-flow lists. */
static void emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1502
1503 static SpvId
1504 get_src_bool(struct ntv_context *ctx, nir_src *src)
1505 {
1506 SpvId def = get_src_uint(ctx, src);
1507 assert(nir_src_bit_size(*src) == 1);
1508 unsigned num_components = nir_src_num_components(*src);
1509 return uvec_to_bvec(ctx, def, num_components);
1510 }
1511
1512 static void
1513 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1514 {
1515 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1516
1517 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1518 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1519 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1520 SpvId else_id = endif_id;
1521
1522 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1523 if (has_else) {
1524 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1525 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1526 }
1527
1528 /* create a header-block */
1529 start_block(ctx, header_id);
1530 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1531 SpvSelectionControlMaskNone);
1532 branch_conditional(ctx, condition, then_id, else_id);
1533
1534 emit_cf_list(ctx, &if_stmt->then_list);
1535
1536 if (has_else) {
1537 if (ctx->block_started)
1538 branch(ctx, endif_id);
1539
1540 emit_cf_list(ctx, &if_stmt->else_list);
1541 }
1542
1543 start_block(ctx, endif_id);
1544 }
1545
1546 static void
1547 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1548 {
1549 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1550 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1551 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1552 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1553
1554 /* create a header-block */
1555 start_block(ctx, header_id);
1556 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1557 branch(ctx, begin_id);
1558
1559 SpvId save_break = ctx->loop_break;
1560 SpvId save_cont = ctx->loop_cont;
1561 ctx->loop_break = break_id;
1562 ctx->loop_cont = cont_id;
1563
1564 emit_cf_list(ctx, &loop->body);
1565
1566 ctx->loop_break = save_break;
1567 ctx->loop_cont = save_cont;
1568
1569 branch(ctx, cont_id);
1570 start_block(ctx, cont_id);
1571 branch(ctx, header_id);
1572
1573 start_block(ctx, break_id);
1574 }
1575
1576 static void
1577 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1578 {
1579 foreach_list_typed(nir_cf_node, node, node, list) {
1580 switch (node->type) {
1581 case nir_cf_node_block:
1582 emit_block(ctx, nir_cf_node_as_block(node));
1583 break;
1584
1585 case nir_cf_node_if:
1586 emit_if(ctx, nir_cf_node_as_if(node));
1587 break;
1588
1589 case nir_cf_node_loop:
1590 emit_loop(ctx, nir_cf_node_as_loop(node));
1591 break;
1592
1593 case nir_cf_node_function:
1594 unreachable("nir_cf_node_function not supported");
1595 break;
1596 }
1597 }
1598 }
1599
1600 struct spirv_shader *
1601 nir_to_spirv(struct nir_shader *s)
1602 {
1603 struct spirv_shader *ret = NULL;
1604
1605 struct ntv_context ctx = {};
1606
1607 switch (s->info.stage) {
1608 case MESA_SHADER_VERTEX:
1609 case MESA_SHADER_FRAGMENT:
1610 case MESA_SHADER_COMPUTE:
1611 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1612 break;
1613
1614 case MESA_SHADER_TESS_CTRL:
1615 case MESA_SHADER_TESS_EVAL:
1616 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1617 break;
1618
1619 case MESA_SHADER_GEOMETRY:
1620 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1621 break;
1622
1623 default:
1624 unreachable("invalid stage");
1625 }
1626
1627 // TODO: only enable when needed
1628 if (s->info.stage == MESA_SHADER_FRAGMENT)
1629 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1630
1631 ctx.stage = s->info.stage;
1632 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1633 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1634
1635 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1636 SpvMemoryModelGLSL450);
1637
1638 SpvExecutionModel exec_model;
1639 switch (s->info.stage) {
1640 case MESA_SHADER_VERTEX:
1641 exec_model = SpvExecutionModelVertex;
1642 break;
1643 case MESA_SHADER_TESS_CTRL:
1644 exec_model = SpvExecutionModelTessellationControl;
1645 break;
1646 case MESA_SHADER_TESS_EVAL:
1647 exec_model = SpvExecutionModelTessellationEvaluation;
1648 break;
1649 case MESA_SHADER_GEOMETRY:
1650 exec_model = SpvExecutionModelGeometry;
1651 break;
1652 case MESA_SHADER_FRAGMENT:
1653 exec_model = SpvExecutionModelFragment;
1654 break;
1655 case MESA_SHADER_COMPUTE:
1656 exec_model = SpvExecutionModelGLCompute;
1657 break;
1658 default:
1659 unreachable("invalid stage");
1660 }
1661
1662 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1663 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1664 NULL, 0);
1665 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1666 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1667
1668 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1669 _mesa_key_pointer_equal);
1670
1671 nir_foreach_variable(var, &s->inputs)
1672 emit_input(&ctx, var);
1673
1674 nir_foreach_variable(var, &s->outputs)
1675 emit_output(&ctx, var);
1676
1677 nir_foreach_variable(var, &s->uniforms)
1678 emit_uniform(&ctx, var);
1679
1680 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1681 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1682 SpvExecutionModeOriginUpperLeft);
1683 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1684 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1685 SpvExecutionModeDepthReplacing);
1686 }
1687
1688
1689 spirv_builder_function(&ctx.builder, entry_point, type_void,
1690 SpvFunctionControlMaskNone,
1691 type_main);
1692
1693 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1694 nir_metadata_require(entry, nir_metadata_block_index);
1695
1696 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1697 if (!ctx.defs)
1698 goto fail;
1699 ctx.num_defs = entry->ssa_alloc;
1700
1701 nir_index_local_regs(entry);
1702 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1703 if (!ctx.regs)
1704 goto fail;
1705 ctx.num_regs = entry->reg_alloc;
1706
1707 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1708 if (!block_ids)
1709 goto fail;
1710
1711 for (int i = 0; i < entry->num_blocks; ++i)
1712 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1713
1714 ctx.block_ids = block_ids;
1715 ctx.num_blocks = entry->num_blocks;
1716
1717 /* emit a block only for the variable declarations */
1718 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1719 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1720 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1721 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1722 SpvStorageClassFunction,
1723 type);
1724 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1725 SpvStorageClassFunction);
1726
1727 ctx.regs[reg->index] = var;
1728 }
1729
1730 emit_cf_list(&ctx, &entry->body);
1731
1732 free(ctx.defs);
1733
1734 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1735 spirv_builder_function_end(&ctx.builder);
1736
1737 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1738 "main", ctx.entry_ifaces,
1739 ctx.num_entry_ifaces);
1740
1741 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1742
1743 ret = CALLOC_STRUCT(spirv_shader);
1744 if (!ret)
1745 goto fail;
1746
1747 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1748 if (!ret->words)
1749 goto fail;
1750
1751 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1752 assert(ret->num_words == num_words);
1753
1754 return ret;
1755
1756 fail:
1757
1758 if (ret)
1759 spirv_shader_delete(ret);
1760
1761 if (ctx.vars)
1762 _mesa_hash_table_destroy(ctx.vars, NULL);
1763
1764 return NULL;
1765 }
1766
1767 void
1768 spirv_shader_delete(struct spirv_shader *s)
1769 {
1770 FREE(s->words);
1771 FREE(s);
1772 }