nir: allow specifying filter callback in lower_alu_to_scalar
[mesa.git] / src / compiler / nir / nir_lower_alu_to_scalar.c
index b16624bd8aaedc5b3878b99b476b0430bb772371..bcd92908253df51f98bb496bfc6d6184f014c25e 100644 (file)
 #include "nir.h"
 #include "nir_builder.h"
 
+struct alu_to_scalar_data {
+   nir_instr_filter_cb cb;
+   const void *data;
+};
+
 /** @file nir_lower_alu_to_scalar.c
  *
  * Replaces nir_alu_instr operations with more than one channel used in the
@@ -89,9 +94,9 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
 }
 
 static nir_ssa_def *
-lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
+lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
 {
-   BITSET_WORD *lower_set = _state;
+   struct alu_to_scalar_data *data = _data;
    nir_alu_instr *alu = nir_instr_as_alu(instr);
    unsigned num_src = nir_op_infos[alu->op].num_inputs;
    unsigned i, chan;
@@ -102,7 +107,7 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
    b->cursor = nir_before_instr(&alu->instr);
    b->exact = alu->exact;
 
-   if (lower_set && !BITSET_TEST(lower_set, alu->op))
+   if (data->cb && !data->cb(instr, data->data))
       return NULL;
 
 #define LOWER_REDUCTION(name, chan, merge) \
@@ -246,10 +251,15 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
 }
 
 bool
-nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set)
+nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *_data)
 {
+   struct alu_to_scalar_data data = {
+      .cb = cb,
+      .data = _data,
+   };
+
    return nir_shader_lower_instructions(shader,
                                         inst_is_vector_alu,
                                         lower_alu_instr_scalar,
-                                        lower_set);
+                                        &data);
 }