diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 8a4b972..6d76a5a 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -94,6 +94,32 @@ class ControlFlow: return f"ControlFlow(beginning={self.beginning}, end={self.end}, next={(', '.join(str(n) for n in self.next_flow)) or 'N/A'}" +class IfResult: + def __init__(self, stmt_id: int, path: bool) -> None: + self.stmt_id = stmt_id + self.path = path + + def makes_tautology(self, other: "IfResult") -> bool: + return self.stmt_id == other.stmt_id and self.path != other.path + + def __repr__(self) -> str: + return f"IfResult(stmt_id={self.stmt_id}, path={self.path})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IfResult): + return NotImplemented + return self.stmt_id == other.stmt_id and self.path == other.path + + def __ne__(self, other: object) -> bool: + if not isinstance(other, IfResult): + return NotImplemented + return not (self.stmt_id == other.stmt_id and self.path == other.path) + + def __hash__(self) -> int: + # Lower bit will be for true/false, upper bits for statement ID. + return (self.stmt_id * 2) + (1 if self.path else 0) + + class ConvertedAction: # An action that has been analyzed and converted to an intermediate representation. pass @@ -1917,6 +1943,115 @@ class ByteCodeDecompiler(VerboseOutput): self.vprint(f"Finished separating if statements out of graph starting at {start_id}") return [c for _, c in chunks_by_id.items()] + def __new_separate_ifs(self, start_id: int, end_id: Optional[int], chunks: Sequence[ArbitraryCodeChunk], offset_map: Dict[int, int]) -> List[ArbitraryCodeChunk]: + # TODO: This algorithm can possibly do better than the original at identifying cases. + # In particular, it handles compound if statements (if x or y) where the previous one + # ends up sticking gotos in. The problem is that it needs to know what if statements + # exist before combining them, and we can't do that until we walk the stack, and the + # stack walking algorithm both a) comes later and b) relies on all ifs being processed. + # So, this stays as a beta for now, and will possibly be integrated at a later time. + chunks_by_id: Dict[int, ArbitraryCodeChunk] = {chunk.id: chunk for chunk in chunks} + chunks_examined: Set[int] = set() + + self.vprint(f"BETA: Separating if statements out of graph starting at {start_id}") + + def walk_children(cur_chunk: ArbitraryCodeChunk, apply_logic: Sequence[IfResult]) -> Dict[int, Set[IfResult]]: + # First, if we have any previous if statements to apply to this chunk, do that now. + self.vprint(f"BETA: Applying {apply_logic} to {cur_chunk.id}") + chunks_to_logic: Dict[int, Set[IfResult]] = {cur_chunk.id: {x for x in apply_logic}} + + # Now, if it is a loop and we haven't already passed over this chunk, recursively + # find if statements inside it as well. + if isinstance(cur_chunk, Loop): + if cur_chunk.id not in chunks_examined: + chunks_examined.add(cur_chunk.id) + + self.vprint(f"BETA: Examining loop {cur_chunk.id} body for if statements...") + cur_chunk.chunks = self.__new_separate_ifs(cur_chunk.id, None, cur_chunk.chunks, offset_map) + self.vprint(f"BETA: Finished examining loop {cur_chunk.id} body for if statements...") + + # Now, see if we need to split logic up or not. + if not cur_chunk.next_chunks: + # We are at the end of our walk. + return chunks_to_logic + + if len(cur_chunk.next_chunks) == 1: + # We only have one child, so follow that link. + next_chunk = cur_chunk.next_chunks[0] + if next_chunk in chunks_by_id: + for cid, logic in walk_children(chunks_by_id[next_chunk], apply_logic).items(): + chunks_to_logic[cid] = {*chunks_to_logic.get(cid, set()), *logic} + return chunks_to_logic + + if not isinstance(cur_chunk, ByteCodeChunk): + # We should only be looking at bytecode chunks at this point, all other + # types should have a single next chunk. + raise Exception(f"Logic error, found converted Loop or If chunk {cur_chunk.id} with multiple successors!") + + if len(cur_chunk.next_chunks) != 2: + # This needs to be an if statement. + raise Exception(f"Logic error, expected 2 successors but got {len(cur_chunk.next_chunks)} in chunk {cur_chunk.id}!") + last_action = cur_chunk.actions[-1] + if not isinstance(last_action, IfAction): + # This needs, again, to be an if statement. + raise Exception("Logic error, only IfActions can have multiple successors in chunk {cur_chunk.id}!") + + # Find the true and false jump points, walk those graphs and assign logical predecessors + # to each of them. + true_jump_point = offset_map[last_action.jump_if_true_offset] + false_jump_points = [n for n in cur_chunk.next_chunks if n != true_jump_point] + if len(false_jump_points) != 1: + raise Exception("Logic error, got more than one false jump point for an if statement!") + false_jump_point = false_jump_points[0] + + if true_jump_point == false_jump_point: + # This should never happen. + raise Exception("Logic error, both true and false jumps are to the same location!") + + self.vprint(f"BETA: Chunk ID {cur_chunk.id} is an if statement with true node {true_jump_point} and false node {false_jump_point}") + + # Walk both halves, assigning the if statement that has to exist to get to each half. + if true_jump_point in chunks_by_id: + for cid, logic in walk_children(chunks_by_id[true_jump_point], [*apply_logic, IfResult(cur_chunk.id, True)]).items(): + chunks_to_logic[cid] = {*chunks_to_logic.get(cid, set()), *logic} + if false_jump_point in chunks_by_id: + for cid, logic in walk_children(chunks_by_id[false_jump_point], [*apply_logic, IfResult(cur_chunk.id, False)]).items(): + chunks_to_logic[cid] = {*chunks_to_logic.get(cid, set()), *logic} + return chunks_to_logic + + # First, walk through and identify how we get to each chunk. + chunks_by_logic = walk_children(chunks_by_id[start_id], []) + self.vprint(f"BETA: List of logics: {chunks_by_logic}") + + # Now, go through each chunk and remove tautologies (where we get to it through a previous + # if statement from both true and false paths, meaning this isn't owned by an if statement). + for cid in chunks_by_logic: + changed: bool = True + while changed: + # Assume we didn't change anything. + changed = False + + # Figure out if there is a tautology existing in this logic. + for path in chunks_by_logic[cid]: + remove: Optional[IfResult] = None + for other in chunks_by_logic[cid]: + if path.makes_tautology(other): + remove = other + break + + if remove: + # We found a tautology, remove both halves. + self.vprint(f"BETA: {path} makes a tautology with {remove}, removing both of them!") + chunks_by_logic[cid].remove(path) + chunks_by_logic[cid].remove(remove) + changed = True + break + + self.vprint(f"BETA: Cleaned up logics: {chunks_by_logic}") + + self.vprint(f"BETA: Finished separating if statements out of graph starting at {start_id}") + return [c for _, c in chunks_by_id.items()] + def __check_graph(self, start_id: int, chunks: Sequence[ArbitraryCodeChunk]) -> List[ArbitraryCodeChunk]: # Recursively go through and verify that all entries to the graph have only one link. # Also, clean up the graph. @@ -2542,7 +2677,7 @@ class ByteCodeDecompiler(VerboseOutput): if isinstance(statement, InsertionLocation): # Convert to any statements we need to insert. if statement.location in insertables: - self.vprint("Inserting temp variable assignments into insertion location {stataement.location}") + self.vprint(f"Inserting temp variable assignments into insertion location {statement.location}") for stmt in insertables[statement.location]: new_statements.append(stmt) else: diff --git a/bemani/tests/helpers.py b/bemani/tests/helpers.py index aba59f5..5a77daa 100644 --- a/bemani/tests/helpers.py +++ b/bemani/tests/helpers.py @@ -1,4 +1,5 @@ # vim: set fileencoding=utf-8 +import sys import unittest from typing import Container, List, Dict, Any @@ -8,6 +9,10 @@ __unittest = True class ExtendedTestCase(unittest.TestCase): + @property + def verbose(self) -> bool: + return ("-v" in sys.argv) or ("--verbose" in sys.argv) + def assertItemsEqual(self, a: Container[Any], b: Container[Any]) -> None: a_items = {x for x in a} b_items = {x for x in b} diff --git a/bemani/tests/test_afp_decompile.py b/bemani/tests/test_afp_decompile.py index 7291863..b9f936a 100644 --- a/bemani/tests/test_afp_decompile.py +++ b/bemani/tests/test_afp_decompile.py @@ -534,7 +534,7 @@ class TestAFPDecompile(ExtendedTestCase): def __call_decompile(self, bytecode: ByteCode) -> List[Statement]: # Just create a dummy compiler so we can access the internal method for testing. bcd = ByteCodeDecompiler(bytecode) - bcd.decompile() + bcd.decompile(verbose=self.verbose) return bcd.statements def __equiv(self, statements: List[Statement]) -> List[str]: @@ -602,7 +602,7 @@ class TestAFPDecompile(ExtendedTestCase): statements = self.__call_decompile(bytecode) self.assertEqual(self.__equiv(statements), ["throw 'exception'"]) - def test_if_handling_basic(self) -> None: + def test_if_handling_basic_flow_to_end(self) -> None: # If by itself case. bytecode = self.__make_bytecode([ # Beginning of the if statement. @@ -750,3 +750,36 @@ class TestAFPDecompile(ExtendedTestCase): # TODO: The output should be optimized to remove redundant return statements. self.assertEqual(self.__equiv(statements), ["if (True) {\n builtin_StartPlaying()\n return\n} else {\n builtin_StopPlaying()\n return\n}"]) + + def test_if_handling_or(self) -> None: + # Two ifs that together make an or (if register == 1 or register == 3) + bytecode = self.__make_bytecode([ + # Beginning of the first if statement. + PushAction(100, [Register(0), 1]), + IfAction(101, IfAction.EQUALS, 104), + # False case (circuit not broken, register is not equal to 1) + PushAction(102, [Register(0), 2]), + IfAction(103, IfAction.NOT_EQUALS, 106), + # This is the true case + AP2Action(104, AP2Action.PLAY), + JumpAction(105, 107), + # This is the false case + AP2Action(106, AP2Action.STOP), + # This is the fall-through after the if. + PushAction(107, ['strval']), + AP2Action(108, AP2Action.RETURN), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as a compound if statement. + self.assertEqual(self.__equiv(statements), [ + "if (registers[0] != 1) {\n" + " if (registers[0] != 2) {\n" + " builtin_StopPlaying()\n" + " label_4:\n" + " return 'strval'\n" + " }\n" + "}", + "builtin_StartPlaying()", + "goto label_4", + ])