nir: don't try to scalarize unpack_double_2x32
[mesa.git] / src / compiler / nir / nir_validate.c
index d1a90485e7ecfda1c03162cbf9edd72d96d27b1f..9f18d1c33e4a3cab0854c7b7f53a7f6fba358d73 100644 (file)
@@ -179,9 +179,12 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
    nir_alu_src *src = &instr->src[index];
 
    unsigned num_components;
-   if (src->src.is_ssa)
+   unsigned src_bit_size;
+   if (src->src.is_ssa) {
+      src_bit_size = src->src.ssa->bit_size;
       num_components = src->src.ssa->num_components;
-   else {
+   } else {
+      src_bit_size = src->src.reg.reg->bit_size;
       if (src->src.reg.reg->is_packed)
          num_components = 4; /* can't check anything */
       else
@@ -194,6 +197,24 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
          assert(src->swizzle[i] < num_components);
    }
 
+   nir_alu_type src_type = nir_op_infos[instr->op].input_types[index];
+
+   /* 8-bit float isn't a thing */
+   if (nir_alu_type_get_base_type(src_type) == nir_type_float)
+      assert(src_bit_size == 16 || src_bit_size == 32 || src_bit_size == 64);
+
+   if (nir_alu_type_get_type_size(src_type)) {
+      /* This source has an explicit bit size */
+      assert(nir_alu_type_get_type_size(src_type) == src_bit_size);
+   } else {
+      if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type)) {
+         unsigned dest_bit_size =
+            instr->dest.dest.is_ssa ? instr->dest.dest.ssa.bit_size
+                                    : instr->dest.dest.reg.reg->bit_size;
+         assert(dest_bit_size == src_bit_size);
+      }
+   }
+
    validate_src(&src->src, state);
 }
 
@@ -263,8 +284,10 @@ validate_dest(nir_dest *dest, validate_state *state)
 }
 
 static void
-validate_alu_dest(nir_alu_dest *dest, validate_state *state)
+validate_alu_dest(nir_alu_instr *instr, validate_state *state)
 {
+   nir_alu_dest *dest = &instr->dest;
+
    unsigned dest_size =
       dest->dest.is_ssa ? dest->dest.ssa.num_components
                         : dest->dest.reg.reg->num_components;
@@ -282,6 +305,17 @@ validate_alu_dest(nir_alu_dest *dest, validate_state *state)
    assert(nir_op_infos[alu->op].output_type == nir_type_float ||
           !dest->saturate);
 
+   unsigned bit_size = dest->dest.is_ssa ? dest->dest.ssa.bit_size
+                                         : dest->dest.reg.reg->bit_size;
+   nir_alu_type type = nir_op_infos[instr->op].output_type;
+
+   /* 8-bit float isn't a thing */
+   if (nir_alu_type_get_base_type(type) == nir_type_float)
+      assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
+
+   assert(nir_alu_type_get_type_size(type) == 0 ||
+          nir_alu_type_get_type_size(type) == bit_size);
+
    validate_dest(&dest->dest, state);
 }
 
@@ -294,7 +328,7 @@ validate_alu_instr(nir_alu_instr *instr, validate_state *state)
       validate_alu_src(instr, i, state);
    }
 
-   validate_alu_dest(&instr->dest, state);
+   validate_alu_dest(instr, state);
 }
 
 static void
@@ -457,10 +491,12 @@ validate_tex_instr(nir_tex_instr *instr, validate_state *state)
 static void
 validate_call_instr(nir_call_instr *instr, validate_state *state)
 {
-   if (instr->return_deref == NULL)
+   if (instr->return_deref == NULL) {
       assert(glsl_type_is_void(instr->callee->return_type));
-   else
+   } else {
       assert(instr->return_deref->deref.type == instr->callee->return_type);
+      validate_deref_var(instr, instr->return_deref, state);
+   }
 
    assert(instr->num_params == instr->callee->num_params);
 
@@ -468,8 +504,6 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
       assert(instr->callee->params[i].type == instr->params[i]->deref.type);
       validate_deref_var(instr, instr->params[i], state);
    }
-
-   validate_deref_var(instr, instr->return_deref, state);
 }
 
 static void
@@ -867,7 +901,7 @@ postvalidate_reg_decl(nir_register *reg, validate_state *state)
 static void
 validate_var_decl(nir_variable *var, bool is_global, validate_state *state)
 {
-   assert(is_global != (var->data.mode == nir_var_local));
+   assert(is_global == nir_variable_is_global(var));
 
    /*
     * TODO validate some things ir_validate.cpp does (requires more GLSL type
@@ -936,13 +970,21 @@ validate_function_impl(nir_function_impl *impl, validate_state *state)
    assert(impl->cf_node.parent == NULL);
 
    assert(impl->num_params == impl->function->num_params);
-   for (unsigned i = 0; i < impl->num_params; i++)
+   for (unsigned i = 0; i < impl->num_params; i++) {
       assert(impl->params[i]->type == impl->function->params[i].type);
+      assert(impl->params[i]->data.mode == nir_var_param);
+      assert(impl->params[i]->data.location == i);
+      validate_var_decl(impl->params[i], false, state);
+   }
 
-   if (glsl_type_is_void(impl->function->return_type))
+   if (glsl_type_is_void(impl->function->return_type)) {
       assert(impl->return_var == NULL);
-   else
+   } else {
       assert(impl->return_var->type == impl->function->return_type);
+      assert(impl->return_var->data.mode == nir_var_param);
+      assert(impl->return_var->data.location == -1);
+      validate_var_decl(impl->return_var, false, state);
+   }
 
    assert(exec_list_is_empty(&impl->end_block->instr_list));
    assert(impl->end_block->successors[0] == NULL);
@@ -1039,6 +1081,11 @@ nir_validate_shader(nir_shader *shader)
      validate_var_decl(var, true, &state);
    }
 
+   exec_list_validate(&shader->shared);
+   nir_foreach_variable(var, &shader->shared) {
+      validate_var_decl(var, true, &state);
+   }
+
    exec_list_validate(&shader->globals);
    nir_foreach_variable(var, &shader->globals) {
      validate_var_decl(var, true, &state);