# pokeheartgold/tools/py_scripts/dump_scrcmds.py
#
# (Listing metadata from the original page: 353 lines, 14 KiB, Python.)

import json
import os
import typing
import warnings
import enum
import abc
import re
from typing import Union, BinaryIO, TextIO, Optional
import argparse
def parse_c_header(filename: str, prefix='') -> dict[int, str]:
    """Map the values of simple integer ``#define``s in a C header to their names.

    Only lines of the form ``#define <prefix><name> <integer literal>``
    (optionally followed by spaces/tabs, and also on the final line when the
    file has no trailing newline) are recognised. Anything else is skipped,
    including definitions with expressions or trailing comments.

    Integer literals follow C rules: ``0x``/``0X`` is hexadecimal, a leading
    ``0`` is octal, and anything else is decimal. When several names share a
    value, the last one in the file wins.

    Args:
        filename: path of the header to read.
        prefix: only names starting with this literal text are collected.

    Returns:
        ``{value: name}`` for every matching definition.
    """
    with open(filename) as fp:
        data = fp.read()
    pattern = re.compile(
        rf'#define[ \t]+({re.escape(prefix)}\w+)[ \t]+'
        r'(0[xX][0-9a-fA-F]+|0[0-7]*|[1-9][0-9]*)[ \t]*$',
        re.MULTILINE,
    )
    result = {}
    for m in pattern.finditer(data):
        literal = m[2]
        if literal[:2] in ('0x', '0X'):
            value = int(literal, 16)
        elif literal.startswith('0'):
            # C octal (a lone '0' is simply zero). int(x, 0) would reject '010'.
            value = int(literal, 8)
        else:
            value = int(literal, 10)
        result[value] = m[1]
    return result
class ScriptType(enum.Enum):
    """Kind of script binary; ``main`` uses it to choose the parser class."""

    normal = 0
    special = 1
    # A dunder name, so Enum keeps this as a plain class attribute, not a member.
    __MAX__ = 2

    @classmethod
    def convert(cls, arg: str):
        """Parse a ``--mode`` value given either as a number or a member name.

        Raises:
            TypeError: ``arg`` is neither numeric nor a member name.
            ValueError: ``arg`` is numeric but not the value of a member.
        """
        if not arg.isnumeric():
            member = cls.__members__.get(arg)
            if member is None:
                raise TypeError
            return member
        return cls(int(arg))
class Namespace(argparse.Namespace):
    """Typed view of the command line parsed in ``main``."""

    binfile: BinaryIO  # input binary, opened with argparse.FileType('rb')
    scrfile: TextIO  # output assembly, opened with argparse.FileType('w')
    name: str  # label prefix; main() defaults it to the input file's base name
    mode: ScriptType  # selects NormalScriptParser or SpecialScriptParser
class ScriptParserBase(abc.ABC):
    """Shared state and interface for the script disassemblers.

    A subclass decodes ``raw`` in :meth:`parse_all` and renders assembly
    source through ``str()``. The subclasses in this file name their labels
    ``f'{prefix}_{offset:04X}'``.
    """

    def __init__(self, raw: bytes, prefix='_EV'):
        self.raw = raw  # binary contents to disassemble
        self.prefix = prefix  # prefix for generated label names
        self.is_parsed = False  # flipped to True by a successful parse_all()

    @abc.abstractmethod
    def parse_all(self):
        """Decode ``self.raw``; the subclasses in this file return ``self``."""
        return NotImplemented

    @abc.abstractmethod
    def __str__(self):
        """Return the assembly source for the parsed data."""
        return NotImplemented

    def __repr__(self):
        """Summarise the parser without dumping the raw bytes."""
        cls_name = type(self).__name__
        size = len(self.raw)
        return f'<{cls_name}(raw=bytes({size}), prefix={self.prefix!r})>'
class NormalScriptParser(ScriptParserBase):
def __init__(self, raw: bytes, prefix='_EV'):
super().__init__(raw, prefix)
with open(os.path.join(os.path.dirname(__file__), 'scrcmd.json')) as jsonfp:
scrcmds = json.load(jsonfp)
header_path = os.path.join(os.path.dirname(__file__), '../..')
self.constants = {
'var': parse_c_header(os.path.join(header_path, 'include/constants/vars.h'), 'VAR_'),
'flag': parse_c_header(os.path.join(header_path, 'include/constants/flags.h'), 'FLAG_'),
'species': parse_c_header(os.path.join(header_path, 'include/constants/species.h'), 'SPECIES_'),
'item': parse_c_header(os.path.join(header_path, 'include/constants/items.h'), 'ITEM_'),
'move': parse_c_header(os.path.join(header_path, 'include/constants/moves.h'), 'MOVE_'),
'sound': parse_c_header(os.path.join(header_path, 'include/constants/sndseq.h'), 'SEQ_'),
'ribbon': parse_c_header(os.path.join(header_path, 'include/constants/ribbon.h'), 'RIBBON_'),
'stdscr': parse_c_header(os.path.join(header_path, 'include/constants/std_script.h'), 'std_'),
'trainer': parse_c_header(os.path.join(header_path, 'include/constants/trainers.h'), 'TRAINER_')
}
self.commands: list[dict[str, Union[str, int, list[int], dict[str, list[int]]]]] = scrcmds.get('commands', [])
self.commands_d = {x['name']: x for x in self.commands}
self.movement_cmds = scrcmds.get('movement_commands', [])
self.exported = []
self.labels = {}
self.lines = {}
self.movement_scripts = set()
self.header_end = 0
self.pc_history = []
def parse_header(self):
for i in range(0, len(self.raw), 4):
if self.raw[i:i + 2] == b'\x13\xfd':
self.header_end = i + 2
break
self.exported.append(int.from_bytes(self.raw[i:i + 4], 'little') + i + 4)
assert(self.exported[-1] < len(self.raw))
if self.header_end != 4 * len(self.exported) + 2:
raise ValueError('malformatted script file')
self.labels |= {addr: False for addr in self.exported}
def get_arg(self, size: typing.Union[int, str], pc: int) -> tuple[typing.Union[int, str], int]:
if isinstance(size, int):
assert size in [1, 2, 4]
return int.from_bytes(self.raw[pc:pc + size], 'little'), pc + size
match size:
case 'addr' | 'script' | 'movement':
value = int.from_bytes(self.raw[pc:pc + 4], 'little')
pc += 4
value += pc
value &= 0xFFFFFFFF
assert self.header_end <= value < len(self.raw)
if size == 'movement':
self.movement_scripts.add(value)
if value not in self.labels:
self.labels[value] = (size != 'script')
else:
self.labels[value] |= (size != 'script')
return f'{self.prefix}_{value:04X}', pc
case 'condition':
value = self.raw[pc]
pc += 1
if len(self.pc_history) >= 2 and value < 2 and self.lines[self.pc_history[-2]][0] == 'checkflag':
conds = ['FALSE', 'TRUE']
else:
conds = ['lt', 'eq', 'gt', 'le', 'ge', 'ne']
return conds[value], pc
case 'var' | 'flag':
value = int.from_bytes(self.raw[pc:pc + 2], 'little')
pc += 2
return self.constants[size].get(value, value), pc
case 'species' | 'item' | 'move' | 'sound' | 'ribbon' | 'stdscr' | 'trainer':
value = int.from_bytes(self.raw[pc:pc + 2], 'little')
pc += 2
return self.constants['var'].get(value, self.constants[size].get(value, value)), pc
case _:
raise ValueError('unknown arg type: ' + size)
def parse_script(self, pc: int):
if self.labels[pc]:
return
self.pc_history.clear()
while pc < len(self.raw):
if pc in self.labels:
self.labels[pc] = True
self.pc_history.append(pc)
cmd_i = int.from_bytes(self.raw[pc:pc + 2], 'little')
if cmd_i >= len(self.commands):
warnings.warn(f'script parser hit illegal command {cmd_i} at position {pc}')
break
pc += 2
args = []
cmd_struct = self.commands[cmd_i]
name = cmd_struct['name']
arg_sizes = cmd_struct['args']
special = cmd_struct.get('cases')
switch_arg: Optional[int] = cmd_struct.get('switch_arg')
try:
for size in arg_sizes:
arg, pc = self.get_arg(size, pc)
args.append(arg)
if special is not None and switch_arg is not None:
for size in special[str(args[switch_arg])]:
arg, pc = self.get_arg(size, pc)
args.append(arg)
except (ValueError, KeyError):
warnings.warn(f'script parser hit illegal command args to {cmd_i} at position {self.pc_history[-1]} '
f'(arg {len(args)}, last good arg: {None if not args else args[-1]})')
break
self.lines[self.pc_history[-1]] = (name, args, pc)
if cmd_struct.get('is_abs_branch'):
break
def parse_all(self):
self.parse_header()
while not all(self.labels.values()):
for label in sorted(self.labels):
self.parse_script(label)
self.is_parsed = True
return self
def make_gap_internal(self, pc, nextpc):
if pc == nextpc:
return ''
s = ''
if pc in self.movement_scripts:
if pc & 1:
pc += 1
while pc < nextpc:
cmd = int.from_bytes(self.raw[pc:pc + 2], 'little')
is_end = cmd == 254
if cmd < len(self.movement_cmds):
cmd = self.movement_cmds[cmd]
duration = int.from_bytes(self.raw[pc + 2:pc + 4], 'little')
s += f'\t.short {cmd}, {duration}\n'
pc += 4
if is_end:
break
if pc == nextpc:
return s
if pc & 15:
gap = min(16 - (pc & 15), nextpc - pc)
s += '\t.byte ' + ', '.join(map('0x{:02x}'.format, self.raw[pc:pc + gap])) + '\n'
pc += gap
while pc < nextpc:
gap = min(16, nextpc - pc)
s += '\t.byte ' + ', '.join(map('0x{:02x}'.format, self.raw[pc:pc + gap])) + '\n'
pc += gap
return s
def make_gap(self, pc, nextpc):
if pc == nextpc or (nextpc == len(self.raw) and all(x == 0 for x in self.raw[pc:nextpc])):
return ''
s = ''
labels = sorted({x for x in self.labels if pc <= x < nextpc} | {pc, nextpc})
for x, y in zip(labels[:-1], labels[1:]):
if x in labels:
s += f'\n{self.prefix}_{x:04X}:\n'
s += self.make_gap_internal(x, y)
return s
def __str__(self):
if not self.is_parsed:
return repr(self)
s = '#include "constants/scrcmd.h"\n'
s += '\t.include "asm/macros/script.inc"\n\n'
s += '\t.rodata\n\n'
for i, addr in enumerate(self.exported):
s += f'\tscrdef {self.prefix}_{addr:04X} ; {i:03d}\n'
s += '\tscrdef_end\n\n'
if not self.lines:
s += self.make_gap(self.header_end, len(self.raw))
else:
if self.header_end not in self.lines:
s += self.make_gap(self.header_end, min(self.lines))
lines = sorted(self.lines.items())
lines.append((len(self.raw), ('', [], -1)))
for i, (pc, (name, args, nextpc)) in enumerate(lines[:-1]):
if pc != nextpc:
args = list(args)
if pc in self.labels:
s += f'{self.prefix}_{pc:04X}:\n'
if args:
s += f'\t{name} ' + ', '.join(map(str, args)) + '\n'
else:
s += f'\t{name}\n'
if nextpc in self.labels and self.commands_d[name].get('is_abs_branch'):
s += '\n'
if nextpc != lines[i + 1][0]:
s += self.make_gap(nextpc, lines[i + 1][0])
s += '\t.balign 4, 0\n'
return s
def make_header(self):
s = f'#ifndef {self.prefix.upper()}_H_\n'
s += f'#define {self.prefix.upper()}_H_\n\n'
for i, name in enumerate(self.exported):
s += f'#define {name} {i: 3d}\n'
s += f'\n#endif //{self.prefix.upper()}_H_\n'
return s
class SpecialScriptParser(ScriptParserBase):
    """Disassembler for "special" script binaries.

    Layout, as decoded by :meth:`parse_all` and reproduced by ``str()``:

    * a table of 5-byte records ended by a single 0 byte. A kind-1 record
      holds a 32-bit offset, relative to the end of the record, to the
      init-var table. Every other kind holds two 16-bit values.
    * optionally, directly after that 0 byte, the init-var table: 6-byte
      records of three 16-bit values, ended by a 0 halfword.
    * zero padding up to a multiple of 4 bytes.
    """

    def __init__(self, raw: bytes, prefix='_EV'):
        super().__init__(raw, prefix)
        header_path = os.path.join(os.path.dirname(__file__), '../..')
        # value -> VAR_* name, used for the first two fields of init-table records
        self.vars = parse_c_header(os.path.join(header_path, 'include/constants/vars.h'), 'VAR_')
        # (kind, value1, value2) per record; (1, -1, -1) for the init-table pointer
        self.table: list[tuple[int, int, int]] = []
        self.init_offset: int = -1  # absolute offset of the init table, -1 if none
        self.init_vars: list[tuple[int, int, int]] = []  # init-table records

    def parse_all(self):
        """Decode the record table and the optional init-var table.

        Returns:
            ``self``, so construction and parsing can be chained.

        Raises:
            ValueError: a table is unterminated, the init table is not placed
                right after the record table (``str()`` could not reproduce
                it), kind-1 records disagree, or the data does not end at the
                next 4-byte boundary.
        """
        raw = self.raw
        pos = 0
        while True:
            if pos >= len(raw):
                raise ValueError('special script table has no terminator')
            kind = raw[pos]
            if kind == 0:
                break
            if kind == 1:
                offset = pos + 5 + int.from_bytes(raw[pos + 1:pos + 5], 'little')
                if self.init_offset not in (-1, offset):
                    raise ValueError('special script has conflicting init table offsets')
                self.init_offset = offset
                self.table.append((1, -1, -1))
            else:
                self.table.append((
                    kind,
                    int.from_bytes(raw[pos + 1:pos + 3], 'little'),
                    int.from_bytes(raw[pos + 3:pos + 5], 'little'),
                ))
            pos += 5
        end = pos + 1  # just past the terminating 0 byte
        if self.init_offset != -1:
            # __str__ emits the init table immediately after the terminator,
            # so any other placement would produce wrong output.
            if self.init_offset != end:
                raise ValueError(f'special script init table at {self.init_offset:#x}, expected {end:#x}')
            pos = self.init_offset
            while True:
                if pos >= len(raw):
                    raise ValueError('special script init table has no terminator')
                var = int.from_bytes(raw[pos:pos + 2], 'little')
                if var == 0:
                    break
                self.init_vars.append((
                    var,
                    int.from_bytes(raw[pos + 2:pos + 4], 'little'),
                    int.from_bytes(raw[pos + 4:pos + 6], 'little'),
                ))
                pos += 6
            end = pos + 2  # just past the terminating 0 halfword
        # Explicit check rather than assert, so it also holds under -O.
        if ((end + 3) & ~3) != len(raw):
            raise ValueError(f'special script ends at {end:#x} but file is {len(raw):#x} bytes')
        self.is_parsed = True
        return self

    def __str__(self):
        """Render the parsed tables as assembly source (``repr`` if not parsed yet)."""
        if not self.is_parsed:
            return repr(self)
        parts = ['#include "constants/scrcmd.h"\n\t.rodata\n\t.option alignment off\n\n']
        for kind, val1, val2 in self.table:
            if kind == 1:
                # label - . - 4 is the offset relative to the end of the .word, as parse_all reads it.
                parts.append(f'\t.byte 1\n\t.word {self.prefix}_{self.init_offset:04X}-.-4\n')
            else:
                parts.append(f'\t.byte {kind}\n\t.short {val1}, {val2}\n')
        parts.append('\t.byte 0\n\n')
        if self.init_offset != -1:
            parts.append(f'{self.prefix}_{self.init_offset:04X}:\n')
            for flex1, flex2, script in self.init_vars:
                parts.append(f'\t.short {self.vars.get(flex1, flex1)}, {self.vars.get(flex2, flex2)}, {script}\n')
            parts.append('\t.short 0\n\n')
        parts.append('\t.balign 4, 0\n')
        return ''.join(parts)
def main():
    """Command-line entry point: disassemble one script binary into assembly.

    Positional arguments:

    * the input binary;
    * an optional output path (default: the input path with a ``.s`` extension);
    * an optional label prefix (default: the input file's base name).

    ``--mode`` picks the parser. If the normal parser fails, the special parser
    is tried before giving up. When both fail, both errors are reported: the
    first becomes the ``__context__`` of the second.
    """
    import sys

    argparser = argparse.ArgumentParser()
    argparser.add_argument('binfile', type=argparse.FileType('rb'))
    argparser.add_argument('scrfile', type=argparse.FileType('w'), nargs='?')
    argparser.add_argument('name', nargs='?')
    argparser.add_argument('--mode', type=ScriptType.convert, default=ScriptType.normal)
    args = argparser.parse_args(namespace=Namespace())
    if args.scrfile is None:
        scrfname = os.path.splitext(args.binfile.name)[0] + '.s'
        args.scrfile = argparse.FileType('w')(scrfname)
    if args.name is None:
        args.name = os.path.splitext(os.path.basename(args.binfile.name))[0]
    if args.mode is ScriptType.normal:
        cls = NormalScriptParser
    elif args.mode is ScriptType.special:
        cls = SpecialScriptParser
    else:
        raise TypeError(args.mode)

    try:
        data = args.binfile.read()
    finally:
        # FileType('rb') returns sys.stdin.buffer for '-'; leave that open.
        if args.binfile is not sys.stdin.buffer:
            args.binfile.close()

    try:
        # Only parsing/rendering failures trigger the fallback; write errors
        # below propagate directly.
        try:
            text = str(cls(data, args.name).parse_all())
        except Exception:
            if cls is not NormalScriptParser:
                print(f'Error with {args.binfile.name}', file=sys.stderr)
                raise
            # Retry as a special script. Raising inside this handler keeps the
            # normal parser's error attached as context.
            try:
                text = str(SpecialScriptParser(data, args.name).parse_all())
            except Exception:
                print(f'Error with {args.binfile.name}', file=sys.stderr)
                raise
        print(text, file=args.scrfile, end='')
    finally:
        # FileType('w') returns sys.stdout for '-'; leave that open.
        if args.scrfile is not sys.stdout:
            args.scrfile.close()
if __name__ == '__main__':
    # Run the command-line tool only when executed directly, not on import.
    main()