extracting demo JPEG bitstream works
[openpower-isa.git] / src / openpower / test / algorithms / jpeg / svp64_jpeg_decode.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 from pathlib import Path
5 from nmutil.plain_data import plain_data
6
# raw bytes of "rainbow_smiley.jpg", which is looked up in the same directory
# as this module; the file is read when the module is imported
RAINBOW_SMILEY = Path(__file__).with_name("rainbow_smiley.jpg").read_bytes()
8
9
@plain_data(unsafe_hash=True, frozen=True)
class HuffmanTableId:
    """Identifies a Huffman table by its class (DC or AC) and its index.

    Declared frozen and hashable, so instances can be used as dict keys.
    """
    __slots__ = "is_ac", "table_id"

    def __init__(self, is_ac, table_id):
        # type: (bool, int) -> None
        # is_ac: True for an AC table, False for a DC table
        # table_id: index of the table within its class
        self.is_ac = is_ac
        self.table_id = table_id

    @staticmethod
    def from_id_byte(id_byte):
        # type: (int) -> HuffmanTableId
        """Decode a packed id byte.

        Any bit set in the upper nibble marks the table as AC; the lower
        nibble is the table index.
        """
        class_bits = id_byte & 0xF0
        index = id_byte & 0xF
        return HuffmanTableId(is_ac=bool(class_bits), table_id=index)
24
25
@plain_data()
class HuffmanTables:
    """Collection of Huffman tables, keyed by `HuffmanTableId`.

    Each table maps a code, written as a string of "0"/"1" characters with
    leading zeros kept, to the byte value that the code stands for.
    """
    __slots__ = "tables",

    def __init__(self, tables=None):
        # type: (None | dict[HuffmanTableId, dict[str, int]]) -> None
        # a fresh dict per instance; never share a mutable default
        if tables is None:
            tables = {}
        self.tables = tables

    def add_from_bytes(self, data):
        # type: (bytes) -> None
        """Parse the payload of a DHT marker segment and store its tables.

        `data` is the segment contents after the 2-byte length field.  Each
        table definition is laid out as: 1 id byte, 16 bytes giving how many
        codes exist for each bit length from 1 through 16, then the values
        for those codes in order of increasing code.  A single segment may
        hold several definitions back to back; every one of them is added.
        A table whose id is already present replaces the stored one.

        Raises IndexError if `data` is empty or ends in the middle of a
        table definition.
        """
        num_counts = 16
        offset = 0
        # always parse at least one table so that empty input is reported
        # as an error instead of being silently accepted
        while True:
            table_id = HuffmanTableId.from_id_byte(data[offset])
            counts_offset = offset + 1
            offset = counts_offset + num_counts
            code = 0
            table = {}  # type: dict[str, int]
            for i in range(num_counts):
                bit_length = 1 + i
                count = data[counts_offset + i]
                # going to the next bit length appends a zero bit to the
                # next unused code
                code <<= 1
                for _ in range(count):
                    value = data[offset]
                    offset += 1
                    code_str = bin(code)[2:].rjust(bit_length, "0")
                    table[code_str] = value
                    code += 1

            self.tables[table_id] = table
            if offset >= len(data):
                break
57
58
@plain_data()
class ScanComp:
    """One component entry of a scan header: a component id together with
    the ids of the DC and AC Huffman tables selected for that component."""
    __slots__ = "comp_id", "dc_huffman_table_id", "ac_huffman_table_id"

    def __init__(self, comp_id, dc_huffman_table_id, ac_huffman_table_id):
        # type: (int, int | HuffmanTableId, int | HuffmanTableId) -> None
        # a plain int is shorthand for the table index; it gets wrapped in
        # a HuffmanTableId of the matching class
        dc_id = dc_huffman_table_id
        if isinstance(dc_id, int):
            dc_id = HuffmanTableId(is_ac=False, table_id=dc_id)
        assert not dc_id.is_ac, "dc huffman table id must be a dc table"

        ac_id = ac_huffman_table_id
        if isinstance(ac_id, int):
            ac_id = HuffmanTableId(is_ac=True, table_id=ac_id)
        assert ac_id.is_ac, "ac huffman table id must be an ac table"

        self.comp_id = comp_id
        self.dc_huffman_table_id = dc_id
        self.ac_huffman_table_id = ac_id
78
79
def parse_start_of_scan(data):
    # type: (bytes) -> list[ScanComp]
    """Parse the component list at the start of an SOS segment payload.

    `data[0]` is the number of components.  It is followed by one 2-byte
    entry per component: the component id, then a byte whose upper nibble
    is the DC table index and whose lower nibble is the AC table index.
    Anything after the component entries is ignored.

    Raises IndexError if `data` is shorter than the count byte implies.
    """
    comp_cnt = data[0]
    # entries start right after the count byte and are 2 bytes apart
    entry_starts = range(1, 1 + 2 * comp_cnt, 2)
    return [
        ScanComp(
            comp_id=data[start],
            dc_huffman_table_id=data[start + 1] >> 4,
            ac_huffman_table_id=data[start + 1] & 0xF,
        )
        for start in entry_starts
    ]
95
96
@plain_data()
class FrameHeaderComp:
    """Per-component fields of a frame header: the component id, the
    horizontal and vertical sampling factors, and the quantization table
    selector."""
    __slots__ = "comp_id", "h_smpl_fac", "v_smpl_fac", "quant_tbl"

    def __init__(self, comp_id, h_smpl_fac, v_smpl_fac, quant_tbl):
        # type: (int, int, int, int) -> None
        self.comp_id = comp_id
        self.h_smpl_fac = h_smpl_fac
        self.v_smpl_fac = v_smpl_fac
        self.quant_tbl = quant_tbl

    @property
    def repeat(self):
        """Product of the horizontal and vertical sampling factors."""
        horizontal = self.h_smpl_fac
        vertical = self.v_smpl_fac
        return horizontal * vertical

    @property
    def mcu_h(self):
        """Eight times the horizontal sampling factor."""
        # NOTE(review): presumably this component's MCU extent in samples,
        # with 8 being the block edge length -- confirm against the users
        return 8 * self.h_smpl_fac

    @property
    def mcu_v(self):
        """Eight times the vertical sampling factor."""
        return 8 * self.v_smpl_fac
119
120
@plain_data()
class FrameHeader:
    """Contents of a start-of-frame segment: sample precision, image height
    and width, and the per-component headers keyed by component id."""
    __slots__ = "smpl_prec", "img_h", "img_w", "components"

    def __init__(self, smpl_prec, img_h, img_w, components):
        # type: (int, int, int, dict[int, FrameHeaderComp]) -> None
        # scalar fields
        self.smpl_prec = smpl_prec
        self.img_h, self.img_w = img_h, img_w
        # component id -> FrameHeaderComp
        self.components = components
131
132
def parse_start_of_frame(marker, data):
    # type: (int, bytes) -> FrameHeader
    """Parse the payload of a start-of-frame segment.

    `marker` is the second byte of the marker code; only 0xC0 is accepted.
    `data` holds, in order: 1 byte sample precision, 2 bytes height and
    2 bytes width (both big-endian), 1 byte component count, then 3 bytes
    per component (id, sampling factors packed as two nibbles with the
    horizontal factor in the upper one, quantization table selector).

    Raises ValueError for anything other than an 8-bit, 3-component frame
    with non-zero dimensions, and IndexError if `data` is too short.
    """
    if marker != 0xC0:
        raise ValueError("only baseline DCT JPEG encoding supported")

    # each field is validated as soon as it is read, so the first problem
    # in file order is the one that gets reported
    smpl_prec = data[0]
    if smpl_prec != 8:
        raise ValueError(f"unsupported sample-precision {smpl_prec}")

    img_h = data[1] << 8 | data[2]
    if img_h == 0:
        raise ValueError("image height not being defined in "
                         "start-of-frame is unsupported")

    img_w = data[3] << 8 | data[4]
    if img_w == 0:
        raise ValueError("invalid image width")

    comp_cnt = data[5]
    if comp_cnt != 3:
        raise ValueError("non RGB/YCbCr JPEG not supported")

    components = {}
    # component entries start at byte 6 and are 3 bytes each
    for base in range(6, 6 + 3 * comp_cnt, 3):
        comp_id = data[base]
        smpl_facs = data[base + 1]
        components[comp_id] = FrameHeaderComp(
            comp_id=comp_id,
            h_smpl_fac=smpl_facs >> 4,
            v_smpl_fac=smpl_facs & 0xF,
            quant_tbl=data[base + 2],
        )
    return FrameHeader(smpl_prec=smpl_prec, img_h=img_h,
                       img_w=img_w, components=components)
167
168
@plain_data()
class DemoBitstream:
    """Bundle of everything pulled out of the demo JPEG: the extracted
    bitstream bytes, the Huffman tables, the frame header and the list of
    scan components."""
    __slots__ = ("bitstream", "huffman_tables",
                 "frame_header", "scan_header")

    def __init__(self, bitstream, huffman_tables,
                 frame_header, scan_header):
        # type: (bytes, HuffmanTables, FrameHeader, list[ScanComp]) -> None
        # raw extracted bytes
        self.bitstream = bitstream
        # parsed headers and tables that go with those bytes
        self.huffman_tables = huffman_tables
        self.frame_header = frame_header
        self.scan_header = scan_header
181
182
def extract_demo_bitstream(data):
    # type: (bytes) -> DemoBitstream
    """Extract the first scan of a baseline JPEG file, plus its headers.

    Walks the marker segments in `data`, collecting the Huffman tables
    (DHT), the frame header (SOF) and the scan header (SOS).  The bytes
    found between markers, from the first SOS up to the end-of-image marker
    or a second SOS (whichever comes first), are returned as the bitstream.
    Byte-stuffed `FF 00` pairs are copied into the bitstream unchanged.
    APPn, COM and DQT segments are skipped.

    Raises ValueError if `data` does not start like a JPEG, is truncated,
    has an invalid segment size, contains restart markers or an
    unrecognised marker, or lacks image data or a frame header.  Errors
    raised while parsing an individual segment propagate unchanged.
    """
    # explicit raises instead of `assert`: these checks validate external
    # input and must still run under `python -O`
    if not data.startswith(b"\xFF\xD8\xFF"):
        raise ValueError("not a jpeg")
    size = len(data)
    huffman_tables = HuffmanTables()
    scan_header = []
    # until the first SOS is seen, stray bytes between markers are gathered
    # in this throw-away list
    bitstream = []
    extracted_bitstream = None
    frame_header = None

    offset = 0
    while True:
        chunk_start = offset
        # advance to the next marker: a 0xFF byte that is not followed by a
        # stuffed zero byte
        while True:
            offset = data.find(b"\xFF", offset)
            if offset < 0 or offset + 1 >= size:
                raise ValueError("file truncated before end-of-image marker")
            if data[offset + 1] != 0:
                break
            offset += 2
        if chunk_start != offset:
            bitstream.append(data[chunk_start:offset])
        # skip the 0xFF prefix and any 0xFF fill bytes before the marker code
        offset += 1
        while offset < size and data[offset] == 0xFF:
            offset += 1
        if offset >= size:
            raise ValueError("file truncated before end-of-image marker")
        marker = data[offset]
        offset += 1
        if 0xD0 <= marker < 0xD8:  # restart marker
            raise ValueError("restart markers not supported")
        if marker == 0xD8:  # start of image
            continue
        if marker == 0xD9:  # end of image
            if extracted_bitstream is None:
                raise ValueError("missing JPEG image data")
            break
        # every remaining marker is followed by a big-endian 16-bit size
        # that counts the size field itself
        if offset + 2 > size:
            raise ValueError("file truncated before end of marker segment")
        segment_size = data[offset] << 8 | data[offset + 1]
        if segment_size < 2:
            raise ValueError("invalid marker segment size")
        segment_end = offset + segment_size
        if segment_end > size:
            raise ValueError("file truncated before end of marker segment")
        segment_data = data[offset + 2:segment_end]
        offset = segment_end
        if 0xE0 <= marker <= 0xEF:  # APP0 through APP15
            continue  # ignored
        if marker == 0xFE:  # COM -- comment
            continue  # ignored
        if marker == 0xDB:  # DQT -- define quantization table
            continue  # ignored
        if marker in (0xC0, 0xC1, 0xC2, 0xC3,
                      0xC5, 0xC6, 0xC7,
                      0xC9, 0xCA, 0xCB,
                      0xCD, 0xCE, 0xCF):  # SOF0-15 -- start of frame
            frame_header = parse_start_of_frame(marker, segment_data)
            continue
        if marker == 0xC4:  # DHT -- define huffman table
            huffman_tables.add_from_bytes(segment_data)
            continue
        if marker == 0xDA:  # SOS -- start of scan
            if extracted_bitstream is not None:
                break  # only the first scan is extracted
            scan_header = parse_start_of_scan(segment_data)
            # from here on, bytes between markers are the wanted bitstream
            bitstream = extracted_bitstream = []
            continue
        raise ValueError(f"unknown marker: 0xFF{marker:02X}: {segment_data}")
    if frame_header is None:
        raise ValueError("missing SOF0 marker (0xFF 0xC0)")
    return DemoBitstream(bitstream=b"".join(extracted_bitstream),
                         huffman_tables=huffman_tables,
                         frame_header=frame_header,
                         scan_header=scan_header)
253
254
# the demo image is parsed once, when this module is imported
DEMO_BITSTREAM = extract_demo_bitstream(RAINBOW_SMILEY)

# running this module as a script prints the parsed result
if __name__ == "__main__":
    print(DEMO_BITSTREAM)