nir: Generalize nir_intrinsic_vote_eq
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
27 /**
28 * \file nir_lower_subgroups.c
29 */
30
31 static nir_ssa_def *
32 ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
33 {
34 /* We only use this on uvec4 types */
35 assert(value->num_components == 4 && value->bit_size == 32);
36
37 if (bit_size == 32) {
38 return nir_channel(b, value, 0);
39 } else {
40 assert(bit_size == 64);
41 return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
42 nir_channel(b, value, 1));
43 }
44 }
45
46 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
47 static nir_ssa_def *
48 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
49 unsigned num_components, unsigned bit_size)
50 {
51 assert(value->num_components == 1);
52 assert(value->bit_size == 32 || value->bit_size == 64);
53
54 nir_ssa_def *zero = nir_imm_int(b, 0);
55 if (num_components > 1) {
56 /* SPIR-V uses a uvec4 for ballot values */
57 assert(num_components == 4);
58 assert(bit_size == 32);
59
60 if (value->bit_size == 32) {
61 return nir_vec4(b, value, zero, zero, zero);
62 } else {
63 assert(value->bit_size == 64);
64 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
65 nir_unpack_64_2x32_split_y(b, value),
66 zero, zero);
67 }
68 } else {
69 /* GLSL uses a uint64_t for ballot values */
70 assert(num_components == 1);
71 assert(bit_size == 64);
72
73 if (value->bit_size == 32) {
74 return nir_pack_64_2x32_split(b, value, zero);
75 } else {
76 assert(value->bit_size == 64);
77 return value;
78 }
79 }
80 }
81
82 static nir_ssa_def *
83 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
84 {
85 /* This is safe to call on scalar things but it would be silly */
86 assert(intrin->dest.ssa.num_components > 1);
87
88 nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
89 intrin->num_components);
90 nir_ssa_def *reads[4];
91
92 for (unsigned i = 0; i < intrin->num_components; i++) {
93 nir_intrinsic_instr *chan_intrin =
94 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
95 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
96 1, intrin->dest.ssa.bit_size, NULL);
97 chan_intrin->num_components = 1;
98
99 /* value */
100 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
101 /* invocation */
102 if (intrin->intrinsic == nir_intrinsic_read_invocation)
103 nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
104
105 nir_builder_instr_insert(b, &chan_intrin->instr);
106
107 reads[i] = &chan_intrin->dest.ssa;
108 }
109
110 return nir_vec(b, reads, intrin->num_components);
111 }
112
113 static nir_ssa_def *
114 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
115 const nir_lower_subgroups_options *options)
116 {
117 switch (intrin->intrinsic) {
118 case nir_intrinsic_vote_any:
119 case nir_intrinsic_vote_all:
120 if (options->lower_vote_trivial)
121 return nir_ssa_for_src(b, intrin->src[0], 1);
122 break;
123
124 case nir_intrinsic_vote_feq:
125 case nir_intrinsic_vote_ieq:
126 if (options->lower_vote_trivial)
127 return nir_imm_int(b, NIR_TRUE);
128 break;
129
130 case nir_intrinsic_load_subgroup_size:
131 if (options->subgroup_size)
132 return nir_imm_int(b, options->subgroup_size);
133 break;
134
135 case nir_intrinsic_read_invocation:
136 case nir_intrinsic_read_first_invocation:
137 if (options->lower_to_scalar && intrin->num_components > 1)
138 return lower_read_invocation_to_scalar(b, intrin);
139 break;
140
141 case nir_intrinsic_load_subgroup_eq_mask:
142 case nir_intrinsic_load_subgroup_ge_mask:
143 case nir_intrinsic_load_subgroup_gt_mask:
144 case nir_intrinsic_load_subgroup_le_mask:
145 case nir_intrinsic_load_subgroup_lt_mask: {
146 if (!options->lower_subgroup_masks)
147 return NULL;
148
149 /* If either the result or the requested bit size is 64-bits then we
150 * know that we have 64-bit types and using them will probably be more
151 * efficient than messing around with 32-bit shifts and packing.
152 */
153 const unsigned bit_size = MAX2(options->ballot_bit_size,
154 intrin->dest.ssa.bit_size);
155
156 assert(options->subgroup_size <= 64);
157 uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
158
159 nir_ssa_def *count = nir_load_subgroup_invocation(b);
160 nir_ssa_def *val;
161 switch (intrin->intrinsic) {
162 case nir_intrinsic_load_subgroup_eq_mask:
163 val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
164 break;
165 case nir_intrinsic_load_subgroup_ge_mask:
166 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
167 nir_imm_intN_t(b, group_mask, bit_size));
168 break;
169 case nir_intrinsic_load_subgroup_gt_mask:
170 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
171 nir_imm_intN_t(b, group_mask, bit_size));
172 break;
173 case nir_intrinsic_load_subgroup_le_mask:
174 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
175 break;
176 case nir_intrinsic_load_subgroup_lt_mask:
177 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
178 break;
179 default:
180 unreachable("you seriously can't tell this is unreachable?");
181 }
182
183 return uint_to_ballot_type(b, val,
184 intrin->dest.ssa.num_components,
185 intrin->dest.ssa.bit_size);
186 }
187
188 case nir_intrinsic_ballot: {
189 if (intrin->dest.ssa.num_components == 1 &&
190 intrin->dest.ssa.bit_size == options->ballot_bit_size)
191 return NULL;
192
193 nir_intrinsic_instr *ballot =
194 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
195 ballot->num_components = 1;
196 nir_ssa_dest_init(&ballot->instr, &ballot->dest,
197 1, options->ballot_bit_size, NULL);
198 nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
199 nir_builder_instr_insert(b, &ballot->instr);
200
201 return uint_to_ballot_type(b, &ballot->dest.ssa,
202 intrin->dest.ssa.num_components,
203 intrin->dest.ssa.bit_size);
204 }
205
206 case nir_intrinsic_ballot_bitfield_extract:
207 case nir_intrinsic_ballot_bit_count_reduce:
208 case nir_intrinsic_ballot_find_lsb:
209 case nir_intrinsic_ballot_find_msb: {
210 assert(intrin->src[0].is_ssa);
211 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
212 options->ballot_bit_size);
213 switch (intrin->intrinsic) {
214 case nir_intrinsic_ballot_bitfield_extract:
215 assert(intrin->src[1].is_ssa);
216 return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
217 intrin->src[1].ssa),
218 nir_imm_int(b, 1)));
219 case nir_intrinsic_ballot_bit_count_reduce:
220 return nir_bit_count(b, int_val);
221 case nir_intrinsic_ballot_find_lsb:
222 return nir_find_lsb(b, int_val);
223 case nir_intrinsic_ballot_find_msb:
224 return nir_ufind_msb(b, int_val);
225 default:
226 unreachable("you seriously can't tell this is unreachable?");
227 }
228 }
229
230 case nir_intrinsic_ballot_bit_count_exclusive:
231 case nir_intrinsic_ballot_bit_count_inclusive: {
232 nir_ssa_def *count = nir_load_subgroup_invocation(b);
233 nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
234 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
235 const unsigned bits = options->ballot_bit_size;
236 mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
237 } else {
238 mask = nir_inot(b, nir_ishl(b, mask, count));
239 }
240
241 assert(intrin->src[0].is_ssa);
242 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
243 options->ballot_bit_size);
244
245 return nir_bit_count(b, nir_iand(b, int_val, mask));
246 }
247
248 case nir_intrinsic_elect: {
249 nir_intrinsic_instr *first =
250 nir_intrinsic_instr_create(b->shader,
251 nir_intrinsic_first_invocation);
252 nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL);
253 nir_builder_instr_insert(b, &first->instr);
254
255 return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
256 }
257
258 default:
259 break;
260 }
261
262 return NULL;
263 }
264
265 static bool
266 lower_subgroups_impl(nir_function_impl *impl,
267 const nir_lower_subgroups_options *options)
268 {
269 nir_builder b;
270 nir_builder_init(&b, impl);
271 bool progress = false;
272
273 nir_foreach_block(block, impl) {
274 nir_foreach_instr_safe(instr, block) {
275 if (instr->type != nir_instr_type_intrinsic)
276 continue;
277
278 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
279 b.cursor = nir_before_instr(instr);
280
281 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
282 if (!lower)
283 continue;
284
285 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
286 nir_instr_remove(instr);
287 progress = true;
288 }
289 }
290
291 return progress;
292 }
293
294 bool
295 nir_lower_subgroups(nir_shader *shader,
296 const nir_lower_subgroups_options *options)
297 {
298 bool progress = false;
299
300 nir_foreach_function(function, shader) {
301 if (!function->impl)
302 continue;
303
304 if (lower_subgroups_impl(function->impl, options)) {
305 progress = true;
306 nir_metadata_preserve(function->impl, nir_metadata_block_index |
307 nir_metadata_dominance);
308 }
309 }
310
311 return progress;
312 }