From 224ed0e2a82cf5e4067c0a0d0baffa5d3efffc7a Mon Sep 17 00:00:00 2001 From: Bottersnike Date: Fri, 9 Sep 2022 13:44:29 +0100 Subject: [PATCH] Add decopmressed encoding; add decompressed test cases --- bemani/protocol/binary.py | 31 +++++++++++++++++++++++++++---- bemani/protocol/protocol.py | 5 +++++ bemani/protocol/stream.py | 14 ++++++++++++++ bemani/tests/test_protocol.py | 4 +++- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/bemani/protocol/binary.py b/bemani/protocol/binary.py index d941b1a..7aa3abd 100644 --- a/bemani/protocol/binary.py +++ b/bemani/protocol/binary.py @@ -483,7 +483,7 @@ class BinaryEncoder: A class capable of taking a Node tree and encoding it into a binary format. """ - def __init__(self, tree: Node, encoding: str) -> None: + def __init__(self, tree: Node, encoding: str, compressed: bool=True) -> None: """ Initialize the object. @@ -498,6 +498,7 @@ class BinaryEncoder: self.__body: List[int] = [] self.__body_len = 0 self.executed = False + self.compressed = compressed # Generate the characer LUT self.char_lut: Dict[str, int] = {} @@ -512,6 +513,22 @@ class BinaryEncoder: Parameters: name - A string name which should be encoded as a node name """ + if not self.compressed: + encoded = name.encode(self.encoding) + length = len(encoded) + + if length > BinaryEncoding.NAME_MAX_DECOMPRESSED: + raise BinaryEncodingException("Node name length over decompressed limit") + + if length < 64: + self.stream.write_int(length + 0x3f) + else: + length += 0x7fbf + self.stream.write_int((length >> 8) & 0xff) + self.stream.write_int(length & 0xff) + self.stream.write_blob(encoded) + return + def char_to_bin(ch: str) -> str: index = self.char_lut[ch] val = bin(index)[2:] @@ -826,7 +843,7 @@ class BinaryEncoding: else: return None - def encode(self, tree: Node, encoding: Optional[str]=None) -> bytes: + def encode(self, tree: Node, encoding: Optional[str]=None, compressed: bool=True) -> bytes: """ Given a tree of Node objects, encode the data with the current encoding. @@ -852,6 +869,12 @@ class BinaryEncoding: if encoding_magic is None: raise BinaryEncodingException(f"Invalid text encoding {encoding}") - encoder = BinaryEncoder(tree, self.__sanitize_encoding(encoding)) + encoder = BinaryEncoder(tree, self.__sanitize_encoding(encoding), compressed) data = encoder.get_data() - return struct.pack(">BBBB", BinaryEncoding.MAGIC, BinaryEncoding.COMPRESSED_WITH_DATA, encoding_magic, (~encoding_magic & 0xFF)) + data + return struct.pack( + ">BBBB", + BinaryEncoding.MAGIC, + BinaryEncoding.COMPRESSED_WITH_DATA if compressed else BinaryEncoding.DECOMPRESSED_WITH_DATA, + encoding_magic, + (~encoding_magic & 0xFF) + ) + data diff --git a/bemani/protocol/protocol.py b/bemani/protocol/protocol.py index c911d57..f111c0c 100644 --- a/bemani/protocol/protocol.py +++ b/bemani/protocol/protocol.py @@ -23,6 +23,7 @@ class EAmuseProtocol: XML: Final[int] = 1 BINARY: Final[int] = 2 + BINARY_DECOMPRESSED: Final[int] = 3 SHIFT_JIS_LEGACY: Final[str] = "shift-jis-legacy" SHIFT_JIS: Final[str] = "shift-jis" @@ -214,6 +215,10 @@ class EAmuseProtocol: # It's binary, encode it binary = BinaryEncoding() return binary.encode(tree, encoding=text_encoding) + elif packet_encoding == EAmuseProtocol.BINARY_DECOMPRESSED: + # It's binary, encode it + binary = BinaryEncoding() + return binary.encode(tree, encoding=text_encoding, compressed=False) elif packet_encoding == EAmuseProtocol.XML: # It's XML, encode it xml = XmlEncoding() diff --git a/bemani/protocol/stream.py b/bemani/protocol/stream.py index 34580b1..acf3fcf 100644 --- a/bemani/protocol/stream.py +++ b/bemani/protocol/stream.py @@ -128,6 +128,20 @@ class OutputStream: self.__formatted_data = b''.join(self.__data) return self.__formatted_data + def write_blob(self, blob: bytes) -> int: + """ + Write a binary blob of data to the stream + + Parameters: + blob - An blob of data to write. + + Returns: + the number of bytes written + """ + self.__data.append(blob) + self.__data_len += len(blob) + return len(blob) + def write_byte(self, byte: bytes) -> None: """ Write a raw byte to the end of the output stream. diff --git a/bemani/tests/test_protocol.py b/bemani/tests/test_protocol.py index 2977853..4e652eb 100644 --- a/bemani/tests/test_protocol.py +++ b/bemani/tests/test_protocol.py @@ -11,9 +11,11 @@ class TestProtocol(unittest.TestCase): def assertLoopback(self, root: Node) -> None: proto = EAmuseProtocol() - for encoding in [EAmuseProtocol.BINARY, EAmuseProtocol.XML]: + for encoding in [EAmuseProtocol.BINARY, EAmuseProtocol.BINARY_DECOMPRESSED, EAmuseProtocol.XML]: if encoding == EAmuseProtocol.BINARY: loop_name = "binary" + elif encoding == EAmuseProtocol.BINARY_DECOMPRESSED: + loop_name = "decompressed binary" elif encoding == EAmuseProtocol.XML: loop_name = "xml" else: