nir: Get rid of nir_shader::stage
[mesa.git] / src / intel / compiler / brw_nir_lower_cs_intrinsics.c
1 /*
2 * Copyright (c) 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "brw_nir.h"
25 #include "compiler/nir/nir_builder.h"
26
27 struct lower_intrinsics_state {
28 nir_shader *nir;
29 struct brw_cs_prog_data *prog_data;
30 nir_function_impl *impl;
31 bool progress;
32 nir_builder builder;
33 int thread_local_id_index;
34 };
35
36 static nir_ssa_def *
37 read_thread_local_id(struct lower_intrinsics_state *state)
38 {
39 struct brw_cs_prog_data *prog_data = state->prog_data;
40 nir_builder *b = &state->builder;
41 nir_shader *nir = state->nir;
42 const unsigned *sizes = nir->info.cs.local_size;
43 const unsigned group_size = sizes[0] * sizes[1] * sizes[2];
44
45 /* Some programs have local_size dimensions so small that the thread local
46 * ID will always be 0.
47 */
48 if (group_size <= 8)
49 return nir_imm_int(b, 0);
50
51 if (state->thread_local_id_index == -1) {
52 state->thread_local_id_index = prog_data->base.nr_params;
53 uint32_t *param = brw_stage_prog_data_add_params(&prog_data->base, 1);
54 *param = BRW_PARAM_BUILTIN_THREAD_LOCAL_ID;
55 nir->num_uniforms += 4;
56 }
57 unsigned id_index = state->thread_local_id_index;
58
59 nir_intrinsic_instr *load =
60 nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
61 load->num_components = 1;
62 load->src[0] = nir_src_for_ssa(nir_imm_int(b, 0));
63 nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
64 nir_intrinsic_set_base(load, id_index * sizeof(uint32_t));
65 nir_intrinsic_set_range(load, sizeof(uint32_t));
66 nir_builder_instr_insert(b, &load->instr);
67 return &load->dest.ssa;
68 }
69
70 static bool
71 lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
72 nir_block *block)
73 {
74 bool progress = false;
75 nir_builder *b = &state->builder;
76 nir_shader *nir = state->nir;
77
78 nir_foreach_instr_safe(instr, block) {
79 if (instr->type != nir_instr_type_intrinsic)
80 continue;
81
82 nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
83
84 b->cursor = nir_after_instr(&intrinsic->instr);
85
86 nir_ssa_def *sysval;
87 switch (intrinsic->intrinsic) {
88 case nir_intrinsic_load_local_invocation_index: {
89 /* We construct the local invocation index from:
90 *
91 * gl_LocalInvocationIndex =
92 * cs_thread_local_id + subgroup_invocation;
93 */
94 nir_ssa_def *thread_local_id = read_thread_local_id(state);
95 nir_ssa_def *channel = nir_load_subgroup_invocation(b);
96 sysval = nir_iadd(b, channel, thread_local_id);
97 break;
98 }
99
100 case nir_intrinsic_load_local_invocation_id: {
101 /* We lower gl_LocalInvocationID from gl_LocalInvocationIndex based
102 * on this formula:
103 *
104 * gl_LocalInvocationID.x =
105 * gl_LocalInvocationIndex % gl_WorkGroupSize.x;
106 * gl_LocalInvocationID.y =
107 * (gl_LocalInvocationIndex / gl_WorkGroupSize.x) %
108 * gl_WorkGroupSize.y;
109 * gl_LocalInvocationID.z =
110 * (gl_LocalInvocationIndex /
111 * (gl_WorkGroupSize.x * gl_WorkGroupSize.y)) %
112 * gl_WorkGroupSize.z;
113 */
114 unsigned *size = nir->info.cs.local_size;
115
116 nir_ssa_def *local_index = nir_load_local_invocation_index(b);
117
118 nir_const_value uvec3;
119 uvec3.u32[0] = 1;
120 uvec3.u32[1] = size[0];
121 uvec3.u32[2] = size[0] * size[1];
122 nir_ssa_def *div_val = nir_build_imm(b, 3, 32, uvec3);
123 uvec3.u32[0] = size[0];
124 uvec3.u32[1] = size[1];
125 uvec3.u32[2] = size[2];
126 nir_ssa_def *mod_val = nir_build_imm(b, 3, 32, uvec3);
127
128 sysval = nir_umod(b, nir_udiv(b, local_index, div_val), mod_val);
129 break;
130 }
131
132 default:
133 continue;
134 }
135
136 nir_ssa_def_rewrite_uses(&intrinsic->dest.ssa, nir_src_for_ssa(sysval));
137 nir_instr_remove(&intrinsic->instr);
138
139 state->progress = true;
140 }
141
142 return progress;
143 }
144
145 static void
146 lower_cs_intrinsics_convert_impl(struct lower_intrinsics_state *state)
147 {
148 nir_builder_init(&state->builder, state->impl);
149
150 nir_foreach_block(block, state->impl) {
151 lower_cs_intrinsics_convert_block(state, block);
152 }
153
154 nir_metadata_preserve(state->impl,
155 nir_metadata_block_index | nir_metadata_dominance);
156 }
157
158 bool
159 brw_nir_lower_cs_intrinsics(nir_shader *nir,
160 struct brw_cs_prog_data *prog_data)
161 {
162 assert(nir->info.stage == MESA_SHADER_COMPUTE);
163
164 bool progress = false;
165 struct lower_intrinsics_state state;
166 memset(&state, 0, sizeof(state));
167 state.nir = nir;
168 state.prog_data = prog_data;
169
170 state.thread_local_id_index = -1;
171
172 do {
173 state.progress = false;
174 nir_foreach_function(function, nir) {
175 if (function->impl) {
176 state.impl = function->impl;
177 lower_cs_intrinsics_convert_impl(&state);
178 }
179 }
180 progress |= state.progress;
181 } while (state.progress);
182
183 return progress;
184 }