896b9504868467bf7cc030aa33bb0d7dbca1e4b2
[mesa.git] / src / compiler / nir / nir_lower_io_to_vector.c
1 /*
2 * Copyright © 2019 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_deref.h"
27
28 /** @file nir_lower_io_to_vector.c
29 *
30 * Merges compatible input/output variables residing in different components
31 * of the same location. It's expected that further passes such as
32 * nir_lower_io_to_temporaries will combine loads and stores of the merged
33 * variables, producing vector nir_load_input/nir_store_output instructions
34 * when all is said and done.
35 */
36
37 /* FRAG_RESULT_MAX+1 instead of just FRAG_RESULT_MAX because of how this pass
38 * handles dual source blending */
39 #define MAX_SLOTS MAX2(VARYING_SLOT_TESS_MAX, FRAG_RESULT_MAX+1)
40
41 static unsigned
42 get_slot(const nir_variable *var)
43 {
44 /* This handling of dual-source blending might not be correct when more than
45 * one render target is supported, but it seems no driver supports more than
46 * one. */
47 return var->data.location + var->data.index;
48 }
49
static const struct glsl_type *
resize_array_vec_type(const struct glsl_type *type, unsigned num_components)
{
   /* Rebuild `type` with the innermost vector/scalar widened to
    * num_components, preserving any array dimensions around it.
    */
   if (!glsl_type_is_array(type)) {
      assert(glsl_type_is_vector_or_scalar(type));
      return glsl_vector_type(glsl_get_base_type(type), num_components);
   }

   /* Recurse through each array dimension and re-wrap the resized element. */
   const struct glsl_type *resized_elem =
      resize_array_vec_type(glsl_get_array_element(type), num_components);
   return glsl_array_type(resized_elem, glsl_get_length(type), 0);
}
62
63 static bool
64 variable_can_rewrite(const nir_variable *var)
65 {
66 /* Skip complex types we don't split in the first place */
67 if (!glsl_type_is_vector_or_scalar(glsl_without_array(var->type)))
68 return false;
69
70 /* TODO: add 64/16bit support ? */
71 if (glsl_get_bit_size(glsl_without_array(var->type)) != 32)
72 return false;
73
74 return true;
75 }
76
77 static bool
78 variables_can_merge(nir_shader *shader,
79 const nir_variable *a, const nir_variable *b)
80 {
81 const struct glsl_type *a_type_tail = a->type;
82 const struct glsl_type *b_type_tail = b->type;
83
84 /* They must have the same array structure */
85 while (glsl_type_is_array(a_type_tail)) {
86 if (!glsl_type_is_array(b_type_tail))
87 return false;
88
89 if (glsl_get_length(a_type_tail) != glsl_get_length(b_type_tail))
90 return false;
91
92 a_type_tail = glsl_get_array_element(a_type_tail);
93 b_type_tail = glsl_get_array_element(b_type_tail);
94 }
95
96 if (!glsl_type_is_vector_or_scalar(a_type_tail) ||
97 !glsl_type_is_vector_or_scalar(b_type_tail))
98 return false;
99
100 if (glsl_get_base_type(a->type) != glsl_get_base_type(b->type))
101 return false;
102
103 assert(a->data.mode == b->data.mode);
104 if (shader->info.stage == MESA_SHADER_FRAGMENT &&
105 a->data.mode == nir_var_shader_in &&
106 a->data.interpolation != b->data.interpolation)
107 return false;
108
109 if (shader->info.stage == MESA_SHADER_FRAGMENT &&
110 a->data.mode == nir_var_shader_out &&
111 a->data.index != b->data.index)
112 return false;
113
114 return true;
115 }
116
/* Scans io_list for mergeable variables, records each rewritable variable in
 * old_vars[slot][component], and for every run of adjacent mergeable
 * variables creates one new vector variable covering the run, storing it in
 * new_vars[slot][component] for each component the run spans.  Returns true
 * if at least one merge happened.
 */
static bool
create_new_io_vars(nir_shader *shader, struct exec_list *io_list,
                   nir_variable *old_vars[MAX_SLOTS][4],
                   nir_variable *new_vars[MAX_SLOTS][4])
{
   if (exec_list_is_empty(io_list))
      return false;

   /* Index every rewritable variable by (slot, starting component). */
   nir_foreach_variable(var, io_list) {
      if (variable_can_rewrite(var)) {
         unsigned frac = var->data.location_frac;
         old_vars[get_slot(var)][frac] = var;
      }
   }

   bool merged_any_vars = false;

   /* We don't handle combining vars of different type e.g. different array
    * lengths.
    */
   for (unsigned loc = 0; loc < MAX_SLOTS; loc++) {
      unsigned frac = 0;
      while (frac < 4) {
         nir_variable *first_var = old_vars[loc][frac];
         if (!first_var) {
            /* Empty component; keep scanning this slot. */
            frac++;
            continue;
         }

         int first = frac;
         bool found_merge = false;

         /* Extend the run [first, frac) while consecutive components hold
          * variables that can merge with first_var.
          */
         while (frac < 4) {
            nir_variable *var = old_vars[loc][frac];
            if (!var)
               break;

            if (var != first_var) {
               if (!variables_can_merge(shader, first_var, var))
                  break;

               found_merge = true;
            }

            const unsigned num_components =
               glsl_get_components(glsl_without_array(var->type));

            /* We had better not have any overlapping vars */
            for (unsigned i = 1; i < num_components; i++)
               assert(old_vars[loc][frac + i] == NULL);

            frac += num_components;
         }

         /* A run of one variable needs no new vector; frac already points
          * past it, so the outer loop resumes at the next candidate.
          */
         if (!found_merge)
            continue;

         merged_any_vars = true;

         /* Clone the first variable of the run and widen it (through any
          * array dims) to cover components [first, frac).
          */
         nir_variable *var = nir_variable_clone(old_vars[loc][first], shader);
         var->data.location_frac = first;
         var->type = resize_array_vec_type(var->type, frac - first);

         nir_shader_add_variable(shader, var);
         /* Every component of the run maps back to the same new variable. */
         for (unsigned i = first; i < frac; i++)
            new_vars[loc][i] = var;
      }
   }

   return merged_any_vars;
}
188
189 static nir_deref_instr *
190 build_array_deref_of_new_var(nir_builder *b, nir_variable *new_var,
191 nir_deref_instr *leader)
192 {
193 if (leader->deref_type == nir_deref_type_var)
194 return nir_build_deref_var(b, new_var);
195
196 nir_deref_instr *parent =
197 build_array_deref_of_new_var(b, new_var, nir_deref_instr_parent(leader));
198
199 return nir_build_deref_follower(b, parent, leader);
200 }
201
/* Per-function body of the pass: builds the old->new variable maps for the
 * requested modes, then rewrites every load/store/interp deref intrinsic
 * that touches a merged variable to address the new vector variable instead.
 * Returns true if anything was rewritten.
 */
static bool
nir_lower_io_to_vector_impl(nir_function_impl *impl, nir_variable_mode modes)
{
   /* Only shader inputs and outputs are handled by this pass. */
   assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_metadata_require(impl, nir_metadata_dominance);

   nir_shader *shader = impl->function->shader;
   /* Indexed by [slot][component]; entries left NULL where no merge applies. */
   nir_variable *old_inputs[MAX_SLOTS][4] = {{0}};
   nir_variable *new_inputs[MAX_SLOTS][4] = {{0}};
   nir_variable *old_outputs[MAX_SLOTS][4] = {{0}};
   nir_variable *new_outputs[MAX_SLOTS][4] = {{0}};

   if (modes & nir_var_shader_in) {
      /* Vertex shaders support overlapping inputs. We don't do those */
      assert(b.shader->info.stage != MESA_SHADER_VERTEX);

      /* If we don't actually merge any variables, remove that bit from modes
       * so we don't bother doing extra non-work.
       */
      if (!create_new_io_vars(shader, &shader->inputs,
                              old_inputs, new_inputs))
         modes &= ~nir_var_shader_in;
   }

   if (modes & nir_var_shader_out) {
      /* If we don't actually merge any variables, remove that bit from modes
       * so we don't bother doing extra non-work.
       */
      if (!create_new_io_vars(shader, &shader->outputs,
                              old_outputs, new_outputs))
         modes &= ~nir_var_shader_out;
   }

   if (!modes)
      return false;

   bool progress = false;

   /* Actually lower all the IO load/store intrinsics. Load instructions are
    * lowered to a vector load and an ALU instruction to grab the channels we
    * want. Outputs are lowered to a write-masked store of the vector output.
    * For non-TCS outputs, we then run nir_lower_io_to_temporaries at the end
    * to clean up the partial writes.
    */
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

         switch (intrin->intrinsic) {
         case nir_intrinsic_load_deref:
         case nir_intrinsic_interp_deref_at_centroid:
         case nir_intrinsic_interp_deref_at_sample:
         case nir_intrinsic_interp_deref_at_offset: {
            nir_deref_instr *old_deref = nir_src_as_deref(intrin->src[0]);
            if (!(old_deref->mode & modes))
               break;

            /* Loads of outputs only make sense in stages that can read
             * their own outputs.
             */
            if (old_deref->mode == nir_var_shader_out)
               assert(b.shader->info.stage == MESA_SHADER_TESS_CTRL ||
                      b.shader->info.stage == MESA_SHADER_FRAGMENT);

            nir_variable *old_var = nir_deref_instr_get_variable(old_deref);

            const unsigned loc = get_slot(old_var);
            const unsigned old_frac = old_var->data.location_frac;
            nir_variable *new_var = old_deref->mode == nir_var_shader_in ?
                                    new_inputs[loc][old_frac] :
                                    new_outputs[loc][old_frac];
            /* NULL entry means this variable wasn't merged; leave it alone. */
            if (!new_var)
               break;

            assert(get_slot(new_var) == loc);
            const unsigned new_frac = new_var->data.location_frac;

            /* Channels the old load covered, positioned within the vec4 slot. */
            nir_component_mask_t vec4_comp_mask =
               ((1 << intrin->num_components) - 1) << old_frac;

            b.cursor = nir_before_instr(&intrin->instr);

            /* Rewrite the load to use the new variable and only select a
             * portion of the result.
             */
            nir_deref_instr *new_deref =
               build_array_deref_of_new_var(&b, new_var, old_deref);
            assert(glsl_type_is_vector(new_deref->type));
            nir_instr_rewrite_src(&intrin->instr, &intrin->src[0],
                                  nir_src_for_ssa(&new_deref->dest.ssa));

            /* Widen the load to the full merged vector... */
            intrin->num_components =
               glsl_get_components(new_deref->type);
            intrin->dest.ssa.num_components = intrin->num_components;

            b.cursor = nir_after_instr(&intrin->instr);

            /* ...then extract just the channels the original load wanted
             * (shifted into the new variable's component space) and use
             * that everywhere the old result was used.
             */
            nir_ssa_def *new_vec = nir_channels(&b, &intrin->dest.ssa,
                                                vec4_comp_mask >> new_frac);
            nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa,
                                           nir_src_for_ssa(new_vec),
                                           new_vec->parent_instr);

            progress = true;
            break;
         }

         case nir_intrinsic_store_deref: {
            nir_deref_instr *old_deref = nir_src_as_deref(intrin->src[0]);
            if (old_deref->mode != nir_var_shader_out)
               break;

            nir_variable *old_var = nir_deref_instr_get_variable(old_deref);

            const unsigned loc = get_slot(old_var);
            const unsigned old_frac = old_var->data.location_frac;
            nir_variable *new_var = new_outputs[loc][old_frac];
            /* NULL entry means this variable wasn't merged; leave it alone. */
            if (!new_var)
               break;

            assert(get_slot(new_var) == loc);
            const unsigned new_frac = new_var->data.location_frac;

            b.cursor = nir_before_instr(&intrin->instr);

            /* Rewrite the store to be a masked store to the new variable */
            nir_deref_instr *new_deref =
               build_array_deref_of_new_var(&b, new_var, old_deref);
            assert(glsl_type_is_vector(new_deref->type));
            nir_instr_rewrite_src(&intrin->instr, &intrin->src[0],
                                  nir_src_for_ssa(&new_deref->dest.ssa));

            intrin->num_components =
               glsl_get_components(new_deref->type);

            nir_component_mask_t old_wrmask = nir_intrinsic_write_mask(intrin);

            /* Rebuild the stored value channel by channel in the new
             * variable's component space: channels covered by the old
             * write mask come from the old value, the rest are undef
             * (they are masked off by the write mask below anyway).
             */
            assert(intrin->src[1].is_ssa);
            nir_ssa_def *old_value = intrin->src[1].ssa;
            nir_ssa_def *comps[4];
            for (unsigned c = 0; c < intrin->num_components; c++) {
               if (new_frac + c >= old_frac &&
                   (old_wrmask & 1 << (new_frac + c - old_frac))) {
                  comps[c] = nir_channel(&b, old_value,
                                         new_frac + c - old_frac);
               } else {
                  comps[c] = nir_ssa_undef(&b, old_value->num_components,
                                           old_value->bit_size);
               }
            }
            nir_ssa_def *new_value = nir_vec(&b, comps, intrin->num_components);
            nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
                                  nir_src_for_ssa(new_value));

            /* Shift the write mask from the old variable's component space
             * into the new variable's (old_frac >= new_frac by construction).
             */
            nir_intrinsic_set_write_mask(intrin,
                                         old_wrmask << (old_frac - new_frac));

            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   }

   return progress;
}
380
381 bool
382 nir_lower_io_to_vector(nir_shader *shader, nir_variable_mode modes)
383 {
384 bool progress = false;
385
386 nir_foreach_function(function, shader) {
387 if (function->impl)
388 progress |= nir_lower_io_to_vector_impl(function->impl, modes);
389 }
390
391 return progress;
392 }