nir/lower_subgroups: Lower ballot intrinsics to the specified bit size
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
27 /**
28 * \file nir_lower_subgroups.c
29 */
30
31 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
32 static nir_ssa_def *
33 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
34 unsigned num_components, unsigned bit_size)
35 {
36 assert(value->num_components == 1);
37 assert(value->bit_size == 32 || value->bit_size == 64);
38
39 nir_ssa_def *zero = nir_imm_int(b, 0);
40 if (num_components > 1) {
41 /* SPIR-V uses a uvec4 for ballot values */
42 assert(num_components == 4);
43 assert(bit_size == 32);
44
45 if (value->bit_size == 32) {
46 return nir_vec4(b, value, zero, zero, zero);
47 } else {
48 assert(value->bit_size == 64);
49 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
50 nir_unpack_64_2x32_split_y(b, value),
51 zero, zero);
52 }
53 } else {
54 /* GLSL uses a uint64_t for ballot values */
55 assert(num_components == 1);
56 assert(bit_size == 64);
57
58 if (value->bit_size == 32) {
59 return nir_pack_64_2x32_split(b, value, zero);
60 } else {
61 assert(value->bit_size == 64);
62 return value;
63 }
64 }
65 }
66
67 static nir_ssa_def *
68 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
69 {
70 /* This is safe to call on scalar things but it would be silly */
71 assert(intrin->dest.ssa.num_components > 1);
72
73 nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
74 intrin->num_components);
75 nir_ssa_def *reads[4];
76
77 for (unsigned i = 0; i < intrin->num_components; i++) {
78 nir_intrinsic_instr *chan_intrin =
79 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
80 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
81 1, intrin->dest.ssa.bit_size, NULL);
82 chan_intrin->num_components = 1;
83
84 /* value */
85 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
86 /* invocation */
87 if (intrin->intrinsic == nir_intrinsic_read_invocation)
88 nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
89
90 nir_builder_instr_insert(b, &chan_intrin->instr);
91
92 reads[i] = &chan_intrin->dest.ssa;
93 }
94
95 return nir_vec(b, reads, intrin->num_components);
96 }
97
98 static nir_ssa_def *
99 high_subgroup_mask(nir_builder *b,
100 nir_ssa_def *count,
101 uint64_t base_mask,
102 unsigned bit_size)
103 {
104 /* group_mask could probably be calculated more efficiently but we want to
105 * be sure not to shift by 64 if the subgroup size is 64 because the GLSL
106 * shift operator is undefined in that case. In any case if we were worried
107 * about efficency this should probably be done further down because the
108 * subgroup size is likely to be known at compile time.
109 */
110 nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
111 nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size);
112 nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size);
113 nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift);
114 nir_ssa_def *higher_bits =
115 nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count);
116
117 return nir_iand(b, higher_bits, group_mask);
118 }
119
120 static nir_ssa_def *
121 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
122 const nir_lower_subgroups_options *options)
123 {
124 switch (intrin->intrinsic) {
125 case nir_intrinsic_vote_any:
126 case nir_intrinsic_vote_all:
127 if (options->lower_vote_trivial)
128 return nir_ssa_for_src(b, intrin->src[0], 1);
129 break;
130
131 case nir_intrinsic_vote_eq:
132 if (options->lower_vote_trivial)
133 return nir_imm_int(b, NIR_TRUE);
134 break;
135
136 case nir_intrinsic_read_invocation:
137 case nir_intrinsic_read_first_invocation:
138 if (options->lower_to_scalar && intrin->num_components > 1)
139 return lower_read_invocation_to_scalar(b, intrin);
140 break;
141
142 case nir_intrinsic_load_subgroup_eq_mask:
143 case nir_intrinsic_load_subgroup_ge_mask:
144 case nir_intrinsic_load_subgroup_gt_mask:
145 case nir_intrinsic_load_subgroup_le_mask:
146 case nir_intrinsic_load_subgroup_lt_mask: {
147 if (!options->lower_subgroup_masks)
148 return NULL;
149
150 /* If either the result or the requested bit size is 64-bits then we
151 * know that we have 64-bit types and using them will probably be more
152 * efficient than messing around with 32-bit shifts and packing.
153 */
154 const unsigned bit_size = MAX2(options->ballot_bit_size,
155 intrin->dest.ssa.bit_size);
156
157 nir_ssa_def *count = nir_load_subgroup_invocation(b);
158 nir_ssa_def *val;
159 switch (intrin->intrinsic) {
160 case nir_intrinsic_load_subgroup_eq_mask:
161 val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
162 break;
163 case nir_intrinsic_load_subgroup_ge_mask:
164 val = high_subgroup_mask(b, count, ~0ull, bit_size);
165 break;
166 case nir_intrinsic_load_subgroup_gt_mask:
167 val = high_subgroup_mask(b, count, ~1ull, bit_size);
168 break;
169 case nir_intrinsic_load_subgroup_le_mask:
170 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
171 break;
172 case nir_intrinsic_load_subgroup_lt_mask:
173 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
174 break;
175 default:
176 unreachable("you seriously can't tell this is unreachable?");
177 }
178
179 return uint_to_ballot_type(b, val,
180 intrin->dest.ssa.num_components,
181 intrin->dest.ssa.bit_size);
182 }
183
184 case nir_intrinsic_ballot: {
185 if (intrin->dest.ssa.num_components == 1 &&
186 intrin->dest.ssa.bit_size == options->ballot_bit_size)
187 return NULL;
188
189 nir_intrinsic_instr *ballot =
190 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
191 ballot->num_components = 1;
192 nir_ssa_dest_init(&ballot->instr, &ballot->dest,
193 1, options->ballot_bit_size, NULL);
194 nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
195 nir_builder_instr_insert(b, &ballot->instr);
196
197 return uint_to_ballot_type(b, &ballot->dest.ssa,
198 intrin->dest.ssa.num_components,
199 intrin->dest.ssa.bit_size);
200 }
201
202 default:
203 break;
204 }
205
206 return NULL;
207 }
208
209 static bool
210 lower_subgroups_impl(nir_function_impl *impl,
211 const nir_lower_subgroups_options *options)
212 {
213 nir_builder b;
214 nir_builder_init(&b, impl);
215 bool progress = false;
216
217 nir_foreach_block(block, impl) {
218 nir_foreach_instr_safe(instr, block) {
219 if (instr->type != nir_instr_type_intrinsic)
220 continue;
221
222 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
223 b.cursor = nir_before_instr(instr);
224
225 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
226 if (!lower)
227 continue;
228
229 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
230 nir_instr_remove(instr);
231 progress = true;
232 }
233 }
234
235 return progress;
236 }
237
238 bool
239 nir_lower_subgroups(nir_shader *shader,
240 const nir_lower_subgroups_options *options)
241 {
242 bool progress = false;
243
244 nir_foreach_function(function, shader) {
245 if (!function->impl)
246 continue;
247
248 if (lower_subgroups_impl(function->impl, options)) {
249 progress = true;
250 nir_metadata_preserve(function->impl, nir_metadata_block_index |
251 nir_metadata_dominance);
252 }
253 }
254
255 return progress;
256 }