radv: keep a pointer to a NIR shader into radv_shader_context
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Wed, 28 Aug 2019 15:08:29 +0000 (17:08 +0200)
committerSamuel Pitoiset <samuel.pitoiset@gmail.com>
Fri, 30 Aug 2019 07:33:30 +0000 (09:33 +0200)
This avoids multiple copies for nothing and it's more elegant.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Dave Airlie <airlied@redhat.com>
src/amd/vulkan/radv_nir_to_llvm.c

index 1c1633a51e77486e391b88bd4f606c163939f76c..047a77d6c9658e8f6396483c6fae19f77ad1bf40 100644 (file)
@@ -49,6 +49,7 @@ struct radv_shader_context {
        struct ac_llvm_context ac;
        const struct radv_nir_compiler_options *options;
        struct radv_shader_variant_info *shader_info;
+       const struct nir_shader *shader;
        struct ac_shader_abi abi;
 
        unsigned max_workgroup_size;
@@ -104,14 +105,7 @@ struct radv_shader_context {
        LLVMValueRef gs_generated_prims[4];
        LLVMValueRef gs_ngg_emit;
        LLVMValueRef gs_ngg_scratch;
-       unsigned gs_max_out_vertices;
-       unsigned gs_output_prim;
 
-       unsigned tes_primitive_mode;
-
-       uint32_t tcs_patch_outputs_read;
-       uint64_t tcs_outputs_read;
-       uint32_t tcs_vertices_per_patch;
        uint32_t tcs_num_inputs;
        uint32_t tcs_num_patches;
        uint32_t max_gsvs_emit_size;
@@ -159,13 +153,13 @@ static unsigned
 get_tcs_num_patches(struct radv_shader_context *ctx)
 {
        unsigned num_tcs_input_cp = ctx->options->key.tcs.input_vertices;
-       unsigned num_tcs_output_cp = ctx->tcs_vertices_per_patch;
+       unsigned num_tcs_output_cp = ctx->shader->info.tess.tcs_vertices_out;
        uint32_t input_vertex_size = ctx->tcs_num_inputs * 16;
        uint32_t input_patch_size = ctx->options->key.tcs.input_vertices * input_vertex_size;
        uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->info.tcs.outputs_written);
        uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->info.tcs.patch_outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
-       uint32_t pervertex_output_patch_size = ctx->tcs_vertices_per_patch * output_vertex_size;
+       uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        uint32_t output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
        unsigned num_patches;
        unsigned hardware_lds_size;
@@ -217,7 +211,7 @@ calculate_tess_lds_size(struct radv_shader_context *ctx)
        unsigned num_patches;
        unsigned lds_size;
 
-       num_tcs_output_cp = ctx->tcs_vertices_per_patch;
+       num_tcs_output_cp = ctx->shader->info.tess.tcs_vertices_out;
        num_tcs_outputs = util_last_bit64(ctx->shader_info->info.tcs.outputs_written);
        num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->info.tcs.patch_outputs_written);
 
@@ -273,7 +267,7 @@ get_tcs_out_patch_stride(struct radv_shader_context *ctx)
        uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->info.tcs.outputs_written);
        uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->info.tcs.patch_outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
-       uint32_t pervertex_output_patch_size = ctx->tcs_vertices_per_patch * output_vertex_size;
+       uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        uint32_t output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
        output_patch_size /= 4;
        return LLVMConstInt(ctx->ac.i32, output_patch_size, false);
@@ -312,7 +306,7 @@ get_tcs_out_patch0_patch_data_offset(struct radv_shader_context *ctx)
 
        uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->info.tcs.outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
-       uint32_t pervertex_output_patch_size = ctx->tcs_vertices_per_patch * output_vertex_size;
+       uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        unsigned num_patches = ctx->tcs_num_patches;
 
        output_patch0_offset *= num_patches;
@@ -1333,7 +1327,7 @@ static LLVMValueRef get_non_vertex_index_offset(struct radv_shader_context *ctx)
                num_tcs_outputs = ctx->options->key.tes.tcs_num_outputs;
 
        uint32_t output_vertex_size = num_tcs_outputs * 16;
-       uint32_t pervertex_output_patch_size = ctx->tcs_vertices_per_patch * output_vertex_size;
+       uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
 
        return LLVMConstInt(ctx->ac.i32, pervertex_output_patch_size * num_patches, false);
 }
@@ -1343,7 +1337,7 @@ static LLVMValueRef calc_param_stride(struct radv_shader_context *ctx,
 {
        LLVMValueRef param_stride;
        if (vertex_index)
-               param_stride = LLVMConstInt(ctx->ac.i32, ctx->tcs_vertices_per_patch * ctx->tcs_num_patches, false);
+               param_stride = LLVMConstInt(ctx->ac.i32, ctx->shader->info.tess.tcs_vertices_out * ctx->tcs_num_patches, false);
        else
                param_stride = LLVMConstInt(ctx->ac.i32, ctx->tcs_num_patches, false);
        return param_stride;
@@ -1356,7 +1350,7 @@ static LLVMValueRef get_tcs_tes_buffer_address(struct radv_shader_context *ctx,
        LLVMValueRef base_addr;
        LLVMValueRef param_stride, constant16;
        LLVMValueRef rel_patch_id = get_rel_patch_id(ctx);
-       LLVMValueRef vertices_per_patch = LLVMConstInt(ctx->ac.i32, ctx->tcs_vertices_per_patch, false);
+       LLVMValueRef vertices_per_patch = LLVMConstInt(ctx->ac.i32, ctx->shader->info.tess.tcs_vertices_out, false);
        constant16 = LLVMConstInt(ctx->ac.i32, 16, false);
        param_stride = calc_param_stride(ctx, vertex_index);
        if (vertex_index) {
@@ -1503,10 +1497,10 @@ store_tcs_output(struct ac_shader_abi *abi,
        bool store_lds = true;
 
        if (is_patch) {
-               if (!(ctx->tcs_patch_outputs_read & (1U << (location - VARYING_SLOT_PATCH0))))
+               if (!(ctx->shader->info.patch_outputs_read & (1U << (location - VARYING_SLOT_PATCH0))))
                        store_lds = false;
        } else {
-               if (!(ctx->tcs_outputs_read & (1ULL << location)))
+               if (!(ctx->shader->info.outputs_read & (1ULL << location)))
                        store_lds = false;
        }
 
@@ -1771,7 +1765,7 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
         * effects other than emitting vertices.
         */
        can_emit = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, gs_next_vertex,
-                                LLVMConstInt(ctx->ac.i32, ctx->gs_max_out_vertices, false), "");
+                                LLVMConstInt(ctx->ac.i32, ctx->shader->info.gs.vertices_out, false), "");
        ac_build_kill_if_false(&ctx->ac, can_emit);
 
        for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
@@ -1794,7 +1788,7 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
                                                             out_ptr[j], "");
                        LLVMValueRef voffset =
                                LLVMConstInt(ctx->ac.i32, offset *
-                                            ctx->gs_max_out_vertices, false);
+                                            ctx->shader->info.gs.vertices_out, false);
 
                        offset++;
 
@@ -1846,7 +1840,7 @@ load_tess_coord(struct ac_shader_abi *abi)
                ctx->ac.f32_0,
        };
 
-       if (ctx->tes_primitive_mode == GL_TRIANGLES)
+       if (ctx->shader->info.tess.primitive_mode == GL_TRIANGLES)
                coord[2] = LLVMBuildFSub(ctx->ac.builder, ctx->ac.f32_1,
                                        LLVMBuildFAdd(ctx->ac.builder, coord[0], coord[1], ""), "");
 
@@ -3084,7 +3078,7 @@ ngg_gs_vertex_ptr(struct radv_shader_context *ctx, LLVMValueRef vertexidx)
        LLVMValueRef storage = ngg_gs_get_vertex_storage(ctx);
 
        /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
-       unsigned write_stride_2exp = ffs(ctx->gs_max_out_vertices) - 1;
+       unsigned write_stride_2exp = ffs(ctx->shader->info.gs.vertices_out) - 1;
        if (write_stride_2exp) {
                LLVMValueRef row =
                        LLVMBuildLShr(builder, vertexidx,
@@ -3106,7 +3100,7 @@ ngg_gs_emit_vertex_ptr(struct radv_shader_context *ctx, LLVMValueRef gsthread,
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef tmp;
 
-       tmp = LLVMConstInt(ctx->ac.i32, ctx->gs_max_out_vertices, false);
+       tmp = LLVMConstInt(ctx->ac.i32, ctx->shader->info.gs.vertices_out, false);
        tmp = LLVMBuildMul(builder, tmp, gsthread, "");
        const LLVMValueRef vertexidx = LLVMBuildAdd(builder, tmp, emitidx, "");
        return ngg_gs_vertex_ptr(ctx, vertexidx);
@@ -3358,7 +3352,7 @@ static void gfx10_ngg_gs_emit_epilogue_1(struct radv_shader_context *ctx)
                const LLVMValueRef vertexidx =
                        LLVMBuildLoad(builder, ctx->gs_next_vertex[stream], "");
                tmp = LLVMBuildICmp(builder, LLVMIntUGE, vertexidx,
-                       LLVMConstInt(ctx->ac.i32, ctx->gs_max_out_vertices, false), "");
+                       LLVMConstInt(ctx->ac.i32, ctx->shader->info.gs.vertices_out, false), "");
                ac_build_ifcc(&ctx->ac, tmp, 5101);
                ac_build_break(&ctx->ac);
                ac_build_endif(&ctx->ac, 5101);
@@ -3381,7 +3375,7 @@ static void gfx10_ngg_gs_emit_epilogue_1(struct radv_shader_context *ctx)
 
 static void gfx10_ngg_gs_emit_epilogue_2(struct radv_shader_context *ctx)
 {
-       const unsigned verts_per_prim = si_conv_gl_prim_to_vertices(ctx->gs_output_prim);
+       const unsigned verts_per_prim = si_conv_gl_prim_to_vertices(ctx->shader->info.gs.output_primitive);
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef tmp, tmp2;
 
@@ -3620,7 +3614,7 @@ static void gfx10_ngg_gs_emit_vertex(struct radv_shader_context *ctx,
         */
        const LLVMValueRef can_emit =
                LLVMBuildICmp(builder, LLVMIntULT, vertexidx,
-                             LLVMConstInt(ctx->ac.i32, ctx->gs_max_out_vertices, false), "");
+                             LLVMConstInt(ctx->ac.i32, ctx->shader->info.gs.vertices_out, false), "");
        ac_build_kill_if_false(&ctx->ac, can_emit);
 
        tmp = LLVMBuildAdd(builder, vertexidx, ctx->ac.i32_1, "");
@@ -3666,7 +3660,7 @@ static void gfx10_ngg_gs_emit_vertex(struct radv_shader_context *ctx,
        /* Determine and store whether this vertex completed a primitive. */
        const LLVMValueRef curverts = LLVMBuildLoad(builder, ctx->gs_curprim_verts[stream], "");
 
-       tmp = LLVMConstInt(ctx->ac.i32, si_conv_gl_prim_to_vertices(ctx->gs_output_prim) - 1, false);
+       tmp = LLVMConstInt(ctx->ac.i32, si_conv_gl_prim_to_vertices(ctx->shader->info.gs.output_primitive) - 1, false);
        const LLVMValueRef iscompleteprim =
                LLVMBuildICmp(builder, LLVMIntUGE, curverts, tmp, "");
 
@@ -4080,7 +4074,7 @@ ac_setup_rings(struct radv_shader_context *ctx)
                        if (!num_components)
                                continue;
 
-                       stride = 4 * num_components * ctx->gs_max_out_vertices;
+                       stride = 4 * num_components * ctx->shader->info.gs.vertices_out;
 
                        /* Limit on the stride field for <= GFX7. */
                        assert(stride < (1 << 14));
@@ -4243,6 +4237,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
 
        for(int i = 0; i < shader_count; ++i) {
                ctx.stage = shaders[i]->info.stage;
+               ctx.shader = shaders[i];
                ctx.output_mask = 0;
 
                if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
@@ -4272,28 +4267,21 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                                        "ngg_emit");
                        }
 
-                       ctx.gs_max_out_vertices = shaders[i]->info.gs.vertices_out;
-                       ctx.gs_output_prim = shaders[i]->info.gs.output_primitive;
                        ctx.abi.load_inputs = load_gs_input;
                        ctx.abi.emit_primitive = visit_end_primitive;
                } else if (shaders[i]->info.stage == MESA_SHADER_TESS_CTRL) {
-                       ctx.tcs_outputs_read = shaders[i]->info.outputs_read;
-                       ctx.tcs_patch_outputs_read = shaders[i]->info.patch_outputs_read;
                        ctx.abi.load_tess_varyings = load_tcs_varyings;
                        ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                        ctx.abi.store_tcs_outputs = store_tcs_output;
-                       ctx.tcs_vertices_per_patch = shaders[i]->info.tess.tcs_vertices_out;
                        if (shader_count == 1)
                                ctx.tcs_num_inputs = ctx.options->key.tcs.num_inputs;
                        else
                                ctx.tcs_num_inputs = util_last_bit64(shader_info->info.vs.ls_outputs_written);
                        ctx.tcs_num_patches = get_tcs_num_patches(&ctx);
                } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
-                       ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
                        ctx.abi.load_tess_varyings = load_tes_input;
                        ctx.abi.load_tess_coord = load_tess_coord;
                        ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
-                       ctx.tcs_vertices_per_patch = shaders[i]->info.tess.tcs_vertices_out;
                        ctx.tcs_num_patches = ctx.options->key.tes.num_patches;
                } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
                        ctx.abi.load_base_vertex = radv_load_base_vertex;
@@ -4645,7 +4633,7 @@ ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
 
                                soffset = LLVMConstInt(ctx->ac.i32,
                                                       offset *
-                                                      ctx->gs_max_out_vertices * 16 * 4, false);
+                                                      ctx->shader->info.gs.vertices_out * 16 * 4, false);
 
                                offset++;
 
@@ -4701,12 +4689,12 @@ radv_compile_gs_copy_shader(struct ac_llvm_compiler *ac_llvm,
 
        ctx.is_gs_copy_shader = true;
        ctx.stage = MESA_SHADER_VERTEX;
+       ctx.shader = geom_shader;
 
        radv_nir_shader_info_pass(geom_shader, options, &shader_info->info);
 
        create_function(&ctx, MESA_SHADER_VERTEX, false, MESA_SHADER_VERTEX);
 
-       ctx.gs_max_out_vertices = geom_shader->info.gs.vertices_out;
        ac_setup_rings(&ctx);
 
        nir_foreach_variable(variable, &geom_shader->outputs) {