nir: Validate base types on array dereferences
[mesa.git] / src / compiler / nir / nir_validate.c
index 16efcb2356c0832829f6d0e8e5413f2f8dfd4ef2..9bf8c7029012ef26af146b56b1be41bc5fbc0805 100644 (file)
@@ -227,12 +227,9 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
    nir_alu_src *src = &instr->src[index];
 
    unsigned num_components;
-   unsigned src_bit_size;
    if (src->src.is_ssa) {
-      src_bit_size = src->src.ssa->bit_size;
       num_components = src->src.ssa->num_components;
    } else {
-      src_bit_size = src->src.reg.reg->bit_size;
       if (src->src.reg.reg->is_packed)
          num_components = 4; /* can't check anything */
       else
@@ -245,24 +242,6 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
          validate_assert(state, src->swizzle[i] < num_components);
    }
 
-   nir_alu_type src_type = nir_op_infos[instr->op].input_types[index];
-
-   /* 8-bit float isn't a thing */
-   if (nir_alu_type_get_base_type(src_type) == nir_type_float)
-      validate_assert(state, src_bit_size == 16 || src_bit_size == 32 || src_bit_size == 64);
-
-   if (nir_alu_type_get_type_size(src_type)) {
-      /* This source has an explicit bit size */
-      validate_assert(state, nir_alu_type_get_type_size(src_type) == src_bit_size);
-   } else {
-      if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type)) {
-         unsigned dest_bit_size =
-            instr->dest.dest.is_ssa ? instr->dest.dest.ssa.bit_size
-                                    : instr->dest.dest.reg.reg->bit_size;
-         validate_assert(state, dest_bit_size == src_bit_size);
-      }
-   }
-
    validate_src(&src->src, state, 0, 0);
 }
 
@@ -369,17 +348,6 @@ validate_alu_dest(nir_alu_instr *instr, validate_state *state)
            nir_type_float) ||
           !dest->saturate);
 
-   unsigned bit_size = dest->dest.is_ssa ? dest->dest.ssa.bit_size
-                                         : dest->dest.reg.reg->bit_size;
-   nir_alu_type type = nir_op_infos[instr->op].output_type;
-
-   /* 8-bit float isn't a thing */
-   if (nir_alu_type_get_base_type(type) == nir_type_float)
-      validate_assert(state, bit_size == 16 || bit_size == 32 || bit_size == 64);
-
-   validate_assert(state, nir_alu_type_get_type_size(type) == 0 ||
-          nir_alu_type_get_type_size(type) == bit_size);
-
    validate_dest(&dest->dest, state, 0, 0);
 }
 
@@ -388,15 +356,49 @@ validate_alu_instr(nir_alu_instr *instr, validate_state *state)
 {
    validate_assert(state, instr->op < nir_num_opcodes);
 
+   unsigned instr_bit_size = 0;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
+      nir_alu_type src_type = nir_op_infos[instr->op].input_types[i];
+      unsigned src_bit_size = nir_src_bit_size(instr->src[i].src);
+      if (nir_alu_type_get_type_size(src_type)) {
+         validate_assert(state, src_bit_size == nir_alu_type_get_type_size(src_type));
+      } else if (instr_bit_size) {
+         validate_assert(state, src_bit_size == instr_bit_size);
+      } else {
+         instr_bit_size = src_bit_size;
+      }
+
+      if (nir_alu_type_get_base_type(src_type) == nir_type_float) {
+         /* 8-bit float isn't a thing */
+         validate_assert(state, src_bit_size == 16 || src_bit_size == 32 ||
+                                src_bit_size == 64);
+      }
+
       validate_alu_src(instr, i, state);
    }
 
+   nir_alu_type dest_type = nir_op_infos[instr->op].output_type;
+   unsigned dest_bit_size = nir_dest_bit_size(instr->dest.dest);
+   if (nir_alu_type_get_type_size(dest_type)) {
+      validate_assert(state, dest_bit_size == nir_alu_type_get_type_size(dest_type));
+   } else if (instr_bit_size) {
+      validate_assert(state, dest_bit_size == instr_bit_size);
+   } else {
+      /* The only unsized thing is the destination so it's vacuously valid */
+   }
+
+   if (nir_alu_type_get_base_type(dest_type) == nir_type_float) {
+      /* 8-bit float isn't a thing */
+      validate_assert(state, dest_bit_size == 16 || dest_bit_size == 32 ||
+                             dest_bit_size == 64);
+   }
+
    validate_alu_dest(instr, state);
 }
 
 static void
-validate_deref_chain(nir_deref *deref, validate_state *state)
+validate_deref_chain(nir_deref *deref, nir_variable_mode mode,
+                     validate_state *state)
 {
    validate_assert(state, deref->child == NULL || ralloc_parent(deref->child) == deref);
 
@@ -404,6 +406,19 @@ validate_deref_chain(nir_deref *deref, validate_state *state)
    while (deref != NULL) {
       switch (deref->deref_type) {
       case nir_deref_type_array:
+         if (mode == nir_var_shared) {
+            /* Shared variables have a bit more relaxed rules because we need
+             * to be able to handle array derefs on vectors.  Fortunately,
+             * nir_lower_io handles these just fine.
+             */
+            validate_assert(state, glsl_type_is_array(parent->type) ||
+                                   glsl_type_is_matrix(parent->type) ||
+                                   glsl_type_is_vector(parent->type));
+         } else {
+            /* Most of NIR cannot handle array derefs on vectors */
+            validate_assert(state, glsl_type_is_array(parent->type) ||
+                                   glsl_type_is_matrix(parent->type));
+         }
          validate_assert(state, deref->type == glsl_get_array_element(parent->type));
          if (nir_deref_as_array(deref)->deref_array_type ==
              nir_deref_array_type_indirect)
@@ -450,7 +465,7 @@ validate_deref_var(void *parent_mem_ctx, nir_deref_var *deref, validate_state *s
 
    validate_var_use(deref->var, state);
 
-   validate_deref_chain(&deref->deref, state);
+   validate_deref_chain(&deref->deref, deref->var->data.mode, state);
 }
 
 static void
@@ -972,7 +987,7 @@ validate_var_decl(nir_variable *var, bool is_global, validate_state *state)
       assert(glsl_type_is_array(var->type));
 
       const struct glsl_type *type = glsl_get_array_element(var->type);
-      if (nir_is_per_vertex_io(var, state->shader->stage)) {
+      if (nir_is_per_vertex_io(var, state->shader->info.stage)) {
          assert(glsl_type_is_array(type));
          assert(glsl_type_is_scalar(glsl_get_array_element(type)));
       } else {