1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
4 from pathlib
import Path
5 from nmutil
.plain_data
import plain_data
# Raw bytes of rainbow_smiley.jpg, read at import time from the directory
# that contains this module.
RAINBOW_SMILEY = (
    Path(__file__).parent.joinpath("rainbow_smiley.jpg").read_bytes()
)
# HuffmanTableId: identifies one Huffman table by kind (AC when is_ac is
# true, otherwise DC) and by table number. Instances are used as the keys
# of HuffmanTables.tables (see the HuffmanTables type comment). The
# plain_data arguments presumably make instances frozen and hashable --
# confirm against nmutil.plain_data.
# NOTE(review): the embedded original line numbers skip 11, 13, 16, 18 and
# 19. Besides blank lines, presumably the `class HuffmanTableId:` header,
# the `self.is_ac = is_ac` assignment and a `@staticmethod` decorator for
# from_id_byte are missing from this view -- confirm.
10 @plain_data(unsafe_hash
=True, frozen
=True)
12 __slots__
= "is_ac", "table_id"
14 def __init__(self
, is_ac
, table_id
):
15 # type: (bool, int) -> None
17 self
.table_id
= table_id
# Build a HuffmanTableId from a packed id byte. A non-zero high nibble
# selects an AC table; the low nibble is the table number (presumably the
# JPEG DHT Tc/Th fields).
20 def from_id_byte(id_byte
):
21 # type: (int) -> HuffmanTableId
# In Python `&` binds tighter than `!=`, so this reads as
# (id_byte & 0xF0) != 0, i.e. "high nibble non-zero".
22 return HuffmanTableId(is_ac
=id_byte
& 0xF0 != 0,
23 table_id
=id_byte
& 0xF)
# Presumably HuffmanTables.__init__ (the class header is not visible).
# Per the type comment, `tables` is either None or a dict mapping each
# HuffmanTableId to a table of {code bit-string: value}.
# NOTE(review): the body (original lines 32-35) is not visible in this
# view; presumably None is replaced by a fresh empty dict -- confirm.
30 def __init__(self
, tables
=None):
31 # type: (None | dict[HuffmanTableId, dict[str, int]]) -> None
# Decode a DHT (define Huffman table) segment payload and store the
# decoded table in self.tables under its HuffmanTableId.
# extract_demo_bitstream calls this with the segment data for marker 0xC4.
# NOTE(review): the definitions of id_offset, counts_offset, num_counts,
# code, bit_length and value are not visible in this view (original
# lines 38-39, 41, 43, 46, 48, 50-51 and 54-55 are missing). The notes
# below about them are assumptions to confirm against the full file.
36 def add_from_bytes(self
, data
):
37 # type: (bytes) -> None
# The table id byte is decoded by HuffmanTableId.from_id_byte.
40 table_id
= HuffmanTableId
.from_id_byte(data
[id_offset
])
# Presumably the symbol values start right after the num_counts count bytes.
42 offset
= counts_offset
+ num_counts
44 table
= {} # type: dict[str, int]
# count = number of codes of the i-th code length (assumed: bit_length
# corresponds to i + 1; its definition is not visible).
45 for i
in range(num_counts
):
47 count
= data
[counts_offset
+ i
]
49 for _
in range(count
):
# Key each entry by the code written as a zero-padded binary string of
# bit_length digits, e.g. code 5 with bit_length 4 -> "0101".
52 code_str
= bin(code
)[2:].rjust(bit_length
, "0")
53 table
[code_str
] = value
# Assigning by key means a later table with the same id replaces the
# earlier one.
56 self
.tables
[table_id
] = table
__slots__ = ("comp_id", "dc_huffman_table_id", "ac_huffman_table_id")

def __init__(self, comp_id, dc_huffman_table_id, ac_huffman_table_id):
    # type: (int, int | HuffmanTableId, int | HuffmanTableId) -> None
    # One component entry of a start-of-scan header.
    #
    # Each table id may be a bare table number, which is wrapped here in a
    # HuffmanTableId of the matching kind. It may also be an existing
    # HuffmanTableId, which must already be of the matching kind (DC for
    # dc_huffman_table_id, AC for ac_huffman_table_id); otherwise an
    # AssertionError is raised.
    self.comp_id = comp_id
    dc_id = (HuffmanTableId(is_ac=False, table_id=dc_huffman_table_id)
             if isinstance(dc_huffman_table_id, int)
             else dc_huffman_table_id)
    assert not dc_id.is_ac, "dc huffman table id must be a dc table"
    ac_id = (HuffmanTableId(is_ac=True, table_id=ac_huffman_table_id)
             if isinstance(ac_huffman_table_id, int)
             else ac_huffman_table_id)
    assert ac_id.is_ac, "ac huffman table id must be an ac table"
    self.dc_huffman_table_id = dc_id
    self.ac_huffman_table_id = ac_id
# Parse a start-of-scan (SOS) segment payload into a list of ScanComp,
# one per scan component. extract_demo_bitstream calls this with the
# segment data for marker 0xDA.
# NOTE(review): the embedded original line numbers skip 82, 84-85 and 88,
# and the closing lines (original 91-97) are not visible either. The
# setup of offset/retval, the comp_id argument, the offset advance and
# the return statement are therefore not shown here -- confirm against
# the full file.
80 def parse_start_of_scan(data
):
81 # type: (bytes) -> list[ScanComp]
# Number of components in this scan.
83 comp_cnt
= data
[offset
]
86 for _
in range(comp_cnt
):
87 retval
.append(ScanComp(
# data[offset + 1] packs the DC table number (high nibble) and the AC
# table number (low nibble); ScanComp wraps these ints in
# HuffmanTableIds.
89 dc_huffman_table_id
=data
[offset
+ 1] >> 4,
90 ac_huffman_table_id
=data
[offset
+ 1] & 0xF,
# Per-component parameters from a start-of-frame segment, as built by
# parse_start_of_frame: component id, horizontal and vertical sampling
# factors, and the quantization table byte (presumably the quantization
# table number).
# NOTE(review): the embedded original line numbers skip 100, 107-109,
# 111-113 and 115-117. The def lines (presumably properties) owning the
# three return statements below are not visible, so their names are
# unknown from this view.
98 class FrameHeaderComp
:
99 __slots__
= "comp_id", "h_smpl_fac", "v_smpl_fac", "quant_tbl"
101 def __init__(self
, comp_id
, h_smpl_fac
, v_smpl_fac
, quant_tbl
):
102 # type: (int, int, int, int) -> None
103 self
.comp_id
= comp_id
104 self
.h_smpl_fac
= h_smpl_fac
105 self
.v_smpl_fac
= v_smpl_fac
106 self
.quant_tbl
= quant_tbl
# Product of the two sampling factors (presumably this component's
# blocks per MCU).
110 return self
.h_smpl_fac
* self
.v_smpl_fac
# 8 * horizontal sampling factor (presumably the MCU width in pixels,
# 8 being the DCT block size).
114 return 8 * self
.h_smpl_fac
# 8 * vertical sampling factor (presumably the MCU height in pixels).
118 return 8 * self
.v_smpl_fac
# Presumably the FrameHeader class body. The class header is not visible,
# but parse_start_of_frame constructs FrameHeader with these keyword
# arguments. Fields: sample precision, image height and width, and the
# frame components keyed by component id.
# NOTE(review): original lines 128-129 are not visible. Presumably they
# assign img_h and img_w, since __slots__ declares both -- confirm.
123 __slots__
= "smpl_prec", "img_h", "img_w", "components"
125 def __init__(self
, smpl_prec
, img_h
, img_w
, components
):
126 # type: (int, int, int, dict[int, FrameHeaderComp]) -> None
127 self
.smpl_prec
= smpl_prec
130 self
.components
= components
# Parse a start-of-frame (SOF) segment payload into a FrameHeader.
# `marker` is the SOF marker byte matched by extract_demo_bitstream.
# `data` is the segment data following the two length bytes.
# Raises ValueError for the unsupported cases below.
# NOTE(review): the conditions guarding each raise, the offset
# initialisation and advances, and the creation of the `components` dict
# are not visible in this view (original lines 135, 137, 139-140, 143-144,
# 148-149, 152-153, 155, 159 and 163-164 are missing). The comments on
# them below are assumptions to confirm.
133 def parse_start_of_frame(marker
, data
):
134 # type: (int, bytes) -> FrameHeader
# Presumably raised unless `marker` is 0xC0 (SOF0, baseline DCT).
136 raise ValueError("only baseline DCT JPEG encoding supported")
138 smpl_prec
= data
[offset
]
# Presumably raised for sample precisions other than the one supported.
141 raise ValueError(f
"unsupported sample-precision {smpl_prec}")
# Image height: 16-bit big-endian (high byte first).
142 img_h
= (data
[offset
] << 8) | data
[offset
+ 1]
# Presumably raised when the height field is zero.
145 raise ValueError("image height not being defined in "
146 "start-of-frame is unsupported")
# Image width: 16-bit big-endian (high byte first).
147 img_w
= (data
[offset
] << 8) | data
[offset
+ 1]
# Presumably raised when the width field is zero.
150 raise ValueError("invalid image width")
151 comp_cnt
= data
[offset
]
# Presumably raised unless there are exactly three components.
154 raise ValueError("non RGB/YCbCr JPEG not supported")
# Each component entry: id byte, then a byte packing the horizontal (high
# nibble) and vertical (low nibble) sampling factors, then the
# quantization table byte.
156 for _
in range(comp_cnt
):
157 comp_id
= data
[offset
]
158 components
[comp_id
] = FrameHeaderComp(
160 h_smpl_fac
=data
[offset
+ 1] >> 4,
161 v_smpl_fac
=data
[offset
+ 1] & 0xF,
162 quant_tbl
=data
[offset
+ 2],
165 return FrameHeader(smpl_prec
=smpl_prec
, img_h
=img_h
,
166 img_w
=img_w
, components
=components
)
__slots__ = "bitstream", "huffman_tables", "frame_header", "scan_header"

def __init__(self, bitstream, huffman_tables, frame_header, scan_header):
    # type: (bytes, HuffmanTables, FrameHeader, list[ScanComp]) -> None
    # Plain container for what extract_demo_bitstream pulls out of a JPEG:
    # the joined scan bytes, the Huffman tables, the frame header and the
    # scan header. The values are stored as given, without validation.
    (self.bitstream, self.huffman_tables,
     self.frame_header, self.scan_header) = (
        bitstream, huffman_tables, frame_header, scan_header)
# Walk the JPEG marker segments in `data` and return a DemoBitstream. It
# collects the Huffman tables (DHT), the frame header (SOF) and the scan
# header (SOS), and joins the scan's entropy-coded byte chunks.
# Raises ValueError for restart markers, unknown markers and a missing
# frame header, and propagates the ValueErrors raised by
# parse_start_of_frame. Other malformed input is rejected with `assert`.
# NOTE(review): `assert` checks are stripped under `python -O`, so with
# -O those malformed inputs would not be rejected here -- consider raising
# ValueError instead.
# NOTE(review): many original lines (e.g. 187-188, 190-195, 198-203, 207,
# 210, 212, 216, 219, 228, 230, 232-233, 236, 239, 242, 245) are not
# visible in this view. They include the offset bookkeeping, the loop
# structure, the initialisation of frame_header/scan_header and several
# `continue`/`raise` statements. Comments below that depend on them are
# assumptions to confirm.
183 def extract_demo_bitstream(data
):
184 # type: (bytes) -> DemoBitstream
# Require SOI (0xFF 0xD8, "start of image" below) followed by the 0xFF of
# the next marker.
185 assert data
.startswith(b
"\xFF\xD8\xFF"), "not a jpeg"
186 huffman_tables
= HuffmanTables()
# Stays None until an SOS (start of scan) marker has been seen.
189 extracted_bitstream
= None
# Inside scan data, 0xFF 0x00 is presumably a stuffed 0xFF data byte; any
# other 0xFF presumably begins a marker.
196 if data
[offset
] == 0xFF:
197 if data
[offset
+ 1] == 0:
# Flush the pending run of scan bytes (chunk_start..chunk_end) if it is
# non-empty.
204 if chunk_start
!= chunk_end
:
205 bitstream
.append(data
[chunk_start
:chunk_end
])
206 assert data
[offset
] == 0xFF
# Presumably checked after stepping past the 0xFF (line 207 not visible):
# a marker code can never be 0x00.
208 assert data
[offset
] != 0
# Presumably skips extra 0xFF fill bytes before the marker code.
209 while data
[offset
] == 0xFF:
211 marker
= data
[offset
]
213 if 0xD0 <= marker
< 0xD8: # restart marker
214 raise ValueError("restart markers not supported")
215 if marker
== 0xD8: # start of image
217 if marker
== 0xD9: # end of image
218 assert extracted_bitstream
is not None, "missing JPEG image data"
# Segment length: 16-bit big-endian. It counts the two length bytes
# themselves, hence the payload slice below starts at offset + 2.
220 segment_size
= data
[offset
] << 8
221 segment_size |
= data
[offset
+ 1]
222 assert segment_size
>= 2, "invalid marker segment size"
223 segment_data
= data
[offset
+ 2:offset
+ segment_size
]
# Slicing never raises, so a truncated segment just yields a short
# segment_data; this assert is what reports the truncation.
224 assert len(data
) >= offset
+ segment_size
, \
225 "file truncated before end of marker segment"
226 offset
+= segment_size
# Presumably ignored (the following line 228 is not visible).
227 if 0xE0 <= marker
<= 0xEF: # APP0 through APP15
# Presumably ignored (the following line 230 is not visible).
229 if marker
== 0xDB: # DQT -- define quantization table
# NOTE(review): original lines 232-233 of this tuple are not visible. 0xC4
# must not be listed, since it is handled as DHT below -- confirm which
# codes are included.
231 if marker
in (0xC0, 0xC1, 0xC2, 0xC3,
234 0xCD, 0xCE, 0xCF): # SOF0-15 -- start of frame
235 frame_header
= parse_start_of_frame(marker
, segment_data
)
237 if marker
== 0xC4: # DHT -- define huffman table
238 huffman_tables
.add_from_bytes(segment_data
)
240 if marker
== 0xDA: # SOS -- start of scan
# Presumably raises: only a single scan is supported (line 242 not
# visible).
241 if extracted_bitstream
is not None:
243 scan_header
= parse_start_of_scan(segment_data
)
# Both names are bound to the same list, so chunks appended through
# `bitstream` accumulate in extracted_bitstream.
244 bitstream
= extracted_bitstream
= []
246 raise ValueError(f
"unknown marker: 0xFF{marker:02X}: {segment_data}")
# NOTE(review): the message names SOF0, although several SOF marker codes
# are routed to parse_start_of_frame above.
247 if frame_header
is None:
248 raise ValueError("missing SOF0 marker (0xFF 0xC0)")
249 return DemoBitstream(bitstream
=b
"".join(extracted_bitstream
),
250 huffman_tables
=huffman_tables
,
251 frame_header
=frame_header
,
252 scan_header
=scan_header
)
# Parse the demo image once, at import time; every importer shares this
# module-level result.
DEMO_BITSTREAM = extract_demo_bitstream(data=RAINBOW_SMILEY)

if __name__ == "__main__":
    # Running the module directly prints the parsed DemoBitstream.
    print(DEMO_BITSTREAM)