nir/algebraic: trivially enable existing 32-bit patterns for all bit sizes
[mesa.git] / src / compiler / nir / nir_lower_bool_to_float.c
index 52fc55c6b4c78472ed5a04d0145242f040027cdf..32f2ca056b2d5d1e60fb0e65258b48bb617ed574 100644 (file)
@@ -52,18 +52,30 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
    /* Replacement SSA value */
    nir_ssa_def *rep = NULL;
    switch (alu->op) {
-   case nir_op_b2f32: alu->op = nir_op_fmov; break;
-   case nir_op_b2i32: alu->op = nir_op_fmov; break;
+   case nir_op_mov:
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4:
+   case nir_op_vec8:
+   case nir_op_vec16:
+      if (alu->dest.dest.ssa.bit_size != 1)
+         return false;
+      /* These we expect to have booleans but the opcode doesn't change */
+      break;
+
+   case nir_op_b2f32: alu->op = nir_op_mov; break;
+   case nir_op_b2i32: alu->op = nir_op_mov; break;
    case nir_op_f2b1:
    case nir_op_i2b1:
       rep = nir_sne(b, nir_ssa_for_alu_src(b, alu, 0),
                        nir_imm_float(b, 0));
       break;
+   case nir_op_b2b1: alu->op = nir_op_mov; break;
 
    case nir_op_flt: alu->op = nir_op_slt; break;
    case nir_op_fge: alu->op = nir_op_sge; break;
    case nir_op_feq: alu->op = nir_op_seq; break;
-   case nir_op_fne: alu->op = nir_op_sne; break;
+   case nir_op_fneu: alu->op = nir_op_sne; break;
    case nir_op_ilt: alu->op = nir_op_slt; break;
    case nir_op_ige: alu->op = nir_op_sge; break;
    case nir_op_ieq: alu->op = nir_op_seq; break;
@@ -86,7 +98,6 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
 
    case nir_op_bcsel: alu->op = nir_op_fcsel; break;
 
-   case nir_op_imov: alu->op = nir_op_fmov; break;
    case nir_op_iand: alu->op = nir_op_fmul; break;
    case nir_op_ixor: alu->op = nir_op_sne; break;
    case nir_op_ior: alu->op = nir_op_fmax; break;
@@ -133,9 +144,9 @@ nir_lower_bool_to_float_impl(nir_function_impl *impl)
          case nir_instr_type_load_const: {
             nir_load_const_instr *load = nir_instr_as_load_const(instr);
             if (load->def.bit_size == 1) {
-               nir_const_value value = load->value;
+               nir_const_value *value = load->value;
                for (unsigned i = 0; i < load->def.num_components; i++)
-                  load->value.f32[i] = value.b[i] ? 1.0 : 0.0;
+                  load->value[i].f32 = value[i].b ? 1.0 : 0.0;
                load->def.bit_size = 32;
                progress = true;
             }