zink: simplify gl-to-vulkan lowering
[mesa.git] / src / gallium / drivers / zink / zink_compiler.c
1 /*
2 * Copyright 2018 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * on the rights to use, copy, modify, merge, publish, distribute, sub
8 * license, and/or sell copies of the Software, and to permit persons to whom
9 * the Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 * USE OR OTHER DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "zink_compiler.h"
25 #include "zink_screen.h"
26 #include "nir_to_spirv/nir_to_spirv.h"
27
28 #include "pipe/p_state.h"
29
30 #include "nir.h"
31 #include "compiler/nir/nir_builder.h"
32
33 #include "nir/tgsi_to_nir.h"
34 #include "tgsi/tgsi_dump.h"
35 #include "tgsi/tgsi_from_mesa.h"
36
37 #include "util/u_memory.h"
38
39 static bool
40 lower_instr(nir_intrinsic_instr *instr, nir_builder *b)
41 {
42 b->cursor = nir_before_instr(&instr->instr);
43
44 if (instr->intrinsic == nir_intrinsic_load_ubo) {
45 nir_ssa_def *old_idx = nir_ssa_for_src(b, instr->src[0], 1);
46 nir_ssa_def *new_idx = nir_iadd(b, old_idx, nir_imm_int(b, 1));
47 nir_instr_rewrite_src(&instr->instr, &instr->src[0],
48 nir_src_for_ssa(new_idx));
49 return true;
50 }
51
52 if (instr->intrinsic == nir_intrinsic_load_uniform) {
53 nir_ssa_def *ubo_idx = nir_imm_int(b, 0);
54 nir_ssa_def *ubo_offset =
55 nir_iadd(b, nir_imm_int(b, nir_intrinsic_base(instr)),
56 nir_ssa_for_src(b, instr->src[0], 1));
57
58 nir_intrinsic_instr *load =
59 nir_intrinsic_instr_create(b->shader, nir_intrinsic_load_ubo);
60 load->num_components = instr->num_components;
61 load->src[0] = nir_src_for_ssa(ubo_idx);
62 load->src[1] = nir_src_for_ssa(ubo_offset);
63 nir_ssa_dest_init(&load->instr, &load->dest,
64 load->num_components, instr->dest.ssa.bit_size,
65 instr->dest.ssa.name);
66 nir_builder_instr_insert(b, &load->instr);
67 nir_ssa_def_rewrite_uses(&instr->dest.ssa, nir_src_for_ssa(&load->dest.ssa));
68
69 nir_instr_remove(&instr->instr);
70 return true;
71 }
72
73 return false;
74 }
75
76 static bool
77 lower_uniforms_to_ubo(nir_shader *shader)
78 {
79 bool progress = false;
80
81 nir_foreach_function(function, shader) {
82 if (function->impl) {
83 nir_builder builder;
84 nir_builder_init(&builder, function->impl);
85 nir_foreach_block(block, function->impl) {
86 nir_foreach_instr_safe(instr, block) {
87 if (instr->type == nir_instr_type_intrinsic)
88 progress |= lower_instr(nir_instr_as_intrinsic(instr),
89 &builder);
90 }
91 }
92
93 nir_metadata_preserve(function->impl, nir_metadata_block_index |
94 nir_metadata_dominance);
95 }
96 }
97
98 if (progress) {
99 assert(shader->num_uniforms > 0);
100 const struct glsl_type *type = glsl_array_type(glsl_vec4_type(),
101 shader->num_uniforms, 0);
102 nir_variable *ubo = nir_variable_create(shader, nir_var_mem_ubo, type,
103 "uniform_0");
104 ubo->data.binding = 0;
105
106 struct glsl_struct_field field = {
107 .type = type,
108 .name = "data",
109 .location = -1,
110 };
111 ubo->interface_type =
112 glsl_interface_type(&field, 1, GLSL_INTERFACE_PACKING_STD430,
113 false, "__ubo0_interface");
114 }
115
116 return progress;
117 }
118
119 static void
120 lower_pos_write(nir_builder *b, struct nir_instr *instr)
121 {
122 if (instr->type != nir_instr_type_intrinsic)
123 return;
124
125 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
126 struct nir_src *src;
127 if (intr->intrinsic == nir_intrinsic_store_output) {
128 if (nir_intrinsic_base(intr) != VARYING_SLOT_POS)
129 return;
130 src = &intr->src[0];
131 } else if (intr->intrinsic == nir_intrinsic_store_deref) {
132 nir_variable *var = nir_intrinsic_get_var(intr, 0);
133 if (var->data.mode != nir_var_shader_out ||
134 var->data.location != VARYING_SLOT_POS)
135 return;
136 src = &intr->src[1];
137 } else
138 return;
139
140 b->cursor = nir_before_instr(&intr->instr);
141
142 nir_ssa_def *pos = nir_ssa_for_src(b, *src, 4);
143 nir_ssa_def *def = nir_vec4(b,
144 nir_channel(b, pos, 0),
145 nir_channel(b, pos, 1),
146 nir_fmul(b,
147 nir_fadd(b,
148 nir_channel(b, pos, 2),
149 nir_channel(b, pos, 3)),
150 nir_imm_float(b, 0.5)),
151 nir_channel(b, pos, 3));
152 nir_instr_rewrite_src(&intr->instr, src, nir_src_for_ssa(def));
153 }
154
155 static void
156 lower_clip_halfz(nir_shader *s)
157 {
158 if (s->info.stage != MESA_SHADER_VERTEX)
159 return;
160
161 nir_foreach_function(function, s) {
162 if (function->impl) {
163 nir_builder b;
164 nir_builder_init(&b, function->impl);
165
166 nir_foreach_block(block, function->impl) {
167 nir_foreach_instr_safe(instr, block) {
168 lower_pos_write(&b, instr);
169 }
170 }
171
172 nir_metadata_preserve(function->impl, nir_metadata_block_index |
173 nir_metadata_dominance);
174 }
175 }
176 }
177
178 static bool
179 lower_discard_if_instr(nir_intrinsic_instr *instr, nir_builder *b)
180 {
181 if (instr->intrinsic == nir_intrinsic_discard_if) {
182 b->cursor = nir_before_instr(&instr->instr);
183
184 nir_if *if_stmt = nir_push_if(b, nir_ssa_for_src(b, instr->src[0], 1));
185 nir_intrinsic_instr *discard =
186 nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard);
187 nir_builder_instr_insert(b, &discard->instr);
188 nir_pop_if(b, if_stmt);
189 nir_instr_remove(&instr->instr);
190 return true;
191 }
192 assert(instr->intrinsic != nir_intrinsic_discard ||
193 nir_block_last_instr(instr->instr.block) == &instr->instr);
194
195 return false;
196 }
197
198 static bool
199 lower_discard_if(nir_shader *shader)
200 {
201 bool progress = false;
202
203 nir_foreach_function(function, shader) {
204 if (function->impl) {
205 nir_builder builder;
206 nir_builder_init(&builder, function->impl);
207 nir_foreach_block(block, function->impl) {
208 nir_foreach_instr_safe(instr, block) {
209 if (instr->type == nir_instr_type_intrinsic)
210 progress |= lower_discard_if_instr(
211 nir_instr_as_intrinsic(instr),
212 &builder);
213 }
214 }
215
216 nir_metadata_preserve(function->impl, nir_metadata_dominance);
217 }
218 }
219
220 return progress;
221 }
222
223 static const struct nir_shader_compiler_options nir_options = {
224 .lower_all_io_to_temps = true,
225 .lower_ffma = true,
226 .lower_flrp32 = true,
227 .lower_fpow = true,
228 .lower_fsat = true,
229 };
230
231 const void *
232 zink_get_compiler_options(struct pipe_screen *screen,
233 enum pipe_shader_ir ir,
234 enum pipe_shader_type shader)
235 {
236 assert(ir == PIPE_SHADER_IR_NIR);
237 return &nir_options;
238 }
239
240 struct nir_shader *
241 zink_tgsi_to_nir(struct pipe_screen *screen, const struct tgsi_token *tokens)
242 {
243 if (zink_debug & ZINK_DEBUG_TGSI) {
244 fprintf(stderr, "TGSI shader:\n---8<---\n");
245 tgsi_dump_to_file(tokens, 0, stderr);
246 fprintf(stderr, "---8<---\n\n");
247 }
248
249 return tgsi_to_nir(tokens, screen);
250 }
251
252 static void
253 optimize_nir(struct nir_shader *s)
254 {
255 bool progress;
256 do {
257 progress = false;
258 NIR_PASS_V(s, nir_lower_vars_to_ssa);
259 NIR_PASS(progress, s, nir_copy_prop);
260 NIR_PASS(progress, s, nir_opt_remove_phis);
261 NIR_PASS(progress, s, nir_opt_dce);
262 NIR_PASS(progress, s, nir_opt_dead_cf);
263 NIR_PASS(progress, s, nir_opt_cse);
264 NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
265 NIR_PASS(progress, s, nir_opt_algebraic);
266 NIR_PASS(progress, s, nir_opt_constant_folding);
267 NIR_PASS(progress, s, nir_opt_undef);
268 } while (progress);
269 }
270
271 static uint32_t
272 zink_binding(enum pipe_shader_type stage, VkDescriptorType type, int index)
273 {
274 if (stage == PIPE_SHADER_COMPUTE) {
275 unreachable("not supported");
276 } else {
277 uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
278 PIPE_MAX_SHADER_SAMPLER_VIEWS);
279
280 switch (type) {
281 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
282 assert(index < PIPE_MAX_CONSTANT_BUFFERS);
283 return stage_offset + index;
284
285 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
286 assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
287 return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
288
289 default:
290 unreachable("unexpected type");
291 }
292 }
293 }
294
295 struct zink_shader *
296 zink_compile_nir(struct zink_screen *screen, struct nir_shader *nir)
297 {
298 struct zink_shader *ret = CALLOC_STRUCT(zink_shader);
299
300 NIR_PASS_V(nir, lower_uniforms_to_ubo);
301 NIR_PASS_V(nir, lower_clip_halfz);
302 NIR_PASS_V(nir, nir_lower_regs_to_ssa);
303 optimize_nir(nir);
304 NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp);
305 NIR_PASS_V(nir, lower_discard_if);
306 NIR_PASS_V(nir, nir_convert_from_ssa, true);
307
308 if (zink_debug & ZINK_DEBUG_NIR) {
309 fprintf(stderr, "NIR shader:\n---8<---\n");
310 nir_print_shader(nir, stderr);
311 fprintf(stderr, "---8<---\n");
312 }
313
314 enum pipe_shader_type stage = pipe_shader_type_from_mesa(nir->info.stage);
315
316 ret->num_bindings = 0;
317 nir_foreach_variable(var, &nir->uniforms) {
318 if (glsl_type_is_sampler(var->type)) {
319 ret->bindings[ret->num_bindings].index = var->data.driver_location;
320 var->data.binding = zink_binding(stage, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, var->data.driver_location);
321 ret->bindings[ret->num_bindings].binding = var->data.binding;
322 ret->bindings[ret->num_bindings].type = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
323 ret->num_bindings++;
324 } else if (var->interface_type) {
325 ret->bindings[ret->num_bindings].index = var->data.binding;
326 var->data.binding = zink_binding(stage, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, var->data.binding);
327 ret->bindings[ret->num_bindings].binding = var->data.binding;
328 ret->bindings[ret->num_bindings].type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
329 ret->num_bindings++;
330 }
331 }
332
333 ret->info = nir->info;
334
335 struct spirv_shader *spirv = nir_to_spirv(nir);
336 assert(spirv);
337
338 if (zink_debug & ZINK_DEBUG_SPIRV) {
339 char buf[256];
340 static int i;
341 snprintf(buf, sizeof(buf), "dump%02d.spv", i++);
342 FILE *fp = fopen(buf, "wb");
343 fwrite(spirv->words, sizeof(uint32_t), spirv->num_words, fp);
344 fclose(fp);
345 fprintf(stderr, "wrote '%s'...\n", buf);
346 }
347
348 VkShaderModuleCreateInfo smci = {};
349 smci.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
350 smci.codeSize = spirv->num_words * sizeof(uint32_t);
351 smci.pCode = spirv->words;
352
353 if (vkCreateShaderModule(screen->dev, &smci, NULL, &ret->shader_module) != VK_SUCCESS)
354 return NULL;
355
356 return ret;
357 }
358
359 void
360 zink_shader_free(struct zink_screen *screen, struct zink_shader *shader)
361 {
362 vkDestroyShaderModule(screen->dev, shader->shader_module, NULL);
363 FREE(shader);
364 }