nir: allow specifying filter callback in lower_alu_to_scalar
[mesa.git] / src / gallium / drivers / lima / lima_program.c
index c22636fc50e665a39c1edabb55b3d8c074be0477..c0683b886008dbcd7a2492660cbd973143da7431 100644 (file)
@@ -110,7 +110,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
       progress = false;
 
       NIR_PASS_V(s, nir_lower_vars_to_ssa);
-      NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL);
+      NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
       NIR_PASS(progress, s, nir_lower_phis_to_scalar);
       NIR_PASS(progress, s, nir_copy_prop);
       NIR_PASS(progress, s, nir_opt_remove_phis);
@@ -145,30 +145,54 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
    nir_sweep(s);
 }
 
+static bool
+lima_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
+{
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   switch (alu->op) {
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_flog2:
+   case nir_op_fexp2:
+   case nir_op_fsqrt:
+   case nir_op_fsin:
+   case nir_op_fcos:
+   /* nir vec4 fcsel assumes that each component of the condition will be
+    * used to select the same component from the two options, but lima
+    * can't implement that since we only have 1 component condition */
+   case nir_op_fcsel:
+   case nir_op_bcsel:
+      return true;
+   default:
+      break;
+   }
+
+   return false;
+}
+
 void
 lima_program_optimize_fs_nir(struct nir_shader *s)
 {
-   BITSET_DECLARE(alu_lower, nir_num_opcodes) = {0};
    bool progress;
 
-   BITSET_SET(alu_lower, nir_op_frcp);
-   BITSET_SET(alu_lower, nir_op_frsq);
-   BITSET_SET(alu_lower, nir_op_flog2);
-   BITSET_SET(alu_lower, nir_op_fexp2);
-   BITSET_SET(alu_lower, nir_op_fsqrt);
-   BITSET_SET(alu_lower, nir_op_fsin);
-   BITSET_SET(alu_lower, nir_op_fcos);
-
    NIR_PASS_V(s, nir_lower_fragcoord_wtrans);
    NIR_PASS_V(s, nir_lower_io, nir_var_all, type_size, 0);
    NIR_PASS_V(s, nir_lower_regs_to_ssa);
    NIR_PASS_V(s, nir_lower_tex, &tex_options);
 
+   do {
+      progress = false;
+      NIR_PASS(progress, s, nir_opt_vectorize);
+   } while (progress);
+
    do {
       progress = false;
 
       NIR_PASS_V(s, nir_lower_vars_to_ssa);
-      NIR_PASS(progress, s, nir_lower_alu_to_scalar, alu_lower);
+      NIR_PASS(progress, s, nir_lower_alu_to_scalar, lima_alu_to_scalar_filter_cb, NULL);
       NIR_PASS(progress, s, nir_lower_phis_to_scalar);
       NIR_PASS(progress, s, nir_copy_prop);
       NIR_PASS(progress, s, nir_opt_remove_phis);
@@ -195,6 +219,9 @@ lima_program_optimize_fs_nir(struct nir_shader *s)
       NIR_PASS(progress, s, nir_opt_algebraic);
    } while (progress);
 
+   /* Must be run after optimization loop */
+   NIR_PASS_V(s, lima_nir_scale_trig);
+
    /* Lower modifiers */
    NIR_PASS_V(s, nir_lower_to_source_mods, nir_lower_all_source_mods);
    NIR_PASS_V(s, nir_copy_prop);
@@ -214,6 +241,7 @@ static void *
 lima_create_fs_state(struct pipe_context *pctx,
                      const struct pipe_shader_state *cso)
 {
+   struct lima_context *ctx = lima_context(pctx);
    struct lima_screen *screen = lima_screen(pctx->screen);
    struct lima_fs_shader_state *so = rzalloc(NULL, struct lima_fs_shader_state);
 
@@ -234,7 +262,7 @@ lima_create_fs_state(struct pipe_context *pctx,
    if (lima_debug & LIMA_DEBUG_PP)
       nir_print_shader(nir, stdout);
 
-   if (!ppir_compile_nir(so, nir, screen->pp_ra)) {
+   if (!ppir_compile_nir(so, nir, screen->pp_ra, &ctx->debug)) {
       ralloc_free(so);
       return NULL;
    }
@@ -299,6 +327,8 @@ lima_update_fs_state(struct lima_context *ctx)
       fs->shader = NULL;
    }
 
+   ctx->pp_max_stack_size = MAX2(ctx->pp_max_stack_size, ctx->fs->stack_size);
+
    return true;
 }
 
@@ -306,6 +336,7 @@ static void *
 lima_create_vs_state(struct pipe_context *pctx,
                      const struct pipe_shader_state *cso)
 {
+   struct lima_context *ctx = lima_context(pctx);
    struct lima_vs_shader_state *so = rzalloc(NULL, struct lima_vs_shader_state);
 
    if (!so)
@@ -325,7 +356,7 @@ lima_create_vs_state(struct pipe_context *pctx,
    if (lima_debug & LIMA_DEBUG_GP)
       nir_print_shader(nir, stdout);
 
-   if (!gpir_compile_nir(so, nir)) {
+   if (!gpir_compile_nir(so, nir, &ctx->debug)) {
       ralloc_free(so);
       return NULL;
    }