+
+import re
+
+type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
+
+def type_has_size(type_):
+ return type_[-1:].isdigit()
+
+def type_size(type_):
+ assert type_has_size(type_)
+ return int(type_split_re.match(type_).group('bits'))
+
+def type_sizes(type_):
+ if type_has_size(type_):
+ return [type_size(type_)]
+ elif type_ == 'float':
+ return [16, 32, 64]
+ else:
+ return [8, 16, 32, 64]
+
+def type_add_size(type_, size):
+ if type_has_size(type_):
+ return type_
+ return type_ + str(size)
+
+def op_bit_sizes(op):
+ sizes = None
+ if not type_has_size(op.output_type):
+ sizes = set(type_sizes(op.output_type))
+
+ for input_type in op.input_types:
+ if not type_has_size(input_type):
+ if sizes is None:
+ sizes = set(type_sizes(input_type))
+ else:
+ sizes = sizes.intersection(set(type_sizes(input_type)))
+
+ return sorted(list(sizes)) if sizes is not None else None
+
+def get_const_field(type_):
+ if type_ == "bool32":
+ return "u32"
+ elif type_ == "float16":
+ return "u16"
+ else:
+ m = type_split_re.match(type_)
+ if not m:
+ raise Exception(str(type_))
+ return m.group('type')[0] + m.group('bits')
+