nir: copy intrinsic type when lowering load input/uniform and store output
[mesa.git] / src / compiler / nir / nir_lower_io.c
index 331ecc08324ed88e52d5fdc3f6cecbd2c12e5ede..b0197399bdf6f060f6751ca637b42ce0b5111244 100644 (file)
@@ -204,7 +204,7 @@ get_io_offset(nir_builder *b, nir_deref_instr *deref,
 static nir_intrinsic_instr *
 lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
            nir_ssa_def *vertex_index, nir_variable *var, nir_ssa_def *offset,
-           unsigned component)
+           unsigned component, const struct glsl_type *type)
 {
    const nir_shader *nir = state->builder.shader;
    nir_variable_mode mode = var->data.mode;
@@ -261,6 +261,10 @@ lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
       nir_intrinsic_set_range(load,
                               state->type_size(var->type, var->data.bindless));
 
+   if (load->intrinsic == nir_intrinsic_load_input ||
+       load->intrinsic == nir_intrinsic_load_uniform)
+      nir_intrinsic_set_type(load, nir_get_nir_type_for_glsl_type(type));
+
    if (vertex_index) {
       load->src[0] = nir_src_for_ssa(vertex_index);
       load->src[1] = nir_src_for_ssa(offset);
@@ -277,7 +281,7 @@ lower_load(nir_intrinsic_instr *intrin, struct lower_io_state *state,
 static nir_intrinsic_instr *
 lower_store(nir_intrinsic_instr *intrin, struct lower_io_state *state,
             nir_ssa_def *vertex_index, nir_variable *var, nir_ssa_def *offset,
-            unsigned component)
+            unsigned component, const struct glsl_type *type)
 {
    nir_variable_mode mode = var->data.mode;
 
@@ -301,6 +305,9 @@ lower_store(nir_intrinsic_instr *intrin, struct lower_io_state *state,
    if (mode == nir_var_shader_out)
       nir_intrinsic_set_component(store, component);
 
+   if (store->intrinsic == nir_intrinsic_store_output)
+      nir_intrinsic_set_type(store, nir_get_nir_type_for_glsl_type(type));
+
    nir_intrinsic_set_write_mask(store, nir_intrinsic_write_mask(intrin));
 
    if (vertex_index)
@@ -356,13 +363,14 @@ lower_atomic(nir_intrinsic_instr *intrin, struct lower_io_state *state,
 
 static nir_intrinsic_instr *
 lower_interpolate_at(nir_intrinsic_instr *intrin, struct lower_io_state *state,
-                     nir_variable *var, nir_ssa_def *offset, unsigned component)
+                     nir_variable *var, nir_ssa_def *offset, unsigned component,
+                     const struct glsl_type *type)
 {
    assert(var->data.mode == nir_var_shader_in);
 
    /* Ignore interpolateAt() for flat variables - flat is flat. */
    if (var->data.interpolation == INTERP_MODE_FLAT)
-      return lower_load(intrin, state, NULL, var, offset, component);
+      return lower_load(intrin, state, NULL, var, offset, component, type);
 
    nir_intrinsic_op bary_op;
    switch (intrin->intrinsic) {
@@ -485,12 +493,12 @@ nir_lower_io_block(nir_block *block,
       switch (intrin->intrinsic) {
       case nir_intrinsic_load_deref:
          replacement = lower_load(intrin, state, vertex_index, var, offset,
-                                  component_offset);
+                                  component_offset, deref->type);
          break;
 
       case nir_intrinsic_store_deref:
          replacement = lower_store(intrin, state, vertex_index, var, offset,
-                                   component_offset);
+                                   component_offset, deref->type);
          break;
 
       case nir_intrinsic_deref_atomic_add:
@@ -516,7 +524,7 @@ nir_lower_io_block(nir_block *block,
       case nir_intrinsic_interp_deref_at_offset:
          assert(vertex_index == NULL);
          replacement = lower_interpolate_at(intrin, state, var, offset,
-                                            component_offset);
+                                            component_offset, deref->type);
          break;
 
       default:
@@ -604,6 +612,7 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
    switch (addr_format) {
    case nir_address_format_32bit_global:
    case nir_address_format_64bit_global:
+   case nir_address_format_32bit_offset:
       assert(addr->num_components == 1);
       return nir_iadd(b, addr, offset);
 
@@ -618,6 +627,8 @@ build_addr_iadd(nir_builder *b, nir_ssa_def *addr,
       assert(addr->num_components == 2);
       return nir_vec2(b, nir_channel(b, addr, 0),
                          nir_iadd(b, nir_channel(b, addr, 1), offset));
+   case nir_address_format_logical:
+      unreachable("Unsupported address format");
    }
    unreachable("Invalid address format");
 }
@@ -673,6 +684,8 @@ addr_to_global(nir_builder *b, nir_ssa_def *addr,
                          nir_u2u64(b, nir_channel(b, addr, 3)));
 
    case nir_address_format_32bit_index_offset:
+   case nir_address_format_32bit_offset:
+   case nir_address_format_logical:
       unreachable("Cannot get a 64-bit address with this address format");
    }
 
@@ -754,10 +767,8 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
        * as to what we can do with an OOB read.  Unfortunately, returning
        * undefined values isn't one of them so we return an actual zero.
        */
-      nir_const_value zero_val;
-      memset(&zero_val, 0, sizeof(zero_val));
-      nir_ssa_def *zero = nir_build_imm(b, load->num_components,
-                                        load->dest.ssa.bit_size, zero_val);
+      nir_ssa_def *zero = nir_imm_zero(b, load->num_components,
+                                          load->dest.ssa.bit_size);
 
       const unsigned load_size =
          (load->dest.ssa.bit_size / 8) * load->num_components;
@@ -898,38 +909,17 @@ build_explicit_io_atomic(nir_builder *b, nir_intrinsic_instr *intrin,
    }
 }
 
-static void
-lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref,
-                        nir_address_format addr_format)
+nir_ssa_def *
+nir_explicit_io_address_from_deref(nir_builder *b, nir_deref_instr *deref,
+                                   nir_ssa_def *base_addr,
+                                   nir_address_format addr_format)
 {
-   /* Just delete the deref if it's not used.  We can't use
-    * nir_deref_instr_remove_if_unused here because it may remove more than
-    * one deref which could break our list walking since we walk the list
-    * backwards.
-    */
-   assert(list_empty(&deref->dest.ssa.if_uses));
-   if (list_empty(&deref->dest.ssa.uses)) {
-      nir_instr_remove(&deref->instr);
-      return;
-   }
-
-   b->cursor = nir_after_instr(&deref->instr);
-
-   nir_ssa_def *parent_addr = NULL;
-   if (deref->deref_type != nir_deref_type_var) {
-      assert(deref->parent.is_ssa);
-      parent_addr = deref->parent.ssa;
-   }
-
-
-   nir_ssa_def *addr = NULL;
    assert(deref->dest.is_ssa);
    switch (deref->deref_type) {
    case nir_deref_type_var:
       assert(deref->mode == nir_var_shader_in);
-      addr = nir_imm_intN_t(b, deref->var->data.driver_location,
+      return nir_imm_intN_t(b, deref->var->data.driver_location,
                             deref->dest.ssa.bit_size);
-      break;
 
    case nir_deref_type_array: {
       nir_deref_instr *parent = nir_deref_instr_parent(deref);
@@ -943,19 +933,17 @@ lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref,
       assert(stride > 0);
 
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
-      index = nir_i2i(b, index, parent_addr->bit_size);
-      addr = build_addr_iadd(b, parent_addr, addr_format,
+      index = nir_i2i(b, index, base_addr->bit_size);
+      return build_addr_iadd(b, base_addr, addr_format,
                                 nir_imul_imm(b, index, stride));
-      break;
    }
 
    case nir_deref_type_ptr_as_array: {
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
-      index = nir_i2i(b, index, parent_addr->bit_size);
+      index = nir_i2i(b, index, base_addr->bit_size);
       unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
-      addr = build_addr_iadd(b, parent_addr, addr_format,
+      return build_addr_iadd(b, base_addr, addr_format,
                                 nir_imul_imm(b, index, stride));
-      break;
    }
 
    case nir_deref_type_array_wildcard:
@@ -967,23 +955,22 @@ lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref,
       int offset = glsl_get_struct_field_offset(parent->type,
                                                 deref->strct.index);
       assert(offset >= 0);
-      addr = build_addr_iadd_imm(b, parent_addr, addr_format, offset);
-      break;
+      return build_addr_iadd_imm(b, base_addr, addr_format, offset);
    }
 
    case nir_deref_type_cast:
       /* Nothing to do here */
-      addr = parent_addr;
-      break;
+      return base_addr;
    }
 
-   nir_instr_remove(&deref->instr);
-   nir_ssa_def_rewrite_uses(&deref->dest.ssa, nir_src_for_ssa(addr));
+   unreachable("Invalid NIR deref type");
 }
 
-static void
-lower_explicit_io_access(nir_builder *b, nir_intrinsic_instr *intrin,
-                         nir_address_format addr_format)
+void
+nir_lower_explicit_io_instr(nir_builder *b,
+                            nir_intrinsic_instr *intrin,
+                            nir_ssa_def *addr,
+                            nir_address_format addr_format)
 {
    b->cursor = nir_after_instr(&intrin->instr);
 
@@ -993,7 +980,6 @@ lower_explicit_io_access(nir_builder *b, nir_intrinsic_instr *intrin,
    assert(vec_stride == 0 || glsl_type_is_vector(deref->type));
    assert(vec_stride == 0 || vec_stride >= scalar_size);
 
-   nir_ssa_def *addr = &deref->dest.ssa;
    if (intrin->intrinsic == nir_intrinsic_load_deref) {
       nir_ssa_def *value;
       if (vec_stride > scalar_size) {
@@ -1037,6 +1023,44 @@ lower_explicit_io_access(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_instr_remove(&intrin->instr);
 }
 
+/* Replaces one deref instruction with explicit address arithmetic.
+ *
+ * A deref without uses is simply removed.  Otherwise the address is built
+ * right after the deref by nir_explicit_io_address_from_deref(), the deref is
+ * removed, and every use of its SSA value is rewritten to that address.
+ */
+static void
+lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref,
+                        nir_address_format addr_format)
+{
+   nir_ssa_def *deref_def = &deref->dest.ssa;
+
+   assert(list_empty(&deref_def->if_uses));
+
+   /* Drop an unused deref on the spot.  nir_deref_instr_remove_if_unused is
+    * deliberately not used here: it may remove more than one deref, which
+    * could break our list walking since the list is walked backwards.
+    */
+   if (list_empty(&deref_def->uses)) {
+      nir_instr_remove(&deref->instr);
+      return;
+   }
+
+   b->cursor = nir_after_instr(&deref->instr);
+
+   /* Only non-variable derefs take their parent's SSA value as the base
+    * address; a variable deref passes NULL.
+    */
+   const bool has_parent = deref->deref_type != nir_deref_type_var;
+   assert(!has_parent || deref->parent.is_ssa);
+   nir_ssa_def *base_addr = has_parent ? deref->parent.ssa : NULL;
+
+   nir_ssa_def *addr =
+      nir_explicit_io_address_from_deref(b, deref, base_addr, addr_format);
+
+   nir_instr_remove(&deref->instr);
+   nir_ssa_def_rewrite_uses(deref_def, nir_src_for_ssa(addr));
+}
+
+/* Lowers a deref-based access intrinsic by handing the SSA value of its
+ * first source to nir_lower_explicit_io_instr() as the address.
+ */
+static void
+lower_explicit_io_access(nir_builder *b, nir_intrinsic_instr *intrin,
+                         nir_address_format addr_format)
+{
+   nir_src *addr_src = &intrin->src[0];
+
+   assert(addr_src->is_ssa);
+   nir_lower_explicit_io_instr(b, intrin, addr_src->ssa, addr_format);
+}
+
 static void
 lower_explicit_io_array_length(nir_builder *b, nir_intrinsic_instr *intrin,
                                nir_address_format addr_format)
@@ -1214,3 +1238,23 @@ nir_get_io_vertex_index_src(nir_intrinsic_instr *instr)
       return NULL;
    }
 }
+
+/**
+ * Returns the numeric constants that identify a NULL pointer for the given
+ * address format.
+ *
+ * The result points at a static, read-only row of NIR_MAX_VEC_COMPONENTS
+ * values; components without an explicit initializer are zero.  The three
+ * global formats use 0.  The 32-bit index/offset format uses all-ones in its
+ * first two components, and the 32-bit offset and logical formats use
+ * all-ones in their first component.
+ *
+ * addr_format is asserted to be an index inside the table.
+ */
+const nir_const_value *
+nir_address_format_null_value(nir_address_format addr_format)
+{
+   /* "static" goes first: placing a storage-class specifier after a type
+    * qualifier is an obsolescent form that compilers may warn about.
+    */
+   static const nir_const_value null_values[][NIR_MAX_VEC_COMPONENTS] = {
+      [nir_address_format_32bit_global] = {{0}},
+      [nir_address_format_64bit_global] = {{0}},
+      [nir_address_format_64bit_bounded_global] = {{0}},
+      [nir_address_format_32bit_index_offset] = {{.u32 = ~0}, {.u32 = ~0}},
+      [nir_address_format_32bit_offset] = {{.u32 = ~0}},
+      [nir_address_format_logical] = {{.u32 = ~0}},
+   };
+
+   assert(addr_format < ARRAY_SIZE(null_values));
+   return null_values[addr_format];
+}