amd/common: Implement global memory accesses.
author Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Thu, 24 Jan 2019 01:04:10 +0000 (02:04 +0100)
committer Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Wed, 6 Feb 2019 21:36:11 +0000 (22:36 +0100)
Needed for VK_EXT_buffer_device_address.

The pointers are implemented as i8*, since I could not figure
out how to emulate setting struct offsets in LLVM based on the
SPIR-V offsets (and more weird stuff like row major matrices).

Acked-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
src/amd/common/ac_nir_to_llvm.c

index f78b6b505c01b536b22a1d7248cce7db06cb88aa..54559b19f0277f9a586fe45b7cdb863aca71f57a 100644 (file)
@@ -1878,6 +1878,14 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
        return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
 }
 
+static unsigned
+type_scalar_size_bytes(const struct glsl_type *type)
+{
+   assert(glsl_type_is_vector_or_scalar(type) ||
+          glsl_type_is_matrix(type));
+   return glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8; /* bit size / 8, except booleans which are fixed at 4 */
+}
+
 static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                                   nir_intrinsic_instr *instr)
 {
@@ -1892,7 +1900,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
        LLVMValueRef ret;
        unsigned const_index;
        unsigned stride = 4;
-       int mode = nir_var_mem_shared;
+       int mode = deref->mode;
        
        if (var) {
                bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
@@ -1999,6 +2007,32 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                        }
                }
                break;
+       case nir_var_mem_global:  {
+               LLVMValueRef address = get_src(ctx, instr->src[0]);
+               unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
+               unsigned natural_stride = type_scalar_size_bytes(deref->type);
+               unsigned stride = explicit_stride ? explicit_stride : natural_stride;
+
+               LLVMTypeRef result_type = get_def_type(ctx, &instr->dest.ssa);
+               if (stride != natural_stride) {
+                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMGetElementType(result_type),
+                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                       address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
+
+                       for (unsigned i = 0; i < instr->dest.ssa.num_components; ++i) {
+                               LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, i * stride / natural_stride, 0);
+                               values[i] = LLVMBuildLoad(ctx->ac.builder,
+                                                         ac_build_gep_ptr(&ctx->ac, address, offset), "");
+                       }
+                       return ac_build_gather_values(&ctx->ac, values, instr->dest.ssa.num_components);
+               } else {
+                       LLVMTypeRef ptr_type =  LLVMPointerType(result_type,
+                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                       address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
+                       LLVMValueRef val = LLVMBuildLoad(ctx->ac.builder, address, "");
+                       return val;
+               }
+       }
        default:
                unreachable("unhandle variable mode");
        }
@@ -2114,33 +2148,52 @@ visit_store_var(struct ac_nir_context *ctx,
                        }
                }
                break;
+
+       case nir_var_mem_global:
        case nir_var_mem_shared: {
                int writemask = instr->const_index[0];
                LLVMValueRef address = get_src(ctx, instr->src[0]);
                LLVMValueRef val = get_src(ctx, instr->src[1]);
-               if (writemask == (1u << ac_get_llvm_num_components(val)) - 1) {
-                       val = LLVMBuildBitCast(
-                          ctx->ac.builder, val,
-                          LLVMGetElementType(LLVMTypeOf(address)), "");
+
+               unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
+               unsigned natural_stride = type_scalar_size_bytes(deref->type);
+               unsigned stride = explicit_stride ? explicit_stride : natural_stride;
+
+               LLVMTypeRef ptr_type =  LLVMPointerType(LLVMTypeOf(val),
+                                                       LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+               address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
+
+               if (writemask == (1u << ac_get_llvm_num_components(val)) - 1 &&
+                   stride == natural_stride) {
+                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMTypeOf(val),
+                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                       address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
+
+                       val = LLVMBuildBitCast(ctx->ac.builder, val,
+                                              LLVMGetElementType(LLVMTypeOf(address)), "");
                        LLVMBuildStore(ctx->ac.builder, val, address);
                } else {
+                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMGetElementType(LLVMTypeOf(val)),
+                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                       address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
                        for (unsigned chan = 0; chan < 4; chan++) {
                                if (!(writemask & (1 << chan)))
                                        continue;
-                               LLVMValueRef ptr =
-                                       LLVMBuildStructGEP(ctx->ac.builder,
-                                                          address, chan, "");
+
+                               LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, chan * stride / natural_stride, 0);
+
+                               LLVMValueRef ptr = ac_build_gep_ptr(&ctx->ac, address, offset);
                                LLVMValueRef src = ac_llvm_extract_elem(&ctx->ac, val,
                                                                        chan);
-                               src = LLVMBuildBitCast(
-                                  ctx->ac.builder, src,
-                                  LLVMGetElementType(LLVMTypeOf(ptr)), "");
+                               src = LLVMBuildBitCast(ctx->ac.builder, src,
+                                                      LLVMGetElementType(LLVMTypeOf(ptr)), "");
                                LLVMBuildStore(ctx->ac.builder, src, ptr);
                        }
                }
                break;
        }
        default:
+               abort();
                break;
        }
 }
@@ -3899,7 +3952,8 @@ glsl_to_llvm_type(struct ac_llvm_context *ac,
 static void visit_deref(struct ac_nir_context *ctx,
                         nir_deref_instr *instr)
 {
-       if (instr->mode != nir_var_mem_shared)
+       if (instr->mode != nir_var_mem_shared &&
+           instr->mode != nir_var_mem_global)
                return;
 
        LLVMValueRef result = NULL;
@@ -3910,22 +3964,79 @@ static void visit_deref(struct ac_nir_context *ctx,
                break;
        }
        case nir_deref_type_struct:
-               result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
-                                      LLVMConstInt(ctx->ac.i32, instr->strct.index, 0));
+               if (instr->mode == nir_var_mem_global) {
+                       nir_deref_instr *parent = nir_deref_instr_parent(instr);
+                       uint64_t offset = glsl_get_struct_field_offset(parent->type,
+                                                                       instr->strct.index);
+                       result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
+                                              LLVMConstInt(ctx->ac.i32, offset, 0));
+               } else {
+                       result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
+                                              LLVMConstInt(ctx->ac.i32, instr->strct.index, 0));
+               }
                break;
        case nir_deref_type_array:
-               result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
-                                      get_src(ctx, instr->arr.index));
+               if (instr->mode == nir_var_mem_global) {
+                       nir_deref_instr *parent = nir_deref_instr_parent(instr);
+                       unsigned stride = glsl_get_explicit_stride(parent->type);
+
+                       if ((glsl_type_is_matrix(parent->type) &&
+                            glsl_matrix_type_is_row_major(parent->type)) ||
+                           (glsl_type_is_vector(parent->type) && stride == 0))
+                               stride = type_scalar_size_bytes(parent->type);
+
+                       assert(stride > 0);
+                       LLVMValueRef index = get_src(ctx, instr->arr.index);
+                       if (LLVMTypeOf(index) != ctx->ac.i64)
+                               index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, "");
+
+                       LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), "");
+
+                       result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset);
+               } else {
+                       result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
+                                              get_src(ctx, instr->arr.index));
+               }
                break;
        case nir_deref_type_ptr_as_array:
-               result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
-                                         get_src(ctx, instr->arr.index));
+               if (instr->mode == nir_var_mem_global) {
+                       unsigned stride = nir_deref_instr_ptr_as_array_stride(instr);
+
+                       LLVMValueRef index = get_src(ctx, instr->arr.index);
+                       if (LLVMTypeOf(index) != ctx->ac.i64)
+                               index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, "");
+
+                       LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), "");
+
+                       result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset);
+               } else {
+                       result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
+                                              get_src(ctx, instr->arr.index));
+               }
                break;
        case nir_deref_type_cast: {
                result = get_src(ctx, instr->parent);
 
-               LLVMTypeRef pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type);
-               LLVMTypeRef type = LLVMPointerType(pointee_type, AC_ADDR_SPACE_LDS);
+               /* We can't use the structs from LLVM because the shader
+                * specifies its own offsets. */
+               LLVMTypeRef pointee_type = ctx->ac.i8;
+               if (instr->mode == nir_var_mem_shared)
+                       pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type);
+
+               unsigned address_space;
+
+               switch(instr->mode) {
+               case nir_var_mem_shared:
+                       address_space = AC_ADDR_SPACE_LDS;
+                       break;
+               case nir_var_mem_global:
+                       address_space = AC_ADDR_SPACE_GLOBAL;
+                       break;
+               default:
+                       unreachable("Unhandled address space");
+               }
+
+               LLVMTypeRef type = LLVMPointerType(pointee_type, address_space);
 
                if (LLVMTypeOf(result) != type) {
                        if (LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind) {