nir: Add a bit_size to nir_register and nir_ssa_def
[mesa.git] / src / compiler / nir / nir_validate.c
index 0c9d816a384d02b29ed32181bfd24e464f0f40e5..9f18d1c33e4a3cab0854c7b7f53a7f6fba358d73 100644 (file)
@@ -179,9 +179,12 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
    nir_alu_src *src = &instr->src[index];
 
    unsigned num_components;
-   if (src->src.is_ssa)
+   unsigned src_bit_size;
+   if (src->src.is_ssa) {
+      src_bit_size = src->src.ssa->bit_size;
       num_components = src->src.ssa->num_components;
-   else {
+   } else {
+      src_bit_size = src->src.reg.reg->bit_size;
       if (src->src.reg.reg->is_packed)
          num_components = 4; /* can't check anything */
       else
@@ -194,6 +197,24 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
          assert(src->swizzle[i] < num_components);
    }
 
+   nir_alu_type src_type = nir_op_infos[instr->op].input_types[index];
+
+   /* 8-bit float isn't a thing */
+   if (nir_alu_type_get_base_type(src_type) == nir_type_float)
+      assert(src_bit_size == 16 || src_bit_size == 32 || src_bit_size == 64);
+
+   if (nir_alu_type_get_type_size(src_type)) {
+      /* This source has an explicit bit size */
+      assert(nir_alu_type_get_type_size(src_type) == src_bit_size);
+   } else {
+      if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type)) {
+         unsigned dest_bit_size =
+            instr->dest.dest.is_ssa ? instr->dest.dest.ssa.bit_size
+                                    : instr->dest.dest.reg.reg->bit_size;
+         assert(dest_bit_size == src_bit_size);
+      }
+   }
+
    validate_src(&src->src, state);
 }
 
@@ -263,8 +284,10 @@ validate_dest(nir_dest *dest, validate_state *state)
 }
 
 static void
-validate_alu_dest(nir_alu_dest *dest, validate_state *state)
+validate_alu_dest(nir_alu_instr *instr, validate_state *state)
 {
+   nir_alu_dest *dest = &instr->dest;
+
    unsigned dest_size =
       dest->dest.is_ssa ? dest->dest.ssa.num_components
                         : dest->dest.reg.reg->num_components;
@@ -282,6 +305,17 @@ validate_alu_dest(nir_alu_dest *dest, validate_state *state)
    assert(nir_op_infos[alu->op].output_type == nir_type_float ||
           !dest->saturate);
 
+   unsigned bit_size = dest->dest.is_ssa ? dest->dest.ssa.bit_size
+                                         : dest->dest.reg.reg->bit_size;
+   nir_alu_type type = nir_op_infos[instr->op].output_type;
+
+   /* 8-bit float isn't a thing */
+   if (nir_alu_type_get_base_type(type) == nir_type_float)
+      assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
+
+   assert(nir_alu_type_get_type_size(type) == 0 ||
+          nir_alu_type_get_type_size(type) == bit_size);
+
    validate_dest(&dest->dest, state);
 }
 
@@ -294,7 +328,7 @@ validate_alu_instr(nir_alu_instr *instr, validate_state *state)
       validate_alu_src(instr, i, state);
    }
 
-   validate_alu_dest(&instr->dest, state);
+   validate_alu_dest(instr, state);
 }
 
 static void
@@ -1047,6 +1081,11 @@ nir_validate_shader(nir_shader *shader)
      validate_var_decl(var, true, &state);
    }
 
+   exec_list_validate(&shader->shared);
+   nir_foreach_variable(var, &shader->shared) {
+      validate_var_decl(var, true, &state);
+   }
+
    exec_list_validate(&shader->globals);
    nir_foreach_variable(var, &shader->globals) {
      validate_var_decl(var, true, &state);