zink: implement nir_texop_txs
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31
32 struct ntv_context {
33 struct spirv_builder builder;
34
35 SpvId GLSL_std_450;
36
37 gl_shader_stage stage;
38
39 SpvId ubos[128];
40 size_t num_ubos;
41 SpvId samplers[PIPE_MAX_SAMPLERS];
42 size_t num_samplers;
43 SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
44 size_t num_entry_ifaces;
45
46 SpvId *defs;
47 size_t num_defs;
48
49 SpvId *regs;
50 size_t num_regs;
51
52 struct hash_table *vars; /* nir_variable -> SpvId */
53
54 const SpvId *block_ids;
55 size_t num_blocks;
56 bool block_started;
57 SpvId loop_break, loop_cont;
58
59 SpvId front_face_var, vertex_id_var;
60 };
61
62 static SpvId
63 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
64 unsigned num_components, float value);
65
66 static SpvId
67 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
68 unsigned num_components, uint32_t value);
69
70 static SpvId
71 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
72 unsigned num_components, int32_t value);
73
74 static SpvId
75 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
76
77 static SpvId
78 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
79 SpvId src0, SpvId src1);
80
81 static SpvId
82 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
83 SpvId src0, SpvId src1, SpvId src2);
84
85 static SpvId
86 get_bvec_type(struct ntv_context *ctx, int num_components)
87 {
88 SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
89 if (num_components > 1)
90 return spirv_builder_type_vector(&ctx->builder, bool_type,
91 num_components);
92
93 assert(num_components == 1);
94 return bool_type;
95 }
96
97 static SpvId
98 block_label(struct ntv_context *ctx, nir_block *block)
99 {
100 assert(block->index < ctx->num_blocks);
101 return ctx->block_ids[block->index];
102 }
103
104 static SpvId
105 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
106 {
107 assert(bit_size == 32);
108 return spirv_builder_const_float(&ctx->builder, bit_size, value);
109 }
110
111 static SpvId
112 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
113 {
114 assert(bit_size == 32);
115 return spirv_builder_const_uint(&ctx->builder, bit_size, value);
116 }
117
118 static SpvId
119 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
120 {
121 assert(bit_size == 32);
122 return spirv_builder_const_int(&ctx->builder, bit_size, value);
123 }
124
125 static SpvId
126 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
127 {
128 assert(bit_size == 32); // only 32-bit floats supported so far
129
130 SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
131 if (num_components > 1)
132 return spirv_builder_type_vector(&ctx->builder, float_type,
133 num_components);
134
135 assert(num_components == 1);
136 return float_type;
137 }
138
139 static SpvId
140 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
141 {
142 assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
143
144 SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
145 if (num_components > 1)
146 return spirv_builder_type_vector(&ctx->builder, int_type,
147 num_components);
148
149 assert(num_components == 1);
150 return int_type;
151 }
152
153 static SpvId
154 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
155 {
156 assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
157
158 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
159 if (num_components > 1)
160 return spirv_builder_type_vector(&ctx->builder, uint_type,
161 num_components);
162
163 assert(num_components == 1);
164 return uint_type;
165 }
166
167 static SpvId
168 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
169 {
170 return get_uvec_type(ctx, nir_dest_bit_size(*dest),
171 nir_dest_num_components(*dest));
172 }
173
174 static SpvId
175 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
176 {
177 switch (type) {
178 case GLSL_TYPE_BOOL:
179 return spirv_builder_type_bool(&ctx->builder);
180
181 case GLSL_TYPE_FLOAT:
182 return spirv_builder_type_float(&ctx->builder, 32);
183
184 case GLSL_TYPE_INT:
185 return spirv_builder_type_int(&ctx->builder, 32);
186
187 case GLSL_TYPE_UINT:
188 return spirv_builder_type_uint(&ctx->builder, 32);
189 /* TODO: handle more types */
190
191 default:
192 unreachable("unknown GLSL type");
193 }
194 }
195
196 static SpvId
197 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
198 {
199 assert(type);
200 if (glsl_type_is_scalar(type))
201 return get_glsl_basetype(ctx, glsl_get_base_type(type));
202
203 if (glsl_type_is_vector(type))
204 return spirv_builder_type_vector(&ctx->builder,
205 get_glsl_basetype(ctx, glsl_get_base_type(type)),
206 glsl_get_vector_elements(type));
207
208 if (glsl_type_is_array(type)) {
209 SpvId ret = spirv_builder_type_array(&ctx->builder,
210 get_glsl_type(ctx, glsl_get_array_element(type)),
211 emit_uint_const(ctx, 32, glsl_get_length(type)));
212 uint32_t stride = glsl_get_explicit_stride(type);
213 if (stride)
214 spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
215 return ret;
216 }
217
218
219 unreachable("we shouldn't get here, I think...");
220 }
221
222 static void
223 emit_input(struct ntv_context *ctx, struct nir_variable *var)
224 {
225 SpvId var_type = get_glsl_type(ctx, var->type);
226 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
227 SpvStorageClassInput,
228 var_type);
229 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
230 SpvStorageClassInput);
231
232 if (var->name)
233 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
234
235 if (ctx->stage == MESA_SHADER_FRAGMENT) {
236 if (var->data.location >= VARYING_SLOT_VAR0)
237 spirv_builder_emit_location(&ctx->builder, var_id,
238 var->data.location -
239 VARYING_SLOT_VAR0 +
240 VARYING_SLOT_TEX0);
241 else if ((var->data.location >= VARYING_SLOT_COL0 &&
242 var->data.location <= VARYING_SLOT_TEX7) ||
243 var->data.location == VARYING_SLOT_BFC0 ||
244 var->data.location == VARYING_SLOT_BFC1) {
245 spirv_builder_emit_location(&ctx->builder, var_id,
246 var->data.location);
247 } else {
248 switch (var->data.location) {
249 case VARYING_SLOT_POS:
250 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
251 break;
252
253 case VARYING_SLOT_PNTC:
254 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
255 break;
256
257 default:
258 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
259 unreachable("unexpected varying slot");
260 }
261 }
262 } else {
263 spirv_builder_emit_location(&ctx->builder, var_id,
264 var->data.driver_location);
265 }
266
267 if (var->data.location_frac)
268 spirv_builder_emit_component(&ctx->builder, var_id,
269 var->data.location_frac);
270
271 if (var->data.interpolation == INTERP_MODE_FLAT)
272 spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
273
274 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
275
276 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
277 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
278 }
279
280 static void
281 emit_output(struct ntv_context *ctx, struct nir_variable *var)
282 {
283 SpvId var_type = get_glsl_type(ctx, var->type);
284 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
285 SpvStorageClassOutput,
286 var_type);
287 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
288 SpvStorageClassOutput);
289 if (var->name)
290 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
291
292
293 if (ctx->stage == MESA_SHADER_VERTEX) {
294 if (var->data.location >= VARYING_SLOT_VAR0)
295 spirv_builder_emit_location(&ctx->builder, var_id,
296 var->data.location -
297 VARYING_SLOT_VAR0 +
298 VARYING_SLOT_TEX0);
299 else if ((var->data.location >= VARYING_SLOT_COL0 &&
300 var->data.location <= VARYING_SLOT_TEX7) ||
301 var->data.location == VARYING_SLOT_BFC0 ||
302 var->data.location == VARYING_SLOT_BFC1) {
303 spirv_builder_emit_location(&ctx->builder, var_id,
304 var->data.location);
305 } else {
306 switch (var->data.location) {
307 case VARYING_SLOT_POS:
308 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
309 break;
310
311 case VARYING_SLOT_PSIZ:
312 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
313 break;
314
315 case VARYING_SLOT_CLIP_DIST0:
316 assert(glsl_type_is_array(var->type));
317 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
318 break;
319
320 default:
321 debug_printf("unknown varying slot: %s\n", gl_varying_slot_name(var->data.location));
322 unreachable("unexpected varying slot");
323 }
324 }
325 } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
326 if (var->data.location >= FRAG_RESULT_DATA0)
327 spirv_builder_emit_location(&ctx->builder, var_id,
328 var->data.location - FRAG_RESULT_DATA0);
329 else {
330 switch (var->data.location) {
331 case FRAG_RESULT_COLOR:
332 spirv_builder_emit_location(&ctx->builder, var_id, 0);
333 break;
334
335 case FRAG_RESULT_DEPTH:
336 spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
337 break;
338
339 default:
340 spirv_builder_emit_location(&ctx->builder, var_id,
341 var->data.driver_location);
342 }
343 }
344 }
345
346 if (var->data.location_frac)
347 spirv_builder_emit_component(&ctx->builder, var_id,
348 var->data.location_frac);
349
350 _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
351
352 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
353 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
354 }
355
356 static SpvDim
357 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
358 {
359 *is_ms = false;
360 switch (gdim) {
361 case GLSL_SAMPLER_DIM_1D:
362 return SpvDim1D;
363 case GLSL_SAMPLER_DIM_2D:
364 return SpvDim2D;
365 case GLSL_SAMPLER_DIM_3D:
366 return SpvDim3D;
367 case GLSL_SAMPLER_DIM_CUBE:
368 return SpvDimCube;
369 case GLSL_SAMPLER_DIM_RECT:
370 return SpvDimRect;
371 case GLSL_SAMPLER_DIM_BUF:
372 return SpvDimBuffer;
373 case GLSL_SAMPLER_DIM_EXTERNAL:
374 return SpvDim2D; /* seems dodgy... */
375 case GLSL_SAMPLER_DIM_MS:
376 *is_ms = true;
377 return SpvDim2D;
378 default:
379 fprintf(stderr, "unknown sampler type %d\n", gdim);
380 break;
381 }
382 return SpvDim2D;
383 }
384
385 static void
386 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
387 {
388 bool is_ms;
389 SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
390 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
391 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
392 dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
393 SpvImageFormatUnknown);
394
395 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
396 image_type);
397 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
398 SpvStorageClassUniformConstant,
399 sampled_type);
400 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
401 SpvStorageClassUniformConstant);
402
403 if (var->name)
404 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
405
406 assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
407 ctx->samplers[ctx->num_samplers++] = var_id;
408
409 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
410 var->data.descriptor_set);
411 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
412 }
413
414 static void
415 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
416 {
417 uint32_t size = glsl_count_attribute_slots(var->type, false);
418 SpvId vec4_type = get_uvec_type(ctx, 32, 4);
419 SpvId array_length = emit_uint_const(ctx, 32, size);
420 SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
421 array_length);
422 spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
423
424 // wrap UBO-array in a struct
425 SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
426 if (var->name) {
427 char struct_name[100];
428 snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
429 spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
430 }
431
432 spirv_builder_emit_decoration(&ctx->builder, struct_type,
433 SpvDecorationBlock);
434 spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
435
436
437 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
438 SpvStorageClassUniform,
439 struct_type);
440
441 SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
442 SpvStorageClassUniform);
443 if (var->name)
444 spirv_builder_emit_name(&ctx->builder, var_id, var->name);
445
446 assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
447 ctx->ubos[ctx->num_ubos++] = var_id;
448
449 spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
450 var->data.descriptor_set);
451 spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
452 }
453
454 static void
455 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
456 {
457 if (var->data.mode == nir_var_mem_ubo)
458 emit_ubo(ctx, var);
459 else {
460 assert(var->data.mode == nir_var_uniform);
461 if (glsl_type_is_sampler(var->type))
462 emit_sampler(ctx, var);
463 }
464 }
465
466 static SpvId
467 get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
468 {
469 assert(ssa->index < ctx->num_defs);
470 assert(ctx->defs[ssa->index] != 0);
471 return ctx->defs[ssa->index];
472 }
473
474 static SpvId
475 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
476 {
477 assert(reg->index < ctx->num_regs);
478 assert(ctx->regs[reg->index] != 0);
479 return ctx->regs[reg->index];
480 }
481
482 static SpvId
483 get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
484 {
485 assert(reg->reg);
486 assert(!reg->indirect);
487 assert(!reg->base_offset);
488
489 SpvId var = get_var_from_reg(ctx, reg->reg);
490 SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
491 return spirv_builder_emit_load(&ctx->builder, type, var);
492 }
493
494 static SpvId
495 get_src_uint(struct ntv_context *ctx, nir_src *src)
496 {
497 if (src->is_ssa)
498 return get_src_uint_ssa(ctx, src->ssa);
499 else
500 return get_src_uint_reg(ctx, &src->reg);
501 }
502
503 static SpvId
504 get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
505 {
506 assert(!alu->src[src].negate);
507 assert(!alu->src[src].abs);
508
509 SpvId def = get_src_uint(ctx, &alu->src[src].src);
510
511 unsigned used_channels = 0;
512 bool need_swizzle = false;
513 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
514 if (!nir_alu_instr_channel_used(alu, src, i))
515 continue;
516
517 used_channels++;
518
519 if (alu->src[src].swizzle[i] != i)
520 need_swizzle = true;
521 }
522 assert(used_channels != 0);
523
524 unsigned live_channels = nir_src_num_components(alu->src[src].src);
525 if (used_channels != live_channels)
526 need_swizzle = true;
527
528 if (!need_swizzle)
529 return def;
530
531 int bit_size = nir_src_bit_size(alu->src[src].src);
532 assert(bit_size == 1 || bit_size == 32);
533
534 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
535 if (used_channels == 1) {
536 uint32_t indices[] = { alu->src[src].swizzle[0] };
537 return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
538 def, indices,
539 ARRAY_SIZE(indices));
540 } else if (live_channels == 1) {
541 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
542 used_channels);
543
544 SpvId constituents[NIR_MAX_VEC_COMPONENTS];
545 for (unsigned i = 0; i < used_channels; ++i)
546 constituents[i] = def;
547
548 return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
549 constituents,
550 used_channels);
551 } else {
552 SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
553 used_channels);
554
555 uint32_t components[NIR_MAX_VEC_COMPONENTS];
556 size_t num_components = 0;
557 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
558 if (!nir_alu_instr_channel_used(alu, src, i))
559 continue;
560
561 components[num_components++] = alu->src[src].swizzle[i];
562 }
563
564 return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
565 def, def, components, num_components);
566 }
567 }
568
569 static void
570 store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
571 {
572 assert(result != 0);
573 assert(ssa->index < ctx->num_defs);
574 ctx->defs[ssa->index] = result;
575 }
576
577 static SpvId
578 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
579 SpvId if_true, SpvId if_false)
580 {
581 return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
582 }
583
584 static SpvId
585 bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
586 {
587 SpvId otype = get_uvec_type(ctx, 32, num_components);
588 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
589 SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
590 return emit_select(ctx, otype, value, one, zero);
591 }
592
593 static SpvId
594 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
595 {
596 SpvId type = get_bvec_type(ctx, num_components);
597 SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
598 return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
599 }
600
601 static SpvId
602 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
603 {
604 return emit_unop(ctx, SpvOpBitcast, type, value);
605 }
606
607 static SpvId
608 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
609 unsigned num_components)
610 {
611 SpvId type = get_uvec_type(ctx, bit_size, num_components);
612 return emit_bitcast(ctx, type, value);
613 }
614
615 static SpvId
616 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
617 unsigned num_components)
618 {
619 SpvId type = get_ivec_type(ctx, bit_size, num_components);
620 return emit_bitcast(ctx, type, value);
621 }
622
623 static SpvId
624 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
625 unsigned num_components)
626 {
627 SpvId type = get_fvec_type(ctx, bit_size, num_components);
628 return emit_bitcast(ctx, type, value);
629 }
630
631 static void
632 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
633 {
634 SpvId var = get_var_from_reg(ctx, reg->reg);
635 assert(var);
636 spirv_builder_emit_store(&ctx->builder, var, result);
637 }
638
639 static void
640 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
641 {
642 if (dest->is_ssa)
643 store_ssa_def_uint(ctx, &dest->ssa, result);
644 else
645 store_reg_def(ctx, &dest->reg, result);
646 }
647
648 static void
649 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
650 {
651 unsigned num_components = nir_dest_num_components(*dest);
652 unsigned bit_size = nir_dest_bit_size(*dest);
653
654 switch (nir_alu_type_get_base_type(type)) {
655 case nir_type_bool:
656 assert(bit_size == 1);
657 result = bvec_to_uvec(ctx, result, num_components);
658 break;
659
660 case nir_type_uint:
661 break; /* nothing to do! */
662
663 case nir_type_int:
664 case nir_type_float:
665 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
666 break;
667
668 default:
669 unreachable("unsupported nir_alu_type");
670 }
671
672 store_dest_uint(ctx, dest, result);
673 }
674
675 static SpvId
676 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
677 {
678 return spirv_builder_emit_unop(&ctx->builder, op, type, src);
679 }
680
681 static SpvId
682 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
683 SpvId src0, SpvId src1)
684 {
685 return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
686 }
687
688 static SpvId
689 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
690 SpvId src0, SpvId src1, SpvId src2)
691 {
692 return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
693 }
694
695 static SpvId
696 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
697 SpvId src)
698 {
699 SpvId args[] = { src };
700 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
701 op, args, ARRAY_SIZE(args));
702 }
703
704 static SpvId
705 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
706 SpvId src0, SpvId src1)
707 {
708 SpvId args[] = { src0, src1 };
709 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
710 op, args, ARRAY_SIZE(args));
711 }
712
713 static SpvId
714 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
715 SpvId src0, SpvId src1, SpvId src2)
716 {
717 SpvId args[] = { src0, src1, src2 };
718 return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
719 op, args, ARRAY_SIZE(args));
720 }
721
722 static SpvId
723 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
724 unsigned num_components, float value)
725 {
726 assert(bit_size == 32);
727
728 SpvId result = emit_float_const(ctx, bit_size, value);
729 if (num_components == 1)
730 return result;
731
732 assert(num_components > 1);
733 SpvId components[num_components];
734 for (int i = 0; i < num_components; i++)
735 components[i] = result;
736
737 SpvId type = get_fvec_type(ctx, bit_size, num_components);
738 return spirv_builder_const_composite(&ctx->builder, type, components,
739 num_components);
740 }
741
742 static SpvId
743 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
744 unsigned num_components, uint32_t value)
745 {
746 assert(bit_size == 32);
747
748 SpvId result = emit_uint_const(ctx, bit_size, value);
749 if (num_components == 1)
750 return result;
751
752 assert(num_components > 1);
753 SpvId components[num_components];
754 for (int i = 0; i < num_components; i++)
755 components[i] = result;
756
757 SpvId type = get_uvec_type(ctx, bit_size, num_components);
758 return spirv_builder_const_composite(&ctx->builder, type, components,
759 num_components);
760 }
761
762 static SpvId
763 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
764 unsigned num_components, int32_t value)
765 {
766 assert(bit_size == 32);
767
768 SpvId result = emit_int_const(ctx, bit_size, value);
769 if (num_components == 1)
770 return result;
771
772 assert(num_components > 1);
773 SpvId components[num_components];
774 for (int i = 0; i < num_components; i++)
775 components[i] = result;
776
777 SpvId type = get_ivec_type(ctx, bit_size, num_components);
778 return spirv_builder_const_composite(&ctx->builder, type, components,
779 num_components);
780 }
781
782 static inline unsigned
783 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
784 {
785 if (nir_op_infos[instr->op].input_sizes[src] > 0)
786 return nir_op_infos[instr->op].input_sizes[src];
787
788 if (instr->dest.dest.is_ssa)
789 return instr->dest.dest.ssa.num_components;
790 else
791 return instr->dest.dest.reg.reg->num_components;
792 }
793
794 static SpvId
795 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
796 {
797 SpvId uint_value = get_alu_src_uint(ctx, alu, src);
798
799 unsigned num_components = alu_instr_src_components(alu, src);
800 unsigned bit_size = nir_src_bit_size(alu->src[src].src);
801 nir_alu_type type = nir_op_infos[alu->op].input_types[src];
802
803 switch (nir_alu_type_get_base_type(type)) {
804 case nir_type_bool:
805 assert(bit_size == 1);
806 return uvec_to_bvec(ctx, uint_value, num_components);
807
808 case nir_type_int:
809 return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
810
811 case nir_type_uint:
812 return uint_value;
813
814 case nir_type_float:
815 return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
816
817 default:
818 unreachable("unknown nir_alu_type");
819 }
820 }
821
822 static void
823 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
824 {
825 assert(!alu->dest.saturate);
826 return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
827 }
828
829 static SpvId
830 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
831 {
832 unsigned num_components = nir_dest_num_components(*dest);
833 unsigned bit_size = nir_dest_bit_size(*dest);
834
835 switch (nir_alu_type_get_base_type(type)) {
836 case nir_type_bool:
837 return get_bvec_type(ctx, num_components);
838
839 case nir_type_int:
840 return get_ivec_type(ctx, bit_size, num_components);
841
842 case nir_type_uint:
843 return get_uvec_type(ctx, bit_size, num_components);
844
845 case nir_type_float:
846 return get_fvec_type(ctx, bit_size, num_components);
847
848 default:
849 unreachable("unsupported nir_alu_type");
850 }
851 }
852
853 static void
854 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
855 {
856 SpvId src[nir_op_infos[alu->op].num_inputs];
857 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
858 src[i] = get_alu_src(ctx, alu, i);
859
860 SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
861 nir_op_infos[alu->op].output_type);
862 unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
863 unsigned num_components = nir_dest_num_components(alu->dest.dest);
864
865 SpvId result = 0;
866 switch (alu->op) {
867 case nir_op_mov:
868 assert(nir_op_infos[alu->op].num_inputs == 1);
869 result = src[0];
870 break;
871
872 #define UNOP(nir_op, spirv_op) \
873 case nir_op: \
874 assert(nir_op_infos[alu->op].num_inputs == 1); \
875 result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
876 break;
877
878 UNOP(nir_op_ineg, SpvOpSNegate)
879 UNOP(nir_op_fneg, SpvOpFNegate)
880 UNOP(nir_op_fddx, SpvOpDPdx)
881 UNOP(nir_op_fddy, SpvOpDPdy)
882 UNOP(nir_op_f2i32, SpvOpConvertFToS)
883 UNOP(nir_op_f2u32, SpvOpConvertFToU)
884 UNOP(nir_op_i2f32, SpvOpConvertSToF)
885 UNOP(nir_op_u2f32, SpvOpConvertUToF)
886 UNOP(nir_op_inot, SpvOpNot)
887 #undef UNOP
888
889 case nir_op_b2i32:
890 assert(nir_op_infos[alu->op].num_inputs == 1);
891 result = emit_select(ctx, dest_type, src[0],
892 get_ivec_constant(ctx, 32, num_components, 1),
893 get_ivec_constant(ctx, 32, num_components, 0));
894 break;
895
896 case nir_op_b2f32:
897 assert(nir_op_infos[alu->op].num_inputs == 1);
898 result = emit_select(ctx, dest_type, src[0],
899 get_fvec_constant(ctx, 32, num_components, 1),
900 get_fvec_constant(ctx, 32, num_components, 0));
901 break;
902
903 #define BUILTIN_UNOP(nir_op, spirv_op) \
904 case nir_op: \
905 assert(nir_op_infos[alu->op].num_inputs == 1); \
906 result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
907 break;
908
909 BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
910 BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
911 BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
912 BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
913 BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
914 BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
915 BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
916 BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
917 BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
918 BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
919 BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
920 BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
921 BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
922 BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
923 #undef BUILTIN_UNOP
924
925 case nir_op_frcp:
926 assert(nir_op_infos[alu->op].num_inputs == 1);
927 result = emit_binop(ctx, SpvOpFDiv, dest_type,
928 get_fvec_constant(ctx, bit_size, num_components, 1),
929 src[0]);
930 break;
931
932 case nir_op_f2b1:
933 assert(nir_op_infos[alu->op].num_inputs == 1);
934 result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
935 get_fvec_constant(ctx,
936 nir_src_bit_size(alu->src[0].src),
937 num_components, 0));
938 break;
939
940
941 #define BINOP(nir_op, spirv_op) \
942 case nir_op: \
943 assert(nir_op_infos[alu->op].num_inputs == 2); \
944 result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
945 break;
946
947 BINOP(nir_op_iadd, SpvOpIAdd)
948 BINOP(nir_op_isub, SpvOpISub)
949 BINOP(nir_op_imul, SpvOpIMul)
950 BINOP(nir_op_idiv, SpvOpSDiv)
951 BINOP(nir_op_udiv, SpvOpUDiv)
952 BINOP(nir_op_umod, SpvOpUMod)
953 BINOP(nir_op_fadd, SpvOpFAdd)
954 BINOP(nir_op_fsub, SpvOpFSub)
955 BINOP(nir_op_fmul, SpvOpFMul)
956 BINOP(nir_op_fdiv, SpvOpFDiv)
957 BINOP(nir_op_fmod, SpvOpFMod)
958 BINOP(nir_op_ilt, SpvOpSLessThan)
959 BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
960 BINOP(nir_op_ieq, SpvOpIEqual)
961 BINOP(nir_op_ine, SpvOpINotEqual)
962 BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
963 BINOP(nir_op_flt, SpvOpFOrdLessThan)
964 BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
965 BINOP(nir_op_feq, SpvOpFOrdEqual)
966 BINOP(nir_op_fne, SpvOpFOrdNotEqual)
967 BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
968 BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
969 BINOP(nir_op_ushr, SpvOpShiftRightLogical)
970 BINOP(nir_op_iand, SpvOpBitwiseAnd)
971 BINOP(nir_op_ior, SpvOpBitwiseOr)
972 #undef BINOP
973
974 #define BUILTIN_BINOP(nir_op, spirv_op) \
975 case nir_op: \
976 assert(nir_op_infos[alu->op].num_inputs == 2); \
977 result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
978 break;
979
980 BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
981 BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
982 #undef BUILTIN_BINOP
983
984 case nir_op_fdot2:
985 case nir_op_fdot3:
986 case nir_op_fdot4:
987 assert(nir_op_infos[alu->op].num_inputs == 2);
988 result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
989 break;
990
991 case nir_op_seq:
992 case nir_op_sne:
993 case nir_op_slt:
994 case nir_op_sge: {
995 assert(nir_op_infos[alu->op].num_inputs == 2);
996 int num_components = nir_dest_num_components(alu->dest.dest);
997 SpvId bool_type = get_bvec_type(ctx, num_components);
998
999 SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1000 SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1001 if (num_components > 1) {
1002 SpvId zero_comps[num_components], one_comps[num_components];
1003 for (int i = 0; i < num_components; i++) {
1004 zero_comps[i] = zero;
1005 one_comps[i] = one;
1006 }
1007
1008 zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1009 zero_comps, num_components);
1010 one = spirv_builder_const_composite(&ctx->builder, dest_type,
1011 one_comps, num_components);
1012 }
1013
1014 SpvOp op;
1015 switch (alu->op) {
1016 case nir_op_seq: op = SpvOpFOrdEqual; break;
1017 case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1018 case nir_op_slt: op = SpvOpFOrdLessThan; break;
1019 case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1020 default: unreachable("unexpected op");
1021 }
1022
1023 result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1024 result = emit_select(ctx, dest_type, result, one, zero);
1025 }
1026 break;
1027
1028 case nir_op_flrp:
1029 assert(nir_op_infos[alu->op].num_inputs == 3);
1030 result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1031 src[0], src[1], src[2]);
1032 break;
1033
1034 case nir_op_fcsel:
1035 result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1036 get_bvec_type(ctx, num_components),
1037 src[0],
1038 get_fvec_constant(ctx,
1039 nir_src_bit_size(alu->src[0].src),
1040 num_components, 0));
1041 result = emit_select(ctx, dest_type, result, src[1], src[2]);
1042 break;
1043
1044 case nir_op_bcsel:
1045 assert(nir_op_infos[alu->op].num_inputs == 3);
1046 result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1047 break;
1048
1049 case nir_op_bany_fnequal2:
1050 case nir_op_bany_fnequal3:
1051 case nir_op_bany_fnequal4:
1052 assert(nir_op_infos[alu->op].num_inputs == 2);
1053 assert(alu_instr_src_components(alu, 0) ==
1054 alu_instr_src_components(alu, 1));
1055 result = emit_binop(ctx, SpvOpFOrdNotEqual,
1056 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1057 src[0], src[1]);
1058 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1059 break;
1060
1061 case nir_op_ball_fequal2:
1062 case nir_op_ball_fequal3:
1063 case nir_op_ball_fequal4:
1064 assert(nir_op_infos[alu->op].num_inputs == 2);
1065 assert(alu_instr_src_components(alu, 0) ==
1066 alu_instr_src_components(alu, 1));
1067 result = emit_binop(ctx, SpvOpFOrdEqual,
1068 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1069 src[0], src[1]);
1070 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1071 break;
1072
1073 case nir_op_bany_inequal2:
1074 case nir_op_bany_inequal3:
1075 case nir_op_bany_inequal4:
1076 assert(nir_op_infos[alu->op].num_inputs == 2);
1077 assert(alu_instr_src_components(alu, 0) ==
1078 alu_instr_src_components(alu, 1));
1079 result = emit_binop(ctx, SpvOpINotEqual,
1080 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1081 src[0], src[1]);
1082 result = emit_unop(ctx, SpvOpAny, dest_type, result);
1083 break;
1084
1085 case nir_op_ball_iequal2:
1086 case nir_op_ball_iequal3:
1087 case nir_op_ball_iequal4:
1088 assert(nir_op_infos[alu->op].num_inputs == 2);
1089 assert(alu_instr_src_components(alu, 0) ==
1090 alu_instr_src_components(alu, 1));
1091 result = emit_binop(ctx, SpvOpIEqual,
1092 get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1093 src[0], src[1]);
1094 result = emit_unop(ctx, SpvOpAll, dest_type, result);
1095 break;
1096
1097 case nir_op_vec2:
1098 case nir_op_vec3:
1099 case nir_op_vec4: {
1100 int num_inputs = nir_op_infos[alu->op].num_inputs;
1101 assert(2 <= num_inputs && num_inputs <= 4);
1102 result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1103 src, num_inputs);
1104 }
1105 break;
1106
1107 default:
1108 fprintf(stderr, "emit_alu: not implemented (%s)\n",
1109 nir_op_infos[alu->op].name);
1110
1111 unreachable("unsupported opcode");
1112 return;
1113 }
1114
1115 store_alu_result(ctx, alu, result);
1116 }
1117
1118 static void
1119 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1120 {
1121 unsigned bit_size = load_const->def.bit_size;
1122 unsigned num_components = load_const->def.num_components;
1123
1124 SpvId constant;
1125 if (num_components > 1) {
1126 SpvId components[num_components];
1127 SpvId type;
1128 if (bit_size == 1) {
1129 for (int i = 0; i < num_components; i++)
1130 components[i] = spirv_builder_const_bool(&ctx->builder,
1131 load_const->value[i].b);
1132
1133 type = get_bvec_type(ctx, num_components);
1134 } else {
1135 for (int i = 0; i < num_components; i++)
1136 components[i] = emit_uint_const(ctx, bit_size,
1137 load_const->value[i].u32);
1138
1139 type = get_uvec_type(ctx, bit_size, num_components);
1140 }
1141 constant = spirv_builder_const_composite(&ctx->builder, type,
1142 components, num_components);
1143 } else {
1144 assert(num_components == 1);
1145 if (bit_size == 1)
1146 constant = spirv_builder_const_bool(&ctx->builder,
1147 load_const->value[0].b);
1148 else
1149 constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1150 }
1151
1152 if (bit_size == 1)
1153 constant = bvec_to_uvec(ctx, constant, num_components);
1154
1155 store_ssa_def_uint(ctx, &load_const->def, constant);
1156 }
1157
1158 static void
1159 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1160 {
1161 nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1162 assert(const_block_index); // no dynamic indexing for now
1163 assert(const_block_index->u32 == 0); // we only support the default UBO for now
1164
1165 nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1166 if (const_offset) {
1167 SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1168 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1169 SpvStorageClassUniform,
1170 uvec4_type);
1171
1172 unsigned idx = const_offset->u32;
1173 SpvId member = emit_uint_const(ctx, 32, 0);
1174 SpvId offset = emit_uint_const(ctx, 32, idx);
1175 SpvId offsets[] = { member, offset };
1176 SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1177 ctx->ubos[0], offsets,
1178 ARRAY_SIZE(offsets));
1179 SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1180
1181 SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1182 unsigned num_components = nir_dest_num_components(intr->dest);
1183 if (num_components == 1) {
1184 uint32_t components[] = { 0 };
1185 result = spirv_builder_emit_composite_extract(&ctx->builder,
1186 type,
1187 result, components,
1188 1);
1189 } else if (num_components < 4) {
1190 SpvId constituents[num_components];
1191 SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1192 for (uint32_t i = 0; i < num_components; ++i)
1193 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1194 uint_type,
1195 result, &i,
1196 1);
1197
1198 result = spirv_builder_emit_composite_construct(&ctx->builder,
1199 type,
1200 constituents,
1201 num_components);
1202 }
1203
1204 store_dest_uint(ctx, &intr->dest, result);
1205 } else
1206 unreachable("uniform-addressing not yet supported");
1207 }
1208
1209 static void
1210 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1211 {
1212 assert(ctx->block_started);
1213 spirv_builder_emit_kill(&ctx->builder);
1214 /* discard is weird in NIR, so let's just create an unreachable block after
1215 it and hope that the vulkan driver will DCE any instructinos in it. */
1216 spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1217 }
1218
1219 static void
1220 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1221 {
1222 /* uint is a bit of a lie here; it's really just a pointer */
1223 SpvId ptr = get_src_uint(ctx, intr->src);
1224
1225 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1226 SpvId result = spirv_builder_emit_load(&ctx->builder,
1227 get_glsl_type(ctx, var->type),
1228 ptr);
1229 unsigned num_components = nir_dest_num_components(intr->dest);
1230 unsigned bit_size = nir_dest_bit_size(intr->dest);
1231 result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1232 store_dest_uint(ctx, &intr->dest, result);
1233 }
1234
1235 static void
1236 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1237 {
1238 /* uint is a bit of a lie here; it's really just a pointer */
1239 SpvId ptr = get_src_uint(ctx, &intr->src[0]);
1240 SpvId src = get_src_uint(ctx, &intr->src[1]);
1241
1242 nir_variable *var = nir_intrinsic_get_var(intr, 0);
1243 SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
1244 SpvId result = emit_bitcast(ctx, type, src);
1245 spirv_builder_emit_store(&ctx->builder, ptr, result);
1246 }
1247
1248 static SpvId
1249 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1250 SpvStorageClass storage_class,
1251 const char *name, SpvBuiltIn builtin)
1252 {
1253 SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1254 storage_class,
1255 var_type);
1256 SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1257 storage_class);
1258 spirv_builder_emit_name(&ctx->builder, var, name);
1259 spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1260
1261 assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1262 ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1263 return var;
1264 }
1265
1266 static void
1267 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1268 {
1269 SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1270 if (!ctx->front_face_var)
1271 ctx->front_face_var = create_builtin_var(ctx, var_type,
1272 SpvStorageClassInput,
1273 "gl_FrontFacing",
1274 SpvBuiltInFrontFacing);
1275
1276 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1277 ctx->front_face_var);
1278 assert(1 == nir_dest_num_components(intr->dest));
1279 result = bvec_to_uvec(ctx, result, 1);
1280 store_dest_uint(ctx, &intr->dest, result);
1281 }
1282
1283 static void
1284 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1285 {
1286 SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1287 if (!ctx->vertex_id_var)
1288 ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1289 SpvStorageClassInput,
1290 "gl_VertexID",
1291 SpvBuiltInVertexIndex);
1292
1293 SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1294 ctx->vertex_id_var);
1295 assert(1 == nir_dest_num_components(intr->dest));
1296 store_dest_uint(ctx, &intr->dest, result);
1297 }
1298
1299 static void
1300 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1301 {
1302 switch (intr->intrinsic) {
1303 case nir_intrinsic_load_ubo:
1304 emit_load_ubo(ctx, intr);
1305 break;
1306
1307 case nir_intrinsic_discard:
1308 emit_discard(ctx, intr);
1309 break;
1310
1311 case nir_intrinsic_load_deref:
1312 emit_load_deref(ctx, intr);
1313 break;
1314
1315 case nir_intrinsic_store_deref:
1316 emit_store_deref(ctx, intr);
1317 break;
1318
1319 case nir_intrinsic_load_front_face:
1320 emit_load_front_face(ctx, intr);
1321 break;
1322
1323 case nir_intrinsic_load_vertex_id:
1324 emit_load_vertex_id(ctx, intr);
1325 break;
1326
1327 default:
1328 fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1329 nir_intrinsic_infos[intr->intrinsic].name);
1330 unreachable("unsupported intrinsic");
1331 }
1332 }
1333
1334 static void
1335 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1336 {
1337 SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1338 undef->def.num_components);
1339
1340 store_ssa_def_uint(ctx, &undef->def,
1341 spirv_builder_emit_undef(&ctx->builder, type));
1342 }
1343
1344 static SpvId
1345 get_src_float(struct ntv_context *ctx, nir_src *src)
1346 {
1347 SpvId def = get_src_uint(ctx, src);
1348 unsigned num_components = nir_src_num_components(*src);
1349 unsigned bit_size = nir_src_bit_size(*src);
1350 return bitcast_to_fvec(ctx, def, bit_size, num_components);
1351 }
1352
1353 static SpvId
1354 get_src_int(struct ntv_context *ctx, nir_src *src)
1355 {
1356 SpvId def = get_src_uint(ctx, src);
1357 unsigned num_components = nir_src_num_components(*src);
1358 unsigned bit_size = nir_src_bit_size(*src);
1359 return bitcast_to_ivec(ctx, def, bit_size, num_components);
1360 }
1361
1362 static void
1363 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1364 {
1365 assert(tex->op == nir_texop_tex ||
1366 tex->op == nir_texop_txb ||
1367 tex->op == nir_texop_txl ||
1368 tex->op == nir_texop_txd ||
1369 tex->op == nir_texop_txf ||
1370 tex->op == nir_texop_txs);
1371 assert(tex->op == nir_texop_txs ||
1372 nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
1373 assert(tex->texture_index == tex->sampler_index);
1374
1375 SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0;
1376 unsigned coord_components = 0;
1377 for (unsigned i = 0; i < tex->num_srcs; i++) {
1378 switch (tex->src[i].src_type) {
1379 case nir_tex_src_coord:
1380 if (tex->op == nir_texop_txf)
1381 coord = get_src_int(ctx, &tex->src[i].src);
1382 else
1383 coord = get_src_float(ctx, &tex->src[i].src);
1384 coord_components = nir_src_num_components(tex->src[i].src);
1385 break;
1386
1387 case nir_tex_src_projector:
1388 assert(nir_src_num_components(tex->src[i].src) == 1);
1389 proj = get_src_float(ctx, &tex->src[i].src);
1390 assert(proj != 0);
1391 break;
1392
1393 case nir_tex_src_bias:
1394 assert(tex->op == nir_texop_txb);
1395 bias = get_src_float(ctx, &tex->src[i].src);
1396 assert(bias != 0);
1397 break;
1398
1399 case nir_tex_src_lod:
1400 assert(nir_src_num_components(tex->src[i].src) == 1);
1401 if (tex->op == nir_texop_txf ||
1402 tex->op == nir_texop_txs)
1403 lod = get_src_int(ctx, &tex->src[i].src);
1404 else
1405 lod = get_src_float(ctx, &tex->src[i].src);
1406 assert(lod != 0);
1407 break;
1408
1409 case nir_tex_src_comparator:
1410 assert(nir_src_num_components(tex->src[i].src) == 1);
1411 dref = get_src_float(ctx, &tex->src[i].src);
1412 assert(dref != 0);
1413 break;
1414
1415 case nir_tex_src_ddx:
1416 dx = get_src_float(ctx, &tex->src[i].src);
1417 assert(dx != 0);
1418 break;
1419
1420 case nir_tex_src_ddy:
1421 dy = get_src_float(ctx, &tex->src[i].src);
1422 assert(dy != 0);
1423 break;
1424
1425 default:
1426 fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1427 unreachable("unknown texture source");
1428 }
1429 }
1430
1431 if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1432 lod = emit_float_const(ctx, 32, 0.0f);
1433 assert(lod != 0);
1434 }
1435
1436 bool is_ms;
1437 SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
1438 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1439 SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
1440 dimension, false, tex->is_array, is_ms, 1,
1441 SpvImageFormatUnknown);
1442 SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1443 image_type);
1444
1445 assert(tex->texture_index < ctx->num_samplers);
1446 SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1447 ctx->samplers[tex->texture_index]);
1448
1449 SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1450
1451 if (tex->op == nir_texop_txs) {
1452 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1453 SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1454 dest_type, image,
1455 lod);
1456 store_dest(ctx, &tex->dest, result, tex->dest_type);
1457 return;
1458 }
1459
1460 if (proj) {
1461 SpvId constituents[coord_components + 1];
1462 if (coord_components == 1)
1463 constituents[0] = coord;
1464 else {
1465 assert(coord_components > 1);
1466 SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1467 for (uint32_t i = 0; i < coord_components; ++i)
1468 constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1469 float_type,
1470 coord,
1471 &i, 1);
1472 }
1473
1474 constituents[coord_components++] = proj;
1475
1476 SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1477 coord = spirv_builder_emit_composite_construct(&ctx->builder,
1478 vec_type,
1479 constituents,
1480 coord_components);
1481 }
1482
1483 SpvId actual_dest_type = dest_type;
1484 if (dref)
1485 actual_dest_type = float_type;
1486
1487 SpvId result;
1488 if (tex->op == nir_texop_txf) {
1489 SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1490 result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1491 image, coord, lod);
1492 } else {
1493 result = spirv_builder_emit_image_sample(&ctx->builder,
1494 actual_dest_type, load,
1495 coord,
1496 proj != 0,
1497 lod, bias, dref, dx, dy);
1498 }
1499
1500 spirv_builder_emit_decoration(&ctx->builder, result,
1501 SpvDecorationRelaxedPrecision);
1502
1503 if (dref) {
1504 SpvId components[4] = { result, result, result, result };
1505 result = spirv_builder_emit_composite_construct(&ctx->builder,
1506 dest_type,
1507 components,
1508 4);
1509 }
1510
1511 store_dest(ctx, &tex->dest, result, tex->dest_type);
1512 }
1513
1514 static void
1515 start_block(struct ntv_context *ctx, SpvId label)
1516 {
1517 /* terminate previous block if needed */
1518 if (ctx->block_started)
1519 spirv_builder_emit_branch(&ctx->builder, label);
1520
1521 /* start new block */
1522 spirv_builder_label(&ctx->builder, label);
1523 ctx->block_started = true;
1524 }
1525
1526 static void
1527 branch(struct ntv_context *ctx, SpvId label)
1528 {
1529 assert(ctx->block_started);
1530 spirv_builder_emit_branch(&ctx->builder, label);
1531 ctx->block_started = false;
1532 }
1533
1534 static void
1535 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
1536 SpvId else_id)
1537 {
1538 assert(ctx->block_started);
1539 spirv_builder_emit_branch_conditional(&ctx->builder, condition,
1540 then_id, else_id);
1541 ctx->block_started = false;
1542 }
1543
1544 static void
1545 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
1546 {
1547 switch (jump->type) {
1548 case nir_jump_break:
1549 assert(ctx->loop_break);
1550 branch(ctx, ctx->loop_break);
1551 break;
1552
1553 case nir_jump_continue:
1554 assert(ctx->loop_cont);
1555 branch(ctx, ctx->loop_cont);
1556 break;
1557
1558 default:
1559 unreachable("Unsupported jump type\n");
1560 }
1561 }
1562
1563 static void
1564 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
1565 {
1566 assert(deref->deref_type == nir_deref_type_var);
1567
1568 struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
1569 assert(he);
1570 SpvId result = (SpvId)(intptr_t)he->data;
1571 /* uint is a bit of a lie here, it's really just an opaque type */
1572 store_dest_uint(ctx, &deref->dest, result);
1573 }
1574
1575 static void
1576 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
1577 {
1578 assert(deref->deref_type == nir_deref_type_array);
1579 nir_variable *var = nir_deref_instr_get_variable(deref);
1580
1581 SpvStorageClass storage_class;
1582 switch (var->data.mode) {
1583 case nir_var_shader_in:
1584 storage_class = SpvStorageClassInput;
1585 break;
1586
1587 case nir_var_shader_out:
1588 storage_class = SpvStorageClassOutput;
1589 break;
1590
1591 default:
1592 unreachable("Unsupported nir_variable_mode\n");
1593 }
1594
1595 SpvId index = get_src_uint(ctx, &deref->arr.index);
1596
1597 SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
1598 storage_class,
1599 get_glsl_type(ctx, deref->type));
1600
1601 SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
1602 ptr_type,
1603 get_src_uint(ctx, &deref->parent),
1604 &index, 1);
1605 /* uint is a bit of a lie here, it's really just an opaque type */
1606 store_dest_uint(ctx, &deref->dest, result);
1607 }
1608
1609 static void
1610 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
1611 {
1612 switch (deref->deref_type) {
1613 case nir_deref_type_var:
1614 emit_deref_var(ctx, deref);
1615 break;
1616
1617 case nir_deref_type_array:
1618 emit_deref_array(ctx, deref);
1619 break;
1620
1621 default:
1622 unreachable("unexpected deref_type");
1623 }
1624 }
1625
1626 static void
1627 emit_block(struct ntv_context *ctx, struct nir_block *block)
1628 {
1629 start_block(ctx, block_label(ctx, block));
1630 nir_foreach_instr(instr, block) {
1631 switch (instr->type) {
1632 case nir_instr_type_alu:
1633 emit_alu(ctx, nir_instr_as_alu(instr));
1634 break;
1635 case nir_instr_type_intrinsic:
1636 emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
1637 break;
1638 case nir_instr_type_load_const:
1639 emit_load_const(ctx, nir_instr_as_load_const(instr));
1640 break;
1641 case nir_instr_type_ssa_undef:
1642 emit_undef(ctx, nir_instr_as_ssa_undef(instr));
1643 break;
1644 case nir_instr_type_tex:
1645 emit_tex(ctx, nir_instr_as_tex(instr));
1646 break;
1647 case nir_instr_type_phi:
1648 unreachable("nir_instr_type_phi not supported");
1649 break;
1650 case nir_instr_type_jump:
1651 emit_jump(ctx, nir_instr_as_jump(instr));
1652 break;
1653 case nir_instr_type_call:
1654 unreachable("nir_instr_type_call not supported");
1655 break;
1656 case nir_instr_type_parallel_copy:
1657 unreachable("nir_instr_type_parallel_copy not supported");
1658 break;
1659 case nir_instr_type_deref:
1660 emit_deref(ctx, nir_instr_as_deref(instr));
1661 break;
1662 }
1663 }
1664 }
1665
/* Forward declaration: emit_if() and emit_loop() recurse back into
 * emit_cf_list(), which is defined after them. */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
1668
1669 static SpvId
1670 get_src_bool(struct ntv_context *ctx, nir_src *src)
1671 {
1672 SpvId def = get_src_uint(ctx, src);
1673 assert(nir_src_bit_size(*src) == 1);
1674 unsigned num_components = nir_src_num_components(*src);
1675 return uvec_to_bvec(ctx, def, num_components);
1676 }
1677
/* Translate a NIR if-statement into SPIR-V structured control flow:
 * a header block holding OpSelectionMerge and a conditional branch, followed
 * by the then/else blocks, which all converge on a fresh endif block.
 */
static void
emit_if(struct ntv_context *ctx, nir_if *if_stmt)
{
   /* the condition is converted in the current (preceding) block */
   SpvId condition = get_src_bool(ctx, &if_stmt->condition);

   SpvId header_id = spirv_builder_new_id(&ctx->builder);
   SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
   SpvId endif_id = spirv_builder_new_id(&ctx->builder);
   /* without an else-list, the false edge goes straight to the merge block */
   SpvId else_id = endif_id;

   /* NOTE(review): NIR usually gives every if at least one (possibly empty)
    * else block, so this may always be true -- confirm */
   bool has_else = !exec_list_is_empty(&if_stmt->else_list);
   if (has_else) {
      assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
      else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
   }

   /* create a header-block */
   start_block(ctx, header_id);
   spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
                                      SpvSelectionControlMaskNone);
   branch_conditional(ctx, condition, then_id, else_id);

   emit_cf_list(ctx, &if_stmt->then_list);

   if (has_else) {
      /* the then-side may already be terminated by a jump; only fall through
       * to the merge block if it is still open */
      if (ctx->block_started)
         branch(ctx, endif_id);

      emit_cf_list(ctx, &if_stmt->else_list);
   }

   /* start_block() emits the fall-through branch from the last open block */
   start_block(ctx, endif_id);
}
1711
1712 static void
1713 emit_loop(struct ntv_context *ctx, nir_loop *loop)
1714 {
1715 SpvId header_id = spirv_builder_new_id(&ctx->builder);
1716 SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
1717 SpvId break_id = spirv_builder_new_id(&ctx->builder);
1718 SpvId cont_id = spirv_builder_new_id(&ctx->builder);
1719
1720 /* create a header-block */
1721 start_block(ctx, header_id);
1722 spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
1723 branch(ctx, begin_id);
1724
1725 SpvId save_break = ctx->loop_break;
1726 SpvId save_cont = ctx->loop_cont;
1727 ctx->loop_break = break_id;
1728 ctx->loop_cont = cont_id;
1729
1730 emit_cf_list(ctx, &loop->body);
1731
1732 ctx->loop_break = save_break;
1733 ctx->loop_cont = save_cont;
1734
1735 branch(ctx, cont_id);
1736 start_block(ctx, cont_id);
1737 branch(ctx, header_id);
1738
1739 start_block(ctx, break_id);
1740 }
1741
1742 static void
1743 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
1744 {
1745 foreach_list_typed(nir_cf_node, node, node, list) {
1746 switch (node->type) {
1747 case nir_cf_node_block:
1748 emit_block(ctx, nir_cf_node_as_block(node));
1749 break;
1750
1751 case nir_cf_node_if:
1752 emit_if(ctx, nir_cf_node_as_if(node));
1753 break;
1754
1755 case nir_cf_node_loop:
1756 emit_loop(ctx, nir_cf_node_as_loop(node));
1757 break;
1758
1759 case nir_cf_node_function:
1760 unreachable("nir_cf_node_function not supported");
1761 break;
1762 }
1763 }
1764 }
1765
1766 struct spirv_shader *
1767 nir_to_spirv(struct nir_shader *s)
1768 {
1769 struct spirv_shader *ret = NULL;
1770
1771 struct ntv_context ctx = {};
1772
1773 switch (s->info.stage) {
1774 case MESA_SHADER_VERTEX:
1775 case MESA_SHADER_FRAGMENT:
1776 case MESA_SHADER_COMPUTE:
1777 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
1778 break;
1779
1780 case MESA_SHADER_TESS_CTRL:
1781 case MESA_SHADER_TESS_EVAL:
1782 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
1783 break;
1784
1785 case MESA_SHADER_GEOMETRY:
1786 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
1787 break;
1788
1789 default:
1790 unreachable("invalid stage");
1791 }
1792
1793 // TODO: only enable when needed
1794 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1795 spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
1796 spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
1797 }
1798
1799 ctx.stage = s->info.stage;
1800 ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
1801 spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
1802
1803 spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
1804 SpvMemoryModelGLSL450);
1805
1806 SpvExecutionModel exec_model;
1807 switch (s->info.stage) {
1808 case MESA_SHADER_VERTEX:
1809 exec_model = SpvExecutionModelVertex;
1810 break;
1811 case MESA_SHADER_TESS_CTRL:
1812 exec_model = SpvExecutionModelTessellationControl;
1813 break;
1814 case MESA_SHADER_TESS_EVAL:
1815 exec_model = SpvExecutionModelTessellationEvaluation;
1816 break;
1817 case MESA_SHADER_GEOMETRY:
1818 exec_model = SpvExecutionModelGeometry;
1819 break;
1820 case MESA_SHADER_FRAGMENT:
1821 exec_model = SpvExecutionModelFragment;
1822 break;
1823 case MESA_SHADER_COMPUTE:
1824 exec_model = SpvExecutionModelGLCompute;
1825 break;
1826 default:
1827 unreachable("invalid stage");
1828 }
1829
1830 SpvId type_void = spirv_builder_type_void(&ctx.builder);
1831 SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
1832 NULL, 0);
1833 SpvId entry_point = spirv_builder_new_id(&ctx.builder);
1834 spirv_builder_emit_name(&ctx.builder, entry_point, "main");
1835
1836 ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
1837 _mesa_key_pointer_equal);
1838
1839 nir_foreach_variable(var, &s->inputs)
1840 emit_input(&ctx, var);
1841
1842 nir_foreach_variable(var, &s->outputs)
1843 emit_output(&ctx, var);
1844
1845 nir_foreach_variable(var, &s->uniforms)
1846 emit_uniform(&ctx, var);
1847
1848 if (s->info.stage == MESA_SHADER_FRAGMENT) {
1849 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1850 SpvExecutionModeOriginUpperLeft);
1851 if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
1852 spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
1853 SpvExecutionModeDepthReplacing);
1854 }
1855
1856
1857 spirv_builder_function(&ctx.builder, entry_point, type_void,
1858 SpvFunctionControlMaskNone,
1859 type_main);
1860
1861 nir_function_impl *entry = nir_shader_get_entrypoint(s);
1862 nir_metadata_require(entry, nir_metadata_block_index);
1863
1864 ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
1865 if (!ctx.defs)
1866 goto fail;
1867 ctx.num_defs = entry->ssa_alloc;
1868
1869 nir_index_local_regs(entry);
1870 ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
1871 if (!ctx.regs)
1872 goto fail;
1873 ctx.num_regs = entry->reg_alloc;
1874
1875 SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
1876 if (!block_ids)
1877 goto fail;
1878
1879 for (int i = 0; i < entry->num_blocks; ++i)
1880 block_ids[i] = spirv_builder_new_id(&ctx.builder);
1881
1882 ctx.block_ids = block_ids;
1883 ctx.num_blocks = entry->num_blocks;
1884
1885 /* emit a block only for the variable declarations */
1886 start_block(&ctx, spirv_builder_new_id(&ctx.builder));
1887 foreach_list_typed(nir_register, reg, node, &entry->registers) {
1888 SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
1889 SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
1890 SpvStorageClassFunction,
1891 type);
1892 SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
1893 SpvStorageClassFunction);
1894
1895 ctx.regs[reg->index] = var;
1896 }
1897
1898 emit_cf_list(&ctx, &entry->body);
1899
1900 free(ctx.defs);
1901
1902 spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
1903 spirv_builder_function_end(&ctx.builder);
1904
1905 spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
1906 "main", ctx.entry_ifaces,
1907 ctx.num_entry_ifaces);
1908
1909 size_t num_words = spirv_builder_get_num_words(&ctx.builder);
1910
1911 ret = CALLOC_STRUCT(spirv_shader);
1912 if (!ret)
1913 goto fail;
1914
1915 ret->words = MALLOC(sizeof(uint32_t) * num_words);
1916 if (!ret->words)
1917 goto fail;
1918
1919 ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
1920 assert(ret->num_words == num_words);
1921
1922 return ret;
1923
1924 fail:
1925
1926 if (ret)
1927 spirv_shader_delete(ret);
1928
1929 if (ctx.vars)
1930 _mesa_hash_table_destroy(ctx.vars, NULL);
1931
1932 return NULL;
1933 }
1934
1935 void
1936 spirv_shader_delete(struct spirv_shader *s)
1937 {
1938 FREE(s->words);
1939 FREE(s);
1940 }