ac/nir: Add deref based var loads/stores.
authorBas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Tue, 24 Apr 2018 22:08:39 +0000 (00:08 +0200)
committerJason Ekstrand <jason.ekstrand@intel.com>
Sat, 23 Jun 2018 03:54:03 +0000 (20:54 -0700)
Acked-by: Rob Clark <robdclark@gmail.com>
Acked-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Acked-by: Dave Airlie <airlied@redhat.com>
Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
src/amd/common/ac_nir_to_llvm.c

index d23d3cd5f2e410f1a53dc4884a0c2eb585c71c46..467d1dd19ab567d30552493389628d03ea068f32 100644 (file)
@@ -27,6 +27,7 @@
 #include "ac_binary.h"
 #include "sid.h"
 #include "nir/nir.h"
+#include "nir/nir_deref.h"
 #include "util/bitscan.h"
 #include "util/u_math.h"
 #include "ac_shader_abi.h"
@@ -1627,6 +1628,75 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
                                get_def_type(ctx, &instr->dest.ssa), "");
 }
 
+static void
+get_deref_instr_offset(struct ac_nir_context *ctx, nir_deref_instr *instr,
+                      bool vs_in, unsigned *vertex_index_out,
+                      LLVMValueRef *vertex_index_ref,
+                      unsigned *const_out, LLVMValueRef *indir_out)
+{
+       nir_variable *var = nir_deref_instr_get_variable(instr);
+       nir_deref_path path;
+       unsigned idx_lvl = 1;
+
+       nir_deref_path_init(&path, instr, NULL);
+
+       if (vertex_index_out != NULL || vertex_index_ref != NULL) {
+               if (vertex_index_ref) {
+                       *vertex_index_ref = get_src(ctx, path.path[idx_lvl]->arr.index);
+                       if (vertex_index_out)
+                               *vertex_index_out = 0;
+               } else {
+                       nir_const_value *v = nir_src_as_const_value(path.path[idx_lvl]->arr.index);
+                       assert(v);
+                       *vertex_index_out = v->u32[0];
+               }
+               ++idx_lvl;
+       }
+
+       uint32_t const_offset = 0;
+       LLVMValueRef offset = NULL;
+
+       if (var->data.compact) {
+               assert(instr->deref_type == nir_deref_type_array);
+               nir_const_value *v = nir_src_as_const_value(instr->arr.index);
+               assert(v);
+               const_offset = v->u32[0];
+               goto out;
+       }
+
+       for (; path.path[idx_lvl]; ++idx_lvl) {
+               const struct glsl_type *parent_type = path.path[idx_lvl - 1]->type;
+               if (path.path[idx_lvl]->deref_type == nir_deref_type_struct) {
+                       unsigned index = path.path[idx_lvl]->strct.index;
+
+                       for (unsigned i = 0; i < index; i++) {
+                               const struct glsl_type *ft = glsl_get_struct_field(parent_type, i);
+                               const_offset += glsl_count_attribute_slots(ft, vs_in);
+                       }
+               } else if(path.path[idx_lvl]->deref_type == nir_deref_type_array) {
+                       unsigned size = glsl_count_attribute_slots(path.path[idx_lvl]->type, vs_in);
+                       LLVMValueRef array_off = LLVMBuildMul(ctx->ac.builder, LLVMConstInt(ctx->ac.i32, size, 0),
+                                                             get_src(ctx, path.path[idx_lvl]->arr.index), "");
+                       if (offset)
+                               offset = LLVMBuildAdd(ctx->ac.builder, offset, array_off, "");
+                       else
+                               offset = array_off;
+               } else
+                       unreachable("Uhandled deref type in get_deref_instr_offset");
+       }
+
+out:
+       nir_deref_path_finish(&path);
+
+       if (const_offset && offset)
+               offset = LLVMBuildAdd(ctx->ac.builder, offset,
+                                     LLVMConstInt(ctx->ac.i32, const_offset, 0),
+                                     "");
+
+       *const_out = const_offset;
+       *indir_out = offset;
+}
+
 static void
 get_deref_offset(struct ac_nir_context *ctx, nir_deref_var *deref,
                 bool vs_in, unsigned *vertex_index_out,
@@ -1753,14 +1823,25 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
        LLVMValueRef vertex_index = NULL;
        LLVMValueRef indir_index = NULL;
        unsigned const_index = 0;
-       unsigned location = instr->variables[0]->var->data.location;
-       unsigned driver_location = instr->variables[0]->var->data.driver_location;
-       const bool is_patch =  instr->variables[0]->var->data.patch;
-       const bool is_compact = instr->variables[0]->var->data.compact;
 
-       get_deref_offset(ctx, instr->variables[0],
-                        false, NULL, is_patch ? NULL : &vertex_index,
-                        &const_index, &indir_index);
+       bool uses_deref_chain = instr->intrinsic == nir_intrinsic_load_var;
+       nir_variable *var = uses_deref_chain ? instr->variables[0]->var :
+                            nir_deref_instr_get_variable(nir_instr_as_deref(instr->src[0].ssa->parent_instr));
+
+       unsigned location = var->data.location;
+       unsigned driver_location = var->data.driver_location;
+       const bool is_patch =  var->data.patch;
+       const bool is_compact = var->data.compact;
+
+       if (uses_deref_chain) {
+               get_deref_offset(ctx, instr->variables[0],
+                                false, NULL, is_patch ? NULL : &vertex_index,
+                                &const_index, &indir_index);
+       } else {
+               get_deref_instr_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr),
+                                      false, NULL, is_patch ? NULL : &vertex_index,
+                                      &const_index, &indir_index);
+       }
 
        LLVMTypeRef dest_type = get_def_type(ctx, &instr->dest.ssa);
 
@@ -1773,7 +1854,7 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
        result = ctx->abi->load_tess_varyings(ctx->abi, src_component_type,
                                              vertex_index, indir_index,
                                              const_index, location, driver_location,
-                                             instr->variables[0]->var->data.location_frac,
+                                             var->data.location_frac,
                                              instr->num_components,
                                              is_patch, is_compact, load_inputs);
        return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
@@ -1782,23 +1863,33 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
 static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                                   nir_intrinsic_instr *instr)
 {
+       bool uses_deref_chain = instr->intrinsic == nir_intrinsic_load_var;
+       nir_variable *var = uses_deref_chain ? instr->variables[0]->var :
+                           nir_deref_instr_get_variable(nir_instr_as_deref(instr->src[0].ssa->parent_instr));
+
        LLVMValueRef values[8];
-       int idx = instr->variables[0]->var->data.driver_location;
+       int idx = var->data.driver_location;
        int ve = instr->dest.ssa.num_components;
-       unsigned comp = instr->variables[0]->var->data.location_frac;
+       unsigned comp = var->data.location_frac;
        LLVMValueRef indir_index;
        LLVMValueRef ret;
        unsigned const_index;
-       unsigned stride = instr->variables[0]->var->data.compact ? 1 : 4;
+       unsigned stride = var->data.compact ? 1 : 4;
        bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
-                    instr->variables[0]->var->data.mode == nir_var_shader_in;
-       get_deref_offset(ctx, instr->variables[0], vs_in, NULL, NULL,
-                                     &const_index, &indir_index);
+                    var->data.mode == nir_var_shader_in;
+
+       if (uses_deref_chain) {
+               get_deref_offset(ctx, instr->variables[0], vs_in, NULL, NULL,
+                                &const_index, &indir_index);
+       } else {
+               get_deref_instr_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr), vs_in, NULL, NULL,
+                                      &const_index, &indir_index);
+       }
 
        if (instr->dest.ssa.bit_size == 64)
                ve *= 2;
 
-       switch (instr->variables[0]->var->data.mode) {
+       switch (var->data.mode) {
        case nir_var_shader_in:
                if (ctx->stage == MESA_SHADER_TESS_CTRL ||
                    ctx->stage == MESA_SHADER_TESS_EVAL) {
@@ -1809,20 +1900,25 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                        LLVMTypeRef type = LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.bit_size);
                        LLVMValueRef indir_index;
                        unsigned const_index, vertex_index;
-                       get_deref_offset(ctx, instr->variables[0],
-                                        false, &vertex_index, NULL,
-                                        &const_index, &indir_index);
+                       if (uses_deref_chain) {
+                               get_deref_offset(ctx, instr->variables[0],
+                                                false, &vertex_index, NULL,
+                                                &const_index, &indir_index);
+                       } else {
+                               get_deref_instr_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr),
+                                                      false, &vertex_index, NULL, &const_index, &indir_index);
+                       }
 
-                       return ctx->abi->load_inputs(ctx->abi, instr->variables[0]->var->data.location,
-                                                    instr->variables[0]->var->data.driver_location,
-                                                    instr->variables[0]->var->data.location_frac,
+                       return ctx->abi->load_inputs(ctx->abi, var->data.location,
+                                                    var->data.driver_location,
+                                                    var->data.location_frac,
                                                     instr->num_components, vertex_index, const_index, type);
                }
 
                for (unsigned chan = comp; chan < ve + comp; chan++) {
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
-                                               instr->variables[0]->var->type,
+                                               var->type,
                                                ctx->stage == MESA_SHADER_VERTEX);
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
@@ -1840,7 +1936,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                for (unsigned chan = 0; chan < ve; chan++) {
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
-                                       instr->variables[0]->var->type, false);
+                                       var->type, false);
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->locals + idx + chan, count,
@@ -1855,8 +1951,9 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                }
                break;
        case nir_var_shared: {
-               LLVMValueRef address = build_gep_for_deref(ctx,
-                                                          instr->variables[0]);
+               LLVMValueRef address = uses_deref_chain ?
+                                          build_gep_for_deref(ctx, instr->variables[0])
+                                        : get_src(ctx, instr->src[0]);
                LLVMValueRef val = LLVMBuildLoad(ctx->ac.builder, address, "");
                return LLVMBuildBitCast(ctx->ac.builder, val,
                                        get_def_type(ctx, &instr->dest.ssa),
@@ -1870,7 +1967,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                for (unsigned chan = comp; chan < ve + comp; chan++) {
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
-                                               instr->variables[0]->var->type, false);
+                                               var->type, false);
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->abi->outputs + idx + chan, count,
@@ -1897,15 +1994,25 @@ static void
 visit_store_var(struct ac_nir_context *ctx,
                nir_intrinsic_instr *instr)
 {
+        bool uses_deref_chain = instr->intrinsic == nir_intrinsic_store_var;
+        nir_variable *var = uses_deref_chain ? instr->variables[0]->var :
+                            nir_deref_instr_get_variable(nir_instr_as_deref(instr->src[0].ssa->parent_instr));
+
        LLVMValueRef temp_ptr, value;
-       int idx = instr->variables[0]->var->data.driver_location;
-       unsigned comp = instr->variables[0]->var->data.location_frac;
-       LLVMValueRef src = ac_to_float(&ctx->ac, get_src(ctx, instr->src[0]));
+       int idx = var->data.driver_location;
+       unsigned comp = var->data.location_frac;
+       LLVMValueRef src = ac_to_float(&ctx->ac, get_src(ctx, instr->src[uses_deref_chain ? 0 : 1]));
        int writemask = instr->const_index[0];
        LLVMValueRef indir_index;
        unsigned const_index;
-       get_deref_offset(ctx, instr->variables[0], false,
-                        NULL, NULL, &const_index, &indir_index);
+
+       if (uses_deref_chain) {
+               get_deref_offset(ctx, instr->variables[0], false,
+                                NULL, NULL, &const_index, &indir_index);
+       } else {
+               get_deref_instr_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr), false,
+                                      NULL, NULL, &const_index, &indir_index);
+       }
 
        if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src)) == 64) {
 
@@ -1918,20 +2025,26 @@ visit_store_var(struct ac_nir_context *ctx,
 
        writemask = writemask << comp;
 
-       switch (instr->variables[0]->var->data.mode) {
+       switch (var->data.mode) {
        case nir_var_shader_out:
 
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
                        LLVMValueRef vertex_index = NULL;
                        LLVMValueRef indir_index = NULL;
                        unsigned const_index = 0;
-                       const bool is_patch = instr->variables[0]->var->data.patch;
+                       const bool is_patch = var->data.patch;
 
-                       get_deref_offset(ctx, instr->variables[0],
-                                        false, NULL, is_patch ? NULL : &vertex_index,
-                                        &const_index, &indir_index);
+                       if (uses_deref_chain) {
+                               get_deref_offset(ctx, instr->variables[0],
+                                                false, NULL, is_patch ? NULL : &vertex_index,
+                                                &const_index, &indir_index);
+                       } else {
+                               get_deref_instr_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr),
+                                                false, NULL, is_patch ? NULL : &vertex_index,
+                                                &const_index, &indir_index);
+                       }
 
-                       ctx->abi->store_tcs_outputs(ctx->abi, instr->variables[0]->var,
+                       ctx->abi->store_tcs_outputs(ctx->abi, var,
                                                    vertex_index, indir_index,
                                                    const_index, src, writemask);
                        return;
@@ -1944,11 +2057,11 @@ visit_store_var(struct ac_nir_context *ctx,
 
                        value = ac_llvm_extract_elem(&ctx->ac, src, chan - comp);
 
-                       if (instr->variables[0]->var->data.compact)
+                       if (var->data.compact)
                                stride = 1;
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
-                                               instr->variables[0]->var->type, false);
+                                               var->type, false);
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->abi->outputs + idx + chan, count,
@@ -1974,7 +2087,7 @@ visit_store_var(struct ac_nir_context *ctx,
                        value = ac_llvm_extract_elem(&ctx->ac, src, chan);
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
-                                       instr->variables[0]->var->type, false);
+                                       var->type, false);
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                        &ctx->ac, ctx->locals + idx + chan, count,
@@ -1993,13 +2106,11 @@ visit_store_var(struct ac_nir_context *ctx,
                break;
        case nir_var_shared: {
                int writemask = instr->const_index[0];
-               LLVMValueRef address = build_gep_for_deref(ctx,
-                                                          instr->variables[0]);
-               LLVMValueRef val = get_src(ctx, instr->src[0]);
-               unsigned components =
-                       glsl_get_vector_elements(
-                          nir_deref_tail(&instr->variables[0]->deref)->type);
-               if (writemask == (1 << components) - 1) {
+               LLVMValueRef address = uses_deref_chain ?
+                                            build_gep_for_deref(ctx, instr->variables[0])
+                                          : get_src(ctx, instr->src[0]);
+               LLVMValueRef val = get_src(ctx, instr->src[uses_deref_chain ? 0 : 1]);
+               if (util_is_power_of_two_nonzero(writemask)) {
                        val = LLVMBuildBitCast(
                           ctx->ac.builder, val,
                           LLVMGetElementType(LLVMTypeOf(address)), "");
@@ -3028,9 +3139,11 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_get_buffer_size:
                result = visit_get_buffer_size(ctx, instr);
                break;
+       case nir_intrinsic_load_deref:
        case nir_intrinsic_load_var:
                result = visit_load_var(ctx, instr);
                break;
+       case nir_intrinsic_store_deref:
        case nir_intrinsic_store_var:
                visit_store_var(ctx, instr);
                break;