nir/spirv: Add matrix determinants and inverses
authorJason Ekstrand <jason.ekstrand@intel.com>
Sat, 9 Jan 2016 00:02:06 +0000 (16:02 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Sat, 9 Jan 2016 00:02:30 +0000 (16:02 -0800)
src/glsl/nir/spirv/vtn_glsl450.c

index 739e43949541d95d69ad10c426bf70f19cee7307..82cfc8c91a9a9b163b0aebd5203c7a6813ef6b5c 100644 (file)
 #include "vtn_private.h"
 #include "GLSL.std.450.h"
 
+static nir_ssa_def *
+build_mat2_det(nir_builder *b, nir_ssa_def *col[2])
+{
+   unsigned swiz[4] = {1, 0, 0, 0};
+   nir_ssa_def *p = nir_fmul(b, col[0], nir_swizzle(b, col[1], swiz, 2, true));
+   return nir_fsub(b, nir_channel(b, p, 0), nir_channel(b, p, 1));
+}
+
+static nir_ssa_def *
+build_mat3_det(nir_builder *b, nir_ssa_def *col[3])
+{
+   unsigned yzx[4] = {1, 2, 0, 0};
+   unsigned zxy[4] = {2, 0, 1, 0};
+
+   nir_ssa_def *prod0 =
+      nir_fmul(b, col[0],
+               nir_fmul(b, nir_swizzle(b, col[1], yzx, 3, true),
+                           nir_swizzle(b, col[2], zxy, 3, true)));
+   nir_ssa_def *prod1 =
+      nir_fmul(b, col[0],
+               nir_fmul(b, nir_swizzle(b, col[1], zxy, 3, true),
+                           nir_swizzle(b, col[2], yzx, 3, true)));
+
+   nir_ssa_def *diff = nir_fsub(b, prod0, prod1);
+
+   return nir_fadd(b, nir_channel(b, diff, 0),
+                      nir_fadd(b, nir_channel(b, diff, 1),
+                                  nir_channel(b, diff, 2)));
+}
+
+static nir_ssa_def *
+build_mat4_det(nir_builder *b, nir_ssa_def **col)
+{
+   nir_ssa_def *subdet[4];
+   for (unsigned i = 0; i < 4; i++) {
+      unsigned swiz[3];
+      for (unsigned j = 0; j < 4; j++)
+         swiz[j - (j > i)] = j;
+
+      nir_ssa_def *subcol[3];
+      subcol[0] = nir_swizzle(b, col[1], swiz, 3, true);
+      subcol[1] = nir_swizzle(b, col[2], swiz, 3, true);
+      subcol[2] = nir_swizzle(b, col[3], swiz, 3, true);
+
+      subdet[i] = build_mat3_det(b, subcol);
+   }
+
+   nir_ssa_def *prod = nir_fmul(b, col[0], nir_vec(b, subdet, 4));
+
+   return nir_fadd(b, nir_fsub(b, nir_channel(b, prod, 0),
+                                  nir_channel(b, prod, 1)),
+                      nir_fsub(b, nir_channel(b, prod, 2),
+                                  nir_channel(b, prod, 3)));
+}
+
+static nir_ssa_def *
+build_mat_det(struct vtn_builder *b, struct vtn_ssa_value *src)
+{
+   unsigned size = glsl_get_vector_elements(src->type);
+
+   nir_ssa_def *cols[4];
+   for (unsigned i = 0; i < size; i++)
+      cols[i] = src->elems[i]->def;
+
+   switch(size) {
+   case 2: return build_mat2_det(&b->nb, cols);
+   case 3: return build_mat3_det(&b->nb, cols);
+   case 4: return build_mat4_det(&b->nb, cols);
+   default:
+      unreachable("Invalid matrix size");
+   }
+}
+
+/* Computes the determinate of the submatrix given by taking src and
+ * removing the specified row and column.
+ */
+static nir_ssa_def *
+build_mat_subdet(struct nir_builder *b, struct vtn_ssa_value *src,
+                 unsigned size, unsigned row, unsigned col)
+{
+   assert(row < size && col < size);
+   if (size == 2) {
+      return nir_channel(b, src->elems[1 - col]->def, 1 - row);
+   } else {
+      /* Swizzle to get all but the specified row */
+      unsigned swiz[3];
+      for (unsigned j = 0; j < 4; j++)
+         swiz[j - (j > row)] = j;
+
+      /* Grab all but the specified column */
+      nir_ssa_def *subcol[3];
+      for (unsigned j = 0; j < size; j++) {
+         if (j != col) {
+            subcol[j - (j > col)] = nir_swizzle(b, src->elems[j]->def,
+                                                swiz, size - 1, true);
+         }
+      }
+
+      if (size == 3) {
+         return build_mat2_det(b, subcol);
+      } else {
+         assert(size == 4);
+         return build_mat3_det(b, subcol);
+      }
+   }
+}
+
+static struct vtn_ssa_value *
+matrix_inverse(struct vtn_builder *b, struct vtn_ssa_value *src)
+{
+   nir_ssa_def *adj_col[4];
+   unsigned size = glsl_get_vector_elements(src->type);
+
+   /* Build up an adjugate matrix */
+   for (unsigned c = 0; c < size; c++) {
+      nir_ssa_def *elem[4];
+      for (unsigned r = 0; r < size; r++) {
+         elem[r] = build_mat_subdet(&b->nb, src, size, c, r);
+
+         if ((r + c) % 2)
+            elem[r] = nir_fneg(&b->nb, elem[r]);
+      }
+
+      adj_col[c] = nir_vec(&b->nb, elem, size);
+   }
+
+   nir_ssa_def *det_inv = nir_frcp(&b->nb, build_mat_det(b, src));
+
+   struct vtn_ssa_value *val = vtn_create_ssa_value(b, src->type);
+   for (unsigned i = 0; i < size; i++)
+      val->elems[i]->def = nir_fmul(&b->nb, adj_col[i], det_inv);
+
+   return val;
+}
+
 static nir_ssa_def*
 build_length(nir_builder *b, nir_ssa_def *vec)
 {
@@ -309,18 +444,30 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
 
 bool
 vtn_handle_glsl450_instruction(struct vtn_builder *b, uint32_t ext_opcode,
-                               const uint32_t *words, unsigned count)
+                               const uint32_t *w, unsigned count)
 {
    switch ((enum GLSLstd450)ext_opcode) {
-   case GLSLstd450Determinant:
-   case GLSLstd450MatrixInverse:
+   case GLSLstd450Determinant: {
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+      val->ssa = rzalloc(b, struct vtn_ssa_value);
+      val->ssa->type = vtn_value(b, w[1], vtn_value_type_type)->type->type;
+      val->ssa->def = build_mat_det(b, vtn_ssa_value(b, w[5]));
+      break;
+   }
+
+   case GLSLstd450MatrixInverse: {
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+      val->ssa = matrix_inverse(b, vtn_ssa_value(b, w[5]));
+      break;
+   }
+
    case GLSLstd450InterpolateAtCentroid:
    case GLSLstd450InterpolateAtSample:
    case GLSLstd450InterpolateAtOffset:
       unreachable("Unhandled opcode");
 
    default:
-      handle_glsl450_alu(b, (enum GLSLstd450)ext_opcode, words, count);
+      handle_glsl450_alu(b, (enum GLSLstd450)ext_opcode, w, count);
    }
 
    return true;