radeonsi/nir: don't run si_nir_opts again if there is no change
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index 3d758759af77b28ef5f7e3e43c58d9490a2a9949..210a37a39061158ffae3fa5a968279cb0dbff509 100644 (file)
@@ -38,6 +38,7 @@
 struct ac_nir_context {
        struct ac_llvm_context ac;
        struct ac_shader_abi *abi;
+       const struct ac_shader_args *args;
 
        gl_shader_stage stage;
        shader_info *info;
@@ -1435,16 +1436,22 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
                offset += LLVMConstIntGetZExtValue(src0);
                offset /= 4;
 
-               offset -= ctx->abi->base_inline_push_consts;
+               offset -= ctx->args->base_inline_push_consts;
 
-               if (offset + count <= ctx->abi->num_inline_push_consts) {
+               unsigned num_inline_push_consts = ctx->args->num_inline_push_consts;
+               if (offset + count <= num_inline_push_consts) {
+                       LLVMValueRef push_constants[num_inline_push_consts];
+                       for (unsigned i = 0; i < num_inline_push_consts; i++)
+                               push_constants[i] = ac_get_arg(&ctx->ac,
+                                                              ctx->args->inline_push_consts[i]);
                        return ac_build_gather_values(&ctx->ac,
-                                                     ctx->abi->inline_push_consts + offset,
+                                                     push_constants + offset,
                                                      count);
                }
        }
 
-       ptr = LLVMBuildGEP(ctx->ac.builder, ctx->abi->push_constants, &addr, 1, "");
+       ptr = LLVMBuildGEP(ctx->ac.builder,
+                          ac_get_arg(&ctx->ac, ctx->args->push_constants), &addr, 1, "");
 
        if (instr->dest.ssa.bit_size == 8) {
                unsigned load_dwords = instr->dest.ssa.num_components > 1 ? 2 : 1;
@@ -1643,7 +1650,7 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
                        ac_build_buffer_store_dword(&ctx->ac, rsrc, data,
                                                    num_channels, offset,
                                                    ctx->ac.i32_0, 0,
-                                                   cache_policy, false);
+                                                   cache_policy);
                }
        }
 }
@@ -2540,7 +2547,7 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx,
                const struct glsl_type *type = image_deref->type;
                const nir_variable *var = nir_deref_instr_get_variable(image_deref);
                dim = glsl_get_sampler_dim(type);
-               access = var->data.image.access;
+               access = var->data.access;
                is_array = glsl_sampler_type_is_array(type);
        }
 
@@ -2597,7 +2604,7 @@ static void visit_image_store(struct ac_nir_context *ctx,
                const struct glsl_type *type = image_deref->type;
                const nir_variable *var = nir_deref_instr_get_variable(image_deref);
                dim = glsl_get_sampler_dim(type);
-               access = var->data.image.access;
+               access = var->data.access;
                is_array = glsl_sampler_type_is_array(type);
        }
 
@@ -2902,9 +2909,14 @@ visit_load_local_invocation_index(struct ac_nir_context *ctx)
 {
        LLVMValueRef result;
        LLVMValueRef thread_id = ac_get_thread_id(&ctx->ac);
-       result = LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+       result = LLVMBuildAnd(ctx->ac.builder,
+                             ac_get_arg(&ctx->ac, ctx->args->tg_size),
                              LLVMConstInt(ctx->ac.i32, 0xfc0, false), "");
 
+       if (ctx->ac.wave_size == 32)
+               result = LLVMBuildLShr(ctx->ac.builder, result,
+                                      LLVMConstInt(ctx->ac.i32, 1, false), "");
+
        return LLVMBuildAdd(ctx->ac.builder, result, thread_id, "");
 }
 
@@ -2913,7 +2925,8 @@ visit_load_subgroup_id(struct ac_nir_context *ctx)
 {
        if (ctx->stage == MESA_SHADER_COMPUTE) {
                LLVMValueRef result;
-               result = LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+               result = LLVMBuildAnd(ctx->ac.builder,
+                                     ac_get_arg(&ctx->ac, ctx->args->tg_size),
                                LLVMConstInt(ctx->ac.i32, 0xfc0, false), "");
                return LLVMBuildLShr(ctx->ac.builder, result,  LLVMConstInt(ctx->ac.i32, 6, false), "");
        } else {
@@ -2925,7 +2938,8 @@ static LLVMValueRef
 visit_load_num_subgroups(struct ac_nir_context *ctx)
 {
        if (ctx->stage == MESA_SHADER_COMPUTE) {
-               return LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+               return LLVMBuildAnd(ctx->ac.builder,
+                                   ac_get_arg(&ctx->ac, ctx->args->tg_size),
                                    LLVMConstInt(ctx->ac.i32, 0x3f, false), "");
        } else {
                return LLVMConstInt(ctx->ac.i32, 1, false);
@@ -3055,8 +3069,10 @@ static LLVMValueRef load_sample_pos(struct ac_nir_context *ctx)
        LLVMValueRef values[2];
        LLVMValueRef pos[2];
 
-       pos[0] = ac_to_float(&ctx->ac, ctx->abi->frag_pos[0]);
-       pos[1] = ac_to_float(&ctx->ac, ctx->abi->frag_pos[1]);
+       pos[0] = ac_to_float(&ctx->ac,
+                            ac_get_arg(&ctx->ac, ctx->args->frag_pos[0]));
+       pos[1] = ac_to_float(&ctx->ac,
+                            ac_get_arg(&ctx->ac, ctx->args->frag_pos[1]));
 
        values[0] = ac_build_fract(&ctx->ac, pos[0], 32);
        values[1] = ac_build_fract(&ctx->ac, pos[1], 32);
@@ -3073,19 +3089,19 @@ static LLVMValueRef lookup_interp_param(struct ac_nir_context *ctx,
        case INTERP_MODE_SMOOTH:
        case INTERP_MODE_NONE:
                if (location == INTERP_CENTER)
-                       return ctx->abi->persp_center;
+                       return ac_get_arg(&ctx->ac, ctx->args->persp_center);
                else if (location == INTERP_CENTROID)
                        return ctx->abi->persp_centroid;
                else if (location == INTERP_SAMPLE)
-                       return ctx->abi->persp_sample;
+                       return ac_get_arg(&ctx->ac, ctx->args->persp_sample);
                break;
        case INTERP_MODE_NOPERSPECTIVE:
                if (location == INTERP_CENTER)
-                       return ctx->abi->linear_center;
+                       return ac_get_arg(&ctx->ac, ctx->args->linear_center);
                else if (location == INTERP_CENTROID)
                        return ctx->abi->linear_centroid;
                else if (location == INTERP_SAMPLE)
-                       return ctx->abi->linear_sample;
+                       return ac_get_arg(&ctx->ac, ctx->args->linear_sample);
                break;
        }
        return NULL;
@@ -3199,10 +3215,10 @@ static LLVMValueRef load_interpolated_input(struct ac_nir_context *ctx,
                LLVMValueRef llvm_chan = LLVMConstInt(ctx->ac.i32, comp_start + comp, false);
                if (bitsize == 16) {
                        values[comp] = ac_build_fs_interp_f16(&ctx->ac, llvm_chan, attr_number,
-                                                             ctx->abi->prim_mask, i, j);
+                                                             ac_get_arg(&ctx->ac, ctx->args->prim_mask), i, j);
                } else {
                        values[comp] = ac_build_fs_interp(&ctx->ac, llvm_chan, attr_number,
-                                                         ctx->abi->prim_mask, i, j);
+                                                         ac_get_arg(&ctx->ac, ctx->args->prim_mask), i, j);
                }
        }
 
@@ -3230,7 +3246,7 @@ static LLVMValueRef load_flat_input(struct ac_nir_context *ctx,
                                                      LLVMConstInt(ctx->ac.i32, 2, false),
                                                      llvm_chan,
                                                      attr_number,
-                                                     ctx->abi->prim_mask);
+                                                     ac_get_arg(&ctx->ac, ctx->args->prim_mask));
                values[chan] = LLVMBuildBitCast(ctx->ac.builder, values[chan], ctx->ac.i32, "");
                values[chan] = LLVMBuildTruncOrBitCast(ctx->ac.builder, values[chan],
                                                       bit_size == 16 ? ctx->ac.i16 : ctx->ac.i32, "");
@@ -3270,8 +3286,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                LLVMValueRef values[3];
 
                for (int i = 0; i < 3; i++) {
-                       values[i] = ctx->abi->workgroup_ids[i] ?
-                                   ctx->abi->workgroup_ids[i] : ctx->ac.i32_0;
+                       values[i] = ctx->args->workgroup_ids[i].used ?
+                                   ac_get_arg(&ctx->ac, ctx->args->workgroup_ids[i]) : ctx->ac.i32_0;
                }
 
                result = ac_build_gather_values(&ctx->ac, values, 3);
@@ -3285,51 +3301,56 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->load_local_group_size(ctx->abi);
                break;
        case nir_intrinsic_load_vertex_id:
-               result = LLVMBuildAdd(ctx->ac.builder, ctx->abi->vertex_id,
-                                     ctx->abi->base_vertex, "");
+               result = LLVMBuildAdd(ctx->ac.builder,
+                                     ac_get_arg(&ctx->ac, ctx->args->vertex_id),
+                                     ac_get_arg(&ctx->ac, ctx->args->base_vertex), "");
                break;
        case nir_intrinsic_load_vertex_id_zero_base: {
                result = ctx->abi->vertex_id;
                break;
        }
        case nir_intrinsic_load_local_invocation_id: {
-               result = ctx->abi->local_invocation_ids;
+               result = ac_get_arg(&ctx->ac, ctx->args->local_invocation_ids);
                break;
        }
        case nir_intrinsic_load_base_instance:
-               result = ctx->abi->start_instance;
+               result = ac_get_arg(&ctx->ac, ctx->args->start_instance);
                break;
        case nir_intrinsic_load_draw_id:
-               result = ctx->abi->draw_id;
+               result = ac_get_arg(&ctx->ac, ctx->args->draw_id);
                break;
        case nir_intrinsic_load_view_index:
-               result = ctx->abi->view_index;
+               result = ac_get_arg(&ctx->ac, ctx->args->view_index);
                break;
        case nir_intrinsic_load_invocation_id:
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       result = ac_unpack_param(&ctx->ac, ctx->abi->tcs_rel_ids, 8, 5);
+                       result = ac_unpack_param(&ctx->ac,
+                                                ac_get_arg(&ctx->ac, ctx->args->tcs_rel_ids),
+                                                8, 5);
                } else {
                        if (ctx->ac.chip_class >= GFX10) {
                                result = LLVMBuildAnd(ctx->ac.builder,
-                                                     ctx->abi->gs_invocation_id,
+                                                     ac_get_arg(&ctx->ac, ctx->args->gs_invocation_id),
                                                      LLVMConstInt(ctx->ac.i32, 127, 0), "");
                        } else {
-                               result = ctx->abi->gs_invocation_id;
+                               result = ac_get_arg(&ctx->ac, ctx->args->gs_invocation_id);
                        }
                }
                break;
        case nir_intrinsic_load_primitive_id:
                if (ctx->stage == MESA_SHADER_GEOMETRY) {
-                       result = ctx->abi->gs_prim_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->gs_prim_id);
                } else if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       result = ctx->abi->tcs_patch_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->tcs_patch_id);
                } else if (ctx->stage == MESA_SHADER_TESS_EVAL) {
-                       result = ctx->abi->tes_patch_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->tes_patch_id);
                } else
                        fprintf(stderr, "Unknown primitive id intrinsic: %d", ctx->stage);
                break;
        case nir_intrinsic_load_sample_id:
-               result = ac_unpack_param(&ctx->ac, ctx->abi->ancillary, 8, 4);
+               result = ac_unpack_param(&ctx->ac,
+                                        ac_get_arg(&ctx->ac, ctx->args->ancillary),
+                                        8, 4);
                break;
        case nir_intrinsic_load_sample_pos:
                result = load_sample_pos(ctx);
@@ -3339,10 +3360,11 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                break;
        case nir_intrinsic_load_frag_coord: {
                LLVMValueRef values[4] = {
-                       ctx->abi->frag_pos[0],
-                       ctx->abi->frag_pos[1],
-                       ctx->abi->frag_pos[2],
-                       ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, ctx->abi->frag_pos[3])
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[0]),
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[1]),
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[2]),
+                       ac_build_fdiv(&ctx->ac, ctx->ac.f32_1,
+                                     ac_get_arg(&ctx->ac, ctx->args->frag_pos[3]))
                };
                result = ac_to_integer(&ctx->ac,
                                       ac_build_gather_values(&ctx->ac, values, 4));
@@ -3352,7 +3374,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->inputs[ac_llvm_reg_index_soa(VARYING_SLOT_LAYER, 0)];
                break;
        case nir_intrinsic_load_front_face:
-               result = ctx->abi->front_face;
+               result = ac_get_arg(&ctx->ac, ctx->args->front_face);
                break;
        case nir_intrinsic_load_helper_invocation:
                result = ac_build_load_helper_invocation(&ctx->ac);
@@ -3371,7 +3393,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->instance_id;
                break;
        case nir_intrinsic_load_num_work_groups:
-               result = ctx->abi->num_work_groups;
+               result = ac_get_arg(&ctx->ac, ctx->args->num_work_groups);
                break;
        case nir_intrinsic_load_local_invocation_index:
                result = visit_load_local_invocation_index(ctx);
@@ -4710,13 +4732,14 @@ setup_shared(struct ac_nir_context *ctx,
 }
 
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
-                     struct nir_shader *nir)
+                     const struct ac_shader_args *args, struct nir_shader *nir)
 {
        struct ac_nir_context ctx = {};
        struct nir_function *func;
 
        ctx.ac = *ac;
        ctx.abi = abi;
+       ctx.args = args;
 
        ctx.stage = nir->info.stage;
        ctx.info = &nir->info;
@@ -4760,17 +4783,19 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
        ralloc_free(ctx.vars);
 }
 
-void
+bool
 ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
 {
+       bool progress = false;
+
        /* Lower large variables to scratch first so that we won't bloat the
         * shader by generating large if ladders for them. We later lower
         * scratch to alloca's, assuming LLVM won't generate VGPR indexing.
         */
-       NIR_PASS_V(nir, nir_lower_vars_to_scratch,
-                  nir_var_function_temp,
-                  256,
-                  glsl_get_natural_size_align_bytes);
+       NIR_PASS(progress, nir, nir_lower_vars_to_scratch,
+                nir_var_function_temp,
+                256,
+                glsl_get_natural_size_align_bytes);
 
        /* While it would be nice not to have this flag, we are constrained
         * by the reality that LLVM 9.0 has buggy VGPR indexing on GFX9.
@@ -4802,7 +4827,8 @@ ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
         */
        indirect_mask |= nir_var_function_temp;
 
-       nir_lower_indirect_derefs(nir, indirect_mask);
+       progress |= nir_lower_indirect_derefs(nir, indirect_mask);
+       return progress;
 }
 
 static unsigned