nir: allow specifying filter callback in lower_alu_to_scalar
author Vasily Khoruzhick <anarsoul@gmail.com>
Fri, 30 Aug 2019 04:14:54 +0000 (21:14 -0700)
committer Vasily Khoruzhick <anarsoul@gmail.com>
Fri, 6 Sep 2019 01:51:28 +0000 (01:51 +0000)
A set of opcodes doesn't provide enough flexibility in certain cases. E.g.
the Utgard PP has a vector conditional select operation, but its condition is
always scalar. Lowering all vector selects to scalar increases the instruction
count, so we need a way to filter only those ops that can't be handled
in hardware.

Reviewed-by: Qiang Yu <yuq825@gmail.com>
Reviewed-by: Eric Anholt <eric@anholt.net>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Signed-off-by: Vasily Khoruzhick <anarsoul@gmail.com>
16 files changed:
src/amd/vulkan/radv_shader.c
src/broadcom/compiler/nir_to_vir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_alu_to_scalar.c
src/freedreno/ir3/ir3_nir.c
src/gallium/auxiliary/nir/tgsi_to_nir.c
src/gallium/drivers/etnaviv/etnaviv_compiler_nir.c
src/gallium/drivers/freedreno/a2xx/ir2_nir.c
src/gallium/drivers/lima/lima_program.c
src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
src/gallium/drivers/radeonsi/si_shader_nir.c
src/gallium/drivers/vc4/vc4_program.c
src/intel/compiler/brw_nir.c
src/mesa/state_tracker/st_glsl_to_nir.cpp
src/panfrost/bifrost/bifrost_compile.c
src/panfrost/bifrost/cmdline.c

index 1ab64a6e328d33ac1c43bfcfb9b20880fdcd7749..f90689e85b588e156c93c30589be230d46b04e6a 100644 (file)
@@ -200,7 +200,7 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
                NIR_PASS(progress, shader, nir_remove_dead_variables,
                         nir_var_function_temp);
 
-                NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL);
+                NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL, NULL);
                 NIR_PASS_V(shader, nir_lower_phis_to_scalar);
 
                 NIR_PASS(progress, shader, nir_copy_prop);
index 91e95f9ee5ad65dd16d49e199cccdf26c6935d8f..b640dcc341b905b3d73351bd08702edcb6055a5f 100644 (file)
@@ -1382,7 +1382,7 @@ v3d_optimize_nir(struct nir_shader *s)
                 progress = false;
 
                 NIR_PASS_V(s, nir_lower_vars_to_ssa);
-                NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL);
+                NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
                 NIR_PASS(progress, s, nir_lower_phis_to_scalar);
                 NIR_PASS(progress, s, nir_copy_prop);
                 NIR_PASS(progress, s, nir_opt_remove_phis);
index 5149a0e8c013668fc3c4089d1b5fdbdaf63a8491..bad1d6af212bcff23b8408a386b2ed2dfd3f03f6 100644 (file)
@@ -3606,7 +3606,7 @@ bool nir_lower_alu(nir_shader *shader);
 bool nir_lower_flrp(nir_shader *shader, unsigned lowering_mask,
                     bool always_precise, bool have_ffma);
 
-bool nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set);
+bool nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *data);
 bool nir_lower_bool_to_float(nir_shader *shader);
 bool nir_lower_bool_to_int32(nir_shader *shader);
 bool nir_lower_int_to_float(nir_shader *shader);
index b16624bd8aaedc5b3878b99b476b0430bb772371..bcd92908253df51f98bb496bfc6d6184f014c25e 100644 (file)
 #include "nir.h"
 #include "nir_builder.h"
 
+struct alu_to_scalar_data {
+   nir_instr_filter_cb cb;
+   const void *data;
+};
+
 /** @file nir_lower_alu_to_scalar.c
  *
  * Replaces nir_alu_instr operations with more than one channel used in the
@@ -89,9 +94,9 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
 }
 
 static nir_ssa_def *
-lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
+lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
 {
-   BITSET_WORD *lower_set = _state;
+   struct alu_to_scalar_data *data = _data;
    nir_alu_instr *alu = nir_instr_as_alu(instr);
    unsigned num_src = nir_op_infos[alu->op].num_inputs;
    unsigned i, chan;
@@ -102,7 +107,7 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
    b->cursor = nir_before_instr(&alu->instr);
    b->exact = alu->exact;
 
-   if (lower_set && !BITSET_TEST(lower_set, alu->op))
+   if (data->cb && !data->cb(instr, data->data))
       return NULL;
 
 #define LOWER_REDUCTION(name, chan, merge) \
@@ -246,10 +251,15 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
 }
 
 bool
-nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set)
+nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *_data)
 {
+   struct alu_to_scalar_data data = {
+      .cb = cb,
+      .data = _data,
+   };
+
    return nir_shader_lower_instructions(shader,
                                         inst_is_vector_alu,
                                         lower_alu_instr_scalar,
-                                        lower_set);
+                                        &data);
 }
index 05426bc7a0ec9fa551d42bcd9433dc7ba0015951..50a961f2bad56d108f16a94473cf9769b030d0e8 100644 (file)
@@ -124,7 +124,7 @@ ir3_optimize_loop(nir_shader *s)
                OPT_V(s, nir_lower_vars_to_ssa);
                progress |= OPT(s, nir_opt_copy_prop_vars);
                progress |= OPT(s, nir_opt_dead_write_vars);
-               progress |= OPT(s, nir_lower_alu_to_scalar, NULL);
+               progress |= OPT(s, nir_lower_alu_to_scalar, NULL, NULL);
                progress |= OPT(s, nir_lower_phis_to_scalar);
 
                progress |= OPT(s, nir_copy_prop);
index 20d6c0bfb29a3e49f7e51b44944a1e56f57f1866..eae2ef058d27341b48b0fe544686a3232b1d9ab4 100644 (file)
@@ -2559,7 +2559,7 @@ ttn_optimize_nir(nir_shader *nir, bool scalar)
       NIR_PASS_V(nir, nir_lower_vars_to_ssa);
 
       if (scalar) {
-         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
          NIR_PASS_V(nir, nir_lower_phis_to_scalar);
       }
 
index dc6756af257078fa13155b28370571dd894ea278..c2bfec14ed21deacbfaa67781a4107760a8838ee 100644 (file)
@@ -208,25 +208,34 @@ etna_lower_io(nir_shader *shader, struct etna_shader_variant *v)
    }
 }
 
-static void
-etna_lower_alu_to_scalar(nir_shader *shader, const struct etna_specs *specs)
+static bool
+etna_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
 {
-   BITSET_DECLARE(scalar_ops, nir_num_opcodes);
-   BITSET_ZERO(scalar_ops);
-
-   BITSET_SET(scalar_ops, nir_op_frsq);
-   BITSET_SET(scalar_ops, nir_op_frcp);
-   BITSET_SET(scalar_ops, nir_op_flog2);
-   BITSET_SET(scalar_ops, nir_op_fexp2);
-   BITSET_SET(scalar_ops, nir_op_fsqrt);
-   BITSET_SET(scalar_ops, nir_op_fcos);
-   BITSET_SET(scalar_ops, nir_op_fsin);
-   BITSET_SET(scalar_ops, nir_op_fdiv);
-
-   if (!specs->has_halti2_instructions)
-      BITSET_SET(scalar_ops, nir_op_fdot2);
-
-   nir_lower_alu_to_scalar(shader, scalar_ops);
+   const struct etna_specs *specs = data;
+
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   switch (alu->op) {
+   case nir_op_frsq:
+   case nir_op_frcp:
+   case nir_op_flog2:
+   case nir_op_fexp2:
+   case nir_op_fsqrt:
+   case nir_op_fcos:
+   case nir_op_fsin:
+   case nir_op_fdiv:
+      return true;
+   case nir_op_fdot2:
+      if (!specs->has_halti2_instructions)
+         return true;
+      break;
+   default:
+      break;
+   }
+
+   return false;
 }
 
 static void
@@ -607,7 +616,7 @@ etna_compile_shader_nir(struct etna_shader_variant *v)
    OPT_V(s, nir_lower_vars_to_ssa);
    OPT_V(s, nir_lower_indirect_derefs, nir_var_all);
    OPT_V(s, nir_lower_tex, &(struct nir_lower_tex_options) { .lower_txp = ~0u });
-   OPT_V(s, etna_lower_alu_to_scalar, specs);
+   OPT_V(s, nir_lower_alu_to_scalar, etna_alu_to_scalar_filter_cb, specs);
 
    etna_optimize_loop(s);
 
@@ -627,7 +636,7 @@ etna_compile_shader_nir(struct etna_shader_variant *v)
       nir_print_shader(s, stdout);
 
    while( OPT(s, nir_opt_vectorize) );
-   OPT_V(s, etna_lower_alu_to_scalar, specs);
+   OPT_V(s, nir_lower_alu_to_scalar, etna_alu_to_scalar_filter_cb, specs);
 
    NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp);
    NIR_PASS_V(s, nir_opt_algebraic_late);
index 9d36f7092ef1f6d86200adbaa2c0d024c0d5aadf..6915194234db74957d70bdf0e6387e9cfadb25a4 100644 (file)
@@ -1062,6 +1062,29 @@ static void cleanup_binning(struct ir2_context *ctx)
        ir2_optimize_nir(ctx->nir, false);
 }
 
+static bool
+ir2_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
+{
+       if (instr->type != nir_instr_type_alu)
+               return false;
+
+       nir_alu_instr *alu = nir_instr_as_alu(instr);
+       switch (alu->op) {
+       case nir_op_frsq:
+       case nir_op_frcp:
+       case nir_op_flog2:
+       case nir_op_fexp2:
+       case nir_op_fsqrt:
+       case nir_op_fcos:
+       case nir_op_fsin:
+               return true;
+       default:
+               break;
+       }
+
+       return false;
+}
+
 void
 ir2_nir_compile(struct ir2_context *ctx, bool binning)
 {
@@ -1085,17 +1108,7 @@ ir2_nir_compile(struct ir2_context *ctx, bool binning)
        OPT_V(ctx->nir, nir_lower_bool_to_float);
        OPT_V(ctx->nir, nir_lower_to_source_mods, nir_lower_all_source_mods);
 
-       /* TODO: static bitset ? */
-       BITSET_DECLARE(scalar_ops, nir_num_opcodes);
-       BITSET_ZERO(scalar_ops);
-       BITSET_SET(scalar_ops, nir_op_frsq);
-       BITSET_SET(scalar_ops, nir_op_frcp);
-       BITSET_SET(scalar_ops, nir_op_flog2);
-       BITSET_SET(scalar_ops, nir_op_fexp2);
-       BITSET_SET(scalar_ops, nir_op_fsqrt);
-       BITSET_SET(scalar_ops, nir_op_fcos);
-       BITSET_SET(scalar_ops, nir_op_fsin);
-       OPT_V(ctx->nir, nir_lower_alu_to_scalar, scalar_ops);
+       OPT_V(ctx->nir, nir_lower_alu_to_scalar, ir2_alu_to_scalar_filter_cb, NULL);
 
        OPT_V(ctx->nir, nir_lower_locals_to_regs);
 
index b9c4cbc4d5f0437d4bf638b3e1c0392188bb50a9..c0683b886008dbcd7a2492660cbd973143da7431 100644 (file)
@@ -110,7 +110,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
       progress = false;
 
       NIR_PASS_V(s, nir_lower_vars_to_ssa);
-      NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL);
+      NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
       NIR_PASS(progress, s, nir_lower_phis_to_scalar);
       NIR_PASS(progress, s, nir_copy_prop);
       NIR_PASS(progress, s, nir_opt_remove_phis);
@@ -145,24 +145,38 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
    nir_sweep(s);
 }
 
-void
-lima_program_optimize_fs_nir(struct nir_shader *s)
+static bool
+lima_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
 {
-   BITSET_DECLARE(alu_lower, nir_num_opcodes) = {0};
-   bool progress;
-
-   BITSET_SET(alu_lower, nir_op_frcp);
-   BITSET_SET(alu_lower, nir_op_frsq);
-   BITSET_SET(alu_lower, nir_op_flog2);
-   BITSET_SET(alu_lower, nir_op_fexp2);
-   BITSET_SET(alu_lower, nir_op_fsqrt);
-   BITSET_SET(alu_lower, nir_op_fsin);
-   BITSET_SET(alu_lower, nir_op_fcos);
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   switch (alu->op) {
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_flog2:
+   case nir_op_fexp2:
+   case nir_op_fsqrt:
+   case nir_op_fsin:
+   case nir_op_fcos:
    /* nir vec4 fcsel assumes that each component of the condition will be
     * used to select the same component from the two options, but lima
     * can't implement that since we only have 1 component condition */
-   BITSET_SET(alu_lower, nir_op_fcsel);
-   BITSET_SET(alu_lower, nir_op_bcsel);
+   case nir_op_fcsel:
+   case nir_op_bcsel:
+      return true;
+   default:
+      break;
+   }
+
+   return false;
+}
+
+void
+lima_program_optimize_fs_nir(struct nir_shader *s)
+{
+   bool progress;
 
    NIR_PASS_V(s, nir_lower_fragcoord_wtrans);
    NIR_PASS_V(s, nir_lower_io, nir_var_all, type_size, 0);
@@ -178,7 +192,7 @@ lima_program_optimize_fs_nir(struct nir_shader *s)
       progress = false;
 
       NIR_PASS_V(s, nir_lower_vars_to_ssa);
-      NIR_PASS(progress, s, nir_lower_alu_to_scalar, alu_lower);
+      NIR_PASS(progress, s, nir_lower_alu_to_scalar, lima_alu_to_scalar_filter_cb, NULL);
       NIR_PASS(progress, s, nir_lower_phis_to_scalar);
       NIR_PASS(progress, s, nir_copy_prop);
       NIR_PASS(progress, s, nir_opt_remove_phis);
index 378638bf3a41beaf6a875d95532ccd9d02b70b3a..4e86ab8f8ccc56e4d1feaca3a9504e72842b383e 100644 (file)
@@ -3500,7 +3500,7 @@ Converter::run()
    NIR_PASS_V(nir, nir_lower_regs_to_ssa);
    NIR_PASS_V(nir, nir_lower_load_const_to_scalar);
    NIR_PASS_V(nir, nir_lower_vars_to_ssa);
-   NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+   NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
    NIR_PASS_V(nir, nir_lower_phis_to_scalar);
 
    do {
index fdd139141e207613827082271e737f417f865340..4970b01fd733caeeff1a2cb1b5fc078d533ed238 100644 (file)
@@ -817,7 +817,7 @@ si_nir_opts(struct nir_shader *nir)
                NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
                NIR_PASS(progress, nir, nir_opt_dead_write_vars);
 
-               NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+               NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
                NIR_PASS_V(nir, nir_lower_phis_to_scalar);
 
                /* (Constant) copy propagation is needed for txf with offsets. */
index 1d55b87e1efb821b27b1e2fa6397ccd45c11fa69..e5f7aa31b0ebbffe1c6500ded626f4bceaf151a1 100644 (file)
@@ -1530,7 +1530,7 @@ vc4_optimize_nir(struct nir_shader *s)
                 progress = false;
 
                 NIR_PASS_V(s, nir_lower_vars_to_ssa);
-                NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL);
+                NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
                 NIR_PASS(progress, s, nir_lower_phis_to_scalar);
                 NIR_PASS(progress, s, nir_copy_prop);
                 NIR_PASS(progress, s, nir_opt_remove_phis);
index bd5c50165503b04d9d9844ed50e2196aaffb31cc..c710fe46e5de5ec18ce75b2827d0143b5a40d6d8 100644 (file)
@@ -518,7 +518,7 @@ brw_nir_optimize(nir_shader *nir, const struct brw_compiler *compiler,
       OPT(nir_opt_combine_stores, nir_var_all);
 
       if (is_scalar) {
-         OPT(nir_lower_alu_to_scalar, NULL);
+         OPT(nir_lower_alu_to_scalar, NULL, NULL);
       }
 
       OPT(nir_copy_prop);
@@ -654,7 +654,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
    const bool is_scalar = compiler->scalar_stage[nir->info.stage];
 
    if (is_scalar) {
-      OPT(nir_lower_alu_to_scalar, NULL);
+      OPT(nir_lower_alu_to_scalar, NULL, NULL);
    }
 
    if (nir->info.stage == MESA_SHADER_GEOMETRY)
@@ -871,7 +871,7 @@ brw_postprocess_nir(nir_shader *nir, const struct brw_compiler *compiler,
    OPT(brw_nir_lower_conversions);
 
    if (is_scalar)
-      OPT(nir_lower_alu_to_scalar, NULL);
+      OPT(nir_lower_alu_to_scalar, NULL, NULL);
    OPT(nir_lower_to_source_mods, nir_lower_all_source_mods);
    OPT(nir_copy_prop);
    OPT(nir_opt_dce);
index 4585455c85671c8594b22ec040b8a92a3fbaf3cc..b9ce22c4a6e0570713ad38e881aa7cca606e6679 100644 (file)
@@ -247,7 +247,7 @@ st_nir_opts(nir_shader *nir, bool scalar)
       NIR_PASS(progress, nir, nir_opt_dead_write_vars);
 
       if (scalar) {
-         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
          NIR_PASS_V(nir, nir_lower_phis_to_scalar);
       }
 
@@ -363,7 +363,7 @@ st_glsl_to_nir(struct st_context *st, struct gl_program *prog,
    NIR_PASS_V(nir, nir_lower_var_copies);
 
    if (is_scalar) {
-     NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+     NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
    }
 
    /* before buffers and vars_to_ssa */
index 5e34b95d30836dd5ceb9f0915cbae1a8b7686ab4..2af36ee86681573b2365532efccc79cc6d20481d 100644 (file)
@@ -57,7 +57,7 @@ optimize_nir(nir_shader *nir)
                 NIR_PASS(progress, nir, nir_opt_constant_folding);
 
                 NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
-                NIR_PASS(progress, nir, nir_lower_alu_to_scalar, NULL);
+                NIR_PASS(progress, nir, nir_lower_alu_to_scalar, NULL, NULL);
                 NIR_PASS(progress, nir, nir_opt_if, true);
 
         } while (progress);
index 6dec34af0e5c544a02a486cea9e69f726351228c..0c495d25d54c9340985d865bc596265f3bbb701a 100644 (file)
@@ -59,7 +59,7 @@ compile_shader(char **argv)
                 NIR_PASS_V(nir[i], nir_split_var_copies);
                 NIR_PASS_V(nir[i], nir_lower_var_copies);
 
-                NIR_PASS_V(nir[i], nir_lower_alu_to_scalar, NULL);
+                NIR_PASS_V(nir[i], nir_lower_alu_to_scalar, NULL, NULL);
 
                 /* before buffers and vars_to_ssa */
                 NIR_PASS_V(nir[i], gl_nir_lower_bindless_images);