f45a490c895a0ff79071a3b3570d0e4db744aa27
[mesa.git] / src / gallium / drivers / r600 / sfn / sfn_nir_lower_fs_out_to_vector.cpp
1 /* -*- mesa-c++ -*-
2 *
3 * Copyright (c) 2019 Collabora LTD
4 *
5 * Author: Gert Wollny <gert.wollny@collabora.com>
6 *
7 * Permission is hereby granted, free of charge, to any person obtaining a
8 * copy of this software and associated documentation files (the "Software"),
9 * to deal in the Software without restriction, including without limitation
10 * on the rights to use, copy, modify, merge, publish, distribute, sub
11 * license, and/or sell copies of the Software, and to permit persons to whom
12 * the Software is furnished to do so, subject to the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the next
15 * paragraph) shall be included in all copies or substantial portions of the
16 * Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
21 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
22 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
24 * USE OR OTHER DEALINGS IN THE SOFTWARE.
25 */
26
27 #include "sfn_nir_lower_fs_out_to_vector.h"
28
29 #include "nir_builder.h"
30 #include "nir_deref.h"
31 #include "util/u_math.h"
32
33 #include <set>
34 #include <vector>
35 #include <array>
36 #include <algorithm>
37
38 namespace r600 {
39
40 using std::multiset;
41 using std::vector;
42 using std::array;
43
44 struct nir_intrinsic_instr_less {
45 bool operator () (const nir_intrinsic_instr *lhs, const nir_intrinsic_instr *rhs) const
46 {
47 nir_variable *vlhs = nir_deref_instr_get_variable(nir_src_as_deref(lhs->src[0]));
48 nir_variable *vrhs = nir_deref_instr_get_variable(nir_src_as_deref(rhs->src[0]));
49
50 auto ltype = glsl_get_base_type(vlhs->type);
51 auto rtype = glsl_get_base_type(vrhs->type);
52
53 if (ltype != rtype)
54 return ltype < rtype;
55 return vlhs->data.location < vrhs->data.location;
56 }
57 };
58
59 class NirLowerIOToVector {
60 public:
61 NirLowerIOToVector(int base_slot);
62 bool run(nir_function_impl *shader);
63
64 protected:
65 bool var_can_merge(const nir_variable *lhs, const nir_variable *rhs);
66 bool var_can_rewrite(nir_variable *var) const;
67 void create_new_io_vars(nir_shader *shader);
68 void create_new_io_var(nir_shader *shader, unsigned location, unsigned comps);
69
70 nir_deref_instr *clone_deref_array(nir_builder *b, nir_deref_instr *dst_tail,
71 const nir_deref_instr *src_head);
72
73 bool vectorize_block(nir_builder *b, nir_block *block);
74 bool instr_can_rewrite(nir_instr *instr);
75 bool vec_instr_set_remove(nir_builder *b,nir_instr *instr);
76
77 using InstrSet = multiset<nir_intrinsic_instr *, nir_intrinsic_instr_less>;
78 using InstrSubSet = std::pair<InstrSet::iterator, InstrSet::iterator>;
79
80 bool vec_instr_stack_pop(nir_builder *b, InstrSubSet& ir_set,
81 nir_intrinsic_instr *instr);
82
83 array<array<nir_variable *, 4>, 16> m_vars;
84 InstrSet m_block_io;
85 int m_next_index;
86 private:
87 virtual nir_variable_mode get_io_mode(nir_shader *shader) const = 0;
88 virtual bool instr_can_rewrite_type(nir_intrinsic_instr *intr) const = 0;
89 virtual bool var_can_rewrite_slot(nir_variable *var) const = 0;
90 virtual void create_new_io(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var,
91 nir_ssa_def **srcs, unsigned first_comp, unsigned num_comps) = 0;
92
93 int m_base_slot;
94 };
95
96 class NirLowerFSOutToVector : public NirLowerIOToVector {
97 public:
98 NirLowerFSOutToVector();
99
100 private:
101 nir_variable_mode get_io_mode(nir_shader *shader) const override;
102 bool var_can_rewrite_slot(nir_variable *var) const override;
103 void create_new_io(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var,
104 nir_ssa_def **srcs, unsigned first_comp, unsigned num_comps) override;
105 bool instr_can_rewrite_type(nir_intrinsic_instr *intr) const override;
106
107 nir_ssa_def *create_combined_vector(nir_builder *b, nir_ssa_def **srcs,
108 int first_comp, int num_comp);
109 };
110
111 bool r600_lower_fs_out_to_vector(nir_shader *shader)
112 {
113 NirLowerFSOutToVector processor;
114
115 assert(shader->info.stage == MESA_SHADER_FRAGMENT);
116 bool progress = false;
117
118 nir_foreach_function(function, shader) {
119 if (function->impl)
120 progress |= processor.run(function->impl);
121 }
122 return progress;
123 }
124
125 NirLowerIOToVector::NirLowerIOToVector(int base_slot):
126 m_next_index(0),
127 m_base_slot(base_slot)
128 {
129 for(auto& a : m_vars)
130 for(auto& aa : a)
131 aa = nullptr;
132 }
133
134 bool NirLowerIOToVector::run(nir_function_impl *impl)
135 {
136 nir_builder b;
137 nir_builder_init(&b, impl);
138
139 nir_metadata_require(impl, nir_metadata_dominance);
140 create_new_io_vars(impl->function->shader);
141
142 bool progress = vectorize_block(&b, nir_start_block(impl));
143 if (progress) {
144 nir_metadata_preserve(impl, (nir_metadata )
145 (nir_metadata_block_index |
146 nir_metadata_dominance));
147 }
148 return progress;
149 }
150
151 void NirLowerIOToVector::create_new_io_vars(nir_shader *shader)
152 {
153 nir_variable_mode mode = get_io_mode(shader);
154
155 bool can_rewrite_vars = false;
156 nir_foreach_variable_with_modes(var, shader, mode) {
157 if (var_can_rewrite(var)) {
158 can_rewrite_vars = true;
159 unsigned loc = var->data.location - m_base_slot;
160 m_vars[loc][var->data.location_frac] = var;
161 }
162 }
163
164 if (!can_rewrite_vars)
165 return;
166
167 /* We don't handle combining vars of different type e.g. different array
168 * lengths.
169 */
170 for (unsigned i = 0; i < 16; i++) {
171 unsigned comps = 0;
172
173 for (unsigned j = 0; j < 3; j++) {
174 if (!m_vars[i][j])
175 continue;
176
177 for (unsigned k = j + 1; k < 4; k++) {
178 if (!m_vars[i][k])
179 continue;
180
181 if (!var_can_merge(m_vars[i][j], m_vars[i][k]))
182 continue;
183
184 /* Set comps */
185 for (unsigned n = 0; n < glsl_get_components(m_vars[i][j]->type); ++n)
186 comps |= 1 << (m_vars[i][j]->data.location_frac + n);
187
188 for (unsigned n = 0; n < glsl_get_components(m_vars[i][k]->type); ++n)
189 comps |= 1 << (m_vars[i][k]->data.location_frac + n);
190
191 }
192 }
193 if (comps)
194 create_new_io_var(shader, i, comps);
195 }
196 }
197
198 bool
199 NirLowerIOToVector::var_can_merge(const nir_variable *lhs,
200 const nir_variable *rhs)
201 {
202 return (glsl_get_base_type(lhs->type) == glsl_get_base_type(rhs->type));
203 }
204
205 void
206 NirLowerIOToVector::create_new_io_var(nir_shader *shader,
207 unsigned location, unsigned comps)
208 {
209 unsigned num_comps = util_bitcount(comps);
210 assert(num_comps > 1);
211
212 /* Note: u_bit_scan() strips a component of the comps bitfield here */
213 unsigned first_comp = u_bit_scan(&comps);
214
215 nir_variable *var = nir_variable_clone(m_vars[location][first_comp], shader);
216 var->data.location_frac = first_comp;
217 var->type = glsl_replace_vector_type(var->type, num_comps);
218
219 nir_shader_add_variable(shader, var);
220
221 m_vars[location][first_comp] = var;
222
223 while (comps) {
224 const int comp = u_bit_scan(&comps);
225 if (m_vars[location][comp]) {
226 m_vars[location][comp] = var;
227 }
228 }
229 }
230
231 bool NirLowerIOToVector::var_can_rewrite(nir_variable *var) const
232 {
233 /* Skip complex types we don't split in the first place */
234 if (!glsl_type_is_vector_or_scalar(glsl_without_array(var->type)))
235 return false;
236
237 if (glsl_get_bit_size(glsl_without_array(var->type)) != 32)
238 return false;
239
240 return var_can_rewrite_slot(var);
241 }
242
243 bool
244 NirLowerIOToVector::vectorize_block(nir_builder *b, nir_block *block)
245 {
246 bool progress = false;
247
248 nir_foreach_instr_safe(instr, block) {
249 if (instr_can_rewrite(instr)) {
250 instr->index = m_next_index++;
251 nir_intrinsic_instr *ir = nir_instr_as_intrinsic(instr);
252 m_block_io.insert(ir);
253 }
254 }
255
256 for (unsigned i = 0; i < block->num_dom_children; i++) {
257 nir_block *child = block->dom_children[i];
258 progress |= vectorize_block(b, child);
259 }
260
261 nir_foreach_instr_reverse_safe(instr, block) {
262 progress |= vec_instr_set_remove(b, instr);
263 }
264 m_block_io.clear();
265
266 return progress;
267 }
268
269 bool NirLowerIOToVector::instr_can_rewrite(nir_instr *instr)
270 {
271 if (instr->type != nir_instr_type_intrinsic)
272 return false;
273
274 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
275
276 if (intr->num_components > 3)
277 return false;
278
279 return instr_can_rewrite_type(intr);
280 }
281
282 bool NirLowerIOToVector::vec_instr_set_remove(nir_builder *b,nir_instr *instr)
283 {
284 if (!instr_can_rewrite(instr))
285 return false;
286
287 nir_intrinsic_instr *ir = nir_instr_as_intrinsic(instr);
288 auto entry = m_block_io.equal_range(ir);
289 if (entry.first != m_block_io.end()) {
290 vec_instr_stack_pop(b, entry, ir);
291 }
292 return true;
293 }
294
295 nir_deref_instr *
296 NirLowerIOToVector::clone_deref_array(nir_builder *b, nir_deref_instr *dst_tail,
297 const nir_deref_instr *src_head)
298 {
299 const nir_deref_instr *parent = nir_deref_instr_parent(src_head);
300
301 if (!parent)
302 return dst_tail;
303
304 assert(src_head->deref_type == nir_deref_type_array);
305
306 dst_tail = clone_deref_array(b, dst_tail, parent);
307
308 return nir_build_deref_array(b, dst_tail,
309 nir_ssa_for_src(b, src_head->arr.index, 1));
310 }
311
312 NirLowerFSOutToVector::NirLowerFSOutToVector():
313 NirLowerIOToVector(FRAG_RESULT_COLOR)
314 {
315
316 }
317
318 bool NirLowerFSOutToVector::var_can_rewrite_slot(nir_variable *var) const
319 {
320 return ((var->data.mode == nir_var_shader_out) &&
321 ((var->data.location == FRAG_RESULT_COLOR) ||
322 ((var->data.location >= FRAG_RESULT_DATA0) &&
323 (var->data.location <= FRAG_RESULT_DATA7))));
324 }
325
326 bool NirLowerIOToVector::vec_instr_stack_pop(nir_builder *b, InstrSubSet &ir_set,
327 nir_intrinsic_instr *instr)
328 {
329 vector< nir_intrinsic_instr *> ir_sorted_set(ir_set.first, ir_set.second);
330 std::sort(ir_sorted_set.begin(), ir_sorted_set.end(),
331 [](const nir_intrinsic_instr *lhs, const nir_intrinsic_instr *rhs) {
332 return lhs->instr.index > rhs->instr.index;
333 }
334 );
335
336 nir_intrinsic_instr *intr = *ir_sorted_set.begin();
337 nir_variable *var = nir_deref_instr_get_variable(nir_src_as_deref(intr->src[0]));
338
339 unsigned loc = var->data.location - m_base_slot;
340
341 nir_variable *new_var = m_vars[loc][var->data.location_frac];
342 unsigned num_comps = glsl_get_vector_elements(glsl_without_array(new_var->type));
343 unsigned old_num_comps = glsl_get_vector_elements(glsl_without_array(var->type));
344
345 /* Don't bother walking the stack if this component can't be vectorised. */
346 if (old_num_comps > 3) {
347 return false;
348 }
349
350 if (new_var == var) {
351 return false;
352 }
353
354 b->cursor = nir_after_instr(&intr->instr);
355 nir_ssa_undef_instr *instr_undef =
356 nir_ssa_undef_instr_create(b->shader, 1, 32);
357 nir_builder_instr_insert(b, &instr_undef->instr);
358
359 nir_ssa_def *srcs[4];
360 for (int i = 0; i < 4; i++) {
361 srcs[i] = &instr_undef->def;
362 }
363 srcs[var->data.location_frac] = intr->src[1].ssa;
364
365 for (auto k = ir_sorted_set.begin() + 1; k != ir_sorted_set.end(); ++k) {
366 nir_intrinsic_instr *intr2 = *k;
367 nir_variable *var2 =
368 nir_deref_instr_get_variable(nir_src_as_deref(intr2->src[0]));
369 unsigned loc2 = var->data.location - m_base_slot;
370
371 if (m_vars[loc][var->data.location_frac] !=
372 m_vars[loc2][var2->data.location_frac]) {
373 continue;
374 }
375
376 assert(glsl_get_vector_elements(glsl_without_array(var2->type)) < 4);
377
378 if (srcs[var2->data.location_frac] == &instr_undef->def) {
379 assert(intr2->src[1].is_ssa);
380 assert(intr2->src[1].ssa);
381 srcs[var2->data.location_frac] = intr2->src[1].ssa;
382 }
383 nir_instr_remove(&intr2->instr);
384 }
385
386 create_new_io(b, intr, new_var, srcs, new_var->data.location_frac,
387 num_comps);
388 return true;
389 }
390
391 nir_variable_mode NirLowerFSOutToVector::get_io_mode(nir_shader *shader) const
392 {
393 return nir_var_shader_out;
394 }
395
396 void
397 NirLowerFSOutToVector::create_new_io(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var,
398 nir_ssa_def **srcs, unsigned first_comp, unsigned num_comps)
399 {
400 b->cursor = nir_before_instr(&intr->instr);
401
402 nir_intrinsic_instr *new_intr =
403 nir_intrinsic_instr_create(b->shader, intr->intrinsic);
404 new_intr->num_components = num_comps;
405
406 nir_intrinsic_set_write_mask(new_intr, (1 << num_comps) - 1);
407
408 nir_deref_instr *deref = nir_build_deref_var(b, var);
409 deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));
410
411 new_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);
412 new_intr->src[1] = nir_src_for_ssa(create_combined_vector(b, srcs, first_comp, num_comps));
413
414 nir_builder_instr_insert(b, &new_intr->instr);
415
416 /* Remove the old store intrinsic */
417 nir_instr_remove(&intr->instr);
418 }
419
420 bool NirLowerFSOutToVector::instr_can_rewrite_type(nir_intrinsic_instr *intr) const
421 {
422 if (intr->intrinsic != nir_intrinsic_store_deref)
423 return false;
424
425 nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
426 if (deref->mode != nir_var_shader_out)
427 return false;
428
429 return var_can_rewrite(nir_deref_instr_get_variable(deref));
430 }
431
432 nir_ssa_def *NirLowerFSOutToVector::create_combined_vector(nir_builder *b, nir_ssa_def **srcs,
433 int first_comp, int num_comp)
434 {
435 nir_op op;
436 switch (num_comp) {
437 case 2: op = nir_op_vec2; break;
438 case 3: op = nir_op_vec3; break;
439 case 4: op = nir_op_vec4; break;
440 default:
441 assert(0 && "combined vector must have 2 to 4 components");
442
443 }
444 nir_alu_instr * instr = nir_alu_instr_create(b->shader, op);
445 instr->exact = b->exact;
446
447 int i = 0;
448 unsigned k = 0;
449 while (i < num_comp) {
450 nir_ssa_def *s = srcs[first_comp + k];
451 for(uint8_t kk = 0; kk < s->num_components && i < num_comp; ++kk) {
452 instr->src[i].src = nir_src_for_ssa(s);
453 instr->src[i].swizzle[0] = kk;
454 ++i;
455 }
456 k += s->num_components;
457 }
458
459 nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_comp, 32, NULL);
460 instr->dest.write_mask = (1 << num_comp) - 1;
461 nir_builder_instr_insert(b, &instr->instr);
462 return &instr->dest.dest.ssa;
463 }
464
465 }