From: Jason Ekstrand Date: Wed, 7 Nov 2018 18:15:22 +0000 (-0600) Subject: nir/opcodes: Pull in the type helpers from constant_expressions X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=03571a7a6c6d0d4df9409a75fe93c925ec045670;p=mesa.git nir/opcodes: Pull in the type helpers from constant_expressions While we're at it, we rework them a bit to all use regular expressions and assert more. Reviewed-by: Connor Abbott --- diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index 118af9f7818..0cd4ffcf558 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -1,23 +1,8 @@ from __future__ import print_function import re - -type_split_re = re.compile(r'(?P[a-z]+)(?P\d+)') - -def type_has_size(type_): - return type_[-1:].isdigit() - -def type_size(type_): - assert type_has_size(type_) - return int(type_split_re.match(type_).group('bits')) - -def type_sizes(type_): - if type_has_size(type_): - return [type_size(type_)] - elif type_ == 'float': - return [16, 32, 64] - else: - return [8, 16, 32, 64] +from nir_opcodes import opcodes +from nir_opcodes import type_has_size, type_size, type_sizes, type_base_type def type_add_size(type_, size): if type_has_size(type_): @@ -44,10 +29,7 @@ def get_const_field(type_): elif type_ == "float16": return "u16" else: - m = type_split_re.match(type_) - if not m: - raise Exception(str(type_)) - return m.group('type')[0] + m.group('bits') + return type_base_type(type_)[0] + str(type_size(type_)) template = """\ /* @@ -429,7 +411,6 @@ nir_eval_const_opcode(nir_op op, unsigned num_components, } }""" -from nir_opcodes import opcodes from mako.template import Template print(Template(template).render(opcodes=opcodes, type_sizes=type_sizes, diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index 4ef4ecc6f22..d69d09d30ce 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -23,6 +23,7 @@ # Authors: # Connor Abbott (cwabbott0@gmail.com) +import re # Class that represents all the information we have about the opcode # NOTE: this must be kept in sync with nir_op_info @@ -99,6 +100,33 @@ tint64 = "int64" tuint64 = "uint64" tfloat64 = "float64" +_TYPE_SPLIT_RE = re.compile(r'(?Pint|uint|float|bool)(?P\d+)?') + +def type_has_size(type_): + m = _TYPE_SPLIT_RE.match(type_) + assert m is not None, 'Invalid NIR type string: "{}"'.format(type_) + return m.group('bits') is not None + +def type_size(type_): + m = _TYPE_SPLIT_RE.match(type_) + assert m is not None, 'Invalid NIR type string: "{}"'.format(type_) + assert m.group('bits') is not None, \ + 'NIR type string has no bit size: "{}"'.format(type_) + return int(m.group('bits')) + +def type_sizes(type_): + if type_has_size(type_): + return [type_size(type_)] + elif type_ == 'float': + return [16, 32, 64] + else: + return [8, 16, 32, 64] + +def type_base_type(type_): + m = _TYPE_SPLIT_RE.match(type_) + assert m is not None, 'Invalid NIR type string: "{}"'.format(type_) + return m.group('type') + commutative = "commutative " associative = "associative " @@ -175,11 +203,7 @@ for src_t in [tint, tuint, tfloat]: dst_types = [tint, tuint, tfloat] for dst_t in dst_types: - if dst_t == tfloat: - bit_sizes = [16, 32, 64] - else: - bit_sizes = [8, 16, 32, 64] - for bit_size in bit_sizes: + for bit_size in type_sizes(dst_t): if bit_size == 16 and dst_t == tfloat and src_t == tfloat: rnd_modes = ['_rtne', '_rtz', ''] for rnd_mode in rnd_modes: diff --git a/src/compiler/nir/nir_opcodes_c.py b/src/compiler/nir/nir_opcodes_c.py index 8bfcda6d719..9e3c06b8634 100644 --- a/src/compiler/nir/nir_opcodes_c.py +++ b/src/compiler/nir/nir_opcodes_c.py @@ -25,7 +25,7 @@ from __future__ import print_function -from nir_opcodes import opcodes +from nir_opcodes import opcodes, type_sizes from mako.template import Template template = Template(""" @@ -64,12 +64,7 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd % endif % endif switch (dst_bit_size) { -% if dst_t == 'float': -<% bit_sizes = [16, 32, 64] %> -% else: -<% bit_sizes = [8, 16, 32, 64] %> -% endif -% for dst_bits in bit_sizes: +% for dst_bits in type_sizes(dst_t): case ${dst_bits}: % if src_t == 'float' and dst_t == 'float' and dst_bits == 16: switch(rnd) { @@ -137,4 +132,4 @@ const nir_op_info nir_op_infos[nir_num_opcodes] = { }; """) -print(template.render(opcodes=opcodes)) +print(template.render(opcodes=opcodes, type_sizes=type_sizes))