freedreno/ir3: Implement lowering passes for VS and GS
src/freedreno/ir3/ir3_nir_lower_tess.c
/*
 * Copyright © 2019 Google, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "ir3_nir.h"
#include "ir3_compiler.h"
#include "compiler/nir/nir_builder.h"

struct state {
   struct primitive_map {
      unsigned loc[32];
      unsigned size[32];
      unsigned stride;
   } map;

   nir_ssa_def *header;

   nir_variable *vertex_count_var;
   nir_variable *emitted_vertex_var;
   nir_variable *vertex_flags_var;
   nir_variable *vertex_flags_out;

   nir_variable *output_vars[32];
};

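/* Helper to extract a bitfield: (v >> start) & mask.  Note that 'mask' is
 * the post-shift mask value (e.g. 31 for a 5-bit field), not a bit count.
 */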
static nir_ssa_def *
bitfield_extract(nir_builder *b, nir_ssa_def *v, uint32_t start, uint32_t mask)
{
   return nir_iand(b, nir_ushr(b, v, nir_imm_int(b, start)),
         nir_imm_int(b, mask));
}

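/* The driver hands the shader a packed header sysval (nir_load_gs_header_ir3).
 * Judging from the extracts in this file (this is inferred from the code, not
 * from hardware documentation), the layout appears to be:
 *
 *    bits [0..5]   local primitive id
 *    bits [6..10]  vertex id within the primitive
 *    bits [11..15] GS invocation id
 *    bits [16..25] local thread id (see local_thread_id())
 */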
static nir_ssa_def *
build_invocation_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 11, 31);
}

static nir_ssa_def *
build_vertex_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 6, 31);
}

static nir_ssa_def *
build_local_primitive_id(nir_builder *b, struct state *state)
{
   return bitfield_extract(b, state->header, 0, 63);
}

static nir_variable *
get_var(struct exec_list *list, int driver_location)
{
   nir_foreach_variable(v, list) {
      if (v->data.driver_location == driver_location) {
         return v;
      }
   }

   return NULL;
}

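/* Computes the byte offset of one attribute of one vertex within the local
 * storage shared between the VS and GS waves:
 *
 *    offset = primitive_id * primitive_stride
 *           + vertex * vertex_stride
 *           + attr_offset + offset
 *
 * In the VS the layout is known statically from the primitive map; in the
 * GS the stride and attribute locations come from ir3-specific load
 * intrinsics, since the GS is compiled without seeing the VS output layout.
 */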
static nir_ssa_def *
build_local_offset(nir_builder *b, struct state *state,
      nir_ssa_def *vertex, uint32_t base, nir_ssa_def *offset)
{
   nir_ssa_def *primitive_stride = nir_load_vs_primitive_stride_ir3(b);
   nir_ssa_def *primitive_offset =
      nir_imul(b, build_local_primitive_id(b, state), primitive_stride);
   nir_ssa_def *attr_offset;
   nir_ssa_def *vertex_stride;

   if (b->shader->info.stage == MESA_SHADER_VERTEX) {
      vertex_stride = nir_imm_int(b, state->map.stride * 4);
      attr_offset = nir_imm_int(b, state->map.loc[base] * 4);
   } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
      vertex_stride = nir_load_vs_vertex_stride_ir3(b);
      attr_offset = nir_load_primitive_location_ir3(b, base);
   } else {
      unreachable("bad shader stage");
   }

   nir_ssa_def *vertex_offset = nir_imul(b, vertex, vertex_stride);

   return nir_iadd(b, nir_iadd(b, primitive_offset, vertex_offset),
         nir_iadd(b, attr_offset, offset));
}

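/* Clones 'intr' into a new intrinsic with the given op and sources,
 * rewrites any uses of the old destination to the new one, and removes
 * the original instruction.
 */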
static nir_intrinsic_instr *
replace_intrinsic(nir_builder *b, nir_intrinsic_instr *intr,
      nir_intrinsic_op op, nir_ssa_def *src0, nir_ssa_def *src1, nir_ssa_def *src2)
{
   nir_intrinsic_instr *new_intr =
      nir_intrinsic_instr_create(b->shader, op);

   new_intr->src[0] = nir_src_for_ssa(src0);
   if (src1)
      new_intr->src[1] = nir_src_for_ssa(src1);
   if (src2)
      new_intr->src[2] = nir_src_for_ssa(src2);

   new_intr->num_components = intr->num_components;

   if (nir_intrinsic_infos[op].has_dest)
      nir_ssa_dest_init(&new_intr->instr, &new_intr->dest,
            intr->num_components, 32, NULL);

   nir_builder_instr_insert(b, &new_intr->instr);

   if (nir_intrinsic_infos[op].has_dest)
      nir_ssa_def_rewrite_uses(&intr->dest.ssa, nir_src_for_ssa(&new_intr->dest.ssa));

   nir_instr_remove(&intr->instr);

   return new_intr;
}

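/* Packs the variables in 'list' into a contiguous layout: each
 * driver_location gets a dword offset (loc[]) and size, and map->stride
 * ends up as the total size.  Tess levels are skipped, patch variables
 * keep a location but no per-vertex size, and arrayed (per-vertex)
 * variables have their size divided by the array length to give the size
 * of a single vertex's slot.
 */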
static void
build_primitive_map(nir_shader *shader, struct primitive_map *map, struct exec_list *list)
{
   nir_foreach_variable(var, list) {
      switch (var->data.location) {
      case VARYING_SLOT_TESS_LEVEL_OUTER:
      case VARYING_SLOT_TESS_LEVEL_INNER:
         continue;
      }

      unsigned size = glsl_count_attribute_slots(var->type, false) * 4;

      assert(var->data.driver_location < ARRAY_SIZE(map->size));
      map->size[var->data.driver_location] =
         MAX2(map->size[var->data.driver_location], size);
   }

   unsigned loc = 0;
   for (uint32_t i = 0; i < ARRAY_SIZE(map->size); i++) {
      if (map->size[i] == 0)
         continue;
      nir_variable *var = get_var(list, i);
      map->loc[i] = loc;
      loc += map->size[i];

      if (var->data.patch)
         map->size[i] = 0;
      else
         map->size[i] = map->size[i] / glsl_get_length(var->type);
   }

   map->stride = loc;
}

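/* In the VS, mirror every store_output with a store_shared_ir3 (which is
 * expected to map to the STLW instruction) so that the GS can later read
 * the vertex data back from local storage.  The original store_output is
 * intentionally left in place.
 */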
static void
lower_vs_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_foreach_instr_safe(instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *vertex_id = build_vertex_id(b, state);
         nir_ssa_def *offset = build_local_offset(b, state, vertex_id, nir_intrinsic_base(intr),
               intr->src[1].ssa);
         nir_intrinsic_instr *store =
            nir_intrinsic_instr_create(b->shader, nir_intrinsic_store_shared_ir3);

         nir_intrinsic_set_write_mask(store, MASK(intr->num_components));
         store->src[0] = nir_src_for_ssa(intr->src[0].ssa);
         store->src[1] = nir_src_for_ssa(offset);

         store->num_components = intr->num_components;

         nir_builder_instr_insert(b, &store->instr);
         break;
      }

      default:
         break;
      }
   }
}

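/* Thread id within the local cluster of VS/GS waves, taken from bits
 * [16..25] of the GS header.
 */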
static nir_ssa_def *
local_thread_id(nir_builder *b)
{
   return bitfield_extract(b, nir_load_gs_header_ir3(b), 16, 1023);
}

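/* VS lowering entry point: builds the packed output layout, copies it
 * into the shader state so later stages can locate each attribute, and
 * rewrites the VS outputs into explicit shared-memory stores (see
 * lower_vs_block()).
 */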
void
ir3_nir_lower_vs_to_explicit_io(nir_shader *shader, struct ir3_shader *s)
{
   struct state state = { };

   build_primitive_map(shader, &state.map, &shader->outputs);
   memcpy(s->output_loc, state.map.loc, sizeof(s->output_loc));

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   state.header = nir_load_gs_header_ir3(&b);

   nir_foreach_block_safe(block, impl)
      lower_vs_block(block, &b, &state);

   nir_metadata_preserve(impl, nir_metadata_block_index |
         nir_metadata_dominance);

   s->output_size = state.map.stride;
}

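/* GS lowering.  store_output instructions are not rewritten immediately;
 * they are collected per-location and only flushed into the local
 * temporaries at the next emit_vertex, predicated on this thread being
 * the one responsible for that vertex.  end_primitive just sets the
 * pending vertex flag to 4, which (judging by its use in this file)
 * marks the next emitted vertex as starting a new primitive.
 */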
static void
lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
{
   nir_intrinsic_instr *outputs[32] = {};

   nir_foreach_instr_safe(instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

      switch (intr->intrinsic) {
      case nir_intrinsic_store_output: {
         // src[] = { value, offset }.

         uint32_t loc = nir_intrinsic_base(intr);
         outputs[loc] = intr;
         break;
      }

      case nir_intrinsic_end_primitive: {
         b->cursor = nir_before_instr(&intr->instr);
         nir_store_var(b, state->vertex_flags_var, nir_imm_int(b, 4), 0x1);
         nir_instr_remove(&intr->instr);
         break;
      }

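      /* Each hardware thread emits at most one vertex: only the thread
       * whose local id equals the current vertex count latches the
       * pending outputs and flags, but every thread increments its own
       * counters so they stay in lockstep.
       */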
      case nir_intrinsic_emit_vertex: {

         /* Load the vertex count */
         b->cursor = nir_before_instr(&intr->instr);
         nir_ssa_def *count = nir_load_var(b, state->vertex_count_var);

         nir_push_if(b, nir_ieq(b, count, local_thread_id(b)));

         for (uint32_t i = 0; i < ARRAY_SIZE(outputs); i++) {
            if (outputs[i]) {
               nir_store_var(b, state->output_vars[i],
                     outputs[i]->src[0].ssa,
                     (1 << outputs[i]->num_components) - 1);

               nir_instr_remove(&outputs[i]->instr);
            }
            outputs[i] = NULL;
         }

         nir_instr_remove(&intr->instr);

         nir_store_var(b, state->emitted_vertex_var,
               nir_iadd(b, nir_load_var(b, state->emitted_vertex_var), nir_imm_int(b, 1)), 0x1);

         nir_store_var(b, state->vertex_flags_out,
               nir_load_var(b, state->vertex_flags_var), 0x1);

         nir_pop_if(b, NULL);

         /* Increment the vertex count by 1 */
         nir_store_var(b, state->vertex_count_var,
               nir_iadd(b, count, nir_imm_int(b, 1)), 0x1); /* .x */
         nir_store_var(b, state->vertex_flags_var, nir_imm_int(b, 0), 0x1);

         break;
      }

      case nir_intrinsic_load_per_vertex_input: {
         // src[] = { vertex, offset }.

         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *offset = build_local_offset(b, state,
               intr->src[0].ssa, // the vertex index within the input primitive
               nir_intrinsic_base(intr),
               intr->src[1].ssa);

         replace_intrinsic(b, intr, nir_intrinsic_load_shared_ir3, offset, NULL, NULL);
         break;
      }

      case nir_intrinsic_load_invocation_id: {
         b->cursor = nir_before_instr(&intr->instr);

         nir_ssa_def *iid = build_invocation_id(b, state);
         nir_ssa_def_rewrite_uses(&intr->dest.ssa, nir_src_for_ssa(iid));
         nir_instr_remove(&intr->instr);
         break;
      }

      default:
         break;
      }
   }
}

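/* Emits the real store_output instructions at the end of the shader,
 * reading back the values latched into the local temporaries by
 * lower_gs_block().
 */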
static void
emit_store_outputs(nir_builder *b, struct state *state)
{
   /* This also stores the internally added vertex_flags output. */

   for (uint32_t i = 0; i < ARRAY_SIZE(state->output_vars); i++) {
      if (!state->output_vars[i])
         continue;

      nir_intrinsic_instr *store =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_store_output);

      nir_intrinsic_set_base(store, i);
      store->src[0] = nir_src_for_ssa(nir_load_var(b, state->output_vars[i]));
      store->src[1] = nir_src_for_ssa(nir_imm_int(b, 0));
      store->num_components = store->src[0].ssa->num_components;

      nir_builder_instr_insert(b, &store->instr);
   }
}

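/* Earlier lowering can split a variable into several overlapping
 * variables at the same driver_location.  Going by the mask arithmetic
 * below, this keeps only a variable that covers every component used at
 * its location and drops the narrower duplicates.
 */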
static void
clean_up_split_vars(nir_shader *shader, struct exec_list *list)
{
   uint32_t components[32] = {};

   nir_foreach_variable(var, list) {
      uint32_t mask =
         ((1 << glsl_get_components(glsl_without_array(var->type))) - 1) << var->data.location_frac;
      components[var->data.driver_location] |= mask;
   }

   nir_foreach_variable_safe(var, list) {
      uint32_t mask =
         ((1 << glsl_get_components(glsl_without_array(var->type))) - 1) << var->data.location_frac;
      bool subset =
         (components[var->data.driver_location] | mask) != mask;
      if (subset)
         exec_node_remove(&var->node);
   }
}

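/* GS lowering entry point: creates a local temporary for every output
 * (plus the vertex_count/emitted_vertex/vertex_flags bookkeeping vars),
 * lowers each block, and then patches every exit of the shader to kill
 * threads that emitted nothing and store the latched outputs.
 */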
void
ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s)
{
   struct state state = { };

   if (shader_debug_enabled(shader->info.stage)) {
      fprintf(stderr, "NIR (before gs lowering):\n");
      nir_print_shader(shader, stderr);
   }

   clean_up_split_vars(shader, &shader->inputs);
   clean_up_split_vars(shader, &shader->outputs);

   build_primitive_map(shader, &state.map, &shader->inputs);

   uint32_t loc = 0;
   nir_foreach_variable(var, &shader->outputs) {
      uint32_t end = var->data.driver_location + glsl_count_attribute_slots(var->type, false);
      loc = MAX2(loc, end);
   }

   state.vertex_flags_out = nir_variable_create(shader, nir_var_shader_out,
         glsl_uint_type(), "vertex_flags");
   state.vertex_flags_out->data.driver_location = loc;
   state.vertex_flags_out->data.location = VARYING_SLOT_GS_VERTEX_FLAGS_IR3;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   state.header = nir_load_gs_header_ir3(&b);

   nir_foreach_variable(var, &shader->outputs) {
      state.output_vars[var->data.driver_location] =
         nir_local_variable_create(impl, var->type,
               ralloc_asprintf(var, "%s:gs-temp", var->name));
   }

   state.vertex_count_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
   state.emitted_vertex_var =
      nir_local_variable_create(impl, glsl_uint_type(), "emitted_vertex");
   state.vertex_flags_var =
      nir_local_variable_create(impl, glsl_uint_type(), "vertex_flags");
   state.vertex_flags_out = state.output_vars[state.vertex_flags_out->data.driver_location];

   /* initialize the counters to 0, and vertex_flags to 4 so the first
    * emitted vertex is marked as starting a new primitive
    */
   b.cursor = nir_before_cf_list(&impl->body);
   nir_store_var(&b, state.vertex_count_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.emitted_vertex_var, nir_imm_int(&b, 0), 0x1);
   nir_store_var(&b, state.vertex_flags_var, nir_imm_int(&b, 4), 0x1);

   nir_foreach_block_safe(block, impl)
      lower_gs_block(block, &b, &state);

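   /* At each predecessor of the end block (i.e. every exit of the
    * shader), kill threads that never emitted a vertex, then store the
    * latched outputs.
    */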
   set_foreach(impl->end_block->predecessors, block_entry) {
      struct nir_block *block = (void *)block_entry->key;
      b.cursor = nir_after_block_before_jump(block);

      nir_intrinsic_instr *discard_if =
         nir_intrinsic_instr_create(b.shader, nir_intrinsic_discard_if);

      nir_ssa_def *cond = nir_ieq(&b, nir_load_var(&b, state.emitted_vertex_var), nir_imm_int(&b, 0));

      discard_if->src[0] = nir_src_for_ssa(cond);

      nir_builder_instr_insert(&b, &discard_if->instr);

      emit_store_outputs(&b, &state);
   }

   nir_metadata_preserve(impl, 0);

   if (shader_debug_enabled(shader->info.stage)) {
      fprintf(stderr, "NIR (after gs lowering):\n");
      nir_print_shader(shader, stderr);
   }
}