freedreno/ir3: refactor NIR IR handling
[mesa.git] / src / gallium / drivers / freedreno / ir3 / ir3_compiler_nir.c
index eb24120c4c747c8c6885c8dadc86df2b3feb7e34..0a25d5252a11b15f5a8f1962e57561479430871a 100644 (file)
 #include "util/u_string.h"
 #include "util/u_memory.h"
 #include "util/u_inlines.h"
-#include "tgsi/tgsi_lowering.h"
-#include "tgsi/tgsi_strings.h"
-
-#include "nir/tgsi_to_nir.h"
 
 #include "freedreno_util.h"
 
@@ -123,97 +119,10 @@ struct ir3_compile {
 static struct ir3_instruction * create_immed(struct ir3_block *block, uint32_t val);
 static struct ir3_block * get_block(struct ir3_compile *ctx, nir_block *nblock);
 
-static struct nir_shader *to_nir(struct ir3_compile *ctx,
-               const struct tgsi_token *tokens, struct ir3_shader_variant *so)
-{
-       static const nir_shader_compiler_options options = {
-                       .lower_fpow = true,
-                       .lower_fsat = true,
-                       .lower_scmp = true,
-                       .lower_flrp = true,
-                       .lower_ffract = true,
-                       .native_integers = true,
-       };
-       struct nir_lower_tex_options tex_options = {
-                       .lower_rect = 0,
-       };
-       bool progress;
-
-       switch (so->type) {
-       case SHADER_FRAGMENT:
-       case SHADER_COMPUTE:
-               tex_options.saturate_s = so->key.fsaturate_s;
-               tex_options.saturate_t = so->key.fsaturate_t;
-               tex_options.saturate_r = so->key.fsaturate_r;
-               break;
-       case SHADER_VERTEX:
-               tex_options.saturate_s = so->key.vsaturate_s;
-               tex_options.saturate_t = so->key.vsaturate_t;
-               tex_options.saturate_r = so->key.vsaturate_r;
-               break;
-       }
-
-       if (ctx->compiler->gpu_id >= 400) {
-               /* a4xx seems to have *no* sam.p */
-               tex_options.lower_txp = ~0;  /* lower all txp */
-       } else {
-               /* a3xx just needs to avoid sam.p for 3d tex */
-               tex_options.lower_txp = (1 << GLSL_SAMPLER_DIM_3D);
-       }
-
-       struct nir_shader *s = tgsi_to_nir(tokens, &options);
-
-       if (fd_mesa_debug & FD_DBG_OPTMSGS) {
-               debug_printf("----------------------\n");
-               nir_print_shader(s, stdout);
-               debug_printf("----------------------\n");
-       }
-
-       nir_opt_global_to_local(s);
-       nir_convert_to_ssa(s);
-       if (s->stage == MESA_SHADER_VERTEX) {
-               nir_lower_clip_vs(s, so->key.ucp_enables);
-       } else if (s->stage == MESA_SHADER_FRAGMENT) {
-               nir_lower_clip_fs(s, so->key.ucp_enables);
-       }
-       nir_lower_tex(s, &tex_options);
-       if (so->key.color_two_side)
-               nir_lower_two_sided_color(s);
-       nir_lower_idiv(s);
-       nir_lower_load_const_to_scalar(s);
-
-       do {
-               progress = false;
-
-               nir_lower_vars_to_ssa(s);
-               nir_lower_alu_to_scalar(s);
-               nir_lower_phis_to_scalar(s);
-
-               progress |= nir_copy_prop(s);
-               progress |= nir_opt_dce(s);
-               progress |= nir_opt_cse(s);
-               progress |= ir3_nir_lower_if_else(s);
-               progress |= nir_opt_algebraic(s);
-               progress |= nir_opt_constant_folding(s);
-
-       } while (progress);
-
-       nir_remove_dead_variables(s);
-       nir_validate_shader(s);
-
-       if (fd_mesa_debug & FD_DBG_OPTMSGS) {
-               debug_printf("----------------------\n");
-               nir_print_shader(s, stdout);
-               debug_printf("----------------------\n");
-       }
-
-       return s;
-}
 
 static struct ir3_compile *
 compile_init(struct ir3_compiler *compiler,
-               struct ir3_shader_variant *so,
-               const struct tgsi_token *tokens)
+               struct ir3_shader_variant *so)
 {
        struct ir3_compile *ctx = rzalloc(NULL, struct ir3_compile);
 
@@ -236,12 +145,31 @@ compile_init(struct ir3_compiler *compiler,
                        _mesa_hash_pointer, _mesa_key_pointer_equal);
        ctx->var_ht = _mesa_hash_table_create(ctx,
                        _mesa_hash_pointer, _mesa_key_pointer_equal);
-       ctx->addr_ht = _mesa_hash_table_create(ctx,
-                       _mesa_hash_pointer, _mesa_key_pointer_equal);
        ctx->block_ht = _mesa_hash_table_create(ctx,
                        _mesa_hash_pointer, _mesa_key_pointer_equal);
 
-       ctx->s = to_nir(ctx, tokens, so);
+       /* TODO: maybe generate some sort of bitmask of what key
+        * lowers vs what shader has (ie. no need to lower
+        * texture clamp lowering if no texture sample instrs)..
+        * although should be done further up the stack to avoid
+        * creating duplicate variants..
+        */
+
+       if (ir3_key_lowers_nir(&so->key)) {
+               nir_shader *s = nir_shader_clone(ctx, so->shader->nir);
+               ctx->s = ir3_optimize_nir(so->shader, s, &so->key);
+       } else {
+               /* fast-path for shader key that lowers nothing in NIR: */
+               ctx->s = so->shader->nir;
+       }
+
+       if (fd_mesa_debug & FD_DBG_DISASM) {
+               DBG("dump nir%dv%d: type=%d, k={bp=%u,cts=%u,hp=%u}",
+                       so->shader->id, so->id, so->type,
+                       so->key.binning_pass, so->key.color_two_side,
+                       so->key.half_precision);
+               nir_print_shader(ctx->s, stdout);
+       }
 
        so->first_driver_param = so->first_immediate = ctx->s->num_uniforms;
 
@@ -583,12 +511,17 @@ static struct ir3_instruction *
 get_addr(struct ir3_compile *ctx, struct ir3_instruction *src)
 {
        struct ir3_instruction *addr;
-       struct hash_entry *entry;
-       entry = _mesa_hash_table_search(ctx->addr_ht, src);
-       if (entry)
-               return entry->data;
 
-       /* TODO do we need to cache per block? */
+       if (!ctx->addr_ht) {
+               ctx->addr_ht = _mesa_hash_table_create(ctx,
+                               _mesa_hash_pointer, _mesa_key_pointer_equal);
+       } else {
+               struct hash_entry *entry;
+               entry = _mesa_hash_table_search(ctx->addr_ht, src);
+               if (entry)
+                       return entry->data;
+       }
+
        addr = create_addr(ctx->block, src);
        _mesa_hash_table_insert(ctx->addr_ht, src, addr);
 
@@ -1215,6 +1148,7 @@ emit_intrinsic_load_ubo(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
 {
        struct ir3_block *b = ctx->block;
        struct ir3_instruction *addr, *src0, *src1;
+       nir_const_value *const_offset;
        /* UBO addresses are the first driver params: */
        unsigned ubo = regid(ctx->so->first_driver_param + IR3_UBOS_OFF, 0);
        unsigned off = intr->const_index[0];
@@ -1228,7 +1162,10 @@ emit_intrinsic_load_ubo(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
                addr = create_uniform_indirect(ctx, ubo, get_addr(ctx, src0));
        }
 
-       if (intr->intrinsic == nir_intrinsic_load_ubo_indirect) {
+       const_offset = nir_src_as_const_value(intr->src[1]);
+       if (const_offset) {
+               off += const_offset->u[0];
+       } else {
                /* For load_ubo_indirect, second src is indirect offset: */
                src1 = get_src(ctx, &intr->src[1])[0];
 
@@ -1257,7 +1194,7 @@ emit_intrinsic_load_ubo(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
 
 /* handles array reads: */
 static void
-emit_intrinisic_load_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
+emit_intrinsic_load_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
                struct ir3_instruction **dst)
 {
        nir_deref_var *dvar = intr->variables[0];
@@ -1298,7 +1235,7 @@ emit_intrinisic_load_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr,
 
 /* handles array writes: */
 static void
-emit_intrinisic_store_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
+emit_intrinsic_store_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
 {
        nir_deref_var *dvar = intr->variables[0];
        nir_deref_array *darr = nir_deref_as_array(dvar->deref.child);
@@ -1314,6 +1251,10 @@ emit_intrinisic_store_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
        case nir_deref_array_type_direct:
                /* direct access does not require anything special: */
                for (int i = 0; i < intr->num_components; i++) {
+                       /* ttn doesn't generate partial writemasks */
+                       assert(intr->const_index[0] ==
+                              (1 << intr->num_components) - 1);
+
                        unsigned n = darr->base_offset * 4 + i;
                        compile_assert(ctx, n < arr->length);
                        arr->arr[n] = src[i];
@@ -1326,6 +1267,10 @@ emit_intrinisic_store_var(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
                struct ir3_instruction *addr =
                                get_addr(ctx, get_src(ctx, &darr->indirect)[0]);
                for (int i = 0; i < intr->num_components; i++) {
+                       /* ttn doesn't generate partial writemasks */
+                       assert(intr->const_index[0] ==
+                              (1 << intr->num_components) - 1);
+
                        struct ir3_instruction *store;
                        unsigned n = darr->base_offset * 4 + i;
                        compile_assert(ctx, n < arr->length);
@@ -1385,12 +1330,13 @@ static void add_sysval_input(struct ir3_compile *ctx, gl_system_value slot,
 }
 
 static void
-emit_intrinisic(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
+emit_intrinsic(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
 {
        const nir_intrinsic_info *info = &nir_intrinsic_infos[intr->intrinsic];
        struct ir3_instruction **dst, **src;
        struct ir3_block *b = ctx->block;
        unsigned idx = intr->const_index[0];
+       nir_const_value *const_offset;
 
        if (info->has_dest) {
                dst = get_dst(ctx, &intr->dest, intr->num_components);
@@ -1400,52 +1346,62 @@ emit_intrinisic(struct ir3_compile *ctx, nir_intrinsic_instr *intr)
 
        switch (intr->intrinsic) {
        case nir_intrinsic_load_uniform:
-               for (int i = 0; i < intr->num_components; i++) {
-                       unsigned n = idx * 4 + i;
-                       dst[i] = create_uniform(ctx, n);
-               }
-               break;
-       case nir_intrinsic_load_uniform_indirect:
-               src = get_src(ctx, &intr->src[0]);
-               for (int i = 0; i < intr->num_components; i++) {
-                       unsigned n = idx * 4 + i;
-                       dst[i] = create_uniform_indirect(ctx, n,
-                                       get_addr(ctx, src[0]));
+               const_offset = nir_src_as_const_value(intr->src[0]);
+               if (const_offset) {
+                       idx += const_offset->u[0];
+                       for (int i = 0; i < intr->num_components; i++) {
+                               unsigned n = idx * 4 + i;
+                               dst[i] = create_uniform(ctx, n);
+                       }
+               } else {
+                       src = get_src(ctx, &intr->src[0]);
+                       for (int i = 0; i < intr->num_components; i++) {
+                               unsigned n = idx * 4 + i;
+                               dst[i] = create_uniform_indirect(ctx, n,
+                                               get_addr(ctx, src[0]));
+                       }
+                       /* NOTE: if relative addressing is used, we set
+                        * constlen in the compiler (to worst-case value)
+                        * since we don't know in the assembler what the max
+                        * addr reg value can be:
+                        */
+                       ctx->so->constlen = ctx->s->num_uniforms;
                }
-               /* NOTE: if relative addressing is used, we set constlen in
-                * the compiler (to worst-case value) since we don't know in
-                * the assembler what the max addr reg value can be:
-                */
-               ctx->so->constlen = ctx->s->num_uniforms;
                break;
        case nir_intrinsic_load_ubo:
-       case nir_intrinsic_load_ubo_indirect:
                emit_intrinsic_load_ubo(ctx, intr, dst);
                break;
        case nir_intrinsic_load_input:
-               for (int i = 0; i < intr->num_components; i++) {
-                       unsigned n = idx * 4 + i;
-                       dst[i] = ctx->ir->inputs[n];
-               }
-               break;
-       case nir_intrinsic_load_input_indirect:
-               src = get_src(ctx, &intr->src[0]);
-               struct ir3_instruction *collect =
-                               create_collect(b, ctx->ir->inputs, ctx->ir->ninputs);
-               struct ir3_instruction *addr = get_addr(ctx, src[0]);
-               for (int i = 0; i < intr->num_components; i++) {
-                       unsigned n = idx * 4 + i;
-                       dst[i] = create_indirect_load(ctx, ctx->ir->ninputs,
-                                       n, addr, collect);
+               const_offset = nir_src_as_const_value(intr->src[0]);
+               if (const_offset) {
+                       idx += const_offset->u[0];
+                       for (int i = 0; i < intr->num_components; i++) {
+                               unsigned n = idx * 4 + i;
+                               dst[i] = ctx->ir->inputs[n];
+                       }
+               } else {
+                       src = get_src(ctx, &intr->src[0]);
+                       struct ir3_instruction *collect =
+                                       create_collect(b, ctx->ir->inputs, ctx->ir->ninputs);
+                       struct ir3_instruction *addr = get_addr(ctx, src[0]);
+                       for (int i = 0; i < intr->num_components; i++) {
+                               unsigned n = idx * 4 + i;
+                               dst[i] = create_indirect_load(ctx, ctx->ir->ninputs,
+                                               n, addr, collect);
+                       }
                }
                break;
        case nir_intrinsic_load_var:
-               emit_intrinisic_load_var(ctx, intr, dst);
+               emit_intrinsic_load_var(ctx, intr, dst);
                break;
        case nir_intrinsic_store_var:
-               emit_intrinisic_store_var(ctx, intr);
+               emit_intrinsic_store_var(ctx, intr);
                break;
        case nir_intrinsic_store_output:
+               const_offset = nir_src_as_const_value(intr->src[1]);
+               compile_assert(ctx, const_offset != NULL);
+               idx += const_offset->u[0];
+
                src = get_src(ctx, &intr->src[0]);
                for (int i = 0; i < intr->num_components; i++) {
                        unsigned n = idx * 4 + i;
@@ -1909,7 +1865,7 @@ emit_instr(struct ir3_compile *ctx, nir_instr *instr)
                emit_alu(ctx, nir_instr_as_alu(instr));
                break;
        case nir_instr_type_intrinsic:
-               emit_intrinisic(ctx, nir_instr_as_intrinsic(instr));
+               emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
                break;
        case nir_instr_type_load_const:
                emit_load_const(ctx, nir_instr_as_load_const(instr));
@@ -1928,8 +1884,6 @@ emit_instr(struct ir3_compile *ctx, nir_instr *instr)
                case nir_texop_query_levels:
                        emit_tex_query_levels(ctx, tex);
                        break;
-               case nir_texop_samples_identical:
-                       unreachable("nir_texop_samples_identical");
                default:
                        emit_tex(ctx, tex);
                        break;
@@ -1980,6 +1934,10 @@ emit_block(struct ir3_compile *ctx, nir_block *nblock)
        ctx->block = block;
        list_addtail(&block->node, &ctx->ir->block_list);
 
+       /* re-emit addr register in each block if needed: */
+       _mesa_hash_table_destroy(ctx->addr_ht, NULL);
+       ctx->addr_ht = NULL;
+
        nir_foreach_instr(nblock, instr) {
                emit_instr(ctx, instr);
                if (ctx->error)
@@ -2321,10 +2279,10 @@ emit_instructions(struct ir3_compile *ctx)
        nir_function_impl *fxn = NULL;
 
        /* Find the main function: */
-       nir_foreach_overload(ctx->s, overload) {
-               compile_assert(ctx, strcmp(overload->function->name, "main") == 0);
-               compile_assert(ctx, overload->impl);
-               fxn = overload->impl;
+       nir_foreach_function(ctx->s, function) {
+               compile_assert(ctx, strcmp(function->name, "main") == 0);
+               compile_assert(ctx, function->impl);
+               fxn = function->impl;
                break;
        }
 
@@ -2469,7 +2427,7 @@ ir3_compile_shader_nir(struct ir3_compiler *compiler,
 
        assert(!so->ir);
 
-       ctx = compile_init(compiler, so, so->shader->tokens);
+       ctx = compile_init(compiler, so);
        if (!ctx) {
                DBG("INIT failed!");
                ret = -1;