zink: lower two-sided coloring
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId samplers[PIPE_MAX_SAMPLERS];
42 size_t num_samplers;
43 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
44 size_t num_entry_ifaces;
45
46 SpvId *defs;
47 size_t num_defs;
48
49 SpvId *regs;
50 size_t num_regs;
51
52 struct hash_table *vars; /* nir_variable -> SpvId */
53
54 const SpvId *block_ids;
55 size_t num_blocks;
56 bool block_started;
57 SpvId loop_break, loop_cont;
58
59 SpvId front_face_var;
60 };
61
62 static SpvId
63 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
64 unsigned num_components, float value);
65
66 static SpvId
67 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
68 unsigned num_components, uint32_t value);
69
70 static SpvId
71 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
72 unsigned num_components, int32_t value);
73
74 static SpvId
75 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
76
77 static SpvId
78 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
79 SpvId src0, SpvId src1);
80
81 static SpvId
82 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
83 SpvId src0, SpvId src1, SpvId src2);
84
85 static SpvId
86 get_bvec_type(struct ntv_context *ctx, int num_components)
87 {
88 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
89 if (num_components > 1)
90 return spirv_builder_type_vector(&ctx->builder, bool_type,
91 num_components);
92
93 assert(num_components == 1);
94 return bool_type;
95 }
96
97 static SpvId
98 block_label(struct ntv_context *ctx, nir_block *block)
99 {
100 assert(block->index < ctx->num_blocks);
101 return ctx->block_ids[block->index];
102 }
103
104 static SpvId
105 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
106 {
107 assert(bit_size == 32);
108 return spirv_builder_const_float(&ctx->builder, bit_size, value);
109 }
110
111 static SpvId
112 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
113 {
114 assert(bit_size == 32);
115 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
116 }
117
118 static SpvId
119 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
120 {
121 assert(bit_size == 32);
122 return spirv_builder_const_int(&ctx->builder, bit_size, value);
123 }
124
125 static SpvId
126 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
127 {
128 assert(bit_size == 32); // only 32-bit floats supported so far
129
130 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
131 if (num_components > 1)
132 return spirv_builder_type_vector(&ctx->builder, float_type,
133 num_components);
134
135 assert(num_components == 1);
136 return float_type;
137 }
138
139 static SpvId
140 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
141 {
142 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
143
144 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
145 if (num_components > 1)
146 return spirv_builder_type_vector(&ctx->builder, int_type,
147 num_components);
148
149 assert(num_components == 1);
150 return int_type;
151 }
152
153 static SpvId
154 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
155 {
156 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
157
158 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
159 if (num_components > 1)
160 return spirv_builder_type_vector(&ctx->builder, uint_type,
161 num_components);
162
163 assert(num_components == 1);
164 return uint_type;
165 }
166
167 static SpvId
168 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
169 {
170 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
171 nir_dest_num_components(*dest));
172 }
173
174 static SpvId
175 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
176 {
177 switch (type) {
178 case GLSL_TYPE_BOOL:
179 return spirv_builder_type_bool(&ctx->builder);
180
181 case GLSL_TYPE_FLOAT:
182 return spirv_builder_type_float(&ctx->builder, 32);
183
184 case GLSL_TYPE_INT:
185 return spirv_builder_type_int(&ctx->builder, 32);
186
187 case GLSL_TYPE_UINT:
188 return spirv_builder_type_uint(&ctx->builder, 32);
189 /* TODO: handle more types */
190
191 default:
192 unreachable("unknown GLSL type");
193 }
194 }
195
196 static SpvId
197 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
198 {
199 assert(type);
200 if (glsl_type_is_scalar(type))
201 return get_glsl_basetype(ctx, glsl_get_base_type(type));
202
203 if (glsl_type_is_vector(type))
204 return spirv_builder_type_vector(&ctx->builder,
205 get_glsl_basetype(ctx, glsl_get_base_type(type)),
206 glsl_get_vector_elements(type));
207
208 if (glsl_type_is_array(type)) {
209 SpvId ret = spirv_builder_type_array(&ctx->builder,
210 get_glsl_type(ctx, glsl_get_array_element(type)),
211 emit_uint_const(ctx, 32, glsl_get_length(type)));
212 uint32_t stride = glsl_get_explicit_stride(type);
213 if (stride)
214 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
215 return ret;
216 }
217
218
219 unreachable("we shouldn't get here, I think...");
220 }
221
222 static void
223 emit_input(struct ntv_context *ctx, struct nir_variable *var)
224 {
225 SpvId var_type = get_glsl_type(ctx, var->type);
226 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
227 SpvStorageClassInput,
228 var_type);
229 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
230 SpvStorageClassInput);
231
232 if (var->name)
233 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
234
235 if (ctx->stage == MESA_SHADER_FRAGMENT) {
236 if (var->data.location >= VARYING_SLOT_VAR0)
237 spirv_builder_emit_location(&ctx->builder, var_id,
238 var->data.location - VARYING_SLOT_VAR0);
239 else if ((var->data.location >= VARYING_SLOT_COL0 &&
240 var->data.location <= VARYING_SLOT_TEX7) ||
241 var->data.location == VARYING_SLOT_BFC0 ||
242 var->data.location == VARYING_SLOT_BFC1) {
243 spirv_builder_emit_location(&ctx->builder, var_id,
244 var->data.location);
245 } else {
246 switch (var->data.location) {
247 case VARYING_SLOT_POS:
248 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
249 break;
250
251 case VARYING_SLOT_PNTC:
252 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
253 break;
254
255 default:
256 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
257 unreachable("unexpected varying slot");
258 }
259 }
260 } else {
261 spirv_builder_emit_location(&ctx->builder, var_id,
262 var->data.driver_location);
263 }
264
265 if (var->data.location_frac)
266 spirv_builder_emit_component(&ctx->builder, var_id,
267 var->data.location_frac);
268
269 if (var->data.interpolation == INTERP_MODE_FLAT)
270 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
271
272 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
273
274 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
275 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
276 }
277
278 static void
279 emit_output(struct ntv_context *ctx, struct nir_variable *var)
280 {
281 SpvId var_type = get_glsl_type(ctx, var->type);
282 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
283 SpvStorageClassOutput,
284 var_type);
285 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
286 SpvStorageClassOutput);
287 if (var->name)
288 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
289
290
291 if (ctx->stage == MESA_SHADER_VERTEX) {
292 if (var->data.location >= VARYING_SLOT_VAR0)
293 spirv_builder_emit_location(&ctx->builder, var_id,
294 var->data.location - VARYING_SLOT_VAR0);
295 else if ((var->data.location >= VARYING_SLOT_COL0 &&
296 var->data.location <= VARYING_SLOT_TEX7) ||
297 var->data.location == VARYING_SLOT_BFC0 ||
298 var->data.location == VARYING_SLOT_BFC1) {
299 spirv_builder_emit_location(&ctx->builder, var_id,
300 var->data.location);
301 } else {
302 switch (var->data.location) {
303 case VARYING_SLOT_POS:
304 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
305 break;
306
307 case VARYING_SLOT_PSIZ:
308 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
309 break;
310
311 case VARYING_SLOT_CLIP_DIST0:
312 assert(glsl_type_is_array(var->type));
313 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
314 break;
315
316 default:
317 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
318 unreachable("unexpected varying slot");
319 }
320 }
321 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
322 if (var->data.location >= FRAG_RESULT_DATA0)
323 spirv_builder_emit_location(&ctx->builder, var_id,
324 var->data.location - FRAG_RESULT_DATA0);
325 else {
326 switch (var->data.location) {
327 case FRAG_RESULT_COLOR:
328 spirv_builder_emit_location(&ctx->builder, var_id, 0);
329 break;
330
331 case FRAG_RESULT_DEPTH:
332 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
333 break;
334
335 default:
336 spirv_builder_emit_location(&ctx->builder, var_id,
337 var->data.driver_location);
338 }
339 }
340 }
341
342 if (var->data.location_frac)
343 spirv_builder_emit_component(&ctx->builder, var_id,
344 var->data.location_frac);
345
346 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
347
348 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
349 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
350 }
351
352 static SpvDim
353 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
354 {
355 *is_ms = false;
356 switch (gdim) {
357 case GLSL_SAMPLER_DIM_1D:
358 return SpvDim1D;
359 case GLSL_SAMPLER_DIM_2D:
360 return SpvDim2D;
361 case GLSL_SAMPLER_DIM_RECT:
362 return SpvDimRect;
363 case GLSL_SAMPLER_DIM_CUBE:
364 return SpvDimCube;
365 case GLSL_SAMPLER_DIM_3D:
366 return SpvDim3D;
367 case GLSL_SAMPLER_DIM_MS:
368 *is_ms = true;
369 return SpvDim2D;
370 default:
371 fprintf(stderr, "unknown sampler type %d\n", gdim);
372 break;
373 }
374 return SpvDim2D;
375 }
376
377 static void
378 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
379 {
380 bool is_ms;
381 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
382 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
383 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
384 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
385 SpvImageFormatUnknown);
386
387 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
388 image_type);
389 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
390 SpvStorageClassUniformConstant,
391 sampled_type);
392 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
393 SpvStorageClassUniformConstant);
394
395 if (var->name)
396 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
397
398 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
399 ctx->samplers[ctx->num_samplers++] = var_id;
400
401 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
402 var->data.descriptor_set);
403 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
404 }
405
406 static void
407 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
408 {
409 uint32_t size = glsl_count_attribute_slots(var->type, false);
410 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
411 SpvId array_length = emit_uint_const(ctx, 32, size);
412 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
413 array_length);
414 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
415
416 // wrap UBO-array in a struct
417 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
418 if (var->name) {
419 char struct_name[100];
420 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
421 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
422 }
423
424 spirv_builder_emit_decoration(&ctx->builder, struct_type,
425 SpvDecorationBlock);
426 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
427
428
429 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
430 SpvStorageClassUniform,
431 struct_type);
432
433 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
434 SpvStorageClassUniform);
435 if (var->name)
436 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
437
438 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
439 ctx->ubos[ctx->num_ubos++] = var_id;
440
441 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
442 var->data.descriptor_set);
443 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
444 }
445
446 static void
447 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
448 {
449 if (var->data.mode == nir_var_mem_ubo)
450 emit_ubo(ctx, var);
451 else {
452 assert(var->data.mode == nir_var_uniform);
453 if (glsl_type_is_sampler(var->type))
454 emit_sampler(ctx, var);
455 }
456 }
457
458 static SpvId
459 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
460 {
461 assert(ssa->index < ctx->num_defs);
462 assert(ctx->defs[ssa->index] != 0);
463 return ctx->defs[ssa->index];
464 }
465
466 static SpvId
467 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
468 {
469 assert(reg->index < ctx->num_regs);
470 assert(ctx->regs[reg->index] != 0);
471 return ctx->regs[reg->index];
472 }
473
474 static SpvId
475 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
476 {
477 assert(reg->reg);
478 assert(!reg->indirect);
479 assert(!reg->base_offset);
480
481 SpvId var = get_var_from_reg(ctx, reg->reg);
482 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
483 return spirv_builder_emit_load(&ctx->builder, type, var);
484 }
485
486 static SpvId
487 get_src_uint(struct ntv_context *ctx, nir_src *src)
488 {
489 if (src->is_ssa)
490 return get_src_uint_ssa(ctx, src->ssa);
491 else
492 return get_src_uint_reg(ctx, &src->reg);
493 }
494
495 static SpvId
496 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
497 {
498 assert(!alu->src[src].negate);
499 assert(!alu->src[src].abs);
500
501 SpvId def = get_src_uint(ctx, &alu->src[src].src);
502
503 unsigned used_channels = 0;
504 bool need_swizzle = false;
505 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
506 if (!nir_alu_instr_channel_used(alu, src, i))
507 continue;
508
509 used_channels++;
510
511 if (alu->src[src].swizzle[i] != i)
512 need_swizzle = true;
513 }
514 assert(used_channels != 0);
515
516 unsigned live_channels = nir_src_num_components(alu->src[src].src);
517 if (used_channels != live_channels)
518 need_swizzle = true;
519
520 if (!need_swizzle)
521 return def;
522
523 int bit_size = nir_src_bit_size(alu->src[src].src);
524 assert(bit_size == 1 || bit_size == 32);
525
526 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
527 if (used_channels == 1) {
528 uint32_t indices[] = { alu->src[src].swizzle[0] };
529 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
530 def, indices,
531 ARRAY_SIZE(indices));
532 } else if (live_channels == 1) {
533 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
534 used_channels);
535
536 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
537 for (unsigned i = 0; i < used_channels; ++i)
538 constituents[i] = def;
539
540 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
541 constituents,
542 used_channels);
543 } else {
544 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
545 used_channels);
546
547 uint32_t components[NIR_MAX_VEC_COMPONENTS];
548 size_t num_components = 0;
549 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
550 if (!nir_alu_instr_channel_used(alu, src, i))
551 continue;
552
553 components[num_components++] = alu->src[src].swizzle[i];
554 }
555
556 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
557 def, def, components, num_components);
558 }
559 }
560
561 static void
562 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
563 {
564 assert(result != 0);
565 assert(ssa->index < ctx->num_defs);
566 ctx->defs[ssa->index] = result;
567 }
568
569 static SpvId
570 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
571 SpvId if_true, SpvId if_false)
572 {
573 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
574 }
575
576 static SpvId
577 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
578 {
579 SpvId otype = get_uvec_type(ctx, 32, num_components);
580 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
581 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
582 return emit_select(ctx, otype, value, one, zero);
583 }
584
585 static SpvId
586 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
587 {
588 SpvId type = get_bvec_type(ctx, num_components);
589 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
590 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
591 }
592
593 static SpvId
594 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
595 {
596 return emit_unop(ctx, SpvOpBitcast, type, value);
597 }
598
599 static SpvId
600 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
601 unsigned num_components)
602 {
603 SpvId type = get_uvec_type(ctx, bit_size, num_components);
604 return emit_bitcast(ctx, type, value);
605 }
606
607 static SpvId
608 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
609 unsigned num_components)
610 {
611 SpvId type = get_ivec_type(ctx, bit_size, num_components);
612 return emit_bitcast(ctx, type, value);
613 }
614
615 static SpvId
616 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
617 unsigned num_components)
618 {
619 SpvId type = get_fvec_type(ctx, bit_size, num_components);
620 return emit_bitcast(ctx, type, value);
621 }
622
623 static void
624 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
625 {
626 SpvId var = get_var_from_reg(ctx, reg->reg);
627 assert(var);
628 spirv_builder_emit_store(&ctx->builder, var, result);
629 }
630
631 static void
632 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
633 {
634 if (dest->is_ssa)
635 store_ssa_def_uint(ctx, &dest->ssa, result);
636 else
637 store_reg_def(ctx, &dest->reg, result);
638 }
639
640 static void
641 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
642 {
643 unsigned num_components = nir_dest_num_components(*dest);
644 unsigned bit_size = nir_dest_bit_size(*dest);
645
646 switch (nir_alu_type_get_base_type(type)) {
647 case nir_type_bool:
648 assert(bit_size == 1);
649 result = bvec_to_uvec(ctx, result, num_components);
650 break;
651
652 case nir_type_uint:
653 break; /* nothing to do! */
654
655 case nir_type_int:
656 case nir_type_float:
657 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
658 break;
659
660 default:
661 unreachable("unsupported nir_alu_type");
662 }
663
664 store_dest_uint(ctx, dest, result);
665 }
666
667 static SpvId
668 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
669 {
670 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
671 }
672
673 static SpvId
674 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
675 SpvId src0, SpvId src1)
676 {
677 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
678 }
679
680 static SpvId
681 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
682 SpvId src0, SpvId src1, SpvId src2)
683 {
684 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
685 }
686
687 static SpvId
688 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
689 SpvId src)
690 {
691 SpvId args[] = { src };
692 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
693 op, args, ARRAY_SIZE(args));
694 }
695
696 static SpvId
697 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
698 SpvId src0, SpvId src1)
699 {
700 SpvId args[] = { src0, src1 };
701 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
702 op, args, ARRAY_SIZE(args));
703 }
704
705 static SpvId
706 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
707 unsigned num_components, float value)
708 {
709 assert(bit_size == 32);
710
711 SpvId result = emit_float_const(ctx, bit_size, value);
712 if (num_components == 1)
713 return result;
714
715 assert(num_components > 1);
716 SpvId components[num_components];
717 for (int i = 0; i < num_components; i++)
718 components[i] = result;
719
720 SpvId type = get_fvec_type(ctx, bit_size, num_components);
721 return spirv_builder_const_composite(&ctx->builder, type, components,
722 num_components);
723 }
724
725 static SpvId
726 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
727 unsigned num_components, uint32_t value)
728 {
729 assert(bit_size == 32);
730
731 SpvId result = emit_uint_const(ctx, bit_size, value);
732 if (num_components == 1)
733 return result;
734
735 assert(num_components > 1);
736 SpvId components[num_components];
737 for (int i = 0; i < num_components; i++)
738 components[i] = result;
739
740 SpvId type = get_uvec_type(ctx, bit_size, num_components);
741 return spirv_builder_const_composite(&ctx->builder, type, components,
742 num_components);
743 }
744
745 static SpvId
746 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
747 unsigned num_components, int32_t value)
748 {
749 assert(bit_size == 32);
750
751 SpvId result = emit_int_const(ctx, bit_size, value);
752 if (num_components == 1)
753 return result;
754
755 assert(num_components > 1);
756 SpvId components[num_components];
757 for (int i = 0; i < num_components; i++)
758 components[i] = result;
759
760 SpvId type = get_ivec_type(ctx, bit_size, num_components);
761 return spirv_builder_const_composite(&ctx->builder, type, components,
762 num_components);
763 }
764
765 static inline unsigned
766 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
767 {
768 if (nir_op_infos[instr->op].input_sizes[src] > 0)
769 return nir_op_infos[instr->op].input_sizes[src];
770
771 if (instr->dest.dest.is_ssa)
772 return instr->dest.dest.ssa.num_components;
773 else
774 return instr->dest.dest.reg.reg->num_components;
775 }
776
777 static SpvId
778 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
779 {
780 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
781
782 unsigned num_components = alu_instr_src_components(alu, src);
783 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
784 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
785
786 switch (nir_alu_type_get_base_type(type)) {
787 case nir_type_bool:
788 assert(bit_size == 1);
789 return uvec_to_bvec(ctx, uint_value, num_components);
790
791 case nir_type_int:
792 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
793
794 case nir_type_uint:
795 return uint_value;
796
797 case nir_type_float:
798 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
799
800 default:
801 unreachable("unknown nir_alu_type");
802 }
803 }
804
805 static void
806 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
807 {
808 assert(!alu->dest.saturate);
809 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
810 }
811
812 static SpvId
813 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
814 {
815 unsigned num_components = nir_dest_num_components(*dest);
816 unsigned bit_size = nir_dest_bit_size(*dest);
817
818 switch (nir_alu_type_get_base_type(type)) {
819 case nir_type_bool:
820 return get_bvec_type(ctx, num_components);
821
822 case nir_type_int:
823 return get_ivec_type(ctx, bit_size, num_components);
824
825 case nir_type_uint:
826 return get_uvec_type(ctx, bit_size, num_components);
827
828 case nir_type_float:
829 return get_fvec_type(ctx, bit_size, num_components);
830
831 default:
832 unreachable("unsupported nir_alu_type");
833 }
834 }
835
836 static void
837 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
838 {
839 SpvId src[nir_op_infos[alu->op].num_inputs];
840 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
841 src[i] = get_alu_src(ctx, alu, i);
842
843 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
844 nir_op_infos[alu->op].output_type);
845 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
846 unsigned num_components = nir_dest_num_components(alu->dest.dest);
847
848 SpvId result = 0;
849 switch (alu->op) {
850 case nir_op_mov:
851 assert(nir_op_infos[alu->op].num_inputs == 1);
852 result = src[0];
853 break;
854
855 #define UNOP(nir_op, spirv_op) \
856 case nir_op: \
857 assert(nir_op_infos[alu->op].num_inputs == 1); \
858 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
859 break;
860
861 UNOP(nir_op_ineg, SpvOpSNegate)
862 UNOP(nir_op_fneg, SpvOpFNegate)
863 UNOP(nir_op_fddx, SpvOpDPdx)
864 UNOP(nir_op_fddy, SpvOpDPdy)
865 UNOP(nir_op_f2i32, SpvOpConvertFToS)
866 UNOP(nir_op_f2u32, SpvOpConvertFToU)
867 UNOP(nir_op_i2f32, SpvOpConvertSToF)
868 UNOP(nir_op_u2f32, SpvOpConvertUToF)
869 UNOP(nir_op_inot, SpvOpNot)
870 #undef UNOP
871
872 case nir_op_b2i32:
873 assert(nir_op_infos[alu->op].num_inputs == 1);
874 result = emit_select(ctx, dest_type, src[0],
875 get_ivec_constant(ctx, 32, num_components, 1),
876 get_ivec_constant(ctx, 32, num_components, 0));
877 break;
878
879 case nir_op_b2f32:
880 assert(nir_op_infos[alu->op].num_inputs == 1);
881 result = emit_select(ctx, dest_type, src[0],
882 get_fvec_constant(ctx, 32, num_components, 1),
883 get_fvec_constant(ctx, 32, num_components, 0));
884 break;
885
886 #define BUILTIN_UNOP(nir_op, spirv_op) \
887 case nir_op: \
888 assert(nir_op_infos[alu->op].num_inputs == 1); \
889 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
890 break;
891
892 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
893 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
894 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
895 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
896 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
897 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
898 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
899 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
900 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
901 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
902 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
903 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
904 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
905 #undef BUILTIN_UNOP
906
907 case nir_op_frcp:
908 assert(nir_op_infos[alu->op].num_inputs == 1);
909 result = emit_binop(ctx, SpvOpFDiv, dest_type,
910 get_fvec_constant(ctx, bit_size, num_components, 1),
911 src[0]);
912 break;
913
914 case nir_op_f2b1:
915 assert(nir_op_infos[alu->op].num_inputs == 1);
916 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
917 get_fvec_constant(ctx,
918 nir_src_bit_size(alu->src[0].src),
919 num_components, 0));
920 break;
921
922
923 #define BINOP(nir_op, spirv_op) \
924 case nir_op: \
925 assert(nir_op_infos[alu->op].num_inputs == 2); \
926 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
927 break;
928
929 BINOP(nir_op_iadd, SpvOpIAdd)
930 BINOP(nir_op_isub, SpvOpISub)
931 BINOP(nir_op_imul, SpvOpIMul)
932 BINOP(nir_op_idiv, SpvOpSDiv)
933 BINOP(nir_op_udiv, SpvOpUDiv)
934 BINOP(nir_op_fadd, SpvOpFAdd)
935 BINOP(nir_op_fsub, SpvOpFSub)
936 BINOP(nir_op_fmul, SpvOpFMul)
937 BINOP(nir_op_fdiv, SpvOpFDiv)
938 BINOP(nir_op_fmod, SpvOpFMod)
939 BINOP(nir_op_ilt, SpvOpSLessThan)
940 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
941 BINOP(nir_op_ieq, SpvOpIEqual)
942 BINOP(nir_op_ine, SpvOpINotEqual)
943 BINOP(nir_op_flt, SpvOpFOrdLessThan)
944 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
945 BINOP(nir_op_feq, SpvOpFOrdEqual)
946 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
947 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
948 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
949 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
950 BINOP(nir_op_iand, SpvOpBitwiseAnd)
951 BINOP(nir_op_ior, SpvOpBitwiseOr)
952 #undef BINOP
953
954 #define BUILTIN_BINOP(nir_op, spirv_op) \
955 case nir_op: \
956 assert(nir_op_infos[alu->op].num_inputs == 2); \
957 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
958 break;
959
960 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
961 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
962 #undef BUILTIN_BINOP
963
964 case nir_op_fdot2:
965 case nir_op_fdot3:
966 case nir_op_fdot4:
967 assert(nir_op_infos[alu->op].num_inputs == 2);
968 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
969 break;
970
971 case nir_op_seq:
972 case nir_op_sne:
973 case nir_op_slt:
974 case nir_op_sge: {
975 assert(nir_op_infos[alu->op].num_inputs == 2);
976 int num_components = nir_dest_num_components(alu->dest.dest);
977 SpvId bool_type = get_bvec_type(ctx, num_components);
978
979 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
980 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
981 if (num_components > 1) {
982 SpvId zero_comps[num_components], one_comps[num_components];
983 for (int i = 0; i < num_components; i++) {
984 zero_comps[i] = zero;
985 one_comps[i] = one;
986 }
987
988 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
989 zero_comps, num_components);
990 one = spirv_builder_const_composite(&ctx->builder, dest_type,
991 one_comps, num_components);
992 }
993
994 SpvOp op;
995 switch (alu->op) {
996 case nir_op_seq: op = SpvOpFOrdEqual; break;
997 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
998 case nir_op_slt: op = SpvOpFOrdLessThan; break;
999 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1000 default: unreachable("unexpected op");
1001 }
1002
1003 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1004 result = emit_select(ctx, dest_type, result, one, zero);
1005 }
1006 break;
1007
1008 case nir_op_fcsel:
1009 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1010 get_bvec_type(ctx, num_components),
1011 src[0],
1012 get_fvec_constant(ctx,
1013 nir_src_bit_size(alu->src[0].src),
1014 num_components, 0));
1015 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1016 break;
1017
1018 case nir_op_bcsel:
1019 assert(nir_op_infos[alu->op].num_inputs == 3);
1020 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1021 break;
1022
1023 case nir_op_vec2:
1024 case nir_op_vec3:
1025 case nir_op_vec4: {
1026 int num_inputs = nir_op_infos[alu->op].num_inputs;
1027 assert(2 <= num_inputs && num_inputs <= 4);
1028 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1029 src, num_inputs);
1030 }
1031 break;
1032
1033 default:
1034 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1035 nir_op_infos[alu->op].name);
1036
1037 unreachable("unsupported opcode");
1038 return;
1039 }
1040
1041 store_alu_result(ctx, alu, result);
1042 }
1043
1044 static void
1045 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1046 {
1047 uint32_t values[NIR_MAX_VEC_COMPONENTS];
1048 for (int i = 0; i < load_const->def.num_components; ++i)
1049 values[i] = load_const->value[i].u32;
1050
1051 unsigned bit_size = load_const->def.bit_size;
1052 unsigned num_components = load_const->def.num_components;
1053
1054 SpvId constant;
1055 if (num_components > 1) {
1056 SpvId components[num_components];
1057 for (int i = 0; i < num_components; i++)
1058 components[i] = emit_uint_const(ctx, bit_size, values[i]);
1059
1060 SpvId type = get_uvec_type(ctx, bit_size, num_components);
1061 constant = spirv_builder_const_composite(&ctx->builder, type,
1062 components, num_components);
1063 } else {
1064 assert(num_components == 1);
1065 constant = emit_uint_const(ctx, bit_size, values[0]);
1066 }
1067
1068 store_ssa_def_uint(ctx, &load_const->def, constant);
1069 }
1070
1071 static void
1072 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1073 {
1074 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1075 assert(const_block_index); // no dynamic indexing for now
1076 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1077
1078 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1079 if (const_offset) {
1080 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1081 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1082 SpvStorageClassUniform,
1083 uvec4_type);
1084
1085 unsigned idx = const_offset->u32;
1086 SpvId member = emit_uint_const(ctx, 32, 0);
1087 SpvId offset = emit_uint_const(ctx, 32, idx);
1088 SpvId offsets[] = { member, offset };
1089 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1090 ctx->ubos[0], offsets,
1091 ARRAY_SIZE(offsets));
1092 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1093
1094 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1095 unsigned num_components = nir_dest_num_components(intr->dest);
1096 if (num_components == 1) {
1097 uint32_t components[] = { 0 };
1098 result = spirv_builder_emit_composite_extract(&ctx->builder,
1099 type,
1100 result, components,
1101 1);
1102 } else if (num_components < 4) {
1103 SpvId constituents[num_components];
1104 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1105 for (uint32_t i = 0; i < num_components; ++i)
1106 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1107 uint_type,
1108 result, &i,
1109 1);
1110
1111 result = spirv_builder_emit_composite_construct(&ctx->builder,
1112 type,
1113 constituents,
1114 num_components);
1115 }
1116
1117 store_dest_uint(ctx, &intr->dest, result);
1118 } else
1119 unreachable("uniform-addressing not yet supported");
1120 }
1121
1122 static void
1123 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1124 {
1125 assert(ctx->block_started);
1126 spirv_builder_emit_kill(&ctx->builder);
1127 /* discard is weird in NIR, so let's just create an unreachable block after
1128 it and hope that the vulkan driver will DCE any instructinos in it. */
1129 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1130 }
1131
1132 static void
1133 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1134 {
1135 /* uint is a bit of a lie here; it's really just a pointer */
1136 SpvId ptr = get_src_uint(ctx, intr->src);
1137
1138 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1139 SpvId result = spirv_builder_emit_load(&ctx->builder,
1140 get_glsl_type(ctx, var->type),
1141 ptr);
1142 unsigned num_components = nir_dest_num_components(intr->dest);
1143 unsigned bit_size = nir_dest_bit_size(intr->dest);
1144 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1145 store_dest_uint(ctx, &intr->dest, result);
1146 }
1147
1148 static void
1149 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1150 {
1151 /* uint is a bit of a lie here; it's really just a pointer */
1152 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1153 SpvId src = get_src_uint(ctx, &intr->src[1]);
1154
1155 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1156 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1157 SpvId result = emit_bitcast(ctx, type, src);
1158 spirv_builder_emit_store(&ctx->builder, ptr, result);
1159 }
1160
1161 static void
1162 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1163 {
1164 SpvId var_type = get_glsl_type(ctx, glsl_bool_type());
1165 if (!ctx->front_face_var) {
1166 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1167 SpvStorageClassInput,
1168 var_type);
1169 ctx->front_face_var = spirv_builder_emit_var(&ctx->builder,
1170 pointer_type,
1171 SpvStorageClassInput);
1172 spirv_builder_emit_name(&ctx->builder, ctx->front_face_var,
1173 "gl_FrontFacing");
1174 spirv_builder_emit_builtin(&ctx->builder, ctx->front_face_var,
1175 SpvBuiltInFrontFacing);
1176
1177 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1178 ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->front_face_var;
1179 }
1180
1181 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1182 ctx->front_face_var);
1183 assert(1 == nir_dest_num_components(intr->dest));
1184 result = bvec_to_uvec(ctx, result, 1);
1185 store_dest_uint(ctx, &intr->dest, result);
1186 }
1187
1188 static void
1189 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1190 {
1191 switch (intr->intrinsic) {
1192 case nir_intrinsic_load_ubo:
1193 emit_load_ubo(ctx, intr);
1194 break;
1195
1196 case nir_intrinsic_discard:
1197 emit_discard(ctx, intr);
1198 break;
1199
1200 case nir_intrinsic_load_deref:
1201 emit_load_deref(ctx, intr);
1202 break;
1203
1204 case nir_intrinsic_store_deref:
1205 emit_store_deref(ctx, intr);
1206 break;
1207
1208 case nir_intrinsic_load_front_face:
1209 emit_load_front_face(ctx, intr);
1210 break;
1211
1212 default:
1213 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1214 nir_intrinsic_infos[intr->intrinsic].name);
1215 unreachable("unsupported intrinsic");
1216 }
1217 }
1218
1219 static void
1220 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1221 {
1222 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1223 undef->def.num_components);
1224
1225 store_ssa_def_uint(ctx, &undef->def,
1226 spirv_builder_emit_undef(&ctx->builder, type));
1227 }
1228
1229 static SpvId
1230 get_src_float(struct ntv_context *ctx, nir_src *src)
1231 {
1232 SpvId def = get_src_uint(ctx, src);
1233 unsigned num_components = nir_src_num_components(*src);
1234 unsigned bit_size = nir_src_bit_size(*src);
1235 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1236 }
1237
1238 static void
1239 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1240 {
1241 assert(tex->op == nir_texop_tex ||
1242 tex->op == nir_texop_txb ||
1243 tex->op == nir_texop_txl);
1244 assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1245 assert(tex->texture_index == tex->sampler_index);
1246
1247 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0;
1248 unsigned coord_components;
1249 for (unsigned i = 0; i < tex->num_srcs; i++) {
1250 switch (tex->src[i].src_type) {
1251 case nir_tex_src_coord:
1252 coord = get_src_float(ctx, &tex->src[i].src);
1253 coord_components = nir_src_num_components(tex->src[i].src);
1254 break;
1255
1256 case nir_tex_src_projector:
1257 assert(nir_src_num_components(tex->src[i].src) == 1);
1258 proj = get_src_float(ctx, &tex->src[i].src);
1259 assert(proj != 0);
1260 break;
1261
1262 case nir_tex_src_bias:
1263 assert(tex->op == nir_texop_txb);
1264 bias = get_src_float(ctx, &tex->src[i].src);
1265 assert(bias != 0);
1266 break;
1267
1268 case nir_tex_src_lod:
1269 assert(nir_src_num_components(tex->src[i].src) == 1);
1270 lod = get_src_float(ctx, &tex->src[i].src);
1271 assert(lod != 0);
1272 break;
1273
1274 case nir_tex_src_comparator:
1275 assert(nir_src_num_components(tex->src[i].src) == 1);
1276 dref = get_src_float(ctx, &tex->src[i].src);
1277 assert(dref != 0);
1278 break;
1279
1280 default:
1281 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1282 unreachable("unknown texture source");
1283 }
1284 }
1285
1286 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1287 lod = emit_float_const(ctx, 32, 0.0f);
1288 assert(lod != 0);
1289 }
1290
1291 bool is_ms;
1292 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1293 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1294 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1295 dimension, false, tex->is_array, is_ms, 1,
1296 SpvImageFormatUnknown);
1297 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1298 image_type);
1299
1300 assert(tex->texture_index < ctx->num_samplers);
1301 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1302 ctx->samplers[tex->texture_index]);
1303
1304 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1305
1306 if (proj) {
1307 SpvId constituents[coord_components + 1];
1308 if (coord_components == 1)
1309 constituents[0] = coord;
1310 else {
1311 assert(coord_components > 1);
1312 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1313 for (uint32_t i = 0; i < coord_components; ++i)
1314 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1315 float_type,
1316 coord,
1317 &i, 1);
1318 }
1319
1320 constituents[coord_components++] = proj;
1321
1322 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1323 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1324 vec_type,
1325 constituents,
1326 coord_components);
1327 }
1328
1329 SpvId actual_dest_type = dest_type;
1330 if (dref)
1331 actual_dest_type = float_type;
1332
1333 SpvId result = spirv_builder_emit_image_sample(&ctx->builder,
1334 actual_dest_type, load,
1335 coord,
1336 proj != 0,
1337 lod, bias, dref);
1338 spirv_builder_emit_decoration(&ctx->builder, result,
1339 SpvDecorationRelaxedPrecision);
1340
1341 if (dref) {
1342 SpvId components[4] = { result, result, result, result };
1343 result = spirv_builder_emit_composite_construct(&ctx->builder,
1344 dest_type,
1345 components,
1346 4);
1347 }
1348
1349 store_dest(ctx, &tex->dest, result, tex->dest_type);
1350 }
1351
1352 static void
1353 start_block(struct ntv_context *ctx, SpvId label)
1354 {
1355 /* terminate previous block if needed */
1356 if (ctx->block_started)
1357 spirv_builder_emit_branch(&ctx->builder, label);
1358
1359 /* start new block */
1360 spirv_builder_label(&ctx->builder, label);
1361 ctx->block_started = true;
1362 }
1363
1364 static void
1365 branch(struct ntv_context *ctx, SpvId label)
1366 {
1367 assert(ctx->block_started);
1368 spirv_builder_emit_branch(&ctx->builder, label);
1369 ctx->block_started = false;
1370 }
1371
1372 static void
1373 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1374 SpvId else_id)
1375 {
1376 assert(ctx->block_started);
1377 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1378 then_id, else_id);
1379 ctx->block_started = false;
1380 }
1381
1382 static void
1383 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1384 {
1385 switch (jump->type) {
1386 case nir_jump_break:
1387 assert(ctx->loop_break);
1388 branch(ctx, ctx->loop_break);
1389 break;
1390
1391 case nir_jump_continue:
1392 assert(ctx->loop_cont);
1393 branch(ctx, ctx->loop_cont);
1394 break;
1395
1396 default:
1397 unreachable("Unsupported jump type\n");
1398 }
1399 }
1400
1401 static void
1402 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1403 {
1404 assert(deref->deref_type == nir_deref_type_var);
1405
1406 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1407 assert(he);
1408 SpvId result = (SpvId)(intptr_t)he->data;
1409 /* uint is a bit of a lie here, it's really just an opaque type */
1410 store_dest_uint(ctx, &deref->dest, result);
1411 }
1412
1413 static void
1414 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1415 {
1416 assert(deref->deref_type == nir_deref_type_array);
1417 nir_variable *var = nir_deref_instr_get_variable(deref);
1418
1419 SpvStorageClass storage_class;
1420 switch (var->data.mode) {
1421 case nir_var_shader_in:
1422 storage_class = SpvStorageClassInput;
1423 break;
1424
1425 case nir_var_shader_out:
1426 storage_class = SpvStorageClassOutput;
1427 break;
1428
1429 default:
1430 unreachable("Unsupported nir_variable_mode\n");
1431 }
1432
1433 SpvId index = get_src_uint(ctx, &deref->arr.index);
1434
1435 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1436 storage_class,
1437 get_glsl_type(ctx, deref->type));
1438
1439 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1440 ptr_type,
1441 get_src_uint(ctx, &deref->parent),
1442 &index, 1);
1443 /* uint is a bit of a lie here, it's really just an opaque type */
1444 store_dest_uint(ctx, &deref->dest, result);
1445 }
1446
1447 static void
1448 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1449 {
1450 switch (deref->deref_type) {
1451 case nir_deref_type_var:
1452 emit_deref_var(ctx, deref);
1453 break;
1454
1455 case nir_deref_type_array:
1456 emit_deref_array(ctx, deref);
1457 break;
1458
1459 default:
1460 unreachable("unexpected deref_type");
1461 }
1462 }
1463
1464 static void
1465 emit_block(struct ntv_context *ctx, struct nir_block *block)
1466 {
1467 start_block(ctx, block_label(ctx, block));
1468 nir_foreach_instr(instr, block) {
1469 switch (instr->type) {
1470 case nir_instr_type_alu:
1471 emit_alu(ctx, nir_instr_as_alu(instr));
1472 break;
1473 case nir_instr_type_intrinsic:
1474 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1475 break;
1476 case nir_instr_type_load_const:
1477 emit_load_const(ctx, nir_instr_as_load_const(instr));
1478 break;
1479 case nir_instr_type_ssa_undef:
1480 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1481 break;
1482 case nir_instr_type_tex:
1483 emit_tex(ctx, nir_instr_as_tex(instr));
1484 break;
1485 case nir_instr_type_phi:
1486 unreachable("nir_instr_type_phi not supported");
1487 break;
1488 case nir_instr_type_jump:
1489 emit_jump(ctx, nir_instr_as_jump(instr));
1490 break;
1491 case nir_instr_type_call:
1492 unreachable("nir_instr_type_call not supported");
1493 break;
1494 case nir_instr_type_parallel_copy:
1495 unreachable("nir_instr_type_parallel_copy not supported");
1496 break;
1497 case nir_instr_type_deref:
1498 emit_deref(ctx, nir_instr_as_deref(instr));
1499 break;
1500 }
1501 }
1502 }
1503
1504 static void
1505 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1506
1507 static SpvId
1508 get_src_bool(struct ntv_context *ctx, nir_src *src)
1509 {
1510 SpvId def = get_src_uint(ctx, src);
1511 assert(nir_src_bit_size(*src) == 1);
1512 unsigned num_components = nir_src_num_components(*src);
1513 return uvec_to_bvec(ctx, def, num_components);
1514 }
1515
1516 static void
1517 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
1518 {
1519 SpvId condition = get_src_bool(ctx, &if_stmt->condition);
1520
1521 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1522 SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
1523 SpvId endif_id = spirv_builder_new_id(&ctx->builder);
1524 SpvId else_id = endif_id;
1525
1526 bool has_else = !exec_list_is_empty(&if_stmt->else_list);
1527 if (has_else) {
1528 assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
1529 else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
1530 }
1531
1532 /* create a header-block */
1533 start_block(ctx, header_id);
1534 spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
1535 SpvSelectionControlMaskNone);
1536 branch_conditional(ctx, condition, then_id, else_id);
1537
1538 emit_cf_list(ctx, &if_stmt->then_list);
1539
1540 if (has_else) {
1541 if (ctx->block_started)
1542 branch(ctx, endif_id);
1543
1544 emit_cf_list(ctx, &if_stmt->else_list);
1545 }
1546
1547 start_block(ctx, endif_id);
1548 }
1549
1550 static void
1551 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1552 {
1553 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1554 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1555 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1556 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1557
1558 /* create a header-block */
1559 start_block(ctx, header_id);
1560 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1561 branch(ctx, begin_id);
1562
1563 SpvId save_break = ctx->loop_break;
1564 SpvId save_cont = ctx->loop_cont;
1565 ctx->loop_break = break_id;
1566 ctx->loop_cont = cont_id;
1567
1568 emit_cf_list(ctx, &loop->body);
1569
1570 ctx->loop_break = save_break;
1571 ctx->loop_cont = save_cont;
1572
1573 branch(ctx, cont_id);
1574 start_block(ctx, cont_id);
1575 branch(ctx, header_id);
1576
1577 start_block(ctx, break_id);
1578 }
1579
1580 static void
1581 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1582 {
1583 foreach_list_typed(nir_cf_node, node, node, list) {
1584 switch (node->type) {
1585 case nir_cf_node_block:
1586 emit_block(ctx, nir_cf_node_as_block(node));
1587 break;
1588
1589 case nir_cf_node_if:
1590 emit_if(ctx, nir_cf_node_as_if(node));
1591 break;
1592
1593 case nir_cf_node_loop:
1594 emit_loop(ctx, nir_cf_node_as_loop(node));
1595 break;
1596
1597 case nir_cf_node_function:
1598 unreachable("nir_cf_node_function not supported");
1599 break;
1600 }
1601 }
1602 }
1603
1604 struct spirv_shader *
1605 nir_to_spirv(struct nir_shader *s)
1606 {
1607 struct spirv_shader *ret = NULL;
1608
1609 struct ntv_context ctx = {};
1610
1611 switch (s->info.stage) {
1612 case MESA_SHADER_VERTEX:
1613 case MESA_SHADER_FRAGMENT:
1614 case MESA_SHADER_COMPUTE:
1615 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1616 break;
1617
1618 case MESA_SHADER_TESS_CTRL:
1619 case MESA_SHADER_TESS_EVAL:
1620 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1621 break;
1622
1623 case MESA_SHADER_GEOMETRY:
1624 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1625 break;
1626
1627 default:
1628 unreachable("invalid stage");
1629 }
1630
1631 // TODO: only enable when needed
1632 if (s->info.stage == MESA_SHADER_FRAGMENT)
1633 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1634
1635 ctx.stage = s->info.stage;
1636 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1637 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1638
1639 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1640 SpvMemoryModelGLSL450);
1641
1642 SpvExecutionModel exec_model;
1643 switch (s->info.stage) {
1644 case MESA_SHADER_VERTEX:
1645 exec_model = SpvExecutionModelVertex;
1646 break;
1647 case MESA_SHADER_TESS_CTRL:
1648 exec_model = SpvExecutionModelTessellationControl;
1649 break;
1650 case MESA_SHADER_TESS_EVAL:
1651 exec_model = SpvExecutionModelTessellationEvaluation;
1652 break;
1653 case MESA_SHADER_GEOMETRY:
1654 exec_model = SpvExecutionModelGeometry;
1655 break;
1656 case MESA_SHADER_FRAGMENT:
1657 exec_model = SpvExecutionModelFragment;
1658 break;
1659 case MESA_SHADER_COMPUTE:
1660 exec_model = SpvExecutionModelGLCompute;
1661 break;
1662 default:
1663 unreachable("invalid stage");
1664 }
1665
1666 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1667 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1668 NULL, 0);
1669 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1670 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1671
1672 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1673 _mesa_key_pointer_equal);
1674
1675 nir_foreach_variable(var, &s->inputs)
1676 emit_input(&ctx, var);
1677
1678 nir_foreach_variable(var, &s->outputs)
1679 emit_output(&ctx, var);
1680
1681 nir_foreach_variable(var, &s->uniforms)
1682 emit_uniform(&ctx, var);
1683
1684 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1685 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1686 SpvExecutionModeOriginUpperLeft);
1687 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1688 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1689 SpvExecutionModeDepthReplacing);
1690 }
1691
1692
1693 spirv_builder_function(&ctx.builder, entry_point, type_void,
1694 SpvFunctionControlMaskNone,
1695 type_main);
1696
1697 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1698 nir_metadata_require(entry, nir_metadata_block_index);
1699
1700 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1701 if (!ctx.defs)
1702 goto fail;
1703 ctx.num_defs = entry->ssa_alloc;
1704
1705 nir_index_local_regs(entry);
1706 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1707 if (!ctx.regs)
1708 goto fail;
1709 ctx.num_regs = entry->reg_alloc;
1710
1711 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1712 if (!block_ids)
1713 goto fail;
1714
1715 for (int i = 0; i < entry->num_blocks; ++i)
1716 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1717
1718 ctx.block_ids = block_ids;
1719 ctx.num_blocks = entry->num_blocks;
1720
1721 /* emit a block only for the variable declarations */
1722 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1723 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1724 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1725 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1726 SpvStorageClassFunction,
1727 type);
1728 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1729 SpvStorageClassFunction);
1730
1731 ctx.regs[reg->index] = var;
1732 }
1733
1734 emit_cf_list(&ctx, &entry->body);
1735
1736 free(ctx.defs);
1737
1738 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1739 spirv_builder_function_end(&ctx.builder);
1740
1741 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1742 "main", ctx.entry_ifaces,
1743 ctx.num_entry_ifaces);
1744
1745 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1746
1747 ret = CALLOC_STRUCT(spirv_shader);
1748 if (!ret)
1749 goto fail;
1750
1751 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1752 if (!ret->words)
1753 goto fail;
1754
1755 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1756 assert(ret->num_words == num_words);
1757
1758 return ret;
1759
1760 fail:
1761
1762 if (ret)
1763 spirv_shader_delete(ret);
1764
1765 if (ctx.vars)
1766 _mesa_hash_table_destroy(ctx.vars, NULL);
1767
1768 return NULL;
1769 }
1770
1771 void
1772 spirv_shader_delete(struct spirv_shader *s)
1773 {
1774 FREE(s->words);
1775 FREE(s);
1776 }