File size: 3,257 Bytes
9b0f4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
from io import BytesIO
from typing import BinaryIO, Iterator, List, Optional, cast

from pdf2zh.pdfexceptions import PDFEOFError, PDFException

logger = logging.getLogger(__name__)


class CorruptDataError(PDFException):
    """Raised by ``LZWDecoder.feed`` when a code refers to a table entry
    that has not been defined yet."""


class LZWDecoder:
    """Incremental decoder for the LZW compression used by PDF ``LZWDecode``.

    Codes are read MSB-first from *fp* with a variable width of 9 to 12
    bits. Code 256 resets the string table and code 257 ends the data.

    :param fp: binary stream holding the compressed bytes.
    :param early_change: 1 (the default, and the historical behaviour of
        this class) widens the code size one code early, i.e. when the
        table reaches 511/1023/2047 entries. 0 postpones the change to
        512/1024/2048 entries. This mirrors the PDF ``/EarlyChange``
        decode parameter; any non-zero value is treated as 1.
    """

    def __init__(self, fp: BinaryIO, early_change: int = 1) -> None:
        self.fp = fp
        # Bit-reader state: the byte being consumed and how many of its
        # 8 bits are already used (8 means "empty, fetch the next byte").
        self.buff = 0
        self.bpos = 8
        self.early_change = 1 if early_change else 0
        self.nbits = 9
        # NB: self.table stores None only in indices 256 and 257
        self.table: List[Optional[bytes]] = []
        self.prevbuf: Optional[bytes] = None
        # Start with the initial table so streams that omit the leading
        # clear-table code decode instead of indexing an empty table.
        self._reset_table()

    def _reset_table(self) -> None:
        """Restore the initial 258-entry table and the 9-bit code width."""
        self.table = [bytes((c,)) for c in range(256)]  # 0-255
        self.table.append(None)  # 256: clear-table code
        self.table.append(None)  # 257: end-of-data code
        self.prevbuf = b""
        self.nbits = 9

    def readbits(self, bits: int) -> int:
        """Read the next *bits* bits from ``fp`` (MSB first) as an int.

        :raises PDFEOFError: if ``fp`` runs out of bytes mid-read.
        """
        v = 0
        while True:
            # the number of remaining bits we can get from the current buffer.
            r = 8 - self.bpos
            if bits <= r:
                # |-----8-bits-----|
                # |-bpos-|-bits-|  |
                # |      |----r----|
                v = (v << bits) | ((self.buff >> (r - bits)) & ((1 << bits) - 1))
                self.bpos += bits
                break
            else:
                # |-----8-bits-----|
                # |-bpos-|---bits----...
                # |      |----r----|
                v = (v << r) | (self.buff & ((1 << r) - 1))
                bits -= r
                x = self.fp.read(1)
                if not x:
                    raise PDFEOFError
                self.buff = ord(x)
                self.bpos = 0
        return v

    def feed(self, code: int) -> bytes:
        """Consume one code and return the bytes it expands to.

        Clear (256) and end-of-data (257) codes produce ``b""``.

        :raises CorruptDataError: if *code* refers to a table entry that
            does not exist yet.
        """
        x = b""
        if code == 256:
            self._reset_table()
        elif code == 257:
            pass
        elif not self.prevbuf:
            # First code after a clear (or at stream start): it adds nothing
            # to the table. Out-of-range codes here are corrupt, not an
            # IndexError.
            if code >= len(self.table):
                raise CorruptDataError
            x = self.prevbuf = cast(bytes, self.table[code])
        else:
            if code < len(self.table):
                x = cast(bytes, self.table[code])  # assume not None
                self.table.append(self.prevbuf + x[:1])
            elif code == len(self.table):
                # The code being defined by this very step (KwKwK case).
                self.table.append(self.prevbuf + self.prevbuf[:1])
                x = cast(bytes, self.table[code])
            else:
                raise CorruptDataError
            # Widen codes once the table reaches 2**nbits (minus one entry
            # when early_change is set); PDF LZW never exceeds 12 bits.
            limit = (1 << self.nbits) - self.early_change
            if self.nbits < 12 and len(self.table) >= limit:
                self.nbits += 1
            self.prevbuf = x
        return x

    def run(self) -> Iterator[bytes]:
        """Yield decoded chunks until end-of-data, end of input or bad data."""
        while True:
            try:
                code = self.readbits(self.nbits)
            except EOFError:
                # Input exhausted without an EOD marker. NOTE(review): relies
                # on PDFEOFError subclassing EOFError, as the original did.
                break
            if code == 257:
                # End-of-data: remaining bits are padding or trailing junk and
                # must not be decoded as further codes.
                break
            try:
                x = self.feed(code)
            except CorruptDataError:
                # just ignore corrupt data and stop yielding there
                break
            yield x


def lzwdecode(data: bytes) -> bytes:
    """Decompress an LZW-encoded byte string in one call.

    Wraps *data* in an in-memory stream, drains ``LZWDecoder.run`` and
    concatenates every chunk it yields.
    """
    decoder = LZWDecoder(BytesIO(data))
    chunks = list(decoder.run())
    return b"".join(chunks)