ac/nir: handle all lowered IO intrinsics
author Marek Olšák <marek.olsak@amd.com>
Fri, 14 Aug 2020 22:08:20 +0000 (18:08 -0400)
committer Marek Olšák <marek.olsak@amd.com>
Thu, 3 Sep 2020 02:45:38 +0000 (22:45 -0400)
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6445>

src/amd/llvm/ac_nir_to_llvm.c
src/amd/llvm/ac_shader_abi.h
src/amd/vulkan/radv_nir_to_llvm.c
src/gallium/drivers/radeonsi/si_shader_llvm_tess.c

index 6dc155dd94ae37b711d6e8d41ce928fab15e1f91..ed0cb8008f1a940f2c18a0e838384326787d33b5 100644 (file)
@@ -2269,6 +2269,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
 
        switch (mode) {
        case nir_var_shader_in:
+               /* TODO: remove this after RADV switches to lowered IO */
                if (ctx->stage == MESA_SHADER_TESS_CTRL ||
                    ctx->stage == MESA_SHADER_TESS_EVAL) {
                        return load_tess_varyings(ctx, instr, true);
@@ -2324,6 +2325,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                }
                break;
        case nir_var_shader_out:
+               /* TODO: remove this after RADV switches to lowered IO */
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
                        return load_tess_varyings(ctx, instr, false);
                }
@@ -2444,7 +2446,7 @@ visit_store_var(struct ac_nir_context *ctx,
 
        switch (deref->mode) {
        case nir_var_shader_out:
-
+               /* TODO: remove this after RADV switches to lowered IO */
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
                        LLVMValueRef vertex_index = NULL;
                        LLVMValueRef indir_index = NULL;
@@ -2459,7 +2461,9 @@ visit_store_var(struct ac_nir_context *ctx,
 
                        ctx->abi->store_tcs_outputs(ctx->abi, var,
                                                    vertex_index, indir_index,
-                                                   const_index, src, writemask);
+                                                   const_index, src, writemask,
+                                                   var->data.location_frac,
+                                                   var->data.driver_location);
                        break;
                }
 
@@ -2581,6 +2585,71 @@ visit_store_var(struct ac_nir_context *ctx,
                ac_build_endif(&ctx->ac, 7002);
 }
 
+/**
+ * Translate a lowered store_output / store_per_vertex_output intrinsic.
+ *
+ * The stored value (src[0]) is passed through ac_to_float(). A 64-bit source
+ * is bitcast to a vector of twice as many f32 components and its write mask
+ * is widened to match; any element size other than 32 or 64 bits is treated
+ * as unreachable. The write mask is then shifted left by the intrinsic's
+ * component so that it is relative to the start of the output slot.
+ *
+ * The IO offset source must either be the constant 0 or a non-constant
+ * (indirect) index. In a tessellation control shader the store is forwarded
+ * to abi->store_tcs_outputs() with a NULL variable, the optional vertex
+ * index, the indirect index, "component" and a driver location of base * 4.
+ * In every other stage indirect indexing is not allowed and each enabled
+ * channel is stored to abi->outputs[base * 4 + chan].
+ *
+ * When ctx->ac.postponed_kill is set, the store is wrapped in a conditional
+ * block (label 7002). Both the TCS and the non-TCS store paths fall through
+ * to the matching ac_build_endif() so the block is always closed.
+ *
+ * \param ctx    NIR-to-LLVM translation context
+ * \param instr  the store_output or store_per_vertex_output intrinsic
+ */
+static void
+visit_store_output(struct ac_nir_context *ctx, nir_intrinsic_instr *instr)
+{
+       if (ctx->ac.postponed_kill) {
+               LLVMValueRef cond = LLVMBuildLoad(ctx->ac.builder,
+                                                  ctx->ac.postponed_kill, "");
+               ac_build_ifcc(&ctx->ac, cond, 7002);
+       }
+
+       unsigned base = nir_intrinsic_base(instr);
+       unsigned writemask = nir_intrinsic_write_mask(instr);
+       unsigned component = nir_intrinsic_component(instr);
+       LLVMValueRef src = ac_to_float(&ctx->ac, get_src(ctx, instr->src[0]));
+       nir_src offset = *nir_get_io_offset_src(instr);
+       LLVMValueRef indir_index = NULL;
+
+       /* A constant offset must be 0; anything else is an indirect index. */
+       if (nir_src_is_const(offset))
+               assert(nir_src_as_uint(offset) == 0);
+       else
+               indir_index = get_src(ctx, offset);
+
+       switch (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src))) {
+       case 32:
+               break;
+       case 64:
+               /* Each 64-bit component occupies two 32-bit channels. */
+               writemask = widen_mask(writemask, 2);
+               src = LLVMBuildBitCast(ctx->ac.builder, src,
+                                      LLVMVectorType(ctx->ac.f32, ac_get_llvm_num_components(src) * 2),
+                                      "");
+               break;
+       default:
+               unreachable("unhandled store_output bit size");
+               return;
+       }
+
+       /* Make the mask relative to channel 0 of the output slot. */
+       writemask <<= component;
+
+       if (ctx->stage == MESA_SHADER_TESS_CTRL) {
+               nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+               LLVMValueRef vertex_index =
+                               vertex_index_src ? get_src(ctx, *vertex_index_src) : NULL;
+
+               ctx->abi->store_tcs_outputs(ctx->abi, NULL,
+                                           vertex_index, indir_index,
+                                           0, src, writemask,
+                                           component, base * 4);
+       } else {
+               /* No indirect indexing is allowed after this point. */
+               assert(!indir_index);
+
+               for (unsigned chan = 0; chan < 8; chan++) {
+                       if (!(writemask & (1 << chan)))
+                               continue;
+
+                       /* src is indexed from 0, the mask from "component". */
+                       LLVMValueRef value = ac_llvm_extract_elem(&ctx->ac, src, chan - component);
+                       LLVMBuildStore(ctx->ac.builder, value,
+                                      ctx->abi->outputs[base * 4 + chan]);
+               }
+       }
+
+       /* Close the block opened above; reached from both store paths. */
+       if (ctx->ac.postponed_kill)
+               ac_build_endif(&ctx->ac, 7002);
+}
+
+
 static int image_type_to_components_count(enum glsl_sampler_dim dim, bool array)
 {
        switch (dim) {
@@ -3578,18 +3647,82 @@ static LLVMValueRef load_interpolated_input(struct ac_nir_context *ctx,
        return ac_to_integer(&ctx->ac, ac_build_gather_values(&ctx->ac, values, num_components));
 }
 
-static LLVMValueRef load_input(struct ac_nir_context *ctx,
-                              nir_intrinsic_instr *instr)
+static LLVMValueRef visit_load(struct ac_nir_context *ctx,
+                              nir_intrinsic_instr *instr, bool is_output)
 {
-       unsigned offset_idx = instr->intrinsic == nir_intrinsic_load_input ? 0 : 1;
+       LLVMValueRef values[8];
+       LLVMTypeRef dest_type = get_def_type(ctx, &instr->dest.ssa);
+       LLVMTypeRef component_type;
+       unsigned base = nir_intrinsic_base(instr);
+       unsigned component = nir_intrinsic_component(instr);
+       unsigned count = instr->dest.ssa.num_components *
+                        (instr->dest.ssa.bit_size == 64 ? 2 : 1);
+       nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+       LLVMValueRef vertex_index =
+               vertex_index_src ? get_src(ctx, *vertex_index_src) : NULL;
+       nir_src offset = *nir_get_io_offset_src(instr);
+       LLVMValueRef indir_index = NULL;
 
-       /* We only lower inputs for fragment shaders ATM */
-       ASSERTED nir_const_value *offset = nir_src_as_const_value(instr->src[offset_idx]);
-       assert(offset);
-       assert(offset[0].i32 == 0);
+       if (LLVMGetTypeKind(dest_type) == LLVMVectorTypeKind)
+               component_type = LLVMGetElementType(dest_type);
+       else
+               component_type = dest_type;
 
-       unsigned component = nir_intrinsic_component(instr);
-       unsigned index = nir_intrinsic_base(instr);
+       if (nir_src_is_const(offset))
+               assert(nir_src_as_uint(offset) == 0);
+       else
+               indir_index = get_src(ctx, offset);
+
+       if (ctx->stage == MESA_SHADER_TESS_CTRL ||
+           (ctx->stage == MESA_SHADER_TESS_EVAL && !is_output)) {
+               LLVMValueRef result =
+                       ctx->abi->load_tess_varyings(ctx->abi, component_type,
+                                                    vertex_index, indir_index,
+                                                    0, 0, base * 4,
+                                                    component,
+                                                    instr->num_components,
+                                                    false, false, !is_output);
+               if (instr->dest.ssa.bit_size == 16) {
+                       result = ac_to_integer(&ctx->ac, result);
+                       result = LLVMBuildTrunc(ctx->ac.builder, result, dest_type, "");
+               }
+               return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
+       }
+
+       /* No indirect indexing is allowed after this point. */
+       assert(!indir_index);
+
+       if (ctx->stage == MESA_SHADER_GEOMETRY) {
+               LLVMTypeRef type = LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.bit_size);
+               assert(nir_src_is_const(*vertex_index_src));
+
+               return ctx->abi->load_inputs(ctx->abi, 0, base * 4, component,
+                                            instr->num_components,
+                                            nir_src_as_uint(*vertex_index_src),
+                                            0, type);
+       }
+
+       if (ctx->stage == MESA_SHADER_FRAGMENT && is_output &&
+           nir_intrinsic_io_semantics(instr).fb_fetch_output)
+               return ctx->abi->emit_fbfetch(ctx->abi);
+
+       /* Other non-fragment cases have inputs and outputs in temporaries. */
+       if (ctx->stage != MESA_SHADER_FRAGMENT) {
+               for (unsigned chan = component; chan < count + component; chan++) {
+                       if (is_output) {
+                               values[chan] = LLVMBuildLoad(ctx->ac.builder,
+                                                            ctx->abi->outputs[base * 4 + chan], "");
+                       } else {
+                               values[chan] = ctx->abi->inputs[base * 4 + chan];
+                               if (!values[chan])
+                                       values[chan] = LLVMGetUndef(ctx->ac.i32);
+                       }
+               }
+               LLVMValueRef result = ac_build_varying_gather_values(&ctx->ac, values, count, component);
+               return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
+       }
+
+       /* Fragment shader inputs. */
        unsigned vertex_id = 2; /* P0 */
 
        if (instr->intrinsic == nir_intrinsic_load_input_vertex) {
@@ -3610,18 +3743,11 @@ static LLVMValueRef load_input(struct ac_nir_context *ctx,
                }
        }
 
-       LLVMValueRef attr_number = LLVMConstInt(ctx->ac.i32, index, false);
-       LLVMValueRef values[8];
-
-       /* Each component of a 64-bit value takes up two GL-level channels. */
-       unsigned num_components = instr->dest.ssa.num_components;
-       unsigned bit_size = instr->dest.ssa.bit_size;
-       unsigned channels =
-               bit_size == 64 ? num_components * 2 : num_components;
+       LLVMValueRef attr_number = LLVMConstInt(ctx->ac.i32, base, false);
 
-       for (unsigned chan = 0; chan < channels; chan++) {
+       for (unsigned chan = 0; chan < count; chan++) {
                if (component + chan > 4)
-                       attr_number = LLVMConstInt(ctx->ac.i32, index + 1, false);
+                       attr_number = LLVMConstInt(ctx->ac.i32, base + 1, false);
                LLVMValueRef llvm_chan = LLVMConstInt(ctx->ac.i32, (component + chan) % 4, false);
                values[chan] = ac_build_fs_interp_mov(&ctx->ac,
                                                      LLVMConstInt(ctx->ac.i32, vertex_id, false),
@@ -3630,16 +3756,12 @@ static LLVMValueRef load_input(struct ac_nir_context *ctx,
                                                      ac_get_arg(&ctx->ac, ctx->args->prim_mask));
                values[chan] = LLVMBuildBitCast(ctx->ac.builder, values[chan], ctx->ac.i32, "");
                values[chan] = LLVMBuildTruncOrBitCast(ctx->ac.builder, values[chan],
-                                                      bit_size == 16 ? ctx->ac.i16 : ctx->ac.i32, "");
+                                                      instr->dest.ssa.bit_size == 16 ? ctx->ac.i16
+                                                                                     : ctx->ac.i32, "");
        }
 
-       LLVMValueRef result = ac_build_gather_values(&ctx->ac, values, channels);
-       if (bit_size == 64) {
-               LLVMTypeRef type = num_components == 1 ? ctx->ac.i64 :
-                       LLVMVectorType(ctx->ac.i64, num_components);
-               result = LLVMBuildBitCast(ctx->ac.builder, result, type, "");
-       }
-       return result;
+       LLVMValueRef result = ac_build_gather_values(&ctx->ac, values, count);
+       return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
 }
 
 static void visit_intrinsic(struct ac_nir_context *ctx,
@@ -3836,6 +3958,19 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_store_deref:
                visit_store_var(ctx, instr);
                break;
+       case nir_intrinsic_load_input:
+       case nir_intrinsic_load_input_vertex:
+       case nir_intrinsic_load_per_vertex_input:
+               result = visit_load(ctx, instr, false);
+               break;
+       case nir_intrinsic_load_output:
+       case nir_intrinsic_load_per_vertex_output:
+               result = visit_load(ctx, instr, true);
+               break;
+       case nir_intrinsic_store_output:
+       case nir_intrinsic_store_per_vertex_output:
+               visit_store_output(ctx, instr);
+               break;
        case nir_intrinsic_load_shared:
                result = visit_load_shared(ctx, instr);
                break;
@@ -4003,10 +4138,6 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                                                 instr->dest.ssa.bit_size);
                break;
        }
-       case nir_intrinsic_load_input:
-       case nir_intrinsic_load_input_vertex:
-               result = load_input(ctx, instr);
-               break;
        case nir_intrinsic_emit_vertex:
                ctx->abi->emit_vertex(ctx->abi, nir_intrinsic_stream_id(instr), ctx->abi->outputs);
                break;
@@ -5339,9 +5470,13 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 
        ctx.main_function = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
 
-       nir_foreach_shader_out_variable(variable, nir)
-               ac_handle_shader_output_decl(&ctx.ac, ctx.abi, nir, variable,
-                                            ctx.stage);
+       /* TODO: remove this after RADV switches to lowered IO */
+       if (!nir->info.io_lowered) {
+               nir_foreach_shader_out_variable(variable, nir) {
+                       ac_handle_shader_output_decl(&ctx.ac, ctx.abi, nir, variable,
+                                                    ctx.stage);
+               }
+       }
 
        ctx.defs = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                                           _mesa_key_pointer_equal);
index 80b1554ea3e0aad077057eaf7327edf1f089bf67..359e9484fc2b651a05010fe4b96e3f7282825583 100644 (file)
@@ -113,7 +113,9 @@ struct ac_shader_abi {
                                  LLVMValueRef param_index,
                                  unsigned const_index,
                                  LLVMValueRef src,
-                                 unsigned writemask);
+                                 unsigned writemask,
+                                 unsigned component,
+                                 unsigned driver_location);
 
        LLVMValueRef (*load_tess_coord)(struct ac_shader_abi *abi);
 
index b962f25e6d6d9c540829154dc7e1d5871f60d9d4..db21ad809b784e424da9b5839300c2979ce65210 100644 (file)
@@ -589,11 +589,12 @@ store_tcs_output(struct ac_shader_abi *abi,
                 LLVMValueRef param_index,
                 unsigned const_index,
                 LLVMValueRef src,
-                unsigned writemask)
+                unsigned writemask,
+                unsigned component,
+                unsigned driver_location)
 {
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
        const unsigned location = var->data.location;
-       unsigned component = var->data.location_frac;
        const bool is_patch = var->data.patch;
        const bool is_compact = var->data.compact;
        LLVMValueRef dw_addr;
index 13bed5f2569757d316b15c0af904daddbde64145..e54b4a8cf18fc57b71451275458a7c045246f291 100644 (file)
@@ -509,12 +509,11 @@ static LLVMValueRef si_nir_load_input_tes(struct ac_shader_abi *abi, LLVMTypeRef
 
 static void si_nir_store_output_tcs(struct ac_shader_abi *abi, const struct nir_variable *var,
                                     LLVMValueRef vertex_index, LLVMValueRef param_index,
-                                    unsigned const_index, LLVMValueRef src, unsigned writemask)
+                                    unsigned const_index, LLVMValueRef src, unsigned writemask,
+                                    unsigned component, unsigned driver_location)
 {
    struct si_shader_context *ctx = si_shader_context_from_abi(abi);
    struct si_shader_info *info = &ctx->shader->selector->info;
-   unsigned component = var->data.location_frac;
-   unsigned driver_location = var->data.driver_location;
    LLVMValueRef dw_addr, stride;
    LLVMValueRef buffer, base, addr;
    LLVMValueRef values[8];