From 841aab6f50c88d54d1a3a68a311621f948d09126 Mon Sep 17 00:00:00 2001
From: Connor Abbott
Date: Thu, 18 Jun 2015 18:52:44 -0700
Subject: [PATCH] matrices matrices matrices

---
 src/glsl/nir/spirv_to_nir.c         | 291 +++++++++++++++++++++++++---
 src/glsl/nir/spirv_to_nir_private.h |   5 +
 2 files changed, 268 insertions(+), 28 deletions(-)

diff --git a/src/glsl/nir/spirv_to_nir.c b/src/glsl/nir/spirv_to_nir.c
index e84d7564300..a0a75040771 100644
--- a/src/glsl/nir/spirv_to_nir.c
+++ b/src/glsl/nir/spirv_to_nir.c
@@ -37,7 +37,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    if (entry)
       return entry->data;
 
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = type;
 
    switch (glsl_get_base_type(type)) {
@@ -63,7 +63,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
       val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
 
       for (unsigned i = 0; i < columns; i++) {
-         struct vtn_ssa_value *col_val = ralloc(b, struct vtn_ssa_value);
+         struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
          col_val->type = glsl_get_column_type(val->type);
          nir_load_const_instr *load =
             nir_load_const_instr_create(b->shader, rows);
@@ -516,6 +516,7 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
    case SpvDecorationFPFastMathMode:
    case SpvDecorationLinkageAttributes:
    case SpvDecorationSpecId:
+      break;
    default:
       unreachable("Unhandled variable decoration");
    }
@@ -525,7 +526,7 @@ static struct vtn_ssa_value *
 _vtn_variable_load(struct vtn_builder *b, nir_deref_var *src_deref,
                    nir_deref *src_deref_tail)
 {
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = src_deref_tail->type;
 
    /* The deref tail may contain a deref to select a component of a vector (in
@@ -1010,11 +1011,264 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    nir_builder_instr_insert(&b->nb, &instr->instr);
 }
 
+static struct vtn_ssa_value *
+vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
+{
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
+   val->type = type;
+
+   if (!glsl_type_is_vector_or_scalar(type)) {
+      unsigned elems = glsl_get_length(type);
+      val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
+      for (unsigned i = 0; i < elems; i++) {
+         const struct glsl_type *child_type;
+
+         switch (glsl_get_base_type(type)) {
+         case GLSL_TYPE_INT:
+         case GLSL_TYPE_UINT:
+         case GLSL_TYPE_BOOL:
+         case GLSL_TYPE_FLOAT:
+         case GLSL_TYPE_DOUBLE:
+            child_type = glsl_get_column_type(type);
+            break;
+         case GLSL_TYPE_ARRAY:
+            child_type = glsl_get_array_element(type);
+            break;
+         case GLSL_TYPE_STRUCT:
+            child_type = glsl_get_struct_field(type, i);
+            break;
+         default:
+            unreachable("unknown base type");
+         }
+
+         val->elems[i] = vtn_create_ssa_value(b, child_type);
+      }
+   }
+
+   return val;
+}
+
+static nir_alu_instr *
+create_vec(void *mem_ctx, unsigned num_components)
+{
+   nir_op op;
+   switch (num_components) {
+   case 1: op = nir_op_fmov; break;
+   case 2: op = nir_op_vec2; break;
+   case 3: op = nir_op_vec3; break;
+   case 4: op = nir_op_vec4; break;
+   default: unreachable("bad vector size");
+   }
+
+   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
+   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
+
+   return vec;
+}
+
+static struct vtn_ssa_value *
+vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
+{
+   if (src->transposed)
+      return src->transposed;
+
+   struct vtn_ssa_value *dest =
+      vtn_create_ssa_value(b, glsl_transposed_type(src->type));
+
+   for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
+      nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
+      if (glsl_type_is_vector_or_scalar(src->type)) {
+         vec->src[0].src = nir_src_for_ssa(src->def);
+         vec->src[0].swizzle[0] = i;
+      } else {
+         for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
+            vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
+            vec->src[j].swizzle[0] = i;
+         }
+      }
+      nir_builder_instr_insert(&b->nb, &vec->instr);
+      dest->elems[i]->def = &vec->dest.dest.ssa;
+   }
+
+   dest->transposed = src;
+
+   return dest;
+}
+
+/*
+ * Normally, column vectors in SPIR-V correspond to a single NIR SSA
+ * definition. But for matrix multiplies, we want to have one routine for
+ * multiplying a matrix by a matrix and then pretend that vectors are matrices
+ * with one column. So we "wrap" these things, and unwrap the result before we
+ * send it off.
+ */
+
+static struct vtn_ssa_value *
+vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
+{
+   if (val == NULL)
+      return NULL;
+
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
+   dest->type = val->type;
+   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
+   dest->elems[0] = val;
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+vtn_unwrap_matrix(struct vtn_ssa_value *val)
+{
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   return val->elems[0];
+}
+
+static struct vtn_ssa_value *
+vtn_matrix_multiply(struct vtn_builder *b,
+                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
+{
+
+   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
+   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
+   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
+   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
+
+   unsigned src0_rows = glsl_get_vector_elements(src0->type);
+   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
+   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
+
+   struct vtn_ssa_value *dest =
+      vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
+                                               src0_rows, src1_columns));
+
+   dest = vtn_wrap_matrix(b, dest);
+
+   bool transpose_result = false;
+   if (src0_transpose && src1_transpose) {
+      /* transpose(A) * transpose(B) = transpose(B * A) */
+      src1 = src0_transpose;
+      src0 = src1_transpose;
+      src0_transpose = NULL;
+      src1_transpose = NULL;
+      transpose_result = true;
+   }
+
+   if (src0_transpose && !src1_transpose &&
+       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
+      /* We already have the rows of src0 and the columns of src1 available,
+       * so we can just take the dot product of each row with each column to
+       * get the result.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         nir_alu_instr *vec = create_vec(b, src0_rows);
+         for (unsigned j = 0; j < src0_rows; j++) {
+            vec->src[j].src =
+               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
+                                        src1->elems[i]->def));
+         }
+
+         nir_builder_instr_insert(&b->nb, &vec->instr);
+         dest->elems[i]->def = &vec->dest.dest.ssa;
+      }
+   } else {
+      /* We don't handle the case where src1 is transposed but not src0, since
+       * the general case only uses individual components of src1 so the
+       * optimizer should chew through the transpose we emitted for src1.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
+         dest->elems[i]->def =
+            nir_fmul(&b->nb, src0->elems[0]->def,
+                     vtn_vector_extract(b, src1->elems[i]->def, 0));
+         for (unsigned j = 1; j < src0_columns; j++) {
+            dest->elems[i]->def =
+               nir_fadd(&b->nb, dest->elems[i]->def,
+                        nir_fmul(&b->nb, src0->elems[j]->def,
+                                 vtn_vector_extract(b,
+                                                    src1->elems[i]->def, j)));
+         }
+      }
+   }
+
+   dest = vtn_unwrap_matrix(dest);
+
+   if (transpose_result)
+      dest = vtn_transpose(b, dest);
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+vtn_mat_times_scalar(struct vtn_builder *b,
+                     struct vtn_ssa_value *mat,
+                     nir_ssa_def *scalar)
+{
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
+   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
+      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
+         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
+      else
+         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
+   }
+
+   return dest;
+}
+
 static void
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
 {
-   unreachable("Matrix math not handled");
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+
+   switch (opcode) {
+   case SpvOpTranspose: {
+      struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
+      val->ssa = vtn_transpose(b, src);
+      break;
+   }
+
+   case SpvOpOuterProduct: {
+      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
+
+      val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
+      break;
+   }
+
+   case SpvOpMatrixTimesScalar: {
+      struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
+
+      if (mat->transposed) {
+         val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
+                                                          scalar->def));
+      } else {
+         val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
+      }
+      break;
+   }
+
+   case SpvOpVectorTimesMatrix:
+   case SpvOpMatrixTimesVector:
+   case SpvOpMatrixTimesMatrix: {
+      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
+      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
+
+      val->ssa = vtn_matrix_multiply(b, src0, src1);
+      break;
+   }
+
+   default: unreachable("unknown matrix opcode");
+   }
 }
 
 static void
@@ -1197,29 +1451,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 static nir_ssa_def *
 vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
 {
-   nir_alu_src alu_src;
-   alu_src.src = nir_src_for_ssa(src);
-   alu_src.swizzle[0] = index;
-   return nir_fmov_alu(&b->nb, alu_src, 1);
+   unsigned swiz[4] = { index };
+   return nir_swizzle(&b->nb, src, swiz, 1, true);
 }
 
-static nir_alu_instr *
-create_vec(void *mem_ctx, unsigned num_components)
-{
-   nir_op op;
-   switch (num_components) {
-   case 1: op = nir_op_fmov; break;
-   case 2: op = nir_op_vec2; break;
-   case 3: op = nir_op_vec3; break;
-   case 4: op = nir_op_vec4; break;
-   default: unreachable("bad vector size");
-   }
-
-   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
-   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
-
-   return vec;
-}
 
 static nir_ssa_def *
 vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
@@ -1320,7 +1555,7 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
 static struct vtn_ssa_value *
 vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
 {
-   struct vtn_ssa_value *dest = ralloc(mem_ctx, struct vtn_ssa_value);
+   struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
    dest->type = src->type;
 
    if (glsl_type_is_vector_or_scalar(src->type)) {
@@ -1376,7 +1611,7 @@ vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
           * vector to extract.
           */
 
-         struct vtn_ssa_value *ret = ralloc(b, struct vtn_ssa_value);
+         struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
          ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
          ret->def = vtn_vector_extract(b, cur->def, indices[i]);
         return ret;
@@ -1413,7 +1648,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpCompositeConstruct: {
-      val->ssa = ralloc(b, struct vtn_ssa_value);
+      val->ssa = rzalloc(b, struct vtn_ssa_value);
       unsigned elems = count - 3;
       if (glsl_type_is_vector_or_scalar(val->type)) {
          nir_ssa_def *srcs[4];
diff --git a/src/glsl/nir/spirv_to_nir_private.h b/src/glsl/nir/spirv_to_nir_private.h
index a4760620d46..937c45b08c2 100644
--- a/src/glsl/nir/spirv_to_nir_private.h
+++ b/src/glsl/nir/spirv_to_nir_private.h
@@ -71,6 +71,11 @@ struct vtn_ssa_value {
       struct vtn_ssa_value **elems;
    };
 
+   /* For matrices, a transposed version of the value, or NULL if it hasn't
+    * been computed
+    */
+   struct vtn_ssa_value *transposed;
+
    const struct glsl_type *type;
 };
 
-- 
2.30.2
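
Note (illustration only, not part of the patch): the general case in
vtn_matrix_multiply builds each destination column as a linear combination of
the columns of src0, weighted by the components of the corresponding column of
src1, which is the column-major layout GLSL/SPIR-V matrices use. A plain
scalar sketch of that same computation follows; the function name mat_mul_cols
and its parameters are hypothetical and exist only for this illustration.

   /* dest, a and b are arrays of column vectors; column i of dest is the sum
    * over j of a's column j scaled by b[i][j], matching the
    * "dest[i] = sum(src0[j] * src1[i][j] for all j)" comment in the patch.
    */
   static void
   mat_mul_cols(float **dest, float *const *a, float *const *b,
                unsigned a_rows, unsigned a_cols, unsigned b_cols)
   {
      for (unsigned i = 0; i < b_cols; i++) {
         for (unsigned r = 0; r < a_rows; r++) {
            dest[i][r] = a[0][r] * b[i][0];
            for (unsigned j = 1; j < a_cols; j++)
               dest[i][r] += a[j][r] * b[i][j];
         }
      }
   }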