From: Bas Nieuwenhuizen Date: Thu, 24 Jan 2019 01:04:10 +0000 (+0100) Subject: amd/common: Implement global memory accesses. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=8a159502111fcc0b7bd68851dbd3f33cbb167fe1;p=mesa.git amd/common: Implement global memory accesses. Needed for VK_EXT_buffer_device_address. The pointers are implmemented as i8*, since I could not figure out how to emulate setting struct offsets in LLVM based on the SPIR-V offsets (and more weird stuff like row major matrices). Acked-by: Samuel Pitoiset --- diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c index f78b6b505c0..54559b19f02 100644 --- a/src/amd/common/ac_nir_to_llvm.c +++ b/src/amd/common/ac_nir_to_llvm.c @@ -1878,6 +1878,14 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx, return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, ""); } +static unsigned +type_scalar_size_bytes(const struct glsl_type *type) +{ + assert(glsl_type_is_vector_or_scalar(type) || + glsl_type_is_matrix(type)); + return glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8; +} + static LLVMValueRef visit_load_var(struct ac_nir_context *ctx, nir_intrinsic_instr *instr) { @@ -1892,7 +1900,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx, LLVMValueRef ret; unsigned const_index; unsigned stride = 4; - int mode = nir_var_mem_shared; + int mode = deref->mode; if (var) { bool vs_in = ctx->stage == MESA_SHADER_VERTEX && @@ -1999,6 +2007,32 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx, } } break; + case nir_var_mem_global: { + LLVMValueRef address = get_src(ctx, instr->src[0]); + unsigned explicit_stride = glsl_get_explicit_stride(deref->type); + unsigned natural_stride = type_scalar_size_bytes(deref->type); + unsigned stride = explicit_stride ? explicit_stride : natural_stride; + + LLVMTypeRef result_type = get_def_type(ctx, &instr->dest.ssa); + if (stride != natural_stride) { + LLVMTypeRef ptr_type = LLVMPointerType(LLVMGetElementType(result_type), + LLVMGetPointerAddressSpace(LLVMTypeOf(address))); + address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , ""); + + for (unsigned i = 0; i < instr->dest.ssa.num_components; ++i) { + LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, i * stride / natural_stride, 0); + values[i] = LLVMBuildLoad(ctx->ac.builder, + ac_build_gep_ptr(&ctx->ac, address, offset), ""); + } + return ac_build_gather_values(&ctx->ac, values, instr->dest.ssa.num_components); + } else { + LLVMTypeRef ptr_type = LLVMPointerType(result_type, + LLVMGetPointerAddressSpace(LLVMTypeOf(address))); + address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , ""); + LLVMValueRef val = LLVMBuildLoad(ctx->ac.builder, address, ""); + return val; + } + } default: unreachable("unhandle variable mode"); } @@ -2114,33 +2148,52 @@ visit_store_var(struct ac_nir_context *ctx, } } break; + + case nir_var_mem_global: case nir_var_mem_shared: { int writemask = instr->const_index[0]; LLVMValueRef address = get_src(ctx, instr->src[0]); LLVMValueRef val = get_src(ctx, instr->src[1]); - if (writemask == (1u << ac_get_llvm_num_components(val)) - 1) { - val = LLVMBuildBitCast( - ctx->ac.builder, val, - LLVMGetElementType(LLVMTypeOf(address)), ""); + + unsigned explicit_stride = glsl_get_explicit_stride(deref->type); + unsigned natural_stride = type_scalar_size_bytes(deref->type); + unsigned stride = explicit_stride ? explicit_stride : natural_stride; + + LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(val), + LLVMGetPointerAddressSpace(LLVMTypeOf(address))); + address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , ""); + + if (writemask == (1u << ac_get_llvm_num_components(val)) - 1 && + stride == natural_stride) { + LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(val), + LLVMGetPointerAddressSpace(LLVMTypeOf(address))); + address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , ""); + + val = LLVMBuildBitCast(ctx->ac.builder, val, + LLVMGetElementType(LLVMTypeOf(address)), ""); LLVMBuildStore(ctx->ac.builder, val, address); } else { + LLVMTypeRef ptr_type = LLVMPointerType(LLVMGetElementType(LLVMTypeOf(val)), + LLVMGetPointerAddressSpace(LLVMTypeOf(address))); + address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , ""); for (unsigned chan = 0; chan < 4; chan++) { if (!(writemask & (1 << chan))) continue; - LLVMValueRef ptr = - LLVMBuildStructGEP(ctx->ac.builder, - address, chan, ""); + + LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, chan * stride / natural_stride, 0); + + LLVMValueRef ptr = ac_build_gep_ptr(&ctx->ac, address, offset); LLVMValueRef src = ac_llvm_extract_elem(&ctx->ac, val, chan); - src = LLVMBuildBitCast( - ctx->ac.builder, src, - LLVMGetElementType(LLVMTypeOf(ptr)), ""); + src = LLVMBuildBitCast(ctx->ac.builder, src, + LLVMGetElementType(LLVMTypeOf(ptr)), ""); LLVMBuildStore(ctx->ac.builder, src, ptr); } } break; } default: + abort(); break; } } @@ -3899,7 +3952,8 @@ glsl_to_llvm_type(struct ac_llvm_context *ac, static void visit_deref(struct ac_nir_context *ctx, nir_deref_instr *instr) { - if (instr->mode != nir_var_mem_shared) + if (instr->mode != nir_var_mem_shared && + instr->mode != nir_var_mem_global) return; LLVMValueRef result = NULL; @@ -3910,22 +3964,79 @@ static void visit_deref(struct ac_nir_context *ctx, break; } case nir_deref_type_struct: - result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent), - LLVMConstInt(ctx->ac.i32, instr->strct.index, 0)); + if (instr->mode == nir_var_mem_global) { + nir_deref_instr *parent = nir_deref_instr_parent(instr); + uint64_t offset = glsl_get_struct_field_offset(parent->type, + instr->strct.index); + result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), + LLVMConstInt(ctx->ac.i32, offset, 0)); + } else { + result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent), + LLVMConstInt(ctx->ac.i32, instr->strct.index, 0)); + } break; case nir_deref_type_array: - result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent), - get_src(ctx, instr->arr.index)); + if (instr->mode == nir_var_mem_global) { + nir_deref_instr *parent = nir_deref_instr_parent(instr); + unsigned stride = glsl_get_explicit_stride(parent->type); + + if ((glsl_type_is_matrix(parent->type) && + glsl_matrix_type_is_row_major(parent->type)) || + (glsl_type_is_vector(parent->type) && stride == 0)) + stride = type_scalar_size_bytes(parent->type); + + assert(stride > 0); + LLVMValueRef index = get_src(ctx, instr->arr.index); + if (LLVMTypeOf(index) != ctx->ac.i64) + index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, ""); + + LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), ""); + + result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset); + } else { + result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent), + get_src(ctx, instr->arr.index)); + } break; case nir_deref_type_ptr_as_array: - result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), - get_src(ctx, instr->arr.index)); + if (instr->mode == nir_var_mem_global) { + unsigned stride = nir_deref_instr_ptr_as_array_stride(instr); + + LLVMValueRef index = get_src(ctx, instr->arr.index); + if (LLVMTypeOf(index) != ctx->ac.i64) + index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, ""); + + LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), ""); + + result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset); + } else { + result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), + get_src(ctx, instr->arr.index)); + } break; case nir_deref_type_cast: { result = get_src(ctx, instr->parent); - LLVMTypeRef pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type); - LLVMTypeRef type = LLVMPointerType(pointee_type, AC_ADDR_SPACE_LDS); + /* We can't use the structs from LLVM because the shader + * specifies its own offsets. */ + LLVMTypeRef pointee_type = ctx->ac.i8; + if (instr->mode == nir_var_mem_shared) + pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type); + + unsigned address_space; + + switch(instr->mode) { + case nir_var_mem_shared: + address_space = AC_ADDR_SPACE_LDS; + break; + case nir_var_mem_global: + address_space = AC_ADDR_SPACE_GLOBAL; + break; + default: + unreachable("Unhandled address space"); + } + + LLVMTypeRef type = LLVMPointerType(pointee_type, address_space); if (LLVMTypeOf(result) != type) { if (LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind) {