nir: copy intrinsic type when lowering load input/uniform and store output
[mesa.git] / src / compiler / nir / nir_lower_io.c
index 1a7e00bf96afa657ceb3d843480aa09ae5a4dfc2..b0197399bdf6f060f6751ca637b42ce0b5111244 100644 (file)
@@ -204,7 +204,7 @@ get_io_offset(nir_builder *b, nir_deref_instr *deref,
 static nir_intrinsic_instr *
 lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
            nir_ssa_def *vertex_index, nir_variable *var, nir_ssa_def *offset,
-           unsigned component)
+           unsigned component, const struct glsl_type *type)
 {
    const nir_shader *nir = state->builder.shader;
    nir_variable_mode mode = var->data.mode;
@@ -261,6 +261,10 @@ lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
       nir_intrinsic_set_range(load,
                               state->type_size(var->type, var->data.bindless));
 
+   if (load->intrinsic == nir_intrinsic_load_input ||
+       load->intrinsic == nir_intrinsic_load_uniform)
+      nir_intrinsic_set_type(load, nir_get_nir_type_for_glsl_type(type));
+
    if (vertex_index) {
       load->src[0] = nir_src_for_ssa(vertex_index);
       load->src[1] = nir_src_for_ssa(offset);
@@ -277,7 +281,7 @@ lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
 static nir_intrinsic_instr *
 lower_store(nir_intrinsic_instr *intrin, struct lower_io_state *state,
             nir_ssa_def *vertex_index, nir_variable *var, nir_ssa_def *offset,
-            unsigned component)
+            unsigned component, const struct glsl_type *type)
 {
    nir_variable_mode mode = var->data.mode;
 
@@ -301,6 +305,9 @@ lower_store(nir_intrinsic_instr *intrin, struct lower_io_state *state,
    if (mode == nir_var_shader_out)
       nir_intrinsic_set_component(store, component);
 
+   if (store->intrinsic == nir_intrinsic_store_output)
+      nir_intrinsic_set_type(store, nir_get_nir_type_for_glsl_type(type));
+
    nir_intrinsic_set_write_mask(store, nir_intrinsic_write_mask(intrin));
 
    if (vertex_index)
@@ -356,13 +363,14 @@ lower_atomic(nir_intrinsic_instr *intrin, struct lower_io_state *state,
 
 static nir_intrinsic_instr *
 lower_interpolate_at(nir_intrinsic_instr *intrin, struct lower_io_state *state,
-                     nir_variable *var, nir_ssa_def *offset, unsigned component)
+                     nir_variable *var, nir_ssa_def *offset, unsigned component,
+                     const struct glsl_type *type)
 {
    assert(var->data.mode == nir_var_shader_in);
 
    /* Ignore interpolateAt() for flat variables - flat is flat. */
    if (var->data.interpolation == INTERP_MODE_FLAT)
-      return lower_load(intrin, state, NULL, var, offset, component);
+      return lower_load(intrin, state, NULL, var, offset, component, type);
 
    nir_intrinsic_op bary_op;
    switch (intrin->intrinsic) {
@@ -485,12 +493,12 @@ nir_lower_io_block(nir_block *block,
       switch (intrin->intrinsic) {
       case nir_intrinsic_load_deref:
          replacement = lower_load(intrin, state, vertex_index, var, offset,
-                                  component_offset);
+                                  component_offset, deref->type);
          break;
 
       case nir_intrinsic_store_deref:
          replacement = lower_store(intrin, state, vertex_index, var, offset,
-                                   component_offset);
+                                   component_offset, deref->type);
          break;
 
       case nir_intrinsic_deref_atomic_add:
@@ -516,7 +524,7 @@ nir_lower_io_block(nir_block *block,
       case nir_intrinsic_interp_deref_at_offset:
          assert(vertex_index == NULL);
          replacement = lower_interpolate_at(intrin, state, var, offset,
-                                            component_offset);
+                                            component_offset, deref->type);
          break;
 
       default:
@@ -604,6 +612,7 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
    switch (addr_format) {
    case nir_address_format_32bit_global:
    case nir_address_format_64bit_global:
+   case nir_address_format_32bit_offset:
       assert(addr->num_components == 1);
       return nir_iadd(b, addr, offset);
 
@@ -675,6 +684,7 @@ addr_to_global(nir_builder *b, nir_ssa_def *addr,
                          nir_u2u64(b, nir_channel(b, addr, 3)));
 
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_32bit_offset:
    case nir_address_format_logical:
       unreachable("Cannot get a 64-bit address with this address format");
    }
@@ -1228,3 +1238,23 @@ nir_get_io_vertex_index_src(nir_intrinsic_instr *instr)
       return NULL;
    }
 }
+
+/**
+ * Return the numeric constant that identify a NULL pointer for each address
+ * format.
+ */
+const nir_const_value *
+nir_address_format_null_value(nir_address_format addr_format)
+{
+   const static nir_const_value null_values[][NIR_MAX_VEC_COMPONENTS] = {
+      [nir_address_format_32bit_global] = {{0}},
+      [nir_address_format_64bit_global] = {{0}},
+      [nir_address_format_64bit_bounded_global] = {{0}},
+      [nir_address_format_32bit_index_offset] = {{.u32 = ~0}, {.u32 = ~0}},
+      [nir_address_format_32bit_offset] = {{.u32 = ~0}},
+      [nir_address_format_logical] = {{.u32 = ~0}},
+   };
+
+   assert(addr_format < ARRAY_SIZE(null_values));
+   return null_values[addr_format];
+}