ac/nir: implement 8-bit push constant, ssbo and ubo loads
author Rhys Perry <pendingchaos02@gmail.com>
Thu, 6 Dec 2018 13:56:01 +0000 (13:56 +0000)
committer Samuel Pitoiset <samuel.pitoiset@gmail.com>
Thu, 21 Mar 2019 08:02:16 +0000 (09:02 +0100)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
src/amd/common/ac_nir_to_llvm.c

index fd6dbcce530ce9cf73837a54deb4909d547c1392..c088537e1fd8d7a29c786b20a6281552a70a71f7 100644 (file)
@@ -1432,7 +1432,30 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
 
        ptr = ac_build_gep0(&ctx->ac, ctx->abi->push_constants, addr);
 
-       if (instr->dest.ssa.bit_size == 16) {
+       if (instr->dest.ssa.bit_size == 8) {
+               unsigned load_dwords = instr->dest.ssa.num_components > 1 ? 2 : 1;
+               LLVMTypeRef vec_type = LLVMVectorType(LLVMInt8TypeInContext(ctx->ac.context), 4 * load_dwords);
+               ptr = ac_cast_ptr(&ctx->ac, ptr, vec_type);
+               LLVMValueRef res = LLVMBuildLoad(ctx->ac.builder, ptr, "");
+
+               LLVMValueRef params[3];
+               if (load_dwords > 1) {
+                       LLVMValueRef res_vec = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(ctx->ac.i32, 2), "");
+                       params[0] = LLVMBuildExtractElement(ctx->ac.builder, res_vec, LLVMConstInt(ctx->ac.i32, 1, false), "");
+                       params[1] = LLVMBuildExtractElement(ctx->ac.builder, res_vec, LLVMConstInt(ctx->ac.i32, 0, false), "");
+               } else {
+                       res = LLVMBuildBitCast(ctx->ac.builder, res, ctx->ac.i32, "");
+                       params[0] = ctx->ac.i32_0;
+                       params[1] = res;
+               }
+               params[2] = addr;
+               res = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.alignbyte", ctx->ac.i32, params, 3, 0);
+
+               res = LLVMBuildTrunc(ctx->ac.builder, res, LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.num_components * 8), "");
+               if (instr->dest.ssa.num_components > 1)
+                       res = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(LLVMInt8TypeInContext(ctx->ac.context), instr->dest.ssa.num_components), "");
+               return res;
+       } else if (instr->dest.ssa.bit_size == 16) {
                unsigned load_dwords = instr->dest.ssa.num_components / 2 + 1;
                LLVMTypeRef vec_type = LLVMVectorType(LLVMInt16TypeInContext(ctx->ac.context), 2 * load_dwords);
                ptr = ac_cast_ptr(&ctx->ac, ptr, vec_type);
@@ -1705,13 +1728,21 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
                LLVMValueRef immoffset = LLVMConstInt(ctx->ac.i32, i * elem_size_bytes, false);
 
                LLVMValueRef ret;
-               if (load_bytes == 2) {
-                       ret = ac_build_tbuffer_load_short(&ctx->ac,
+
+               if (load_bytes == 1) {
+                       ret = ac_build_tbuffer_load_byte(&ctx->ac,
                                                          rsrc,
                                                          offset,
                                                          ctx->ac.i32_0,
                                                          immoffset,
                                                          cache_policy & ac_glc);
+               } else if (load_bytes == 2) {
+                       ret = ac_build_tbuffer_load_short(&ctx->ac,
+                                                        rsrc,
+                                                        offset,
+                                                        ctx->ac.i32_0,
+                                                        immoffset,
+                                                        cache_policy & ac_glc);
                } else {
                        int num_channels = util_next_power_of_two(load_bytes) / 4;
 
@@ -1751,15 +1782,29 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
        if (instr->dest.ssa.bit_size == 64)
                num_components *= 2;
 
-       if (instr->dest.ssa.bit_size == 16) {
+       if (instr->dest.ssa.bit_size == 16 || instr->dest.ssa.bit_size == 8) {
+               unsigned load_bytes = instr->dest.ssa.bit_size / 8;
                LLVMValueRef results[num_components];
                for (unsigned i = 0; i < num_components; ++i) {
-                       results[i] = ac_build_tbuffer_load_short(&ctx->ac,
-                                                                rsrc,
-                                                                offset,
-                                                                ctx->ac.i32_0,
-                                                                LLVMConstInt(ctx->ac.i32, 2 * i, 0),
-                                                                false);
+                       LLVMValueRef immoffset = LLVMConstInt(ctx->ac.i32,
+                                                             load_bytes * i, 0);
+
+                       if (load_bytes == 1) {
+                               results[i] = ac_build_tbuffer_load_byte(&ctx->ac,
+                                                                       rsrc,
+                                                                       offset,
+                                                                       ctx->ac.i32_0,
+                                                                       immoffset,
+                                                                       false);
+                       } else {
+                               assert(load_bytes == 2);
+                               results[i] = ac_build_tbuffer_load_short(&ctx->ac,
+                                                                        rsrc,
+                                                                        offset,
+                                                                        ctx->ac.i32_0,
+                                                                        immoffset,
+                                                                        false);
+                       }
                }
                ret = ac_build_gather_values(&ctx->ac, results, num_components);
        } else {