nir: Assert that nir_lower_io is only called with allowed modes
[mesa.git] / src / compiler / nir / nir_lower_io.c
index 92d2a1f8ba0f89853947c918d7c496d0cc761249..101460afa9732e9e7dc86be7ee67d50a139327bd 100644 (file)
@@ -352,6 +352,13 @@ lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
       }
 
       return nir_vec(b, comp64, intrin->dest.ssa.num_components);
+   } else if (intrin->dest.ssa.bit_size == 1) {
+      /* Booleans are 32-bit */
+      assert(glsl_type_is_boolean(type));
+      return nir_b2b1(&state->builder,
+                      emit_load(state, vertex_index, var, offset, component,
+                                intrin->dest.ssa.num_components, 32,
+                                nir_type_bool32));
    } else {
       return emit_load(state, vertex_index, var, offset, component,
                        intrin->dest.ssa.num_components,
@@ -445,6 +452,14 @@ lower_store(nir_intrinsic_instr *intrin, struct lower_io_state *state,
          write_mask >>= num_comps;
          offset = nir_iadd_imm(b, offset, slot_size);
       }
+   } else if (intrin->dest.ssa.bit_size == 1) {
+      /* Booleans are 32-bit */
+      assert(glsl_type_is_boolean(type));
+      nir_ssa_def *b32_val = nir_b2b32(&state->builder, intrin->src[1].ssa);
+      emit_store(state, b32_val, vertex_index, var, offset,
+                 component, intrin->num_components,
+                 nir_intrinsic_write_mask(intrin),
+                 nir_type_bool32);
    } else {
       emit_store(state, intrin->src[1].ssa, vertex_index, var, offset,
                  component, intrin->num_components,
@@ -611,16 +626,10 @@ nir_lower_io_block(nir_block *block,
       nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
 
       nir_variable_mode mode = deref->mode;
-
+      assert(util_is_power_of_two_nonzero(mode));
       if ((state->modes & mode) == 0)
          continue;
 
-      if (mode != nir_var_shader_in &&
-          mode != nir_var_shader_out &&
-          mode != nir_var_mem_shared &&
-          mode != nir_var_uniform)
-         continue;
-
       nir_variable *var = nir_deref_instr_get_variable(deref);
 
       b->cursor = nir_before_instr(instr);
@@ -708,6 +717,11 @@ nir_lower_io_impl(nir_function_impl *impl,
    state.type_size = type_size;
    state.options = options;
 
+   ASSERTED nir_variable_mode supported_modes =
+      nir_var_shader_in | nir_var_shader_out |
+      nir_var_mem_shared | nir_var_uniform;
+   assert(!(modes & ~supported_modes));
+
    nir_foreach_block(block, impl) {
       progress |= nir_lower_io_block(block, &state);
    }
@@ -769,6 +783,10 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
       assert(addr->num_components == 2);
       return nir_vec2(b, nir_channel(b, addr, 0),
                          nir_iadd(b, nir_channel(b, addr, 1), offset));
+   case nir_address_format_vec2_index_32bit_offset:
+      assert(addr->num_components == 3);
+      return nir_vec3(b, nir_channel(b, addr, 0), nir_channel(b, addr, 1),
+                         nir_iadd(b, nir_channel(b, addr, 2), offset));
    case nir_address_format_logical:
       unreachable("Unsupported address format");
    }
@@ -787,18 +805,30 @@ static nir_ssa_def *
 addr_to_index(nir_builder *b, nir_ssa_def *addr,
               nir_address_format addr_format)
 {
-   assert(addr_format == nir_address_format_32bit_index_offset);
-   assert(addr->num_components == 2);
-   return nir_channel(b, addr, 0);
+   if (addr_format == nir_address_format_32bit_index_offset) {
+      assert(addr->num_components == 2);
+      return nir_channel(b, addr, 0);
+   } else if (addr_format == nir_address_format_vec2_index_32bit_offset) {
+      assert(addr->num_components == 3);
+      return nir_channels(b, addr, 0x3);
+   } else {
+      unreachable("bad address format for index");
+   }
 }
 
 static nir_ssa_def *
 addr_to_offset(nir_builder *b, nir_ssa_def *addr,
                nir_address_format addr_format)
 {
-   assert(addr_format == nir_address_format_32bit_index_offset);
-   assert(addr->num_components == 2);
-   return nir_channel(b, addr, 1);
+   if (addr_format == nir_address_format_32bit_index_offset) {
+      assert(addr->num_components == 2);
+      return nir_channel(b, addr, 1);
+   } else if (addr_format == nir_address_format_vec2_index_32bit_offset) {
+      assert(addr->num_components == 3);
+      return nir_channel(b, addr, 2);
+   } else {
+      unreachable("bad address format for offset");
+   }
 }
 
 /** Returns true if the given address format resolves to a global address */
@@ -826,6 +856,7 @@ addr_to_global(nir_builder *b, nir_ssa_def *addr,
                          nir_u2u64(b, nir_channel(b, addr, 3)));
 
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_vec2_index_32bit_offset:
    case nir_address_format_32bit_offset:
    case nir_address_format_logical:
       unreachable("Cannot get a 64-bit address with this address format");
@@ -896,7 +927,7 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
       load->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
    }
 
-   if (mode != nir_var_mem_ubo && mode != nir_var_shader_in && mode != nir_var_mem_shared)
+   if (mode != nir_var_shader_in && mode != nir_var_mem_shared)
       nir_intrinsic_set_access(load, nir_intrinsic_access(intrin));
 
    unsigned bit_size = intrin->dest.ssa.bit_size;
@@ -1269,7 +1300,8 @@ lower_explicit_io_array_length(nir_builder *b, nir_intrinsic_instr *intrin,
    unsigned stride = glsl_get_explicit_stride(deref->type);
    assert(stride > 0);
 
-   assert(addr_format == nir_address_format_32bit_index_offset);
+   assert(addr_format == nir_address_format_32bit_index_offset ||
+          addr_format == nir_address_format_vec2_index_32bit_offset);
    nir_ssa_def *addr = &deref->dest.ssa;
    nir_ssa_def *index = addr_to_index(b, addr, addr_format);
    nir_ssa_def *offset = addr_to_offset(b, addr, addr_format);
@@ -1567,6 +1599,7 @@ nir_address_format_null_value(nir_address_format addr_format)
       [nir_address_format_64bit_global] = {{0}},
       [nir_address_format_64bit_bounded_global] = {{0}},
       [nir_address_format_32bit_index_offset] = {{.u32 = ~0}, {.u32 = ~0}},
+      [nir_address_format_vec2_index_32bit_offset] = {{.u32 = ~0}, {.u32 = ~0}, {.u32 = ~0}},
       [nir_address_format_32bit_offset] = {{.u32 = ~0}},
       [nir_address_format_logical] = {{.u32 = ~0}},
    };
@@ -1584,6 +1617,7 @@ nir_build_addr_ieq(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
    case nir_address_format_64bit_global:
    case nir_address_format_64bit_bounded_global:
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_vec2_index_32bit_offset:
    case nir_address_format_32bit_offset:
       return nir_ball_iequal(b, addr0, addr1);
 
@@ -1616,6 +1650,12 @@ nir_build_addr_isub(nir_builder *b, nir_ssa_def *addr0, nir_ssa_def *addr1,
       /* Assume the same buffer index. */
       return nir_isub(b, nir_channel(b, addr0, 1), nir_channel(b, addr1, 1));
 
+   case nir_address_format_vec2_index_32bit_offset:
+      assert(addr0->num_components == 3);
+      assert(addr1->num_components == 3);
+      /* Assume the same buffer index. */
+      return nir_isub(b, nir_channel(b, addr0, 2), nir_channel(b, addr1, 2));
+
    case nir_address_format_logical:
       unreachable("Unsupported address format");
    }