From 7493db034f306a4fa5d3925e7025c4610f8eef8a Mon Sep 17 00:00:00 2001 From: Jennifer Taylor Date: Sun, 2 May 2021 03:49:35 +0000 Subject: [PATCH] Add in basic throw, test code generation of mostly the same code as we tested the code graph with. --- bemani/format/afp/decompile.py | 47 +++++- bemani/tests/test_afp_decompile.py | 235 ++++++++++++++++++++++++++++- 2 files changed, 278 insertions(+), 4 deletions(-) diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 9a79b73..8a4b972 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -217,6 +217,22 @@ class ReturnStatement(Statement): return [f"{prefix}return {ret};"] +class ThrowStatement(Statement): + # A statement which raises an exception. It appears that there is no + # 'catch' in this version of bytecode so it must be used only as an + # assert. + def __init__(self, exc: Any) -> None: + self.exc = exc + + def __repr__(self) -> str: + exc = value_ref(self.exc, "") + return f"throw {exc}" + + def render(self, prefix: str) -> List[str]: + exc = value_ref(self.exc, prefix) + return [f"{prefix}throw {exc};"] + + class NopStatement(Statement): # A literal no-op. We will get rid of these in an optimizing pass. def __repr__(self) -> str: @@ -2257,6 +2273,11 @@ class ByteCodeDecompiler(VerboseOutput): chunk.actions[i] = ReturnStatement(retval) continue + if action.opcode == AP2Action.THROW: + retval = get_stack() + chunk.actions[i] = ThrowStatement(retval) + continue + if action.opcode == AP2Action.POP: # This is a discard. Let's see if its discarding a function or method # call. If so, that means the return doesn't matter. @@ -2635,6 +2656,7 @@ class ByteCodeDecompiler(VerboseOutput): # Calculate the statements for this chunk, as well as the leftover stack entries and any borrows. self.vprint(f"Evaluating graph of ByteCodeChunk {chunk.id}") new_statements, stack_leftovers, new_borrowed_entries = self.__eval_stack(chunk, stack, offset_map) + borrowed_entries.extend(new_borrowed_entries) # We need to check and see if the last entry is an IfExpr, and hoist it # into a statement here. @@ -2671,7 +2693,17 @@ class ByteCodeDecompiler(VerboseOutput): # The stack for both of these is the leftovers from the previous evaluation as they # rollover. stacks[true_start] = [s for s in stack_leftovers] - true_statements, true_borrowed_entries = self.__eval_chunks_impl(true_start, if_body_chunk.true_chunks, next_chunk_id, stacks, insertables, other_stack_locs, offset_map) + true_statements, true_borrowed_entries = self.__eval_chunks_impl( + true_start, + if_body_chunk.true_chunks, + next_chunk_id, + stacks, + insertables, + other_stack_locs, + offset_map, + ) + borrowed_entries.extend(true_borrowed_entries) + false_statements: List[Statement] = [] if if_body_chunk.false_chunks: self.vprint(f"Evaluating graph of IfBody {if_body_chunk.id} false case") @@ -2682,7 +2714,16 @@ class ByteCodeDecompiler(VerboseOutput): # The stack for both of these is the leftovers from the previous evaluation as they # rollover. stacks[false_start] = [s for s in stack_leftovers] - false_statements, false_borrowed_entries = self.__eval_chunks_impl(false_start, if_body_chunk.false_chunks, next_chunk_id, stacks, insertables, other_stack_locs, offset_map) + false_statements, false_borrowed_entries = self.__eval_chunks_impl( + false_start, + if_body_chunk.false_chunks, + next_chunk_id, + stacks, + insertables, + other_stack_locs, + offset_map, + ) + borrowed_entries.extend(false_borrowed_entries) # Convert this IfExpr to a full-blown IfStatement. new_statements[-1] = IfStatement( @@ -2726,7 +2767,7 @@ class ByteCodeDecompiler(VerboseOutput): break start_id = chunk.next_chunks[0] - return statements, stack + return statements, borrowed_entries def __walk(self, statements: Sequence[Statement], do: Callable[[Statement], Optional[Statement]]) -> List[Statement]: new_statements: List[Statement] = [] diff --git a/bemani/tests/test_afp_decompile.py b/bemani/tests/test_afp_decompile.py index 8d0c195..7291863 100644 --- a/bemani/tests/test_afp_decompile.py +++ b/bemani/tests/test_afp_decompile.py @@ -4,7 +4,7 @@ from typing import Dict, List, Sequence, Tuple, Union from bemani.tests.helpers import ExtendedTestCase from bemani.format.afp.types.ap2 import AP2Action, IfAction, JumpAction, PushAction, Register -from bemani.format.afp.decompile import BitVector, ByteCode, ByteCodeChunk, ControlFlow, ByteCodeDecompiler +from bemani.format.afp.decompile import BitVector, ByteCode, ByteCodeChunk, ControlFlow, ByteCodeDecompiler, Statement class TestAFPBitVector(unittest.TestCase): @@ -517,3 +517,236 @@ class TestAFPControlGraph(ExtendedTestCase): self.assertEqual(self.__equiv(chunks_by_id[1]), ["102: PUSH\n 'b'\nEND_PUSH", "103: END"]) self.assertEqual(self.__equiv(chunks_by_id[2]), ["104: PUSH\n 'a'\nEND_PUSH", "105: END"]) self.assertEqual(self.__equiv(chunks_by_id[3]), []) + + +class TestAFPDecompile(ExtendedTestCase): + # Note that the offsets made up in these test functions are not realistic. Jump/If instructions + # take up more than one opcode, and the end offset might be more than one byte past the last + # action if that action takes up more than one byte. However, from the perspective of the + # decompiler, it doesn't care about accurate sizes, only that the offsets are correct. + + def __make_bytecode(self, actions: Sequence[AP2Action]) -> ByteCode: + return ByteCode( + actions, + actions[-1].offset + 1, + ) + + def __call_decompile(self, bytecode: ByteCode) -> List[Statement]: + # Just create a dummy compiler so we can access the internal method for testing. + bcd = ByteCodeDecompiler(bytecode) + bcd.decompile() + return bcd.statements + + def __equiv(self, statements: List[Statement]) -> List[str]: + return [str(x) for x in statements] + + def test_simple_bytecode(self) -> None: + bytecode = self.__make_bytecode([ + AP2Action(100, AP2Action.STOP), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ['builtin_StopPlaying()']) + + def test_jump_handling(self) -> None: + bytecode = self.__make_bytecode([ + JumpAction(100, 102), + JumpAction(101, 104), + JumpAction(102, 101), + JumpAction(103, 106), + JumpAction(104, 103), + JumpAction(105, 107), + JumpAction(106, 105), + AP2Action(107, AP2Action.STOP), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ['builtin_StopPlaying()']) + + def test_dead_code_elimination_jump(self) -> None: + # Jump case + bytecode = self.__make_bytecode([ + AP2Action(100, AP2Action.STOP), + JumpAction(101, 103), + AP2Action(102, AP2Action.PLAY), + AP2Action(103, AP2Action.STOP), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ['builtin_StopPlaying()', 'builtin_StopPlaying()']) + + def test_dead_code_elimination_return(self) -> None: + # Return case + bytecode = self.__make_bytecode([ + PushAction(100, ["strval"]), + AP2Action(101, AP2Action.RETURN), + AP2Action(102, AP2Action.STOP), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ["return 'strval'"]) + + def test_dead_code_elimination_end(self) -> None: + # Return case + bytecode = self.__make_bytecode([ + AP2Action(100, AP2Action.STOP), + AP2Action(101, AP2Action.END), + AP2Action(102, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ['builtin_StopPlaying()']) + + def test_dead_code_elimination_throw(self) -> None: + # Throw case + bytecode = self.__make_bytecode([ + PushAction(100, ["exception"]), + AP2Action(101, AP2Action.THROW), + AP2Action(102, AP2Action.STOP), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ["throw 'exception'"]) + + def test_if_handling_basic(self) -> None: + # If by itself case. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_FALSE, 103), + # False case (fall through from if). + AP2Action(102, AP2Action.PLAY), + # Line after the if statement. + AP2Action(103, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ["if (True) {\n builtin_StartPlaying()\n}"]) + + def test_if_handling_basic_jump_to_end(self) -> None: + # If by itself case. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_FALSE, 103), + # False case (fall through from if). + AP2Action(102, AP2Action.PLAY), + # Some code will jump to the end offset as a way of + # "returning" early from a function. + ]) + statements = self.__call_decompile(bytecode) + + # TODO: The output should be optimized to remove the early return and move the + # start playing section inside the if. + self.assertEqual(self.__equiv(statements), ["if (not True) {\n return\n}", "builtin_StartPlaying()"]) + + def test_if_handling_diamond(self) -> None: + # If true-false diamond case. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_TRUE, 104), + # False case (fall through from if). + AP2Action(102, AP2Action.STOP), + JumpAction(103, 105), + # True case. + AP2Action(104, AP2Action.PLAY), + # Line after the if statement. + AP2Action(105, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ["if (True) {\n builtin_StartPlaying()\n} else {\n builtin_StopPlaying()\n}"]) + + def test_if_handling_diamond_jump_to_end(self) -> None: + # If true-false diamond case. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_TRUE, 104), + # False case (fall through from if). + AP2Action(102, AP2Action.STOP), + JumpAction(103, 105), + # True case. + AP2Action(104, AP2Action.PLAY), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: The output should be optimized to remove redundant return statements. + self.assertEqual(self.__equiv(statements), ["if (True) {\n builtin_StartPlaying()\n return\n} else {\n builtin_StopPlaying()\n return\n}"]) + + def test_if_handling_diamond_return_to_end(self) -> None: + # If true-false diamond case but the cases never converge. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_TRUE, 104), + # False case (fall through from if). + PushAction(102, ['b']), + AP2Action(103, AP2Action.RETURN), + # True case. + PushAction(104, ['a']), + AP2Action(105, AP2Action.RETURN), + ]) + statements = self.__call_decompile(bytecode) + self.assertEqual(self.__equiv(statements), ["if (True) {\n return 'a'\n} else {\n return 'b'\n}"]) + + def test_if_handling_switch(self) -> None: + # Series of ifs (basically a switch statement). + bytecode = self.__make_bytecode([ + # Beginning of the first if statement. + PushAction(100, [Register(0), 1]), + IfAction(101, IfAction.NOT_EQUALS, 104), + # False case (fall through from if). + PushAction(102, ['a']), + JumpAction(103, 113), + + # Beginning of the second if statement. + PushAction(104, [Register(0), 2]), + IfAction(105, IfAction.NOT_EQUALS, 108), + # False case (fall through from if). + PushAction(106, ['b']), + JumpAction(107, 113), + + # Beginning of the third if statement. + PushAction(108, [Register(0), 3]), + IfAction(109, IfAction.NOT_EQUALS, 112), + # False case (fall through from if). + PushAction(110, ['c']), + JumpAction(111, 113), + + # Beginning of default case. + PushAction(112, ['d']), + + # Line after the switch statement. + AP2Action(113, AP2Action.RETURN), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as an if/elseif/else chunk without so much indentation. + self.assertEqual(self.__equiv(statements), [ + "if (registers[0] != 1) {\n" + " if (registers[0] != 2) {\n" + " if (registers[0] != 3) {\n" + " tempvar_0 = 'd'\n" + " } else {\n" + " tempvar_0 = 'c'\n" + " }\n" + " } else {\n" + " tempvar_0 = 'b'\n" + " }\n" + "} else {\n" + " tempvar_0 = 'a'\n" + "}", + "return tempvar_0" + ]) + + def test_if_handling_diamond_end_both_sides(self) -> None: + # If true-false diamond case but the cases never converge. + bytecode = self.__make_bytecode([ + # Beginning of the if statement. + PushAction(100, [True]), + IfAction(101, IfAction.IS_TRUE, 104), + # False case (fall through from if). + AP2Action(102, AP2Action.STOP), + AP2Action(103, AP2Action.END), + # True case. + AP2Action(104, AP2Action.PLAY), + AP2Action(105, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: The output should be optimized to remove redundant return statements. + self.assertEqual(self.__equiv(statements), ["if (True) {\n builtin_StartPlaying()\n return\n} else {\n builtin_StopPlaying()\n return\n}"])