ac/nir: rewrite shared variable handling (v2)
author Connor Abbott <cwabbott0@gmail.com>
Mon, 26 Jun 2017 22:50:07 +0000 (15:50 -0700)
committer Connor Abbott <cwabbott0@gmail.com>
Mon, 17 Jul 2017 21:16:03 +0000 (14:16 -0700)
Translate the NIR variables directly to LLVM instead of lowering to a
TGSI-style giant array of vec4's and then back to a variable. This
should fix indirect dereferences, make shared variables more tightly
packed, and make LLVM's alias analysis more precise. This should fix an
upcoming Feral title, which has a compute shader that was failing to
compile because the extra padding made us run out of LDS space.

v2: Combine the previous two patches into one, and only use this for
shared variables for now, until LLVM becomes smarter.

Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen>
Reviewed-by: Nicolai Hähnle <nicolai.haehnle@amd.com>
Tested-by: Alex Smith <asmith@feralinteractive.com>
src/amd/common/ac_nir_to_llvm.c

index 922331090cb09847a31488e56aae11e5c9e2f3e0..9a69066afa209c40f936384fb7097115fbde3ed2 100644 (file)
@@ -65,6 +65,7 @@ struct nir_to_llvm_context {
 
        struct hash_table *defs;
        struct hash_table *phis;
+       struct hash_table *vars;
 
        LLVMValueRef descriptor_sets[AC_UD_MAX_SETS];
        LLVMValueRef ring_offsets;
@@ -154,7 +155,6 @@ struct nir_to_llvm_context {
        LLVMValueRef inputs[RADEON_LLVM_MAX_INPUTS * 4];
        LLVMValueRef outputs[RADEON_LLVM_MAX_OUTPUTS * 4];
 
-       LLVMValueRef shared_memory;
        uint64_t input_mask;
        uint64_t output_mask;
        int num_locals;
@@ -387,23 +387,6 @@ static LLVMTypeRef const_array(LLVMTypeRef elem_type, int num_elements)
                               CONST_ADDR_SPACE);
 }
 
-static LLVMValueRef get_shared_memory_ptr(struct nir_to_llvm_context *ctx,
-                                         int idx,
-                                         LLVMTypeRef type)
-{
-       LLVMValueRef offset;
-       LLVMValueRef ptr;
-       int addr_space;
-
-       offset = LLVMConstInt(ctx->i32, idx * 16, false);
-
-       ptr = ctx->shared_memory;
-       ptr = LLVMBuildGEP(ctx->builder, ptr, &offset, 1, "");
-       addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
-       ptr = LLVMBuildBitCast(ctx->builder, ptr, LLVMPointerType(type, addr_space), "");
-       return ptr;
-}
-
 static LLVMTypeRef to_integer_type_scalar(struct ac_llvm_context *ctx, LLVMTypeRef t)
 {
        if (t == ctx->f16 || t == ctx->i16)
@@ -2905,6 +2888,45 @@ load_gs_input(struct nir_to_llvm_context *ctx,
        return result;
 }
 
+static LLVMValueRef /* Build the LLVM address selected by a variable deref chain. */
+build_gep_for_deref(struct nir_to_llvm_context *ctx,
+                   nir_deref_var *deref) /* deref->var is the key looked up in ctx->vars */
+{
+       struct hash_entry *entry = _mesa_hash_table_search(ctx->vars, deref->var);
+       assert(entry->data); /* NOTE(review): entry is dereferenced without a NULL check -- confirm the lookup cannot miss */
+       LLVMValueRef val = entry->data; /* value registered for the variable; each chain link below indexes into it */
+       nir_deref *tail = deref->deref.child;
+       while (tail != NULL) { /* one ac_build_gep0 call per link of the chain */
+               LLVMValueRef offset;
+               switch (tail->deref_type) {
+               case nir_deref_type_array: {
+                       nir_deref_array *array = nir_deref_as_array(tail);
+                       offset = LLVMConstInt(ctx->i32, array->base_offset, 0); /* constant part of the array index */
+                       if (array->deref_array_type ==
+                           nir_deref_array_type_indirect) {
+                               offset = LLVMBuildAdd(ctx->builder, offset, /* indirect: add the runtime index on top */
+                                                     get_src(ctx,
+                                                             array->indirect),
+                                                     "");
+                       }
+                       break;
+               }
+               case nir_deref_type_struct: {
+                       nir_deref_struct *deref_struct =
+                               nir_deref_as_struct(tail);
+                       offset = LLVMConstInt(ctx->i32, /* struct link: the index is the constant member number */
+                                             deref_struct->index, 0);
+                       break;
+               }
+               default:
+                       unreachable("bad deref type");
+               }
+               val = ac_build_gep0(&ctx->ac, val, offset); /* step from the current value to the selected element/member */
+               tail = tail->child;
+       }
+       return val; /* with an empty chain this is the looked-up value itself */
+}
+
 static LLVMValueRef visit_load_var(struct nir_to_llvm_context *ctx,
                                   nir_intrinsic_instr *instr)
 {
@@ -2966,6 +2988,14 @@ static LLVMValueRef visit_load_var(struct nir_to_llvm_context *ctx,
                        }
                }
                break;
+       case nir_var_shared: {
+               LLVMValueRef address = build_gep_for_deref(ctx,
+                                                          instr->variables[0]);
+               LLVMValueRef val = LLVMBuildLoad(ctx->builder, address, "");
+               return LLVMBuildBitCast(ctx->builder, val,
+                                       get_def_type(ctx, &instr->dest.ssa),
+                                       "");
+       }
        case nir_var_shader_out:
                if (ctx->stage == MESA_SHADER_TESS_CTRL)
                        return load_tcs_output(ctx, instr);
@@ -2988,23 +3018,6 @@ static LLVMValueRef visit_load_var(struct nir_to_llvm_context *ctx,
                        }
                }
                break;
-       case nir_var_shared: {
-               LLVMValueRef ptr = get_shared_memory_ptr(ctx, idx, ctx->i32);
-               LLVMValueRef derived_ptr;
-
-               if (indir_index)
-                       indir_index = LLVMBuildMul(ctx->builder, indir_index, LLVMConstInt(ctx->i32, 4, false), "");
-
-               for (unsigned chan = 0; chan < ve; chan++) {
-                       LLVMValueRef index = LLVMConstInt(ctx->i32, chan, false);
-                       if (indir_index)
-                               index = LLVMBuildAdd(ctx->builder, index, indir_index, "");
-                       derived_ptr = LLVMBuildGEP(ctx->builder, ptr, &index, 1, "");
-
-                       values[chan] = LLVMBuildLoad(ctx->builder, derived_ptr, "");
-               }
-               break;
-       }
        default:
                unreachable("unhandle variable mode");
        }
@@ -3105,24 +3118,32 @@ visit_store_var(struct nir_to_llvm_context *ctx,
                }
                break;
        case nir_var_shared: {
-               LLVMValueRef ptr = get_shared_memory_ptr(ctx, idx, ctx->i32);
-
-               if (indir_index)
-                       indir_index = LLVMBuildMul(ctx->builder, indir_index, LLVMConstInt(ctx->i32, 4, false), "");
-
-               for (unsigned chan = 0; chan < 8; chan++) {
-                       if (!(writemask & (1 << chan)))
-                               continue;
-                       LLVMValueRef index = LLVMConstInt(ctx->i32, chan, false);
-                       LLVMValueRef derived_ptr;
-
-                       if (indir_index)
-                               index = LLVMBuildAdd(ctx->builder, index, indir_index, "");
-
-                       value = llvm_extract_elem(ctx, src, chan);
-                       derived_ptr = LLVMBuildGEP(ctx->builder, ptr, &index, 1, "");
-                       LLVMBuildStore(ctx->builder,
-                                      to_integer(&ctx->ac, value), derived_ptr);
+               int writemask = instr->const_index[0];
+               LLVMValueRef address = build_gep_for_deref(ctx,
+                                                          instr->variables[0]);
+               LLVMValueRef val = get_src(ctx, instr->src[0]);
+               unsigned components =
+                       glsl_get_vector_elements(
+                          nir_deref_tail(&instr->variables[0]->deref)->type);
+               if (writemask == (1 << components) - 1) {
+                       val = LLVMBuildBitCast(
+                          ctx->builder, val,
+                          LLVMGetElementType(LLVMTypeOf(address)), "");
+                       LLVMBuildStore(ctx->builder, val, address);
+               } else {
+                       for (unsigned chan = 0; chan < 4; chan++) {
+                               if (!(writemask & (1 << chan)))
+                                       continue;
+                               LLVMValueRef ptr =
+                                       LLVMBuildStructGEP(ctx->builder,
+                                                          address, chan, "");
+                               LLVMValueRef src = llvm_extract_elem(ctx, val,
+                                                                    chan);
+                               src = LLVMBuildBitCast(
+                                  ctx->builder, src,
+                                  LLVMGetElementType(LLVMTypeOf(ptr)), "");
+                               LLVMBuildStore(ctx->builder, src, ptr);
+                       }
                }
                break;
        }
@@ -3604,9 +3625,8 @@ static LLVMValueRef visit_var_atomic(struct nir_to_llvm_context *ctx,
                                     const nir_intrinsic_instr *instr)
 {
        LLVMValueRef ptr, result;
-       int idx = instr->variables[0]->var->data.driver_location;
        LLVMValueRef src = get_src(ctx, instr->src[0]);
-       ptr = get_shared_memory_ptr(ctx, idx, ctx->i32);
+       ptr = build_gep_for_deref(ctx, instr->variables[0]);
 
        if (instr->intrinsic == nir_intrinsic_var_atomic_comp_swap) {
                LLVMValueRef src1 = get_src(ctx, instr->src[1]);
@@ -5005,6 +5025,68 @@ handle_shader_output_decl(struct nir_to_llvm_context *ctx,
        ctx->output_mask |= mask_attribs;
 }
 
+static LLVMTypeRef /* Map a GLSL base type to one of the scalar LLVM types cached in ctx. */
+glsl_base_to_llvm_type(struct nir_to_llvm_context *ctx,
+                      enum glsl_base_type type)
+{
+       switch (type) {
+       case GLSL_TYPE_INT:
+       case GLSL_TYPE_UINT:
+       case GLSL_TYPE_BOOL:
+       case GLSL_TYPE_SUBROUTINE:
+               return ctx->i32; /* signed, unsigned, bool and subroutine all share ctx->i32 */
+       case GLSL_TYPE_FLOAT: /* TODO handle mediump */
+               return ctx->f32;
+       case GLSL_TYPE_INT64:
+       case GLSL_TYPE_UINT64:
+               return ctx->i64; /* both 64-bit integer kinds share ctx->i64 */
+       case GLSL_TYPE_DOUBLE:
+               return ctx->f64;
+       default:
+               unreachable("unknown GLSL type"); /* every base type not listed above ends up here */
+       }
+}
+
+static LLVMTypeRef /* Recursively translate a GLSL type into an LLVM type. */
+glsl_to_llvm_type(struct nir_to_llvm_context *ctx,
+                 const struct glsl_type *type)
+{
+       if (glsl_type_is_scalar(type)) { /* scalar: the base type's LLVM type directly */
+               return glsl_base_to_llvm_type(ctx, glsl_get_base_type(type));
+       }
+
+       if (glsl_type_is_vector(type)) { /* vector: LLVM vector of the base type */
+               return LLVMVectorType(
+                  glsl_base_to_llvm_type(ctx, glsl_get_base_type(type)),
+                  glsl_get_vector_elements(type));
+       }
+
+       if (glsl_type_is_matrix(type)) { /* matrix: LLVM array with one translated column type per column */
+               return LLVMArrayType(
+                  glsl_to_llvm_type(ctx, glsl_get_column_type(type)),
+                  glsl_get_matrix_columns(type));
+       }
+
+       if (glsl_type_is_array(type)) { /* array: LLVM array of the translated element type */
+               return LLVMArrayType(
+                  glsl_to_llvm_type(ctx, glsl_get_array_element(type)),
+                  glsl_get_length(type));
+       }
+
+       assert(glsl_type_is_struct(type)); /* anything that reaches this point is treated as a struct */
+
+       LLVMTypeRef member_types[glsl_get_length(type)]; /* NOTE(review): VLA; zero-sized if the struct has no members -- confirm that cannot happen */
+
+       for (unsigned i = 0; i < glsl_get_length(type); i++) { /* translate each member in declaration order */
+               member_types[i] =
+                       glsl_to_llvm_type(ctx,
+                                         glsl_get_struct_field(type, i));
+       }
+
+       return LLVMStructTypeInContext(ctx->context, member_types, /* final argument false: presumably the "packed" flag -- confirm against the LLVM-C API */
+                                      glsl_get_length(type), false);
+}
+
 static void
 setup_locals(struct nir_to_llvm_context *ctx,
             struct nir_function *func)
@@ -5028,6 +5110,20 @@ setup_locals(struct nir_to_llvm_context *ctx,
        }
 }
 
+static void /* Create one LLVM global per variable in nir->shared and record it in ctx->vars. */
+setup_shared(struct nir_to_llvm_context *ctx,
+            struct nir_shader *nir)
+{
+       nir_foreach_variable(variable, &nir->shared) {
+               LLVMValueRef shared =
+                       LLVMAddGlobalInAddressSpace( /* global typed from the variable's GLSL type, placed in LOCAL_ADDR_SPACE */
+                          ctx->module, glsl_to_llvm_type(ctx, variable->type),
+                          variable->name ? variable->name : "", /* a NULL name becomes the empty string */
+                          LOCAL_ADDR_SPACE);
+               _mesa_hash_table_insert(ctx->vars, variable, shared); /* keyed by the nir variable pointer; build_gep_for_deref looks it up by deref->var */
+       }
+}
+
 static LLVMValueRef
 emit_float_saturate(struct ac_llvm_context *ctx, LLVMValueRef v, float lo, float hi)
 {
@@ -5820,15 +5916,6 @@ handle_shader_outputs_post(struct nir_to_llvm_context *ctx)
        }
 }
 
-static void
-handle_shared_compute_var(struct nir_to_llvm_context *ctx,
-                         struct nir_variable *variable, uint32_t *offset, int idx)
-{
-       unsigned size = glsl_count_attribute_slots(variable->type, false);
-       variable->data.driver_location = *offset;
-       *offset += size;
-}
-
 static void ac_llvm_finalize_module(struct nir_to_llvm_context * ctx)
 {
        LLVMPassManagerRef passmgr;
@@ -5985,29 +6072,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
        create_function(&ctx);
 
-       if (nir->stage == MESA_SHADER_COMPUTE) {
-               int num_shared = 0;
-               nir_foreach_variable(variable, &nir->shared)
-                       num_shared++;
-               if (num_shared) {
-                       int idx = 0;
-                       uint32_t shared_size = 0;
-                       LLVMValueRef var;
-                       LLVMTypeRef i8p = LLVMPointerType(ctx.i8, LOCAL_ADDR_SPACE);
-                       nir_foreach_variable(variable, &nir->shared) {
-                               handle_shared_compute_var(&ctx, variable, &shared_size, idx);
-                               idx++;
-                       }
-
-                       shared_size *= 16;
-                       var = LLVMAddGlobalInAddressSpace(ctx.module,
-                                                         LLVMArrayType(ctx.i8, shared_size),
-                                                         "compute_lds",
-                                                         LOCAL_ADDR_SPACE);
-                       LLVMSetAlignment(var, 4);
-                       ctx.shared_memory = LLVMBuildBitCast(ctx.builder, var, i8p, "");
-               }
-       } else if (nir->stage == MESA_SHADER_GEOMETRY) {
+       if (nir->stage == MESA_SHADER_GEOMETRY) {
                ctx.gs_next_vertex = ac_build_alloca(&ctx, ctx.i32, "gs_next_vertex");
 
                ctx.gs_max_out_vertices = nir->info.gs.vertices_out;
@@ -6033,11 +6098,16 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
                                           _mesa_key_pointer_equal);
        ctx.phis = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                                           _mesa_key_pointer_equal);
+       ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                            _mesa_key_pointer_equal);
 
        func = (struct nir_function *)exec_list_get_head(&nir->functions);
 
        setup_locals(&ctx, func);
 
+       if (nir->stage == MESA_SHADER_COMPUTE)
+               setup_shared(&ctx, nir);
+
        visit_cf_list(&ctx, &func->impl->body);
        phi_post_pass(&ctx);
 
@@ -6050,6 +6120,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
        free(ctx.locals);
        ralloc_free(ctx.defs);
        ralloc_free(ctx.phis);
+       ralloc_free(ctx.vars);
 
        if (nir->stage == MESA_SHADER_GEOMETRY) {
                unsigned addclip = ctx.num_output_clips + ctx.num_output_culls > 4;