matrices matrices matrices
authorConnor Abbott <cwabbott0@gmail.com>
Fri, 19 Jun 2015 01:52:44 +0000 (18:52 -0700)
committerConnor Abbott <cwabbott0@gmail.com>
Fri, 19 Jun 2015 01:52:44 +0000 (18:52 -0700)
src/glsl/nir/spirv_to_nir.c
src/glsl/nir/spirv_to_nir_private.h

index e84d75643006780c965bf3973f95c1946008e1f9..a0a75040771476226b8f1e2ab2142ccced3158ea 100644 (file)
@@ -37,7 +37,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    if (entry)
       return entry->data;
 
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = type;
 
    switch (glsl_get_base_type(type)) {
@@ -63,7 +63,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
          val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
 
          for (unsigned i = 0; i < columns; i++) {
-            struct vtn_ssa_value *col_val = ralloc(b, struct vtn_ssa_value);
+            struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
             col_val->type = glsl_get_column_type(val->type);
             nir_load_const_instr *load =
                nir_load_const_instr_create(b->shader, rows);
@@ -516,6 +516,7 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
    case SpvDecorationFPFastMathMode:
    case SpvDecorationLinkageAttributes:
    case SpvDecorationSpecId:
+      break;
    default:
       unreachable("Unhandled variable decoration");
    }
@@ -525,7 +526,7 @@ static struct vtn_ssa_value *
 _vtn_variable_load(struct vtn_builder *b,
                    nir_deref_var *src_deref, nir_deref *src_deref_tail)
 {
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = src_deref_tail->type;
 
    /* The deref tail may contain a deref to select a component of a vector (in
@@ -1010,11 +1011,264 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    nir_builder_instr_insert(&b->nb, &instr->instr);
 }
 
+static struct vtn_ssa_value *
+vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
+{
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
+   val->type = type;
+   
+   if (!glsl_type_is_vector_or_scalar(type)) {
+      unsigned elems = glsl_get_length(type);
+      val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
+      for (unsigned i = 0; i < elems; i++) {
+         const struct glsl_type *child_type;
+
+         switch (glsl_get_base_type(type)) {
+         case GLSL_TYPE_INT:
+         case GLSL_TYPE_UINT:
+         case GLSL_TYPE_BOOL:
+         case GLSL_TYPE_FLOAT:
+         case GLSL_TYPE_DOUBLE:
+            child_type = glsl_get_column_type(type);
+            break;
+         case GLSL_TYPE_ARRAY:
+            child_type = glsl_get_array_element(type);
+            break;
+         case GLSL_TYPE_STRUCT:
+            child_type = glsl_get_struct_field(type, i);
+            break;
+         default:
+            unreachable("unkown base type");
+         }
+
+         val->elems[i] = vtn_create_ssa_value(b, child_type);
+      }
+   }
+
+   return val;
+}
+
+static nir_alu_instr *
+create_vec(void *mem_ctx, unsigned num_components)
+{
+   nir_op op;
+   switch (num_components) {
+   case 1: op = nir_op_fmov; break;
+   case 2: op = nir_op_vec2; break;
+   case 3: op = nir_op_vec3; break;
+   case 4: op = nir_op_vec4; break;
+   default: unreachable("bad vector size");
+   }
+
+   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
+   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
+
+   return vec;
+}
+
+static struct vtn_ssa_value *
+vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
+{
+   if (src->transposed)
+      return src->transposed;
+
+   struct vtn_ssa_value *dest =
+      vtn_create_ssa_value(b, glsl_transposed_type(src->type));
+
+   for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
+      nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
+      if (glsl_type_is_vector_or_scalar(src->type)) {
+          vec->src[0].src = nir_src_for_ssa(src->def);
+          vec->src[0].swizzle[0] = i;
+      } else {
+         for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
+            vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
+            vec->src[j].swizzle[0] = i;
+         }
+      }
+      nir_builder_instr_insert(&b->nb, &vec->instr);
+      dest->elems[i]->def = &vec->dest.dest.ssa;
+   }
+
+   dest->transposed = src;
+
+   return dest;
+}
+
+/*
+ * Normally, column vectors in SPIR-V correspond to a single NIR SSA
+ * definition. But for matrix multiplies, we want to do one routine for
+ * multiplying a matrix by a matrix and then pretend that vectors are matrices
+ * with one column. So we "wrap" these things, and unwrap the result before we
+ * send it off.
+ */
+
+static struct vtn_ssa_value *
+vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
+{
+   if (val == NULL)
+      return NULL;
+
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
+   dest->type = val->type;
+   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
+   dest->elems[0] = val;
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+vtn_unwrap_matrix(struct vtn_ssa_value *val)
+{
+   if (glsl_type_is_matrix(val->type))
+         return val;
+
+   return val->elems[0];
+}
+
+static struct vtn_ssa_value *
+vtn_matrix_multiply(struct vtn_builder *b,
+                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
+{
+
+   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
+   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
+   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
+   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
+
+   unsigned src0_rows = glsl_get_vector_elements(src0->type);
+   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
+   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
+
+   struct vtn_ssa_value *dest =
+      vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
+                                               src0_rows, src1_columns));
+
+   dest = vtn_wrap_matrix(b, dest);
+
+   bool transpose_result = false;
+   if (src0_transpose && src1_transpose) {
+      /* transpose(A) * transpose(B) = transpose(B * A) */
+      src1 = src0_transpose;
+      src0 = src1_transpose;
+      src0_transpose = NULL;
+      src1_transpose = NULL;
+      transpose_result = true;
+   }
+
+   if (src0_transpose && !src1_transpose &&
+       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
+      /* We already have the rows of src0 and the columns of src1 available,
+       * so we can just take the dot product of each row with each column to
+       * get the result.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         nir_alu_instr *vec = create_vec(b, src0_rows);
+         for (unsigned j = 0; j < src0_rows; j++) {
+            vec->src[j].src =
+               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
+                                        src1->elems[i]->def));
+         }
+
+         nir_builder_instr_insert(&b->nb, &vec->instr);
+         dest->elems[i]->def = &vec->dest.dest.ssa;
+      }
+   } else {
+      /* We don't handle the case where src1 is transposed but not src0, since
+       * the general case only uses individual components of src1 so the
+       * optimizer should chew through the transpose we emitted for src1.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
+         dest->elems[i]->def =
+            nir_fmul(&b->nb, src0->elems[0]->def,
+                     vtn_vector_extract(b, src1->elems[i]->def, 0));
+         for (unsigned j = 1; j < src0_columns; j++) {
+            dest->elems[i]->def =
+               nir_fadd(&b->nb, dest->elems[i]->def,
+                        nir_fmul(&b->nb, src0->elems[j]->def,
+                                 vtn_vector_extract(b,
+                                                    src1->elems[i]->def, j)));
+         }
+      }
+   }
+   
+   dest = vtn_unwrap_matrix(dest);
+
+   if (transpose_result)
+      dest = vtn_transpose(b, dest);
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+vtn_mat_times_scalar(struct vtn_builder *b,
+                     struct vtn_ssa_value *mat,
+                     nir_ssa_def *scalar)
+{
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
+   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
+      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
+         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
+      else
+         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
+   }
+
+   return dest;
+}
+
 static void
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
 {
-   unreachable("Matrix math not handled");
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+
+   switch (opcode) {
+   case SpvOpTranspose: {
+      struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
+      val->ssa = vtn_transpose(b, src);
+      break;
+   }
+
+   case SpvOpOuterProduct: {
+      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
+
+      val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
+      break;
+   }
+
+   case SpvOpMatrixTimesScalar: {
+      struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
+
+      if (mat->transposed) {
+         val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
+                                                          scalar->def));
+      } else {
+         val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
+      }
+      break;
+   }
+
+   case SpvOpVectorTimesMatrix:
+   case SpvOpMatrixTimesVector:
+   case SpvOpMatrixTimesMatrix: {
+      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
+
+      val->ssa = vtn_matrix_multiply(b, src0, src1);
+      break;
+   }
+
+   default: unreachable("unknown matrix opcode");
+   }
 }
 
 static void
@@ -1197,29 +1451,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 static nir_ssa_def *
 vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
 {
-   nir_alu_src alu_src;
-   alu_src.src = nir_src_for_ssa(src);
-   alu_src.swizzle[0] = index;
-   return nir_fmov_alu(&b->nb, alu_src, 1);
+   unsigned swiz[4] = { index };
+   return nir_swizzle(&b->nb, src, swiz, 1, true);
 }
 
-static nir_alu_instr *
-create_vec(void *mem_ctx, unsigned num_components)
-{
-   nir_op op;
-   switch (num_components) {
-   case 1: op = nir_op_fmov; break;
-   case 2: op = nir_op_vec2; break;
-   case 3: op = nir_op_vec3; break;
-   case 4: op = nir_op_vec4; break;
-   default: unreachable("bad vector size");
-   }
-
-   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
-   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
-
-   return vec;
-}
 
 static nir_ssa_def *
 vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
@@ -1320,7 +1555,7 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
 static struct vtn_ssa_value *
 vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
 {
-   struct vtn_ssa_value *dest = ralloc(mem_ctx, struct vtn_ssa_value);
+   struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
    dest->type = src->type;
 
    if (glsl_type_is_vector_or_scalar(src->type)) {
@@ -1376,7 +1611,7 @@ vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
           * vector to extract.
           */
 
-         struct vtn_ssa_value *ret = ralloc(b, struct vtn_ssa_value);
+         struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
          ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
          ret->def = vtn_vector_extract(b, cur->def, indices[i]);
          return ret;
@@ -1413,7 +1648,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpCompositeConstruct: {
-      val->ssa = ralloc(b, struct vtn_ssa_value);
+      val->ssa = rzalloc(b, struct vtn_ssa_value);
       unsigned elems = count - 3;
       if (glsl_type_is_vector_or_scalar(val->type)) {
          nir_ssa_def *srcs[4];
index a4760620d468731177e70617b47092107d8fd332..937c45b08c27f3e699bcd8ce2923fa500a06944a 100644 (file)
@@ -71,6 +71,11 @@ struct vtn_ssa_value {
       struct vtn_ssa_value **elems;
    };
 
+   /* For matrices, a transposed version of the value, or NULL if it hasn't
+    * been computed
+    */
+   struct vtn_ssa_value *transposed;
+
    const struct glsl_type *type;
 };