nir/vtn: fix emitting code after loops
[mesa.git] / src / glsl / nir / spirv_to_nir.c
1 /*
2 * Copyright © 2015 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include "spirv_to_nir_private.h"
29 #include "nir_vla.h"
30
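/* Lowers a nir_constant of the given type to a vtn_ssa_value.  Vectors and
 * scalars become a single load_const emitted at the top of the function
 * body; matrices, arrays, and structs are lowered element-by-element, with
 * matrix data read column-major out of value.u[].  A previously lowered
 * constant is reused via b->const_table.
 */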
31 static struct vtn_ssa_value *
32 vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
33 const struct glsl_type *type)
34 {
35 struct hash_entry *entry = _mesa_hash_table_search(b->const_table, constant);
36
37 if (entry)
38 return entry->data;
39
40 struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
41 val->type = type;
42
43 switch (glsl_get_base_type(type)) {
44 case GLSL_TYPE_INT:
45 case GLSL_TYPE_UINT:
46 case GLSL_TYPE_BOOL:
47 case GLSL_TYPE_FLOAT:
48 case GLSL_TYPE_DOUBLE:
49 if (glsl_type_is_vector_or_scalar(type)) {
50 unsigned num_components = glsl_get_vector_elements(val->type);
51 nir_load_const_instr *load =
52 nir_load_const_instr_create(b->shader, num_components);
53
54 for (unsigned i = 0; i < num_components; i++)
55 load->value.u[i] = constant->value.u[i];
56
57 nir_instr_insert_before_cf_list(&b->impl->body, &load->instr);
58 val->def = &load->def;
59 } else {
60 assert(glsl_type_is_matrix(type));
61 unsigned rows = glsl_get_vector_elements(val->type);
62 unsigned columns = glsl_get_matrix_columns(val->type);
63 val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
64
65 for (unsigned i = 0; i < columns; i++) {
66 struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
67 col_val->type = glsl_get_column_type(val->type);
68 nir_load_const_instr *load =
69 nir_load_const_instr_create(b->shader, rows);
70
71 for (unsigned j = 0; j < rows; j++)
72 load->value.u[j] = constant->value.u[rows * i + j];
73
74 nir_instr_insert_before_cf_list(&b->impl->body, &load->instr);
75 col_val->def = &load->def;
76
77 val->elems[i] = col_val;
78 }
79 }
80 break;
81
82 case GLSL_TYPE_ARRAY: {
83 unsigned elems = glsl_get_length(val->type);
84 val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
85 const struct glsl_type *elem_type = glsl_get_array_element(val->type);
86 for (unsigned i = 0; i < elems; i++)
87 val->elems[i] = vtn_const_ssa_value(b, constant->elements[i],
88 elem_type);
89 break;
90 }
91
92 case GLSL_TYPE_STRUCT: {
93 unsigned elems = glsl_get_length(val->type);
94 val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
95 for (unsigned i = 0; i < elems; i++) {
96 const struct glsl_type *elem_type =
97 glsl_get_struct_field(val->type, i);
98 val->elems[i] = vtn_const_ssa_value(b, constant->elements[i],
99 elem_type);
100 }
101 break;
102 }
103
104 default:
105 unreachable("bad constant type");
106 }
107
        _mesa_hash_table_insert(b->const_table, constant, val);
 108    return val;
109 }
110
111 struct vtn_ssa_value *
112 vtn_ssa_value(struct vtn_builder *b, uint32_t value_id)
113 {
114 struct vtn_value *val = vtn_untyped_value(b, value_id);
115 switch (val->value_type) {
116 case vtn_value_type_constant:
117 return vtn_const_ssa_value(b, val->constant, val->type);
118
119 case vtn_value_type_ssa:
120 return val->ssa;
121 default:
122 unreachable("Invalid type for an SSA value");
123 }
124 }
125
126 static char *
127 vtn_string_literal(struct vtn_builder *b, const uint32_t *words,
128 unsigned word_count)
129 {
130 return ralloc_strndup(b, (char *)words, word_count * sizeof(*words));
131 }
132
133 static const uint32_t *
134 vtn_foreach_instruction(struct vtn_builder *b, const uint32_t *start,
135 const uint32_t *end, vtn_instruction_handler handler)
136 {
137 const uint32_t *w = start;
138 while (w < end) {
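          /* Word 0 of every SPIR-V instruction packs the total word count in
           * the high 16 bits and the opcode in the low 16 bits.
           */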
139 SpvOp opcode = w[0] & SpvOpCodeMask;
140 unsigned count = w[0] >> SpvWordCountShift;
141 assert(count >= 1 && w + count <= end);
142
143 if (!handler(b, opcode, w, count))
144 return w;
145
146 w += count;
147 }
148 assert(w == end);
149 return w;
150 }
151
152 static void
153 vtn_handle_extension(struct vtn_builder *b, SpvOp opcode,
154 const uint32_t *w, unsigned count)
155 {
156 switch (opcode) {
157 case SpvOpExtInstImport: {
158 struct vtn_value *val = vtn_push_value(b, w[1], vtn_value_type_extension);
159 if (strcmp((const char *)&w[2], "GLSL.std.450") == 0) {
160 val->ext_handler = vtn_handle_glsl450_instruction;
161 } else {
162 assert(!"Unsupported extension");
163 }
164 break;
165 }
166
167 case SpvOpExtInst: {
168 struct vtn_value *val = vtn_value(b, w[3], vtn_value_type_extension);
169 bool handled = val->ext_handler(b, w[4], w, count);
170 (void)handled;
171 assert(handled);
172 break;
173 }
174
175 default:
176 unreachable("Unhandled opcode");
177 }
178 }
179
180 static void
181 _foreach_decoration_helper(struct vtn_builder *b,
182 struct vtn_value *base_value,
183 struct vtn_value *value,
184 vtn_decoration_foreach_cb cb, void *data)
185 {
186 for (struct vtn_decoration *dec = value->decoration; dec; dec = dec->next) {
187 if (dec->group) {
188 assert(dec->group->value_type == vtn_value_type_decoration_group);
189 _foreach_decoration_helper(b, base_value, dec->group, cb, data);
190 } else {
191 cb(b, base_value, dec, data);
192 }
193 }
194 }
195
196 /** Iterates (recursively if needed) over all of the decorations on a value
197 *
198 * This function iterates over all of the decorations applied to a given
199 * value. If it encounters a decoration group, it recurses into the group
200 * and iterates over all of those decorations as well.
201 */
202 void
203 vtn_foreach_decoration(struct vtn_builder *b, struct vtn_value *value,
204 vtn_decoration_foreach_cb cb, void *data)
205 {
206 _foreach_decoration_helper(b, value, value, cb, data);
207 }
208
209 static void
210 vtn_handle_decoration(struct vtn_builder *b, SpvOp opcode,
211 const uint32_t *w, unsigned count)
212 {
213 switch (opcode) {
214 case SpvOpDecorationGroup:
 215       vtn_push_value(b, w[1], vtn_value_type_decoration_group);
216 break;
217
218 case SpvOpDecorate: {
219 struct vtn_value *val = &b->values[w[1]];
220
221 struct vtn_decoration *dec = rzalloc(b, struct vtn_decoration);
222 dec->decoration = w[2];
223 dec->literals = &w[3];
224
225 /* Link into the list */
226 dec->next = val->decoration;
227 val->decoration = dec;
228 break;
229 }
230
231 case SpvOpGroupDecorate: {
232 struct vtn_value *group = &b->values[w[1]];
233 assert(group->value_type == vtn_value_type_decoration_group);
234
235 for (unsigned i = 2; i < count; i++) {
236 struct vtn_value *val = &b->values[w[i]];
237 struct vtn_decoration *dec = rzalloc(b, struct vtn_decoration);
238 dec->group = group;
239
240 /* Link into the list */
241 dec->next = val->decoration;
242 val->decoration = dec;
243 }
244 break;
245 }
246
247 case SpvOpGroupMemberDecorate:
248 assert(!"Bad instruction. Khronos Bug #13513");
249 break;
250
251 default:
252 unreachable("Unhandled opcode");
253 }
254 }
255
256 static const struct glsl_type *
257 vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
258 const uint32_t *args, unsigned count)
259 {
260 switch (opcode) {
261 case SpvOpTypeVoid:
262 return glsl_void_type();
263 case SpvOpTypeBool:
264 return glsl_bool_type();
265 case SpvOpTypeInt:
266 return glsl_int_type();
267 case SpvOpTypeFloat:
268 return glsl_float_type();
269
270 case SpvOpTypeVector: {
271 const struct glsl_type *base =
272 vtn_value(b, args[0], vtn_value_type_type)->type;
273 unsigned elems = args[1];
274
275 assert(glsl_type_is_scalar(base));
276 return glsl_vector_type(glsl_get_base_type(base), elems);
277 }
278
279 case SpvOpTypeMatrix: {
280 const struct glsl_type *base =
281 vtn_value(b, args[0], vtn_value_type_type)->type;
282 unsigned columns = args[1];
283
284 assert(glsl_type_is_vector(base));
285 return glsl_matrix_type(glsl_get_base_type(base),
286 glsl_get_vector_elements(base),
287 columns);
288 }
289
290 case SpvOpTypeArray:
291 return glsl_array_type(b->values[args[0]].type, args[1]);
292
293 case SpvOpTypeStruct: {
294 NIR_VLA(struct glsl_struct_field, fields, count);
295 for (unsigned i = 0; i < count; i++) {
296 /* TODO: Handle decorators */
297 fields[i].type = vtn_value(b, args[i], vtn_value_type_type)->type;
298 fields[i].name = ralloc_asprintf(b, "field%d", i);
299 fields[i].location = -1;
300 fields[i].interpolation = 0;
301 fields[i].centroid = 0;
302 fields[i].sample = 0;
303 fields[i].matrix_layout = 2;
304 fields[i].stream = -1;
305 }
306 return glsl_struct_type(fields, count, "struct");
307 }
308
309 case SpvOpTypeFunction: {
310 const struct glsl_type *return_type = b->values[args[0]].type;
311 NIR_VLA(struct glsl_function_param, params, count - 1);
312 for (unsigned i = 1; i < count; i++) {
313 params[i - 1].type = vtn_value(b, args[i], vtn_value_type_type)->type;
314
315 /* FIXME: */
316 params[i - 1].in = true;
317 params[i - 1].out = true;
318 }
319 return glsl_function_type(return_type, params, count - 1);
320 }
321
322 case SpvOpTypePointer:
323 /* FIXME: For now, we'll just do the really lame thing and return
324 * the same type. The validator should ensure that the proper number
325 * of dereferences happen
326 */
327 return vtn_value(b, args[1], vtn_value_type_type)->type;
328
329 case SpvOpTypeSampler: {
330 const struct glsl_type *sampled_type =
331 vtn_value(b, args[0], vtn_value_type_type)->type;
332
333 assert(glsl_type_is_vector_or_scalar(sampled_type));
334
335 enum glsl_sampler_dim dim;
336 switch ((SpvDim)args[1]) {
337 case SpvDim1D: dim = GLSL_SAMPLER_DIM_1D; break;
338 case SpvDim2D: dim = GLSL_SAMPLER_DIM_2D; break;
339 case SpvDim3D: dim = GLSL_SAMPLER_DIM_3D; break;
340 case SpvDimCube: dim = GLSL_SAMPLER_DIM_CUBE; break;
341 case SpvDimRect: dim = GLSL_SAMPLER_DIM_RECT; break;
342 case SpvDimBuffer: dim = GLSL_SAMPLER_DIM_BUF; break;
343 default:
344 unreachable("Invalid SPIR-V Sampler dimension");
345 }
346
347 /* TODO: Handle the various texture image/filter options */
348 (void)args[2];
349
350 bool is_array = args[3];
351 bool is_shadow = args[4];
352
 353       assert(args[5] == 0 && "FIXME: Handle multi-sampled textures");
354
355 return glsl_sampler_type(dim, is_shadow, is_array,
356 glsl_get_base_type(sampled_type));
357 }
358
359 case SpvOpTypeRuntimeArray:
360 case SpvOpTypeOpaque:
361 case SpvOpTypeEvent:
362 case SpvOpTypeDeviceEvent:
363 case SpvOpTypeReserveId:
364 case SpvOpTypeQueue:
365 case SpvOpTypePipe:
366 default:
367 unreachable("Unhandled opcode");
368 }
369 }
370
371 static void
372 vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
373 const uint32_t *w, unsigned count)
374 {
375 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_constant);
376 val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
377 val->constant = ralloc(b, nir_constant);
378 switch (opcode) {
379 case SpvOpConstantTrue:
380 assert(val->type == glsl_bool_type());
381 val->constant->value.u[0] = NIR_TRUE;
382 break;
383 case SpvOpConstantFalse:
384 assert(val->type == glsl_bool_type());
385 val->constant->value.u[0] = NIR_FALSE;
386 break;
387 case SpvOpConstant:
388 assert(glsl_type_is_scalar(val->type));
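           /* Only a single 32-bit literal word is read here; wider literals
            * (e.g. doubles) are not handled yet.
            */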
389 val->constant->value.u[0] = w[3];
390 break;
391 case SpvOpConstantComposite: {
392 unsigned elem_count = count - 3;
393 nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
394 for (unsigned i = 0; i < elem_count; i++)
395 elems[i] = vtn_value(b, w[i + 3], vtn_value_type_constant)->constant;
396
397 switch (glsl_get_base_type(val->type)) {
398 case GLSL_TYPE_UINT:
399 case GLSL_TYPE_INT:
400 case GLSL_TYPE_FLOAT:
401 case GLSL_TYPE_BOOL:
402 if (glsl_type_is_matrix(val->type)) {
403 unsigned rows = glsl_get_vector_elements(val->type);
404 assert(glsl_get_matrix_columns(val->type) == elem_count);
405 for (unsigned i = 0; i < elem_count; i++)
406 for (unsigned j = 0; j < rows; j++)
407 val->constant->value.u[rows * i + j] = elems[i]->value.u[j];
408 } else {
409 assert(glsl_type_is_vector(val->type));
410 assert(glsl_get_vector_elements(val->type) == elem_count);
411 for (unsigned i = 0; i < elem_count; i++)
412 val->constant->value.u[i] = elems[i]->value.u[0];
413 }
414 ralloc_free(elems);
415 break;
416
417 case GLSL_TYPE_STRUCT:
418 case GLSL_TYPE_ARRAY:
419 ralloc_steal(val->constant, elems);
420 val->constant->elements = elems;
421 break;
422
423 default:
424 unreachable("Unsupported type for constants");
425 }
426 break;
427 }
428
429 default:
430 unreachable("Unhandled opcode");
431 }
432 }
433
434 static void
435 var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
436 const struct vtn_decoration *dec, void *void_var)
437 {
438 assert(val->value_type == vtn_value_type_deref);
439 assert(val->deref->deref.child == NULL);
440 assert(val->deref->var == void_var);
441
442 nir_variable *var = void_var;
443 switch (dec->decoration) {
444 case SpvDecorationPrecisionLow:
445 case SpvDecorationPrecisionMedium:
446 case SpvDecorationPrecisionHigh:
447 break; /* FIXME: Do nothing with these for now. */
448 case SpvDecorationSmooth:
449 var->data.interpolation = INTERP_QUALIFIER_SMOOTH;
450 break;
451 case SpvDecorationNoperspective:
452 var->data.interpolation = INTERP_QUALIFIER_NOPERSPECTIVE;
453 break;
454 case SpvDecorationFlat:
455 var->data.interpolation = INTERP_QUALIFIER_FLAT;
456 break;
457 case SpvDecorationCentroid:
458 var->data.centroid = true;
459 break;
460 case SpvDecorationSample:
461 var->data.sample = true;
462 break;
463 case SpvDecorationInvariant:
464 var->data.invariant = true;
465 break;
466 case SpvDecorationConstant:
467 assert(var->constant_initializer != NULL);
468 var->data.read_only = true;
469 break;
470 case SpvDecorationNonwritable:
471 var->data.read_only = true;
472 break;
473 case SpvDecorationLocation:
474 var->data.explicit_location = true;
475 var->data.location = dec->literals[0];
476 break;
477 case SpvDecorationComponent:
478 var->data.location_frac = dec->literals[0];
479 break;
480 case SpvDecorationIndex:
481 var->data.explicit_index = true;
482 var->data.index = dec->literals[0];
483 break;
484 case SpvDecorationBinding:
485 var->data.explicit_binding = true;
486 var->data.binding = dec->literals[0];
487 break;
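        /* The decorations below are recognized but not yet translated into
         * anything on the nir_variable.
         */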
488 case SpvDecorationBlock:
489 case SpvDecorationBufferBlock:
490 case SpvDecorationRowMajor:
491 case SpvDecorationColMajor:
492 case SpvDecorationGLSLShared:
493 case SpvDecorationGLSLStd140:
494 case SpvDecorationGLSLStd430:
495 case SpvDecorationGLSLPacked:
496 case SpvDecorationPatch:
497 case SpvDecorationRestrict:
498 case SpvDecorationAliased:
499 case SpvDecorationVolatile:
500 case SpvDecorationCoherent:
501 case SpvDecorationNonreadable:
502 case SpvDecorationUniform:
503 /* This is really nice but we have no use for it right now. */
504 case SpvDecorationNoStaticUse:
505 case SpvDecorationCPacked:
506 case SpvDecorationSaturatedConversion:
507 case SpvDecorationStream:
508 case SpvDecorationDescriptorSet:
509 case SpvDecorationOffset:
510 case SpvDecorationAlignment:
511 case SpvDecorationXfbBuffer:
512 case SpvDecorationStride:
513 case SpvDecorationBuiltIn:
514 case SpvDecorationFuncParamAttr:
515 case SpvDecorationFPRoundingMode:
516 case SpvDecorationFPFastMathMode:
517 case SpvDecorationLinkageAttributes:
518 case SpvDecorationSpecId:
519 break;
520 default:
521 unreachable("Unhandled variable decoration");
522 }
523 }
524
525 static struct vtn_ssa_value *
526 _vtn_variable_load(struct vtn_builder *b,
527 nir_deref_var *src_deref, nir_deref *src_deref_tail)
528 {
529 struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
530 val->type = src_deref_tail->type;
531
532 /* The deref tail may contain a deref to select a component of a vector (in
533 * other words, it might not be an actual tail) so we have to save it away
534 * here since we overwrite it later.
535 */
536 nir_deref *old_child = src_deref_tail->child;
537
538 if (glsl_type_is_vector_or_scalar(val->type)) {
539 nir_intrinsic_instr *load =
540 nir_intrinsic_instr_create(b->shader, nir_intrinsic_load_var);
541 load->variables[0] =
542 nir_deref_as_var(nir_copy_deref(load, &src_deref->deref));
543 load->num_components = glsl_get_vector_elements(val->type);
544 nir_ssa_dest_init(&load->instr, &load->dest, load->num_components, NULL);
545
546 nir_builder_instr_insert(&b->nb, &load->instr);
547
548 if (src_deref->var->data.mode == nir_var_uniform &&
549 glsl_get_base_type(val->type) == GLSL_TYPE_BOOL) {
550 /* Uniform boolean loads need to be fixed up since they're defined
551 * to be zero/nonzero rather than NIR_FALSE/NIR_TRUE.
552 */
553 val->def = nir_ine(&b->nb, &load->dest.ssa, nir_imm_int(&b->nb, 0));
554 } else {
555 val->def = &load->dest.ssa;
556 }
557 } else if (glsl_get_base_type(val->type) == GLSL_TYPE_ARRAY ||
558 glsl_type_is_matrix(val->type)) {
559 unsigned elems = glsl_get_length(val->type);
560 val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
561
562 nir_deref_array *deref = nir_deref_array_create(b);
563 deref->deref_array_type = nir_deref_array_type_direct;
564 deref->deref.type = glsl_get_array_element(val->type);
565 src_deref_tail->child = &deref->deref;
566 for (unsigned i = 0; i < elems; i++) {
567 deref->base_offset = i;
568 val->elems[i] = _vtn_variable_load(b, src_deref, &deref->deref);
569 }
570 } else {
571 assert(glsl_get_base_type(val->type) == GLSL_TYPE_STRUCT);
572 unsigned elems = glsl_get_length(val->type);
573 val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
574
575 nir_deref_struct *deref = nir_deref_struct_create(b, 0);
576 src_deref_tail->child = &deref->deref;
577 for (unsigned i = 0; i < elems; i++) {
578 deref->index = i;
579 deref->deref.type = glsl_get_struct_field(val->type, i);
580 val->elems[i] = _vtn_variable_load(b, src_deref, &deref->deref);
581 }
582 }
583
584 src_deref_tail->child = old_child;
585
586 return val;
587 }
588
589 static void
590 _vtn_variable_store(struct vtn_builder *b, nir_deref_var *dest_deref,
591 nir_deref *dest_deref_tail, struct vtn_ssa_value *src)
592 {
593 nir_deref *old_child = dest_deref_tail->child;
594
595 if (glsl_type_is_vector_or_scalar(src->type)) {
596 nir_intrinsic_instr *store =
597 nir_intrinsic_instr_create(b->shader, nir_intrinsic_store_var);
598 store->variables[0] =
599 nir_deref_as_var(nir_copy_deref(store, &dest_deref->deref));
600 store->src[0] = nir_src_for_ssa(src->def);
601
602 nir_builder_instr_insert(&b->nb, &store->instr);
603 } else if (glsl_get_base_type(src->type) == GLSL_TYPE_ARRAY ||
604 glsl_type_is_matrix(src->type)) {
605 unsigned elems = glsl_get_length(src->type);
606
607 nir_deref_array *deref = nir_deref_array_create(b);
608 deref->deref_array_type = nir_deref_array_type_direct;
609 deref->deref.type = glsl_get_array_element(src->type);
610 dest_deref_tail->child = &deref->deref;
611 for (unsigned i = 0; i < elems; i++) {
612 deref->base_offset = i;
613 _vtn_variable_store(b, dest_deref, &deref->deref, src->elems[i]);
614 }
615 } else {
616 assert(glsl_get_base_type(src->type) == GLSL_TYPE_STRUCT);
617 unsigned elems = glsl_get_length(src->type);
618
619 nir_deref_struct *deref = nir_deref_struct_create(b, 0);
620 dest_deref_tail->child = &deref->deref;
621 for (unsigned i = 0; i < elems; i++) {
622 deref->index = i;
623 deref->deref.type = glsl_get_struct_field(src->type, i);
624 _vtn_variable_store(b, dest_deref, &deref->deref, src->elems[i]);
625 }
626 }
627
628 dest_deref_tail->child = old_child;
629 }
630
631 /*
632 * Gets the NIR-level deref tail, which may have as a child an array deref
633 * selecting which component due to OpAccessChain supporting per-component
634 * indexing in SPIR-V.
635 */
636
637 static nir_deref *
638 get_deref_tail(nir_deref_var *deref)
639 {
640 nir_deref *cur = &deref->deref;
641 while (!glsl_type_is_vector_or_scalar(cur->type) && cur->child)
642 cur = cur->child;
643
644 return cur;
645 }
646
647 static nir_ssa_def *vtn_vector_extract(struct vtn_builder *b,
648 nir_ssa_def *src, unsigned index);
649
650 static nir_ssa_def *vtn_vector_extract_dynamic(struct vtn_builder *b,
651 nir_ssa_def *src,
652 nir_ssa_def *index);
653
654 static struct vtn_ssa_value *
655 vtn_variable_load(struct vtn_builder *b, nir_deref_var *src)
656 {
657 nir_deref *src_tail = get_deref_tail(src);
658 struct vtn_ssa_value *val = _vtn_variable_load(b, src, src_tail);
659
660 if (src_tail->child) {
661 nir_deref_array *vec_deref = nir_deref_as_array(src_tail->child);
662 assert(vec_deref->deref.child == NULL);
663 val->type = vec_deref->deref.type;
664 if (vec_deref->deref_array_type == nir_deref_array_type_direct)
665 val->def = vtn_vector_extract(b, val->def, vec_deref->base_offset);
666 else
667 val->def = vtn_vector_extract_dynamic(b, val->def,
668 vec_deref->indirect.ssa);
669 }
670
671 return val;
672 }
673
674 static nir_ssa_def * vtn_vector_insert(struct vtn_builder *b,
675 nir_ssa_def *src, nir_ssa_def *insert,
676 unsigned index);
677
678 static nir_ssa_def * vtn_vector_insert_dynamic(struct vtn_builder *b,
679 nir_ssa_def *src,
680 nir_ssa_def *insert,
681 nir_ssa_def *index);
682 static void
683 vtn_variable_store(struct vtn_builder *b, struct vtn_ssa_value *src,
684 nir_deref_var *dest)
685 {
686 nir_deref *dest_tail = get_deref_tail(dest);
687 if (dest_tail->child) {
688 struct vtn_ssa_value *val = _vtn_variable_load(b, dest, dest_tail);
689 nir_deref_array *deref = nir_deref_as_array(dest_tail->child);
690 assert(deref->deref.child == NULL);
691 if (deref->deref_array_type == nir_deref_array_type_direct)
692 val->def = vtn_vector_insert(b, val->def, src->def,
693 deref->base_offset);
694 else
695 val->def = vtn_vector_insert_dynamic(b, val->def, src->def,
696 deref->indirect.ssa);
697 _vtn_variable_store(b, dest, dest_tail, val);
698 } else {
699 _vtn_variable_store(b, dest, dest_tail, src);
700 }
701 }
702
703 static void
704 vtn_variable_copy(struct vtn_builder *b, nir_deref_var *src,
705 nir_deref_var *dest)
706 {
707 nir_deref *src_tail = get_deref_tail(src);
708
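        /* If the access chain ends in a per-component vector deref, go through
         * a full load and store so the component insert/extract paths run;
         * otherwise a single copy_var intrinsic copies the whole value.
         */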
709 if (src_tail->child) {
710 assert(get_deref_tail(dest)->child);
711 struct vtn_ssa_value *val = vtn_variable_load(b, src);
712 vtn_variable_store(b, val, dest);
713 } else {
714 nir_intrinsic_instr *copy =
715 nir_intrinsic_instr_create(b->shader, nir_intrinsic_copy_var);
716 copy->variables[0] = nir_deref_as_var(nir_copy_deref(copy, &dest->deref));
717 copy->variables[1] = nir_deref_as_var(nir_copy_deref(copy, &src->deref));
718
719 nir_builder_instr_insert(&b->nb, &copy->instr);
720 }
721 }
722
723 static void
724 vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
725 const uint32_t *w, unsigned count)
726 {
727 switch (opcode) {
728 case SpvOpVariable: {
729 const struct glsl_type *type =
730 vtn_value(b, w[1], vtn_value_type_type)->type;
731 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_deref);
732
733 nir_variable *var = ralloc(b->shader, nir_variable);
734
735 var->type = type;
736 var->name = ralloc_strdup(var, val->name);
737
738 switch ((SpvStorageClass)w[3]) {
739 case SpvStorageClassUniformConstant:
740 var->data.mode = nir_var_uniform;
741 var->data.read_only = true;
742 break;
743 case SpvStorageClassInput:
744 var->data.mode = nir_var_shader_in;
745 var->data.read_only = true;
746 break;
747 case SpvStorageClassOutput:
748 var->data.mode = nir_var_shader_out;
749 break;
750 case SpvStorageClassPrivateGlobal:
751 var->data.mode = nir_var_global;
752 break;
753 case SpvStorageClassFunction:
754 var->data.mode = nir_var_local;
755 break;
756 case SpvStorageClassUniform:
757 case SpvStorageClassWorkgroupLocal:
758 case SpvStorageClassWorkgroupGlobal:
759 case SpvStorageClassGeneric:
760 case SpvStorageClassPrivate:
761 case SpvStorageClassAtomicCounter:
762 default:
763 unreachable("Unhandled variable storage class");
764 }
765
766 if (count > 4) {
767 assert(count == 5);
768 var->constant_initializer =
769 vtn_value(b, w[4], vtn_value_type_constant)->constant;
770 }
771
772 if (var->data.mode == nir_var_local) {
773 exec_list_push_tail(&b->impl->locals, &var->node);
774 } else {
775 exec_list_push_tail(&b->shader->globals, &var->node);
776 }
777
778 val->deref = nir_deref_var_create(b->shader, var);
779
780 vtn_foreach_decoration(b, val, var_decoration_cb, var);
781 break;
782 }
783
784 case SpvOpAccessChain:
785 case SpvOpInBoundsAccessChain: {
786 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_deref);
787 nir_deref_var *base = vtn_value(b, w[3], vtn_value_type_deref)->deref;
788 val->deref = nir_deref_as_var(nir_copy_deref(b, &base->deref));
789
790 nir_deref *tail = &val->deref->deref;
791 while (tail->child)
792 tail = tail->child;
793
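           /* Each index in the chain appends one NIR deref: arrays, vectors,
            * and matrix columns become array derefs (direct or indirect),
            * while struct members require a constant index and become struct
            * derefs.
            */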
794 for (unsigned i = 0; i < count - 4; i++) {
795 assert(w[i + 4] < b->value_id_bound);
796 struct vtn_value *idx_val = &b->values[w[i + 4]];
797
798 enum glsl_base_type base_type = glsl_get_base_type(tail->type);
799 switch (base_type) {
800 case GLSL_TYPE_UINT:
801 case GLSL_TYPE_INT:
802 case GLSL_TYPE_FLOAT:
803 case GLSL_TYPE_DOUBLE:
804 case GLSL_TYPE_BOOL:
805 case GLSL_TYPE_ARRAY: {
806 nir_deref_array *deref_arr = nir_deref_array_create(b);
807 if (base_type == GLSL_TYPE_ARRAY) {
808 deref_arr->deref.type = glsl_get_array_element(tail->type);
809 } else if (glsl_type_is_matrix(tail->type)) {
810 deref_arr->deref.type = glsl_get_column_type(tail->type);
811 } else {
812 assert(glsl_type_is_vector(tail->type));
813 deref_arr->deref.type = glsl_scalar_type(base_type);
814 }
815
816 if (idx_val->value_type == vtn_value_type_constant) {
817 unsigned idx = idx_val->constant->value.u[0];
818 deref_arr->deref_array_type = nir_deref_array_type_direct;
819 deref_arr->base_offset = idx;
820 } else {
821 assert(idx_val->value_type == vtn_value_type_ssa);
822 deref_arr->deref_array_type = nir_deref_array_type_indirect;
823 deref_arr->base_offset = 0;
824 deref_arr->indirect =
 825                   nir_src_for_ssa(vtn_ssa_value(b, w[i + 4])->def);
826 }
827 tail->child = &deref_arr->deref;
828 break;
829 }
830
831 case GLSL_TYPE_STRUCT: {
832 assert(idx_val->value_type == vtn_value_type_constant);
833 unsigned idx = idx_val->constant->value.u[0];
834 nir_deref_struct *deref_struct = nir_deref_struct_create(b, idx);
835 deref_struct->deref.type = glsl_get_struct_field(tail->type, idx);
836 tail->child = &deref_struct->deref;
837 break;
838 }
839 default:
840 unreachable("Invalid type for deref");
841 }
842 tail = tail->child;
843 }
844 break;
845 }
846
847 case SpvOpCopyMemory: {
848 nir_deref_var *dest = vtn_value(b, w[1], vtn_value_type_deref)->deref;
849 nir_deref_var *src = vtn_value(b, w[2], vtn_value_type_deref)->deref;
850
851 vtn_variable_copy(b, src, dest);
852 break;
853 }
854
855 case SpvOpLoad: {
856 nir_deref_var *src = vtn_value(b, w[3], vtn_value_type_deref)->deref;
857 const struct glsl_type *src_type = nir_deref_tail(&src->deref)->type;
858
859 if (glsl_get_base_type(src_type) == GLSL_TYPE_SAMPLER) {
860 vtn_push_value(b, w[2], vtn_value_type_deref)->deref = src;
861 return;
862 }
863
864 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
865 val->ssa = vtn_variable_load(b, src);
866 break;
867 }
868
869 case SpvOpStore: {
870 nir_deref_var *dest = vtn_value(b, w[1], vtn_value_type_deref)->deref;
871 struct vtn_ssa_value *src = vtn_ssa_value(b, w[2]);
872 vtn_variable_store(b, src, dest);
873 break;
874 }
875
876 case SpvOpVariableArray:
877 case SpvOpCopyMemorySized:
878 case SpvOpArrayLength:
879 case SpvOpImagePointer:
880 default:
881 unreachable("Unhandled opcode");
882 }
883 }
884
885 static void
886 vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
887 const uint32_t *w, unsigned count)
888 {
889 unreachable("Unhandled opcode");
890 }
891
892 static nir_tex_src
893 vtn_tex_src(struct vtn_builder *b, unsigned index, nir_tex_src_type type)
894 {
895 nir_tex_src src;
896 src.src = nir_src_for_ssa(vtn_value(b, index, vtn_value_type_ssa)->ssa->def);
897 src.src_type = type;
898 return src;
899 }
900
901 static void
902 vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
903 const uint32_t *w, unsigned count)
904 {
 905    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
        val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
906 nir_deref_var *sampler = vtn_value(b, w[3], vtn_value_type_deref)->deref;
907
908 nir_tex_src srcs[8]; /* 8 should be enough */
909 nir_tex_src *p = srcs;
910
911 unsigned coord_components = 0;
912 switch (opcode) {
913 case SpvOpTextureSample:
914 case SpvOpTextureSampleDref:
915 case SpvOpTextureSampleLod:
916 case SpvOpTextureSampleProj:
917 case SpvOpTextureSampleGrad:
918 case SpvOpTextureSampleOffset:
919 case SpvOpTextureSampleProjLod:
920 case SpvOpTextureSampleProjGrad:
921 case SpvOpTextureSampleLodOffset:
922 case SpvOpTextureSampleProjOffset:
923 case SpvOpTextureSampleGradOffset:
924 case SpvOpTextureSampleProjLodOffset:
925 case SpvOpTextureSampleProjGradOffset:
926 case SpvOpTextureFetchTexelLod:
927 case SpvOpTextureFetchTexelOffset:
928 case SpvOpTextureFetchSample:
929 case SpvOpTextureFetchTexel:
930 case SpvOpTextureGather:
931 case SpvOpTextureGatherOffset:
932 case SpvOpTextureGatherOffsets:
933 case SpvOpTextureQueryLod: {
934 /* All these types have the coordinate as their first real argument */
935 struct vtn_value *coord = vtn_value(b, w[4], vtn_value_type_ssa);
936 coord_components = glsl_get_vector_elements(coord->type);
937 p->src = nir_src_for_ssa(coord->ssa->def);
938 p->src_type = nir_tex_src_coord;
939 p++;
940 break;
941 }
942 default:
943 break;
944 }
945
946 nir_texop texop;
947 switch (opcode) {
948 case SpvOpTextureSample:
949 texop = nir_texop_tex;
950
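           /* An optional extra operand on OpTextureSample is a bias, which
            * turns the plain tex into a txb.
            */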
951 if (count == 6) {
952 texop = nir_texop_txb;
953 *p++ = vtn_tex_src(b, w[5], nir_tex_src_bias);
954 }
955 break;
956
957 case SpvOpTextureSampleDref:
958 case SpvOpTextureSampleLod:
959 case SpvOpTextureSampleProj:
960 case SpvOpTextureSampleGrad:
961 case SpvOpTextureSampleOffset:
962 case SpvOpTextureSampleProjLod:
963 case SpvOpTextureSampleProjGrad:
964 case SpvOpTextureSampleLodOffset:
965 case SpvOpTextureSampleProjOffset:
966 case SpvOpTextureSampleGradOffset:
967 case SpvOpTextureSampleProjLodOffset:
968 case SpvOpTextureSampleProjGradOffset:
969 case SpvOpTextureFetchTexelLod:
970 case SpvOpTextureFetchTexelOffset:
971 case SpvOpTextureFetchSample:
972 case SpvOpTextureFetchTexel:
973 case SpvOpTextureGather:
974 case SpvOpTextureGatherOffset:
975 case SpvOpTextureGatherOffsets:
976 case SpvOpTextureQuerySizeLod:
977 case SpvOpTextureQuerySize:
978 case SpvOpTextureQueryLod:
979 case SpvOpTextureQueryLevels:
980 case SpvOpTextureQuerySamples:
981 default:
982 unreachable("Unhandled opcode");
983 }
984
985 nir_tex_instr *instr = nir_tex_instr_create(b->shader, p - srcs);
986
987 const struct glsl_type *sampler_type = nir_deref_tail(&sampler->deref)->type;
988 instr->sampler_dim = glsl_get_sampler_dim(sampler_type);
989
990 switch (glsl_get_sampler_result_type(sampler_type)) {
991 case GLSL_TYPE_FLOAT: instr->dest_type = nir_type_float; break;
992 case GLSL_TYPE_INT: instr->dest_type = nir_type_int; break;
993 case GLSL_TYPE_UINT: instr->dest_type = nir_type_unsigned; break;
994 case GLSL_TYPE_BOOL: instr->dest_type = nir_type_bool; break;
995 default:
996 unreachable("Invalid base type for sampler result");
997 }
998
999 instr->op = texop;
1000 memcpy(instr->src, srcs, instr->num_srcs * sizeof(*instr->src));
1001 instr->coord_components = coord_components;
1002 instr->is_array = glsl_sampler_type_is_array(sampler_type);
1003 instr->is_shadow = glsl_sampler_type_is_shadow(sampler_type);
1004
1005 instr->sampler = sampler;
1006
1007 nir_ssa_dest_init(&instr->instr, &instr->dest, 4, NULL);
         val->ssa = rzalloc(b, struct vtn_ssa_value);
 1008    val->ssa->def = &instr->dest.ssa;
1009 val->ssa->type = val->type;
1010
1011 nir_builder_instr_insert(&b->nb, &instr->instr);
1012 }
1013
1014 static struct vtn_ssa_value *
1015 vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
1016 {
1017 struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
1018 val->type = type;
1019
1020 if (!glsl_type_is_vector_or_scalar(type)) {
1021 unsigned elems = glsl_get_length(type);
1022 val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
1023 for (unsigned i = 0; i < elems; i++) {
1024 const struct glsl_type *child_type;
1025
1026 switch (glsl_get_base_type(type)) {
1027 case GLSL_TYPE_INT:
1028 case GLSL_TYPE_UINT:
1029 case GLSL_TYPE_BOOL:
1030 case GLSL_TYPE_FLOAT:
1031 case GLSL_TYPE_DOUBLE:
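              /* Vectors and scalars were filtered out above, so the only
               * basic types left here are matrices; their elements are
               * column vectors.
               */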
1032 child_type = glsl_get_column_type(type);
1033 break;
1034 case GLSL_TYPE_ARRAY:
1035 child_type = glsl_get_array_element(type);
1036 break;
1037 case GLSL_TYPE_STRUCT:
1038 child_type = glsl_get_struct_field(type, i);
1039 break;
1040 default:
 1041             unreachable("unknown base type");
1042 }
1043
1044 val->elems[i] = vtn_create_ssa_value(b, child_type);
1045 }
1046 }
1047
1048 return val;
1049 }
1050
1051 static nir_alu_instr *
1052 create_vec(void *mem_ctx, unsigned num_components)
1053 {
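        /* NIR has no vec1 opcode, so a one-component "vector" is just a move. */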
1054 nir_op op;
1055 switch (num_components) {
1056 case 1: op = nir_op_fmov; break;
1057 case 2: op = nir_op_vec2; break;
1058 case 3: op = nir_op_vec3; break;
1059 case 4: op = nir_op_vec4; break;
1060 default: unreachable("bad vector size");
1061 }
1062
1063 nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
1064 nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
1065
1066 return vec;
1067 }
1068
1069 static struct vtn_ssa_value *
1070 vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
1071 {
1072 if (src->transposed)
1073 return src->transposed;
1074
1075 struct vtn_ssa_value *dest =
1076 vtn_create_ssa_value(b, glsl_transposed_type(src->type));
1077
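        /* Column i of the result gathers component i from every source column
         * (or from the lone vector when the source is a single column).
         */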
1078 for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
1079 nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
1080 if (glsl_type_is_vector_or_scalar(src->type)) {
1081 vec->src[0].src = nir_src_for_ssa(src->def);
1082 vec->src[0].swizzle[0] = i;
1083 } else {
1084 for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
1085 vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
1086 vec->src[j].swizzle[0] = i;
1087 }
1088 }
1089 nir_builder_instr_insert(&b->nb, &vec->instr);
1090 dest->elems[i]->def = &vec->dest.dest.ssa;
1091 }
1092
1093 dest->transposed = src;
1094
1095 return dest;
1096 }
1097
1098 /*
1099 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
1100 * definition. But for matrix multiplies, we want to do one routine for
1101 * multiplying a matrix by a matrix and then pretend that vectors are matrices
1102 * with one column. So we "wrap" these things, and unwrap the result before we
1103 * send it off.
1104 */
1105
1106 static struct vtn_ssa_value *
1107 vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
1108 {
1109 if (val == NULL)
1110 return NULL;
1111
1112 if (glsl_type_is_matrix(val->type))
1113 return val;
1114
1115 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
1116 dest->type = val->type;
1117 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
1118 dest->elems[0] = val;
1119
1120 return dest;
1121 }
1122
1123 static struct vtn_ssa_value *
1124 vtn_unwrap_matrix(struct vtn_ssa_value *val)
1125 {
1126 if (glsl_type_is_matrix(val->type))
1127 return val;
1128
1129 return val->elems[0];
1130 }
1131
1132 static struct vtn_ssa_value *
1133 vtn_matrix_multiply(struct vtn_builder *b,
1134 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
1135 {
1136
1137 struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
1138 struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
1139 struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
1140 struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
1141
1142 unsigned src0_rows = glsl_get_vector_elements(src0->type);
1143 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
1144 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
1145
1146 struct vtn_ssa_value *dest =
1147 vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
1148 src0_rows, src1_columns));
1149
1150 dest = vtn_wrap_matrix(b, dest);
1151
1152 bool transpose_result = false;
1153 if (src0_transpose && src1_transpose) {
1154 /* transpose(A) * transpose(B) = transpose(B * A) */
1155 src1 = src0_transpose;
1156 src0 = src1_transpose;
1157 src0_transpose = NULL;
1158 src1_transpose = NULL;
1159 transpose_result = true;
1160 }
1161
1162 if (src0_transpose && !src1_transpose &&
1163 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
1164 /* We already have the rows of src0 and the columns of src1 available,
1165 * so we can just take the dot product of each row with each column to
1166 * get the result.
1167 */
1168
1169 for (unsigned i = 0; i < src1_columns; i++) {
1170 nir_alu_instr *vec = create_vec(b, src0_rows);
1171 for (unsigned j = 0; j < src0_rows; j++) {
1172 vec->src[j].src =
1173 nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
1174 src1->elems[i]->def));
1175 }
1176
1177 nir_builder_instr_insert(&b->nb, &vec->instr);
1178 dest->elems[i]->def = &vec->dest.dest.ssa;
1179 }
1180 } else {
1181 /* We don't handle the case where src1 is transposed but not src0, since
1182 * the general case only uses individual components of src1 so the
1183 * optimizer should chew through the transpose we emitted for src1.
1184 */
1185
1186 for (unsigned i = 0; i < src1_columns; i++) {
1187 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
1188 dest->elems[i]->def =
1189 nir_fmul(&b->nb, src0->elems[0]->def,
1190 vtn_vector_extract(b, src1->elems[i]->def, 0));
1191 for (unsigned j = 1; j < src0_columns; j++) {
1192 dest->elems[i]->def =
1193 nir_fadd(&b->nb, dest->elems[i]->def,
1194 nir_fmul(&b->nb, src0->elems[j]->def,
1195 vtn_vector_extract(b,
1196 src1->elems[i]->def, j)));
1197 }
1198 }
1199 }
1200
1201 dest = vtn_unwrap_matrix(dest);
1202
1203 if (transpose_result)
1204 dest = vtn_transpose(b, dest);
1205
1206 return dest;
1207 }
1208
1209 static struct vtn_ssa_value *
1210 vtn_mat_times_scalar(struct vtn_builder *b,
1211 struct vtn_ssa_value *mat,
1212 nir_ssa_def *scalar)
1213 {
1214 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
1215 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
1216 if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
1217 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
1218 else
1219 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
1220 }
1221
1222 return dest;
1223 }
1224
1225 static void
1226 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
1227 const uint32_t *w, unsigned count)
1228 {
1229 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
1230 val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
1231
1232 switch (opcode) {
1233 case SpvOpTranspose: {
1234 struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
1235 val->ssa = vtn_transpose(b, src);
1236 break;
1237 }
1238
1239 case SpvOpOuterProduct: {
1240 struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
1241 struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
1242
1243 val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
1244 break;
1245 }
1246
1247 case SpvOpMatrixTimesScalar: {
1248 struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
1249 struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
1250
1251 if (mat->transposed) {
1252 val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
1253 scalar->def));
1254 } else {
1255 val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
1256 }
1257 break;
1258 }
1259
1260 case SpvOpVectorTimesMatrix:
1261 case SpvOpMatrixTimesVector:
1262 case SpvOpMatrixTimesMatrix: {
1263 struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
1264 struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
1265
1266 val->ssa = vtn_matrix_multiply(b, src0, src1);
1267 break;
1268 }
1269
1270 default: unreachable("unknown matrix opcode");
1271 }
1272 }
1273
1274 static void
1275 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
1276 const uint32_t *w, unsigned count)
1277 {
1278 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
1279 val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
1280 val->ssa = vtn_create_ssa_value(b, val->type);
1281
1282 /* Collect the various SSA sources */
1283 unsigned num_inputs = count - 3;
1284 nir_ssa_def *src[4];
1285 for (unsigned i = 0; i < num_inputs; i++)
1286 src[i] = vtn_ssa_value(b, w[i + 3])->def;
1287
1288 /* Indicates that the first two arguments should be swapped. This is
1289 * used for implementing greater-than and less-than-or-equal.
1290 */
1291 bool swap = false;
1292
1293 nir_op op;
1294 switch (opcode) {
1295 /* Basic ALU operations */
1296 case SpvOpSNegate: op = nir_op_ineg; break;
1297 case SpvOpFNegate: op = nir_op_fneg; break;
1298 case SpvOpNot: op = nir_op_inot; break;
1299
1300 case SpvOpAny:
1301 switch (src[0]->num_components) {
1302 case 1: op = nir_op_imov; break;
1303 case 2: op = nir_op_bany2; break;
1304 case 3: op = nir_op_bany3; break;
1305 case 4: op = nir_op_bany4; break;
1306 }
1307 break;
1308
1309 case SpvOpAll:
1310 switch (src[0]->num_components) {
1311 case 1: op = nir_op_imov; break;
1312 case 2: op = nir_op_ball2; break;
1313 case 3: op = nir_op_ball3; break;
1314 case 4: op = nir_op_ball4; break;
1315 }
1316 break;
1317
1318 case SpvOpIAdd: op = nir_op_iadd; break;
1319 case SpvOpFAdd: op = nir_op_fadd; break;
1320 case SpvOpISub: op = nir_op_isub; break;
1321 case SpvOpFSub: op = nir_op_fsub; break;
1322 case SpvOpIMul: op = nir_op_imul; break;
1323 case SpvOpFMul: op = nir_op_fmul; break;
1324 case SpvOpUDiv: op = nir_op_udiv; break;
1325 case SpvOpSDiv: op = nir_op_idiv; break;
1326 case SpvOpFDiv: op = nir_op_fdiv; break;
1327 case SpvOpUMod: op = nir_op_umod; break;
1328 case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
1329 case SpvOpFMod: op = nir_op_fmod; break;
1330
1331 case SpvOpDot:
1332 assert(src[0]->num_components == src[1]->num_components);
1333 switch (src[0]->num_components) {
1334 case 1: op = nir_op_fmul; break;
1335 case 2: op = nir_op_fdot2; break;
1336 case 3: op = nir_op_fdot3; break;
1337 case 4: op = nir_op_fdot4; break;
1338 }
1339 break;
1340
1341 case SpvOpShiftRightLogical: op = nir_op_ushr; break;
1342 case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
1343 case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
1344 case SpvOpLogicalOr: op = nir_op_ior; break;
1345 case SpvOpLogicalXor: op = nir_op_ixor; break;
1346 case SpvOpLogicalAnd: op = nir_op_iand; break;
1347 case SpvOpBitwiseOr: op = nir_op_ior; break;
1348 case SpvOpBitwiseXor: op = nir_op_ixor; break;
1349 case SpvOpBitwiseAnd: op = nir_op_iand; break;
1350 case SpvOpSelect: op = nir_op_bcsel; break;
1351 case SpvOpIEqual: op = nir_op_ieq; break;
1352
 1353    /* Comparisons: (TODO: How do we want to handle ordered/unordered?) */
1354 case SpvOpFOrdEqual: op = nir_op_feq; break;
1355 case SpvOpFUnordEqual: op = nir_op_feq; break;
1356 case SpvOpINotEqual: op = nir_op_ine; break;
1357 case SpvOpFOrdNotEqual: op = nir_op_fne; break;
1358 case SpvOpFUnordNotEqual: op = nir_op_fne; break;
1359 case SpvOpULessThan: op = nir_op_ult; break;
1360 case SpvOpSLessThan: op = nir_op_ilt; break;
1361 case SpvOpFOrdLessThan: op = nir_op_flt; break;
1362 case SpvOpFUnordLessThan: op = nir_op_flt; break;
1363 case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
1364 case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
1365 case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
1366 case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
1367 case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
1368 case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
1369 case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
1370 case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
1371 case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
1372 case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
1373 case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
1374 case SpvOpFUnordGreaterThanEqual:op = nir_op_fge; break;
1375
1376 /* Conversions: */
1377 case SpvOpConvertFToU: op = nir_op_f2u; break;
1378 case SpvOpConvertFToS: op = nir_op_f2i; break;
1379 case SpvOpConvertSToF: op = nir_op_i2f; break;
1380 case SpvOpConvertUToF: op = nir_op_u2f; break;
1381 case SpvOpBitcast: op = nir_op_imov; break;
1382 case SpvOpUConvert:
1383 case SpvOpSConvert:
1384 op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
1385 break;
1386 case SpvOpFConvert:
1387 op = nir_op_fmov;
1388 break;
1389
1390 /* Derivatives: */
1391 case SpvOpDPdx: op = nir_op_fddx; break;
1392 case SpvOpDPdy: op = nir_op_fddy; break;
1393 case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
1394 case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
1395 case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
1396 case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
         /* fwidth(p) = |dFdx(p)| + |dFdy(p)|; each OpFwidth* takes a single operand. */
 1397    case SpvOpFwidth:
 1398       val->ssa->def = nir_fadd(&b->nb,
 1399                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
 1400                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
 1401       return;
 1402    case SpvOpFwidthFine:
 1403       val->ssa->def = nir_fadd(&b->nb,
 1404                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
 1405                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
 1406       return;
 1407    case SpvOpFwidthCoarse:
 1408       val->ssa->def = nir_fadd(&b->nb,
 1409                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
 1410                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
 1411       return;
1412
1413 case SpvOpVectorTimesScalar:
1414 /* The builder will take care of splatting for us. */
1415 val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
1416 return;
1417
1418 case SpvOpSRem:
1419 case SpvOpFRem:
1420 unreachable("No NIR equivalent");
1421
1422 case SpvOpIsNan:
1423 case SpvOpIsInf:
1424 case SpvOpIsFinite:
1425 case SpvOpIsNormal:
1426 case SpvOpSignBitSet:
1427 case SpvOpLessOrGreater:
1428 case SpvOpOrdered:
1429 case SpvOpUnordered:
1430 default:
1431 unreachable("Unhandled opcode");
1432 }
1433
1434 if (swap) {
1435 nir_ssa_def *tmp = src[0];
1436 src[0] = src[1];
1437 src[1] = tmp;
1438 }
1439
1440 nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
1441 nir_ssa_dest_init(&instr->instr, &instr->dest.dest,
1442 glsl_get_vector_elements(val->type), val->name);
1443 val->ssa->def = &instr->dest.dest.ssa;
1444
1445 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
1446 instr->src[i].src = nir_src_for_ssa(src[i]);
1447
1448 nir_builder_instr_insert(&b->nb, &instr->instr);
1449 }
1450
1451 static nir_ssa_def *
1452 vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
1453 {
1454 unsigned swiz[4] = { index };
1455 return nir_swizzle(&b->nb, src, swiz, 1, true);
1456 }
1457
1458
1459 static nir_ssa_def *
1460 vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
1461 unsigned index)
1462 {
1463 nir_alu_instr *vec = create_vec(b->shader, src->num_components);
1464
1465 for (unsigned i = 0; i < src->num_components; i++) {
1466 if (i == index) {
1467 vec->src[i].src = nir_src_for_ssa(insert);
1468 } else {
1469 vec->src[i].src = nir_src_for_ssa(src);
1470 vec->src[i].swizzle[0] = i;
1471 }
1472 }
1473
1474 nir_builder_instr_insert(&b->nb, &vec->instr);
1475
1476 return &vec->dest.dest.ssa;
1477 }
1478
1479 static nir_ssa_def *
1480 vtn_vector_extract_dynamic(struct vtn_builder *b, nir_ssa_def *src,
1481 nir_ssa_def *index)
1482 {
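        /* With no constant index available, compare the index against each
         * lane and select the matching component through a chain of bcsels.
         */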
1483 nir_ssa_def *dest = vtn_vector_extract(b, src, 0);
1484 for (unsigned i = 1; i < src->num_components; i++)
1485 dest = nir_bcsel(&b->nb, nir_ieq(&b->nb, index, nir_imm_int(&b->nb, i)),
1486 vtn_vector_extract(b, src, i), dest);
1487
1488 return dest;
1489 }
1490
1491 static nir_ssa_def *
1492 vtn_vector_insert_dynamic(struct vtn_builder *b, nir_ssa_def *src,
1493 nir_ssa_def *insert, nir_ssa_def *index)
1494 {
1495 nir_ssa_def *dest = vtn_vector_insert(b, src, insert, 0);
1496 for (unsigned i = 1; i < src->num_components; i++)
1497 dest = nir_bcsel(&b->nb, nir_ieq(&b->nb, index, nir_imm_int(&b->nb, i)),
1498 vtn_vector_insert(b, src, insert, i), dest);
1499
1500 return dest;
1501 }
1502
1503 static nir_ssa_def *
1504 vtn_vector_shuffle(struct vtn_builder *b, unsigned num_components,
1505 nir_ssa_def *src0, nir_ssa_def *src1,
1506 const uint32_t *indices)
1507 {
1508 nir_alu_instr *vec = create_vec(b->shader, num_components);
1509
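        /* SPIR-V uses the component literal 0xFFFFFFFF in OpVectorShuffle for
         * an undefined component; map those to an SSA undef.
         */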
1510 nir_ssa_undef_instr *undef = nir_ssa_undef_instr_create(b->shader, 1);
1511 nir_builder_instr_insert(&b->nb, &undef->instr);
1512
1513 for (unsigned i = 0; i < num_components; i++) {
1514 uint32_t index = indices[i];
1515 if (index == 0xffffffff) {
1516 vec->src[i].src = nir_src_for_ssa(&undef->def);
1517 } else if (index < src0->num_components) {
1518 vec->src[i].src = nir_src_for_ssa(src0);
1519 vec->src[i].swizzle[0] = index;
1520 } else {
1521 vec->src[i].src = nir_src_for_ssa(src1);
1522 vec->src[i].swizzle[0] = index - src0->num_components;
1523 }
1524 }
1525
1526 nir_builder_instr_insert(&b->nb, &vec->instr);
1527
1528 return &vec->dest.dest.ssa;
1529 }
1530
1531 /*
 1532  * Concatenates a number of vectors/scalars together to produce a vector
1533 */
1534 static nir_ssa_def *
1535 vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
1536 unsigned num_srcs, nir_ssa_def **srcs)
1537 {
1538 nir_alu_instr *vec = create_vec(b->shader, num_components);
1539
1540 unsigned dest_idx = 0;
1541 for (unsigned i = 0; i < num_srcs; i++) {
1542 nir_ssa_def *src = srcs[i];
1543 for (unsigned j = 0; j < src->num_components; j++) {
1544 vec->src[dest_idx].src = nir_src_for_ssa(src);
1545 vec->src[dest_idx].swizzle[0] = j;
1546 dest_idx++;
1547 }
1548 }
1549
1550 nir_builder_instr_insert(&b->nb, &vec->instr);
1551
1552 return &vec->dest.dest.ssa;
1553 }
1554
1555 static struct vtn_ssa_value *
1556 vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
1557 {
1558 struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
1559 dest->type = src->type;
1560
1561 if (glsl_type_is_vector_or_scalar(src->type)) {
1562 dest->def = src->def;
1563 } else {
1564 unsigned elems = glsl_get_length(src->type);
1565
1566 dest->elems = ralloc_array(mem_ctx, struct vtn_ssa_value *, elems);
1567 for (unsigned i = 0; i < elems; i++)
1568 dest->elems[i] = vtn_composite_copy(mem_ctx, src->elems[i]);
1569 }
1570
1571 return dest;
1572 }
1573
1574 static struct vtn_ssa_value *
1575 vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
1576 struct vtn_ssa_value *insert, const uint32_t *indices,
1577 unsigned num_indices)
1578 {
1579 struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
1580
1581 struct vtn_ssa_value *cur = dest;
1582 unsigned i;
1583 for (i = 0; i < num_indices - 1; i++) {
1584 cur = cur->elems[indices[i]];
1585 }
1586
1587 if (glsl_type_is_vector_or_scalar(cur->type)) {
1588 /* According to the SPIR-V spec, OpCompositeInsert may work down to
1589 * the component granularity. In that case, the last index will be
1590 * the index to insert the scalar into the vector.
1591 */
1592
1593 cur->def = vtn_vector_insert(b, cur->def, insert->def, indices[i]);
1594 } else {
1595 cur->elems[indices[i]] = insert;
1596 }
1597
1598 return dest;
1599 }
1600
1601 static struct vtn_ssa_value *
1602 vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
1603 const uint32_t *indices, unsigned num_indices)
1604 {
1605 struct vtn_ssa_value *cur = src;
1606 for (unsigned i = 0; i < num_indices; i++) {
1607 if (glsl_type_is_vector_or_scalar(cur->type)) {
1608 assert(i == num_indices - 1);
1609 /* According to the SPIR-V spec, OpCompositeExtract may work down to
1610 * the component granularity. The last index will be the index of the
1611 * vector to extract.
1612 */
1613
1614 struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
1615 ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
1616 ret->def = vtn_vector_extract(b, cur->def, indices[i]);
1617 return ret;
1618 }
1619 }
1620
1621 return cur;
1622 }
1623
1624 static void
1625 vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
1626 const uint32_t *w, unsigned count)
1627 {
1628 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
 1629    val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
         val->ssa = vtn_create_ssa_value(b, val->type);
1630
1631 switch (opcode) {
1632 case SpvOpVectorExtractDynamic:
1633 val->ssa->def = vtn_vector_extract_dynamic(b, vtn_ssa_value(b, w[3])->def,
1634 vtn_ssa_value(b, w[4])->def);
1635 break;
1636
1637 case SpvOpVectorInsertDynamic:
1638 val->ssa->def = vtn_vector_insert_dynamic(b, vtn_ssa_value(b, w[3])->def,
1639 vtn_ssa_value(b, w[4])->def,
1640 vtn_ssa_value(b, w[5])->def);
1641 break;
1642
1643 case SpvOpVectorShuffle:
1644 val->ssa->def = vtn_vector_shuffle(b, glsl_get_vector_elements(val->type),
1645 vtn_ssa_value(b, w[3])->def,
1646 vtn_ssa_value(b, w[4])->def,
1647 w + 5);
1648 break;
1649
1650 case SpvOpCompositeConstruct: {
1651 val->ssa = rzalloc(b, struct vtn_ssa_value);
1652 unsigned elems = count - 3;
1653 if (glsl_type_is_vector_or_scalar(val->type)) {
1654 nir_ssa_def *srcs[4];
1655 for (unsigned i = 0; i < elems; i++)
1656 srcs[i] = vtn_ssa_value(b, w[3 + i])->def;
1657 val->ssa->def =
1658 vtn_vector_construct(b, glsl_get_vector_elements(val->type),
1659 elems, srcs);
1660 } else {
1661 val->ssa->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
1662 for (unsigned i = 0; i < elems; i++)
1663 val->ssa->elems[i] = vtn_ssa_value(b, w[3 + i]);
1664 }
1665 break;
1666 }
1667 case SpvOpCompositeExtract:
1668 val->ssa = vtn_composite_extract(b, vtn_ssa_value(b, w[3]),
1669 w + 4, count - 4);
1670 break;
1671
1672 case SpvOpCompositeInsert:
1673 val->ssa = vtn_composite_insert(b, vtn_ssa_value(b, w[4]),
1674 vtn_ssa_value(b, w[3]),
1675 w + 5, count - 5);
1676 break;
1677
1678 case SpvOpCopyObject:
1679 val->ssa = vtn_composite_copy(b, vtn_ssa_value(b, w[3]));
1680 break;
1681
1682 default:
1683 unreachable("unknown composite operation");
1684 }
1685
1686 val->ssa->type = val->type;
1687 }
1688
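/* Handles everything that appears before the first OpFunction.  Returning
 * false stops vtn_foreach_instruction, which tells the caller where the
 * preamble ends.
 */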
1689 static bool
1690 vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
1691 const uint32_t *w, unsigned count)
1692 {
1693 switch (opcode) {
1694 case SpvOpSource:
1695 case SpvOpSourceExtension:
1696 case SpvOpCompileFlag:
1697 case SpvOpExtension:
1698 /* Unhandled, but these are for debug so that's ok. */
1699 break;
1700
1701 case SpvOpExtInstImport:
1702 vtn_handle_extension(b, opcode, w, count);
1703 break;
1704
1705 case SpvOpMemoryModel:
1706 assert(w[1] == SpvAddressingModelLogical);
1707 assert(w[2] == SpvMemoryModelGLSL450);
1708 break;
1709
1710 case SpvOpEntryPoint:
1711 assert(b->entry_point == NULL);
1712 b->entry_point = &b->values[w[2]];
1713 b->execution_model = w[1];
1714 break;
1715
1716 case SpvOpExecutionMode:
1717 unreachable("Execution modes not yet implemented");
1718 break;
1719
1720 case SpvOpString:
1721 vtn_push_value(b, w[1], vtn_value_type_string)->str =
1722 vtn_string_literal(b, &w[2], count - 2);
1723 break;
1724
1725 case SpvOpName:
1726 b->values[w[1]].name = vtn_string_literal(b, &w[2], count - 2);
1727 break;
1728
1729 case SpvOpMemberName:
1730 /* TODO */
1731 break;
1732
1733 case SpvOpLine:
1734 break; /* Ignored for now */
1735
1736 case SpvOpDecorationGroup:
1737 case SpvOpDecorate:
1738 case SpvOpMemberDecorate:
1739 case SpvOpGroupDecorate:
1740 case SpvOpGroupMemberDecorate:
1741 vtn_handle_decoration(b, opcode, w, count);
1742 break;
1743
1744 case SpvOpTypeVoid:
1745 case SpvOpTypeBool:
1746 case SpvOpTypeInt:
1747 case SpvOpTypeFloat:
1748 case SpvOpTypeVector:
1749 case SpvOpTypeMatrix:
1750 case SpvOpTypeSampler:
1751 case SpvOpTypeArray:
1752 case SpvOpTypeRuntimeArray:
1753 case SpvOpTypeStruct:
1754 case SpvOpTypeOpaque:
1755 case SpvOpTypePointer:
1756 case SpvOpTypeFunction:
1757 case SpvOpTypeEvent:
1758 case SpvOpTypeDeviceEvent:
1759 case SpvOpTypeReserveId:
1760 case SpvOpTypeQueue:
1761 case SpvOpTypePipe:
1762 vtn_push_value(b, w[1], vtn_value_type_type)->type =
1763 vtn_handle_type(b, opcode, &w[2], count - 2);
1764 break;
1765
1766 case SpvOpConstantTrue:
1767 case SpvOpConstantFalse:
1768 case SpvOpConstant:
1769 case SpvOpConstantComposite:
1770 case SpvOpConstantSampler:
1771 case SpvOpConstantNullPointer:
1772 case SpvOpConstantNullObject:
1773 case SpvOpSpecConstantTrue:
1774 case SpvOpSpecConstantFalse:
1775 case SpvOpSpecConstant:
1776 case SpvOpSpecConstantComposite:
1777 vtn_handle_constant(b, opcode, w, count);
1778 break;
1779
1780 case SpvOpVariable:
1781 vtn_handle_variables(b, opcode, w, count);
1782 break;
1783
1784 default:
1785 return false; /* End of preamble */
1786 }
1787
1788 return true;
1789 }
1790
1791 static bool
1792 vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
1793 const uint32_t *w, unsigned count)
1794 {
1795 switch (opcode) {
1796 case SpvOpFunction: {
1797 assert(b->func == NULL);
1798 b->func = rzalloc(b, struct vtn_function);
1799
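      /* OpFunction operands as used below: w[1] = result type, w[2] = result
       * <id>, w[4] = function type; w[3] is not consulted by this pass.
       */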
1800 const struct glsl_type *result_type =
1801 vtn_value(b, w[1], vtn_value_type_type)->type;
1802 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
1803 const struct glsl_type *func_type =
1804 vtn_value(b, w[4], vtn_value_type_type)->type;
1805
1806 assert(glsl_get_function_return_type(func_type) == result_type);
1807
1808 nir_function *func =
1809 nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
1810
1811 nir_function_overload *overload = nir_function_overload_create(func);
1812 overload->num_params = glsl_get_length(func_type);
1813 overload->params = ralloc_array(overload, nir_parameter,
1814 overload->num_params);
1815 for (unsigned i = 0; i < overload->num_params; i++) {
1816 const struct glsl_function_param *param =
1817 glsl_get_function_param(func_type, i);
1818 overload->params[i].type = param->type;
1819 if (param->in) {
1820 if (param->out) {
1821 overload->params[i].param_type = nir_parameter_inout;
1822 } else {
1823 overload->params[i].param_type = nir_parameter_in;
1824 }
1825 } else {
1826 if (param->out) {
1827 overload->params[i].param_type = nir_parameter_out;
1828 } else {
1829 assert(!"Parameter is neither in nor out");
1830 }
1831 }
1832 }
1833 b->func->overload = overload;
1834 break;
1835 }
1836
1837 case SpvOpFunctionEnd:
1838 b->func = NULL;
1839 break;
1840
1841 case SpvOpFunctionParameter:
1842 break; /* Does nothing */
1843
1844 case SpvOpLabel: {
1845 assert(b->block == NULL);
1846 b->block = rzalloc(b, struct vtn_block);
1847 b->block->label = w;
1848 vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
1849
1850 if (b->func->start_block == NULL) {
1851 /* This is the first block encountered for this function. In this
1852 * case, we set the start block and add it to the list of
1853 * implemented functions that we'll walk later.
1854 */
1855 b->func->start_block = b->block;
1856 exec_list_push_tail(&b->functions, &b->func->node);
1857 }
1858 break;
1859 }
1860
1861 case SpvOpBranch:
1862 case SpvOpBranchConditional:
1863 case SpvOpSwitch:
1864 case SpvOpKill:
1865 case SpvOpReturn:
1866 case SpvOpReturnValue:
1867 case SpvOpUnreachable:
1868 assert(b->block);
1869 b->block->branch = w;
1870 b->block = NULL;
1871 break;
1872
1873 case SpvOpSelectionMerge:
1874 case SpvOpLoopMerge:
1875 assert(b->block && b->block->merge_op == SpvOpNop);
1876 b->block->merge_op = opcode;
1877 b->block->merge_block_id = w[1];
1878 break;
1879
1880 default:
1881       /* Continue on as normal */
1882 return true;
1883 }
1884
1885 return true;
1886 }
1887
1888 static bool
1889 vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
1890 const uint32_t *w, unsigned count)
1891 {
1892 switch (opcode) {
1893 case SpvOpLabel: {
1894 struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block;
1895 assert(block->block == NULL);
1896
1897 struct exec_node *list_tail = exec_list_get_tail(b->nb.cf_node_list);
1898 nir_cf_node *tail_node = exec_node_data(nir_cf_node, list_tail, node);
1899 assert(tail_node->type == nir_cf_node_block);
1900 block->block = nir_cf_node_as_block(tail_node);
1901 break;
1902 }
1903
1904 case SpvOpLoopMerge:
1905 case SpvOpSelectionMerge:
1906       /* These are handled by the CFG pre-pass and vtn_walk_blocks */
1907 break;
1908
1909 case SpvOpUndef:
1910 vtn_push_value(b, w[2], vtn_value_type_undef);
1911 break;
1912
1913 case SpvOpExtInst:
1914 vtn_handle_extension(b, opcode, w, count);
1915 break;
1916
1917 case SpvOpVariable:
1918 case SpvOpVariableArray:
1919 case SpvOpLoad:
1920 case SpvOpStore:
1921 case SpvOpCopyMemory:
1922 case SpvOpCopyMemorySized:
1923 case SpvOpAccessChain:
1924 case SpvOpInBoundsAccessChain:
1925 case SpvOpArrayLength:
1926 case SpvOpImagePointer:
1927 vtn_handle_variables(b, opcode, w, count);
1928 break;
1929
1930 case SpvOpFunctionCall:
1931 vtn_handle_function_call(b, opcode, w, count);
1932 break;
1933
1934 case SpvOpTextureSample:
1935 case SpvOpTextureSampleDref:
1936 case SpvOpTextureSampleLod:
1937 case SpvOpTextureSampleProj:
1938 case SpvOpTextureSampleGrad:
1939 case SpvOpTextureSampleOffset:
1940 case SpvOpTextureSampleProjLod:
1941 case SpvOpTextureSampleProjGrad:
1942 case SpvOpTextureSampleLodOffset:
1943 case SpvOpTextureSampleProjOffset:
1944 case SpvOpTextureSampleGradOffset:
1945 case SpvOpTextureSampleProjLodOffset:
1946 case SpvOpTextureSampleProjGradOffset:
1947 case SpvOpTextureFetchTexelLod:
1948 case SpvOpTextureFetchTexelOffset:
1949 case SpvOpTextureFetchSample:
1950 case SpvOpTextureFetchTexel:
1951 case SpvOpTextureGather:
1952 case SpvOpTextureGatherOffset:
1953 case SpvOpTextureGatherOffsets:
1954 case SpvOpTextureQuerySizeLod:
1955 case SpvOpTextureQuerySize:
1956 case SpvOpTextureQueryLod:
1957 case SpvOpTextureQueryLevels:
1958 case SpvOpTextureQuerySamples:
1959 vtn_handle_texture(b, opcode, w, count);
1960 break;
1961
1962 case SpvOpSNegate:
1963 case SpvOpFNegate:
1964 case SpvOpNot:
1965 case SpvOpAny:
1966 case SpvOpAll:
1967 case SpvOpConvertFToU:
1968 case SpvOpConvertFToS:
1969 case SpvOpConvertSToF:
1970 case SpvOpConvertUToF:
1971 case SpvOpUConvert:
1972 case SpvOpSConvert:
1973 case SpvOpFConvert:
1974 case SpvOpConvertPtrToU:
1975 case SpvOpConvertUToPtr:
1976 case SpvOpPtrCastToGeneric:
1977 case SpvOpGenericCastToPtr:
1978 case SpvOpBitcast:
1979 case SpvOpIsNan:
1980 case SpvOpIsInf:
1981 case SpvOpIsFinite:
1982 case SpvOpIsNormal:
1983 case SpvOpSignBitSet:
1984 case SpvOpLessOrGreater:
1985 case SpvOpOrdered:
1986 case SpvOpUnordered:
1987 case SpvOpIAdd:
1988 case SpvOpFAdd:
1989 case SpvOpISub:
1990 case SpvOpFSub:
1991 case SpvOpIMul:
1992 case SpvOpFMul:
1993 case SpvOpUDiv:
1994 case SpvOpSDiv:
1995 case SpvOpFDiv:
1996 case SpvOpUMod:
1997 case SpvOpSRem:
1998 case SpvOpSMod:
1999 case SpvOpFRem:
2000 case SpvOpFMod:
2001 case SpvOpVectorTimesScalar:
2002 case SpvOpDot:
2003 case SpvOpShiftRightLogical:
2004 case SpvOpShiftRightArithmetic:
2005 case SpvOpShiftLeftLogical:
2006 case SpvOpLogicalOr:
2007 case SpvOpLogicalXor:
2008 case SpvOpLogicalAnd:
2009 case SpvOpBitwiseOr:
2010 case SpvOpBitwiseXor:
2011 case SpvOpBitwiseAnd:
2012 case SpvOpSelect:
2013 case SpvOpIEqual:
2014 case SpvOpFOrdEqual:
2015 case SpvOpFUnordEqual:
2016 case SpvOpINotEqual:
2017 case SpvOpFOrdNotEqual:
2018 case SpvOpFUnordNotEqual:
2019 case SpvOpULessThan:
2020 case SpvOpSLessThan:
2021 case SpvOpFOrdLessThan:
2022 case SpvOpFUnordLessThan:
2023 case SpvOpUGreaterThan:
2024 case SpvOpSGreaterThan:
2025 case SpvOpFOrdGreaterThan:
2026 case SpvOpFUnordGreaterThan:
2027 case SpvOpULessThanEqual:
2028 case SpvOpSLessThanEqual:
2029 case SpvOpFOrdLessThanEqual:
2030 case SpvOpFUnordLessThanEqual:
2031 case SpvOpUGreaterThanEqual:
2032 case SpvOpSGreaterThanEqual:
2033 case SpvOpFOrdGreaterThanEqual:
2034 case SpvOpFUnordGreaterThanEqual:
2035 case SpvOpDPdx:
2036 case SpvOpDPdy:
2037 case SpvOpFwidth:
2038 case SpvOpDPdxFine:
2039 case SpvOpDPdyFine:
2040 case SpvOpFwidthFine:
2041 case SpvOpDPdxCoarse:
2042 case SpvOpDPdyCoarse:
2043 case SpvOpFwidthCoarse:
2044 vtn_handle_alu(b, opcode, w, count);
2045 break;
2046
2047 case SpvOpTranspose:
2048 case SpvOpOuterProduct:
2049 case SpvOpMatrixTimesScalar:
2050 case SpvOpVectorTimesMatrix:
2051 case SpvOpMatrixTimesVector:
2052 case SpvOpMatrixTimesMatrix:
2053 vtn_handle_matrix_alu(b, opcode, w, count);
2054 break;
2055
2056 case SpvOpVectorExtractDynamic:
2057 case SpvOpVectorInsertDynamic:
2058 case SpvOpVectorShuffle:
2059 case SpvOpCompositeConstruct:
2060 case SpvOpCompositeExtract:
2061 case SpvOpCompositeInsert:
2062 case SpvOpCopyObject:
2063 vtn_handle_composite(b, opcode, w, count);
2064 break;
2065
2066 default:
2067 unreachable("Unhandled opcode");
2068 }
2069
2070 return true;
2071 }
2072
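/* Walk the chain of SPIR-V blocks starting at START, emitting NIR control
 * flow at the builder's current insertion point.  BREAK_BLOCK and CONT_BLOCK
 * are the blocks a structured break or continue would branch to (NULL when
 * not inside a loop); branches to them are turned into nir_jump_break /
 * nir_jump_continue instructions.  END_BLOCK is the merge block that ends
 * the walk (NULL for loops and for whole functions).  Loops and ifs are
 * handled by recursing with new break/continue/end blocks.
 */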
2073 static void
2074 vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
2075 struct vtn_block *break_block, struct vtn_block *cont_block,
2076 struct vtn_block *end_block)
2077 {
2078 struct vtn_block *block = start;
2079 while (block != end_block) {
2080 if (block->merge_op == SpvOpLoopMerge) {
2081          /* This block starts a loop: it carried an OpLoopMerge. */
2082 struct vtn_block *new_cont_block = block;
2083 struct vtn_block *new_break_block =
2084 vtn_value(b, block->merge_block_id, vtn_value_type_block)->block;
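         /* The loop header block itself serves as the continue target here:
          * a branch back to it becomes a nir_jump_continue, while a branch to
          * the block named by merge_block_id becomes a nir_jump_break.
          */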
2085
2086 nir_loop *loop = nir_loop_create(b->shader);
2087 nir_cf_node_insert_end(b->nb.cf_node_list, &loop->cf_node);
2088
2089 struct exec_list *old_list = b->nb.cf_node_list;
2090
2091          /* Reset the merge_op to prevent infinite recursion */
2092 block->merge_op = SpvOpNop;
2093
2094 nir_builder_insert_after_cf_list(&b->nb, &loop->body);
2095 vtn_walk_blocks(b, block, new_break_block, new_cont_block, NULL);
2096
2097 nir_builder_insert_after_cf_list(&b->nb, old_list);
2098 block = new_break_block;
2099 continue;
2100 }
2101
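      /* block->branch points at the block's terminating instruction.  The
       * first word of a SPIR-V instruction packs the opcode in its low 16
       * bits (SpvOpCodeMask) and the total word count in its high 16 bits.
       */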
2102 const uint32_t *w = block->branch;
2103 SpvOp branch_op = w[0] & SpvOpCodeMask;
2104
2105 b->block = block;
2106 vtn_foreach_instruction(b, block->label, block->branch,
2107 vtn_handle_body_instruction);
2108
2109 switch (branch_op) {
2110 case SpvOpBranch: {
2111 struct vtn_block *branch_block =
2112 vtn_value(b, w[1], vtn_value_type_block)->block;
2113
2114 if (branch_block == break_block) {
2115 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2116 nir_jump_break);
2117 nir_builder_instr_insert(&b->nb, &jump->instr);
2118
2119 return;
2120 } else if (branch_block == cont_block) {
2121 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2122 nir_jump_continue);
2123 nir_builder_instr_insert(&b->nb, &jump->instr);
2124
2125 return;
2126 } else if (branch_block == end_block) {
2127             /* We're branching to the merge block of an if.  (For loops and
2128              * functions, end_block == NULL.)  We're done here.
2129              */
2130 return;
2131 } else {
2132             /* We're branching to another block.  By the structured
2133              * control-flow rules, that block can only have one predecessor
2134              * (we're the only one jumping to it), so we can simply process
2135              * it next.
2136              */
2137 block = branch_block;
2138 continue;
2139 }
2140 }
2141
2142 case SpvOpBranchConditional: {
2143 /* Gather up the branch blocks */
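         /* OpBranchConditional: w[1] = condition, w[2] = label to take when
          * true, w[3] = label to take when false.
          */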
2144 struct vtn_block *then_block =
2145 vtn_value(b, w[2], vtn_value_type_block)->block;
2146 struct vtn_block *else_block =
2147 vtn_value(b, w[3], vtn_value_type_block)->block;
2148
2149 nir_if *if_stmt = nir_if_create(b->shader);
2150 if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1])->def);
2151 nir_cf_node_insert_end(b->nb.cf_node_list, &if_stmt->cf_node);
2152
2153 if (then_block == break_block) {
2154 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2155 nir_jump_break);
2156 nir_instr_insert_after_cf_list(&if_stmt->then_list,
2157 &jump->instr);
2158 block = else_block;
2159 } else if (else_block == break_block) {
2160 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2161 nir_jump_break);
2162 nir_instr_insert_after_cf_list(&if_stmt->else_list,
2163 &jump->instr);
2164 block = then_block;
2165 } else if (then_block == cont_block) {
2166 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2167 nir_jump_continue);
2168 nir_instr_insert_after_cf_list(&if_stmt->then_list,
2169 &jump->instr);
2170 block = else_block;
2171 } else if (else_block == cont_block) {
2172 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2173 nir_jump_continue);
2174 nir_instr_insert_after_cf_list(&if_stmt->else_list,
2175 &jump->instr);
2176 block = then_block;
2177 } else {
2178             /* According to the rules, we're branching to two blocks that
2179              * have no other predecessors, so we can handle this as a
2180              * conventional if.
2181              */
2182 assert(block->merge_op == SpvOpSelectionMerge);
2183 struct vtn_block *merge_block =
2184 vtn_value(b, block->merge_block_id, vtn_value_type_block)->block;
2185
2186 struct exec_list *old_list = b->nb.cf_node_list;
2187
2188 nir_builder_insert_after_cf_list(&b->nb, &if_stmt->then_list);
2189 vtn_walk_blocks(b, then_block, break_block, cont_block, merge_block);
2190
2191 nir_builder_insert_after_cf_list(&b->nb, &if_stmt->else_list);
2192 vtn_walk_blocks(b, else_block, break_block, cont_block, merge_block);
2193
2194 nir_builder_insert_after_cf_list(&b->nb, old_list);
2195 block = merge_block;
2196 continue;
2197 }
2198
2199 /* If we got here then we inserted a predicated break or continue
2200 * above and we need to handle the other case. We already set
2201 * `block` above to indicate what block to visit after the
2202 * predicated break.
2203 */
2204
2205 /* It's possible that the other branch is also a break/continue.
2206 * If it is, we handle that here.
2207 */
2208 if (block == break_block) {
2209 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2210 nir_jump_break);
2211 nir_builder_instr_insert(&b->nb, &jump->instr);
2212
2213 return;
2214 } else if (block == cont_block) {
2215 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2216 nir_jump_continue);
2217 nir_builder_instr_insert(&b->nb, &jump->instr);
2218
2219 return;
2220 }
2221
2222             /* If we got here, then there was a predicated break/continue
2223              * but the other half of the if contains actual instructions.
2224              * `block` was already set above, so there is nothing left to do.
2225              */
2226 continue;
2227 }
2228
2229 case SpvOpReturn: {
2230 nir_jump_instr *jump = nir_jump_instr_create(b->shader,
2231 nir_jump_return);
2232 nir_builder_instr_insert(&b->nb, &jump->instr);
2233 return;
2234 }
2235
2236 case SpvOpKill: {
2237 nir_intrinsic_instr *discard =
2238 nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard);
2239 nir_builder_instr_insert(&b->nb, &discard->instr);
2240 return;
2241 }
2242
2243 case SpvOpSwitch:
2244 case SpvOpReturnValue:
2245 case SpvOpUnreachable:
2246 default:
2247 unreachable("Unhandled opcode");
2248 }
2249 }
2250 }
2251
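/* Convert a SPIR-V module, given as an array of 32-bit words, into a freshly
 * allocated nir_shader.  A minimal usage sketch (the variable names are
 * hypothetical, for illustration only):
 *
 *    nir_shader *s = spirv_to_nir(spirv_words, spirv_word_count, options);
 *
 * where spirv_words/spirv_word_count describe a module already loaded into
 * memory and options are the target's nir_shader_compiler_options.
 */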
2252 nir_shader *
2253 spirv_to_nir(const uint32_t *words, size_t word_count,
2254 const nir_shader_compiler_options *options)
2255 {
2256 const uint32_t *word_end = words + word_count;
2257
2258    /* Handle the SPIR-V header (first 5 dwords) */
2259 assert(word_count > 5);
2260
2261 assert(words[0] == SpvMagicNumber);
2262    assert(words[1] == 99); /* version number */
2263 /* words[2] == generator magic */
2264 unsigned value_id_bound = words[3];
2265    assert(words[4] == 0); /* schema (reserved) */
2266
2267    words += 5;
2268
2269 nir_shader *shader = nir_shader_create(NULL, options);
2270
2271    /* Initialize the vtn_builder object */
2272 struct vtn_builder *b = rzalloc(NULL, struct vtn_builder);
2273 b->shader = shader;
2274 b->value_id_bound = value_id_bound;
2275 b->values = ralloc_array(b, struct vtn_value, value_id_bound);
2276 exec_list_make_empty(&b->functions);
2277
2278 /* Handle all the preamble instructions */
2279 words = vtn_foreach_instruction(b, words, word_end,
2280 vtn_handle_preamble_instruction);
2281
2282 /* Do a very quick CFG analysis pass */
2283 vtn_foreach_instruction(b, words, word_end,
2284 vtn_handle_first_cfg_pass_instruction);
2285
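   /* Now emit code for each function that had a body: create its
    * nir_function_impl, point the builder at the impl's body, and walk the
    * function's blocks starting from its entry block.
    */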
2286 foreach_list_typed(struct vtn_function, func, node, &b->functions) {
2287 b->impl = nir_function_impl_create(func->overload);
2288 b->const_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
2289 _mesa_key_pointer_equal);
2290 nir_builder_init(&b->nb, b->impl);
2291 nir_builder_insert_after_cf_list(&b->nb, &b->impl->body);
2292 vtn_walk_blocks(b, func->start_block, NULL, NULL, NULL);
2293 }
2294
2295 ralloc_free(b);
2296
2297 return shader;
2298 }