From: Eli Bendersky Date: Mon, 28 Nov 2011 04:25:52 +0000 (+0200) Subject: further optimization of reading a cstring from a stream + added some tests for utils X-Git-Tag: v0.10~63^2 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a1d6140ab97b1c9aac16a03482a0c21ef81c00df;p=pyelftools.git further optimization of reading a cstring from a stream + added some tests for utils --- diff --git a/elftools/common/utils.py b/elftools/common/utils.py index 2ce6746..7aa1cd9 100644 --- a/elftools/common/utils.py +++ b/elftools/common/utils.py @@ -34,16 +34,29 @@ def struct_parse(struct, stream, stream_pos=None): def parse_cstring_from_stream(stream, stream_pos=None): """ Parse a C-string from the given stream. The string is returned without - the terminating \x00 byte. + the terminating \x00 byte. If the terminating byte wasn't found, None + is returned (the stream is exhausted). If stream_pos is provided, the stream is seeked to this position before the parsing is done. Otherwise, the current position of the stream is used. """ - # I could've just used construct.CString, but this function is 4x faster. - # Since it's needed a lot, I created it as an optimization. 
if stream_pos is not None: stream.seek(stream_pos) - return ''.join(iter(lambda: stream.read(1), '\x00')) + CHUNKSIZE = 64 + chunks = [] + found = False + while True: + chunk = stream.read(CHUNKSIZE) + end_index = chunk.find('\x00') + if end_index >= 0: + chunks.append(chunk[:end_index]) + found = True + break + else: + chunks.append(chunk) + if len(chunk) < CHUNKSIZE: + break + return ''.join(chunks) if found else None def elf_assert(cond, msg=''): @@ -77,3 +90,4 @@ def preserve_stream_pos(stream): saved_pos = stream.tell() yield stream.seek(saved_pos) + diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..54a09bb --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,56 @@ +import sys, unittest +from cStringIO import StringIO +from random import randint + +sys.path.extend(['.', '..']) +from elftools.common.utils import (parse_cstring_from_stream, + preserve_stream_pos) + + +class Test_parse_cstring_from_stream(unittest.TestCase): + def _make_random_string(self, n): + return ''.join(chr(randint(32, 127)) for i in range(n)) + + def test_small1(self): + sio = StringIO('abcdefgh\x0012345') + self.assertEqual(parse_cstring_from_stream(sio), 'abcdefgh') + self.assertEqual(parse_cstring_from_stream(sio, 2), 'cdefgh') + self.assertEqual(parse_cstring_from_stream(sio, 8), '') + + def test_small2(self): + sio = StringIO('12345\x006789\x00abcdefg\x00iii') + self.assertEqual(parse_cstring_from_stream(sio), '12345') + self.assertEqual(parse_cstring_from_stream(sio, 5), '') + self.assertEqual(parse_cstring_from_stream(sio, 6), '6789') + + def test_large1(self): + text = 'i' * 400 + '\x00' + 'bb' + sio = StringIO(text) + self.assertEqual(parse_cstring_from_stream(sio), 'i' * 400) + self.assertEqual(parse_cstring_from_stream(sio, 150), 'i' * 250) + + def test_large2(self): + text = self._make_random_string(5000) + '\x00' + 'jujajaja' + sio = StringIO(text) + self.assertEqual(parse_cstring_from_stream(sio), text[:5000]) + 
self.assertEqual(parse_cstring_from_stream(sio, 2348), text[2348:5000]) + + +class Test_preserve_stream_pos(unittest.TestCase): + def test_basic(self): + sio = StringIO('abcdef') + with preserve_stream_pos(sio): + sio.seek(4) + self.assertEqual(sio.tell(), 0) + + sio.seek(5) + with preserve_stream_pos(sio): + sio.seek(0) + self.assertEqual(sio.tell(), 5) + + +if __name__ == '__main__': + unittest.main() + + +