mirror of
https://github.com/pret/pokemon-reverse-engineering-tools.git
synced 2026-03-21 17:24:42 -05:00
405 lines
14 KiB
Python
405 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
from copy import copy
|
|
import hashlib
|
|
import random
|
|
import json
|
|
|
|
from pokemontools.interval_map import IntervalMap
|
|
from pokemontools.chars import chars, jap_chars
|
|
|
|
from pokemontools.romstr import (
|
|
RomStr,
|
|
AsmList,
|
|
)
|
|
|
|
from pokemontools.item_constants import (
|
|
item_constants,
|
|
find_item_label_by_id,
|
|
generate_item_constants,
|
|
)
|
|
|
|
from pokemontools.pointers import (
|
|
calculate_bank,
|
|
calculate_pointer,
|
|
)
|
|
|
|
from pokemontools.pksv import (
|
|
pksv_gs,
|
|
pksv_crystal,
|
|
)
|
|
|
|
from pokemontools.labels import (
|
|
remove_quoted_text,
|
|
line_has_comment_address,
|
|
line_has_label,
|
|
get_label_from_line,
|
|
)
|
|
|
|
from pokemontools.helpers import (
|
|
grouper,
|
|
index,
|
|
)
|
|
|
|
from pokemontools.crystalparts.old_parsers import (
|
|
old_parse_map_header_at,
|
|
)
|
|
|
|
from pokemontools.crystal import (
|
|
script_parse_table,
|
|
load_rom,
|
|
rom_until,
|
|
direct_load_rom,
|
|
parse_script_engine_script_at,
|
|
parse_text_engine_script_at,
|
|
parse_text_at2,
|
|
find_all_text_pointers_in_script_engine_script,
|
|
SingleByteParam,
|
|
HexByte,
|
|
MultiByteParam,
|
|
PointerLabelParam,
|
|
ItemLabelByte,
|
|
DollarSignByte,
|
|
DecimalParam,
|
|
rom_interval,
|
|
map_names,
|
|
Label,
|
|
scan_for_predefined_labels,
|
|
all_labels,
|
|
write_all_labels,
|
|
parse_map_header_at,
|
|
process_00_subcommands,
|
|
parse_all_map_headers,
|
|
translate_command_byte,
|
|
map_name_cleaner,
|
|
load_map_group_offsets,
|
|
load_asm,
|
|
asm,
|
|
is_valid_address,
|
|
how_many_until,
|
|
get_pokemon_constant_by_id,
|
|
generate_map_constant_labels,
|
|
get_map_constant_label_by_id,
|
|
get_id_for_map_constant_label,
|
|
calculate_pointer_from_bytes_at,
|
|
isolate_incbins,
|
|
process_incbins,
|
|
get_labels_between,
|
|
rom_text_at,
|
|
get_label_for,
|
|
split_incbin_line_into_three,
|
|
reset_incbins,
|
|
parse_rom,
|
|
|
|
# globals
|
|
engine_flags,
|
|
)
|
|
|
|
import pokemontools.wram
|
|
|
|
import unittest
|
|
import mock
|
|
|
|
class BasicTestCase(unittest.TestCase):
    """Catch-all tests for ROM/asm loading and raw ROM byte helpers."""

    @classmethod
    def setUpClass(cls):
        # Load the ROM once for the whole class. It is also published as this
        # module's global ``rom`` so module-level code reading that name sees
        # the same data.
        global rom
        cls.rom = direct_load_rom()
        rom = cls.rom

    @classmethod
    def tearDownClass(cls):
        del cls.rom

    def test_direct_load_rom(self):
        rom = self.rom
        # 2097152 bytes == 2 MiB
        self.assertEqual(len(rom), 2097152)
        self.assertIsInstance(rom, RomStr)

    def test_load_rom(self):
        rom = load_rom()
        self.assertIsNotNone(rom)
        # a second call must also yield a real ROM, not RomStr(None)
        rom = load_rom()
        self.assertNotEqual(rom, RomStr(None))

    def test_load_asm(self):
        # named asm_lines so the imported ``asm`` name is not shadowed
        asm_lines = load_asm()
        joined_lines = "\n".join(asm_lines)
        self.assertIn("SECTION", joined_lines)
        self.assertIn("bank", joined_lines)
        self.assertIsInstance(asm_lines, AsmList)

    def test_rom_file_existence(self):
        "ROM file must exist"
        # baserom.gbc is expected two directories above this test file
        here = os.path.abspath(os.path.dirname(__file__))
        rom_path = os.path.join(here, "..", "..", "baserom.gbc")
        self.assertTrue(os.path.isfile(rom_path), "missing " + rom_path)

    def test_rom_md5(self):
        "ROM file must have the correct md5 sum"
        correct = "9f2922b235a5eeb78d65594e82ef5dde"
        md5sum = hashlib.md5(self.rom).hexdigest()
        self.assertEqual(md5sum, correct)

    def test_bizarre_http_presence(self):
        # the ROM contains the literal text "HTTP/1.0" at 0x112116
        rom_segment = self.rom[0x112116:0x112116 + 8]
        self.assertEqual(rom_segment, "HTTP/1.0")

    def test_rom_interval(self):
        address = 0x100
        interval = 10
        correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
                           '0xed', '0x66', '0x66', '0xcc', '0xd']
        byte_strings = rom_interval(address, interval, rom=self.rom, strings=True)
        self.assertEqual(byte_strings, correct_strings)
        correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
        ints = rom_interval(address, interval, rom=self.rom, strings=False)
        self.assertEqual(ints, correct_ints)

    def test_rom_until(self):
        address = 0x1337
        byte = 0x13
        # ``found`` rather than ``bytes`` so the builtin is not shadowed
        found = rom_until(address, byte, rom=self.rom, strings=True)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], '0xd5')
        found = rom_until(address, byte, rom=self.rom, strings=False)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], 0xd5)

    def test_how_many_until(self):
        how_many = how_many_until(chr(0x13), 0x1337, self.rom)
        self.assertEqual(how_many, 3)

    def test_calculate_pointer_from_bytes_at(self):
        addr1 = calculate_pointer_from_bytes_at(0x100, bank=False)
        self.assertEqual(addr1, 0xc300)
        addr2 = calculate_pointer_from_bytes_at(0x100, bank=True)
        self.assertEqual(addr2, 0x2ec3)

    def test_rom_text_at(self):
        self.assertEqual(rom_text_at(0x112116, 8), "HTTP/1.0")
|
|
|
|
class TestRomStr(unittest.TestCase):
    """Tests for the ROM object's ``interval`` and ``until`` methods."""
    sample_text = "hello world!"
    sample = None

    def test_rom_interval(self):
        # Use the ROM load_rom() returns. Previously the result was discarded
        # and a module global was read instead; that global only exists if
        # another TestCase's setUpClass has already run.
        rom = load_rom()
        address = 0x100
        interval = 10
        correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
                           '0xed', '0x66', '0x66', '0xcc', '0xd']
        byte_strings = rom.interval(address, interval, strings=True)
        self.assertEqual(byte_strings, correct_strings)
        correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
        ints = rom.interval(address, interval, strings=False)
        self.assertEqual(ints, correct_ints)

    def test_rom_until(self):
        # see test_rom_interval: bind the loaded ROM locally
        rom = load_rom()
        address = 0x1337
        byte = 0x13
        # ``found`` rather than ``bytes`` so the builtin is not shadowed
        found = rom.until(address, byte, strings=True)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], '0xd5')
        found = rom.until(address, byte, strings=False)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], 0xd5)
|
|
|
|
class TestAsmList(unittest.TestCase):
    # Disabled through the "x" prefix (unittest only collects methods whose
    # names start with "test") because scanning the disassembly is slow.
    def xtest_scan_for_predefined_labels(self):
        # each scanned entry carries: line_number, bank, label, offset, address
        load_asm()
        labels = scan_for_predefined_labels()
        label_names = [entry["label"] for entry in labels]
        for expected in ("GetFarByte", "AddNTimes", "CheckShininess"):
            self.assertIn(expected, label_names)
|
|
|
|
class TestEncodedText(unittest.TestCase):
    """for testing chars-table encoded text chunks"""

    def test_process_00_subcommands(self):
        # 601 bytes of encoded text starting at 0x197186
        start = 0x197186
        groups = process_00_subcommands(start, start + 601, debug=False)
        self.assertEqual(len(groups), 42)
        self.assertEqual(len(groups[0]), 13)
        self.assertEqual(groups[1], [184, 174, 180, 211, 164, 127, 20, 231, 81])

    def test_parse_text_at2(self):
        # decoded text must contain a few known words
        oakspeech = parse_text_at2(0x197186, 601, debug=False)
        for word in ("encyclopedia", "researcher", "dependable"):
            self.assertIn(word, oakspeech)

    def test_parse_text_engine_script_at(self):
        script = parse_text_engine_script_at(0x197185, debug=False)
        self.assertEqual(len(script.commands), 1)
        first_command_asm = script.commands[0].to_asm()
        self.assertEqual(first_command_asm.count("\n"), 40)
|
|
|
|
class TestScript(unittest.TestCase):
    """for testing parse_script_engine_script_at and script parsing in
    general. Script should be a class."""

    def test_find_all_text_pointers_in_script_engine_script(self):
        # an earlier note also listed 0x197634 as a candidate address
        script_address = 0x197637
        script = parse_script_engine_script_at(script_address, debug=False)
        pointers = find_all_text_pointers_in_script_engine_script(
            script, bank=calculate_bank(script_address), debug=False)
        self.assertIn(0x197661, list(pointers))
|
|
|
|
class TestByteParams(unittest.TestCase):
    """Tests for the single-byte parameter classes.

    Most tests read the byte at ROM address 10, which is expected to be 0x2d.
    """

    @classmethod
    def setUpClass(cls):
        # loading the ROM is the expensive part, so do it once per class
        load_rom()
        cls.address = 10

    def setUp(self):
        # Build a fresh parameter for every test. test_to_asm flips
        # should_be_decimal and test_parse re-parses; with one shared
        # instance, results would depend on test execution order.
        self.sbp = SingleByteParam(address=self.address)

    def test__init__(self):
        self.assertEqual(self.sbp.size, 1)
        self.assertEqual(self.sbp.address, self.address)

    def test_parse(self):
        self.sbp.parse()
        self.assertEqual(str(self.sbp.byte), str(45))

    def test_to_asm(self):
        self.assertEqual(self.sbp.to_asm(), "$2d")
        self.sbp.should_be_decimal = True
        self.assertEqual(self.sbp.to_asm(), str(45))

    # HexByte and DollarSignByte are the same now
    def test_HexByte_to_asm(self):
        h = HexByte(address=10)
        a = h.to_asm()
        self.assertEqual(a, "$2d")

    def test_DollarSignByte_to_asm(self):
        d = DollarSignByte(address=10)
        a = d.to_asm()
        self.assertEqual(a, "$2d")

    def test_ItemLabelByte_to_asm(self):
        # byte 54 at address 433 maps to an item constant label; a byte with
        # no matching item falls back to hex
        i = ItemLabelByte(address=433)
        self.assertEqual(i.byte, 54)
        self.assertEqual(i.to_asm(), "COIN_CASE")
        self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d")

    def test_DecimalParam_to_asm(self):
        d = DecimalParam(address=10)
        x = d.to_asm()
        self.assertEqual(x, str(0x2d))
|
|
|
|
class TestMultiByteParam(unittest.TestCase):
    def setup_for(self, somecls, byte_size=2, address=443, **kwargs):
        """Instantiate ``somecls`` at ``address`` and sanity-check it.

        Reloads the ROM into ``self.rom`` and stores the new parameter object
        on ``self.cls`` for the calling test to inspect.
        """
        self.rom = load_rom()
        param = somecls(address=address, size=byte_size, **kwargs)
        self.cls = param
        self.assertEqual(param.address, address)
        expected_bytes = rom_interval(address, byte_size, rom=self.rom, strings=False)
        self.assertEqual(param.bytes, expected_bytes)
        self.assertEqual(param.size, byte_size)

    def test_two_byte_param(self):
        self.setup_for(MultiByteParam, byte_size=2)
        self.assertEqual(self.cls.to_asm(), "$f0c0")

    def test_three_byte_param(self):
        # only the checks performed inside setup_for
        self.setup_for(MultiByteParam, byte_size=3)

    def test_PointerLabelParam_no_bank(self):
        self.setup_for(PointerLabelParam, bank=None)
        # assumes no label is registered at 0xf0c0 yet -> raw hex
        self.assertEqual(self.cls.to_asm(), "$f0c0")
        # Register a label covering 0xf0c0; the pointer should now render as
        # that label. NOTE(review): the entry is written into the shared
        # script_parse_table and never removed, so it outlives this test.
        pointer = 0xf0c0
        script_parse_table[pointer:pointer + 1] = {"label": "poop", "bank": 0, "line_number": 2}
        self.assertEqual(self.cls.to_asm(), "poop")
|
|
|
|
class TestPostParsing(unittest.TestCase):
    """tests that must be run after parsing all maps"""

    @classmethod
    def setUpClass(cls):
        cls.rom = direct_load_rom()
        pokemontools.wram.wram_labels = {}
        # The flag and WRAM-label readers are replaced with stubs returning {}
        # for the duration of the parse.
        with mock.patch("pokemontools.crystal.read_engine_flags", return_value={}), \
                mock.patch("pokemontools.crystal.read_event_flags", return_value={}), \
                mock.patch("pokemontools.crystal.setup_wram_labels", return_value={}):
            parse_rom(rom=cls.rom, _skip_wram_labels=True,
                      _parse_map_header_at=old_parse_map_header_at, debug=False)

    @staticmethod
    def _header(group, map_id):
        """Return map_names[group][map_id]["header_new"]."""
        return map_names[group][map_id]["header_new"]

    def test_signpost_counts(self):
        for group, map_id, expected in ((1, 1, 0), (1, 2, 2), (10, 5, 7)):
            signposts = self._header(group, map_id)["event_header"]["signposts"]
            self.assertEqual(len(signposts), expected)

    def test_warp_counts(self):
        for group, map_id, expected in ((10, 5, 9), (18, 5, 3), (15, 1, 2)):
            event_header = self._header(group, map_id)["event_header"]
            self.assertEqual(event_header["warp_count"], expected)

    def test_map_sizes(self):
        cases = (
            (15, 1, "height", 18),
            (15, 1, "width", 10),
            (7, 1, "height", 4),
            (7, 1, "width", 4),
        )
        for group, map_id, dimension, expected in cases:
            second = self._header(group, map_id)["second_map_header"]
            self.assertEqual(second[dimension], expected)

    def test_map_connection_counts(self):
        # the value 9 for map (11, 1) was flagged as possibly 13
        cases = ((7, 1, 0), (10, 1, 12), (10, 2, 12), (11, 1, 9))
        for group, map_id, expected in cases:
            second = self._header(group, map_id)["second_map_header"]
            self.assertEqual(second["connections"], expected)

    def test_second_map_header_address(self):
        for group, map_id, expected in ((11, 1, 0x9509c), (1, 5, 0x95bd0)):
            header = self._header(group, map_id)
            self.assertEqual(header["second_map_header_address"], expected)

    def test_event_address(self):
        for group, map_id, expected in ((17, 5, 0x194d67), (23, 3, 0x1a9ec9)):
            second = self._header(group, map_id)["second_map_header"]
            self.assertEqual(second["event_address"], expected)

    def test_people_event_counts(self):
        for group, map_id, expected in ((23, 3, 4), (10, 3, 9)):
            people = self._header(group, map_id)["event_header"]["people_events"]
            self.assertEqual(len(people), expected)
|
|
|
|
class TestMapParsing(unittest.TestCase):
    # Disabled via the "x" prefix: unittest's default loader only collects
    # methods whose names start with "test".
    def xtest_parse_all_map_headers(self):
        """Count how many times parse_all_map_headers invokes a header parser.

        Swaps both header parsers for a counting stub, runs
        parse_all_map_headers, checks the call count, then restores the
        parsers (the restore is only reached if the assertion passes).
        """
        # ``global`` makes the assignments and the nested ``def`` below rebind
        # names in *this test module's* namespace.
        # NOTE(review): parse_all_map_headers is imported from
        # pokemontools.crystal, so it presumably resolves its parsers in that
        # module rather than here. In that case the stub is never called and
        # ``counter`` stays 0. Confirm before re-enabling this test.
        global parse_map_header_at, old_parse_map_header_at, counter
        counter = 0
        # Give every map group an "offset" key (default 0). This mutates the
        # shared map_names structure and is not undone afterwards.
        for k in map_names.keys():
            if "offset" not in map_names[k].keys():
                map_names[k]["offset"] = 0
        # keep the real parsers so they can be put back at the end
        temp = parse_map_header_at
        temp2 = old_parse_map_header_at
        # counting stub: records each call and returns an empty header
        def parse_map_header_at(address, map_group=None, map_id=None, debug=False):
            global counter
            counter += 1
            return {}
        old_parse_map_header_at = parse_map_header_at
        parse_all_map_headers(debug=False)
        # parse_all_map_headers is currently doing it 2x
        # because of the new/old map header parsing routines
        # (388 is presumably the total number of maps -- verify)
        self.assertEqual(counter, 388 * 2)
        parse_map_header_at = temp
        old_parse_map_header_at = temp2
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    # module=__name__ is "__main__" here, i.e. the same module unittest.main
    # targets by default.
    unittest.main(module=__name__)
|