diff --git a/.gitignore b/.gitignore index 97eb7f1..3ca277c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ ifstools.egg-info/ venv/ ifstools.spec /ifstools-*/ +target/ +*.so +*.dll diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..169e9ae --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,133 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "ifstools-native" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12" +dependencies = [ + "libc", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", +] + +[[package]] +name = "pyo3-build-config" +version = "0.28.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.28.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 
0000000..633bd7c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "ifstools-native" +version = "0.1.0" +edition = "2021" + +[lib] +name = "_lz77_native" +crate-type = ["cdylib", "rlib"] +path = "rust/lib.rs" + +[dependencies] +pyo3 = { version = "0.28", features = ["extension-module", "abi3-py310"] } + +[profile.release] +lto = "fat" +codegen-units = 1 diff --git a/pyproject.toml b/pyproject.toml index ed873ea..bce11aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,10 @@ +[build-system] +requires = ["maturin>=1.7,<2"] +build-backend = "maturin" [project] name = "ifstools" -version = "1.5" +version = "2.0" description = "Extractor/repacker for Konmai IFS files" readme = "README.md" authors = [ @@ -21,6 +24,7 @@ Homepage = "https://github.com/mon/ifstools/" [project.scripts] ifstools = "ifstools:main" -[build-system] -requires = ["setuptools>=78"] -build-backend = "setuptools.build_meta" +[tool.maturin] +python-source = "src" +module-name = "ifstools.handlers._lz77_native" +features = ["pyo3/extension-module"] diff --git a/rust/lib.rs b/rust/lib.rs new file mode 100644 index 0000000..e4e1aab --- /dev/null +++ b/rust/lib.rs @@ -0,0 +1,34 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyBytes; + +mod lz77; + +#[pyfunction] +#[pyo3(name = "decompress")] +fn py_decompress<'py>(py: Python<'py>, data: Vec<u8>) -> PyResult<Bound<'py, PyBytes>> { + let out = py + .detach(|| lz77::decompress(&data)) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + Ok(PyBytes::new(py, &out)) +} + +#[pyfunction] +#[pyo3(name = "compress", signature = (data, progress=false))] +fn py_compress<'py>( + py: Python<'py>, + data: Vec<u8>, + progress: bool, +) -> PyResult<Bound<'py, PyBytes>> { + let _ = progress; // Pure-Python signature parity; matcher itself is silent. 
+ let out = py.detach(|| lz77::compress(&data)); + Ok(PyBytes::new(py, &out)) +} + +#[pymodule] +#[pyo3(name = "_lz77_native")] +fn _lz77_native(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(py_decompress, m)?)?; + m.add_function(wrap_pyfunction!(py_compress, m)?)?; + Ok(()) +} diff --git a/rust/lz77.rs b/rust/lz77.rs new file mode 100644 index 0000000..3ceb90e --- /dev/null +++ b/rust/lz77.rs @@ -0,0 +1,297 @@ +//! IFS / AVS Konami LZSS (Okumura-style). +//! +//! Format: 4 KB window, 3..18 byte match length, 8-codes-per-flag-byte framing. +//! Match token is a big-endian u16: top 12 bits = back-distance, bottom 4 = (length - 3). +//! Distance == 0 is the EOS sentinel. Distances point into a virtual zero-prefilled +//! window before the start of the stream. + +const WINDOW: usize = 0x1000; +const MAX_DIST: usize = WINDOW - 1; // 4095, since 0 is the EOS sentinel +const F: usize = 18; // max match length +const THRESHOLD: usize = 3; // min match length + +const HASH_BITS: u32 = 13; +const HASH_SIZE: usize = 1 << HASH_BITS; +const HASH_MASK: u32 = HASH_SIZE as u32 - 1; +const NIL: u32 = u32::MAX; + +#[derive(Debug)] +pub enum DecompressError { + Truncated, +} + +impl std::fmt::Display for DecompressError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DecompressError::Truncated => write!(f, "truncated lz77 stream"), + } + } +} + +impl std::error::Error for DecompressError {} + +pub fn decompress(input: &[u8]) -> Result<Vec<u8>, DecompressError> { + let mut out: Vec<u8> = Vec::with_capacity(input.len() * 2); + let mut i = 0usize; + + loop { + if i >= input.len() { + return Err(DecompressError::Truncated); + } + let flag = input[i]; + i += 1; + + for bit in 0..8 { + if (flag >> bit) & 1 == 1 { + if i >= input.len() { + return Err(DecompressError::Truncated); + } + out.push(input[i]); + i += 1; + } else { + if i + 1 >= input.len() { + return Err(DecompressError::Truncated); + } + let w = 
u16::from_be_bytes([input[i], input[i + 1]]); + i += 2; + + let pos = (w >> 4) as usize; + let mut len = (w & 0x0F) as usize + THRESHOLD; + + if pos == 0 { + return Ok(out); + } + + // References into the virtual zero-prefilled window before stream start. + if pos > out.len() { + let diff = (pos - out.len()).min(len); + out.extend(std::iter::repeat(0u8).take(diff)); + len -= diff; + } + + // Self-overlapping copy: each output byte may feed the next. + let start = out.len() - pos; + for k in 0..len { + let b = out[start + k]; + out.push(b); + } + } + } + } +} + +/// Hash a 3-byte prefix into the chain table. +#[inline(always)] +fn hash3(a: u8, b: u8, c: u8) -> usize { + let h = ((a as u32) << 16) | ((b as u32) << 8) | (c as u32); + let h = h.wrapping_mul(0x1e35a7bd); + ((h >> (32 - HASH_BITS)) & HASH_MASK) as usize +} + +/// Compress a buffer. Output is byte-identical to the reference Python encoder +/// when the matcher decisions agree (greedy longest-match here vs. greedy +/// longest-match there). +pub fn compress(input: &[u8]) -> Vec<u8> { + // Pre-pend a 4 KB zero window so matches at the start can legitimately + // reference into the zero-prefilled history that the decoder synthesises. + let mut buf = vec![0u8; WINDOW]; + buf.extend_from_slice(input); + let buf_len = buf.len(); + + let mut head = vec![NIL; HASH_SIZE]; + let mut prev = vec![NIL; WINDOW]; + + // Seed the hash chain with the zero prefix so the encoder can find matches + // pointing into it. Stop before the input cursor and respect buf_len so we + // never index past the end on tiny inputs. 
+ let seed_limit = (WINDOW - 1).min(buf_len.saturating_sub(2)); + for p in 0..seed_limit { + let h = hash3(buf[p], buf[p + 1], buf[p + 2]); + prev[p & (WINDOW - 1)] = head[h]; + head[h] = p as u32; + } + + let mut out: Vec<u8> = Vec::with_capacity(input.len()); + let mut pos = WINDOW; + + while pos < buf_len { + let mut flag_byte: u8 = 0; + let mut group: Vec<u8> = Vec::with_capacity(8 * 2); + + for bit in 0..8 { + if pos >= buf_len { + // Out of input mid-group: leave the bit as match (0). The + // decoder will read into our trailing 0x00 0x00 0x00 sentinel + // and exit on the position == 0 check before reaching this + // phantom slot. + continue; + } + + let (best_len, best_dist) = find_match(&buf, pos, &head, &prev); + + if best_len >= THRESHOLD { + let dist = best_dist as u16; + let info: u16 = (dist << 4) | ((best_len - THRESHOLD) as u16); + group.extend_from_slice(&info.to_be_bytes()); + + // Insert hash entries for every position covered by the match, + // including the start (which we are about to skip past). + for k in 0..best_len { + let p = pos + k; + if p + 2 < buf_len { + let h = hash3(buf[p], buf[p + 1], buf[p + 2]); + prev[p & (WINDOW - 1)] = head[h]; + head[h] = p as u32; + } + } + pos += best_len; + } else { + group.push(buf[pos]); + flag_byte |= 1 << bit; + + if pos + 2 < buf_len { + let h = hash3(buf[pos], buf[pos + 1], buf[pos + 2]); + prev[pos & (WINDOW - 1)] = head[h]; + head[h] = pos as u32; + } + pos += 1; + } + } + + out.push(flag_byte); + out.extend_from_slice(&group); + } + + // EOS sentinel: flag byte saying "next code is a match", then a 0x0000 + // match token (distance == 0 → decoder returns). 
+ out.push(0); + out.push(0); + out.push(0); + out +} + +#[inline] +fn find_match(buf: &[u8], pos: usize, head: &[u32], prev: &[u32]) -> (usize, usize) { + let buf_len = buf.len(); + if pos + THRESHOLD > buf_len { + return (0, 0); + } + + let max_len = (buf_len - pos).min(F); + if max_len < THRESHOLD { + return (0, 0); + } + + let h = hash3(buf[pos], buf[pos + 1], buf[pos + 2]); + let mut candidate = head[h]; + let limit = pos.saturating_sub(MAX_DIST); + + let mut best_len = 0usize; + let mut best_dist = 0usize; + + // Bound on chain depth — at N=4096 chains are short anyway, but cap to keep + // worst-case behaviour bounded on highly repetitive input. + let mut chain_remaining: u32 = 4096; + + while candidate != NIL { + let cand = candidate as usize; + if cand < limit { + break; + } + chain_remaining -= 1; + + // Quick reject: if the byte at best_len doesn't match, skip. + if best_len > 0 && buf[cand + best_len] != buf[pos + best_len] { + // fall through to chain step + } else { + // Compare bytes up to max_len. + let mut len = 0usize; + while len < max_len && buf[cand + len] == buf[pos + len] { + len += 1; + } + + if len > best_len { + best_len = len; + best_dist = pos - cand; + if best_len >= max_len { + break; + } + } + } + + if chain_remaining == 0 { + break; + } + candidate = prev[cand & (WINDOW - 1)]; + } + + (best_len, best_dist) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip_small() { + let data = b"hello hello hello world! 
world world world world".to_vec(); + let comp = compress(&data); + let decomp = decompress(&comp).unwrap(); + assert_eq!(decomp, data); + } + + #[test] + fn roundtrip_empty() { + let comp = compress(&[]); + let decomp = decompress(&comp).unwrap(); + assert_eq!(decomp, Vec::<u8>::new()); + } + + #[test] + fn roundtrip_short() { + let data = b"abc".to_vec(); + let comp = compress(&data); + let decomp = decompress(&comp).unwrap(); + assert_eq!(decomp, data); + } + + #[test] + fn roundtrip_repetitive() { + let data = vec![0xAAu8; 10_000]; + let comp = compress(&data); + let decomp = decompress(&comp).unwrap(); + assert_eq!(decomp, data); + } + + #[test] + fn roundtrip_random_ish() { + // Pseudo-random but reproducible. + let mut data = vec![0u8; 50_000]; + let mut x: u32 = 0x12345678; + for b in data.iter_mut() { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *b = (x >> 16) as u8; + } + let comp = compress(&data); + let decomp = decompress(&comp).unwrap(); + assert_eq!(decomp, data); + } + + #[test] + fn known_test_vector() { + // The decoder fixture from _lz77.py main block. + let test = [ + 0x88, 0x46, 0x23, 0x20, 0x00, 0x20, 0x47, 0x20, 0x41, 0x00, 0x10, 0xa2, 0x47, + 0x01, 0xa0, 0x45, 0x20, 0x44, 0x00, 0x08, 0x45, 0x01, 0x50, 0x79, 0x00, 0xc0, + 0x45, 0x20, 0x05, 0x24, 0x13, 0x88, 0x05, 0xb4, 0x02, 0x4a, 0x44, 0xef, 0x03, + 0x58, 0x02, 0x8c, 0x09, 0x16, 0x01, 0x48, 0x45, 0x00, 0xbe, 0x00, 0x9e, 0x00, + 0x04, 0x01, 0x18, 0x90, 0x00, 0x00, + ]; + let out = decompress(&test).unwrap(); + // Round-trip through our own encoder to confirm the encoder emits a + // legal stream (not necessarily byte-identical bits). 
+ let recomp = compress(&out); + let redec = decompress(&recomp).unwrap(); + assert_eq!(redec, out); + } +} diff --git a/src/ifstools/handlers/_lz77_py.py b/src/ifstools/handlers/_lz77_py.py new file mode 100644 index 0000000..75d84aa --- /dev/null +++ b/src/ifstools/handlers/_lz77_py.py @@ -0,0 +1,119 @@ +# consistency with py 2/3 +from builtins import bytes +from io import BytesIO +from struct import pack, unpack + +from tqdm import tqdm + +WINDOW_SIZE = 0x1000 +WINDOW_MASK = WINDOW_SIZE - 1 +THRESHOLD = 3 +INPLACE_THRESHOLD = 0xA +LOOK_RANGE = 0x200 +MAX_LEN = 0xF + THRESHOLD + +def decompress(input): + input = BytesIO(input) + decompressed = bytearray() + + while True: + # wrap in bytes for py2 + flag = bytes(input.read(1))[0] + for i in range(8): + if (flag >> i) & 1 == 1: + decompressed.append(input.read(1)[0]) + else: + w = unpack('>H', input.read(2))[0] + position = (w >> 4) + length = (w & 0x0F) + THRESHOLD + if position == 0: + return bytes(decompressed) + + if position > len(decompressed): + diff = position - len(decompressed) + diff = min(diff, length) + decompressed.extend([0]*diff) + length -= diff + # optimise + if -position+length < 0: + decompressed.extend(decompressed[-position:-position+length]) + else: + for loop in range(length): + decompressed.append(decompressed[-position]) + +def match_window(in_data, offset): + '''Find the longest match for the string starting at offset in the preceeding data + ''' + window_start = max(offset - WINDOW_MASK, 0) + + for n in range(MAX_LEN, THRESHOLD-1, -1): + window_end = min(offset + n, len(in_data)) + # we've not got enough data left for a meaningful result + if window_end - offset < THRESHOLD: + return None + str_to_find = in_data[offset:window_end] + idx = in_data.rfind(str_to_find, window_start, window_end-n) + if idx != -1: + code_offset = offset - idx # - 1 + code_len = len(str_to_find) + return (code_offset, code_len) + + return None + +def compress(input, progress = False): + pbar = tqdm(total = 
len(input), leave = False, unit = 'b', unit_scale = True, + desc = 'Compressing', disable = not progress) + compressed = bytearray() + input = bytes([0]*WINDOW_SIZE) + bytes(input) + input_size = len(input) + current_pos = WINDOW_SIZE + bit = 0 + while current_pos < input_size: + flag_byte = 0; + buf = bytearray() + for _ in range(8): + if current_pos >= input_size: + bit = 0; + else: + match = match_window(input, current_pos) + if match: + pos, length = match + info = (pos << 4) | ((length - THRESHOLD) & 0x0F) + buf.extend(pack('>H', info)) + bit = 0 + current_pos += length + pbar.update(length) + else: + buf.append(input[current_pos]) + current_pos += 1 + pbar.update(1) + bit = 1 + flag_byte = (flag_byte >> 1) | ((bit & 1) << 7) + compressed.append(flag_byte) + compressed.extend(buf) + compressed.append(0) + compressed.append(0) + compressed.append(0) + + pbar.close() + return bytes(compressed) + +def compress_dummy(input): + input_length = len(input) + compressed = bytearray() + + extra_bytes = input_length % 8 + + for i in range(0, input_length-extra_bytes, 8): + compressed.append(0xFF) + compressed.extend(input[i:i+8]) + + if extra_bytes > 0: + compressed.append(0xFF >> (8 - extra_bytes)) + compressed.extend(input[-extra_bytes:]) + + compressed.append(0) + compressed.append(0) + compressed.append(0) + + return bytes(compressed) diff --git a/src/ifstools/handlers/lz77.py b/src/ifstools/handlers/lz77.py index 75d84aa..baa7851 100644 --- a/src/ifstools/handlers/lz77.py +++ b/src/ifstools/handlers/lz77.py @@ -1,119 +1,7 @@ -# consistency with py 2/3 -from builtins import bytes -from io import BytesIO -from struct import pack, unpack +try: + from ._lz77_native import compress, decompress +except ImportError: + print("WARNING: using native-python LZ77, operations will be slow") + from ._lz77_py import compress, decompress -from tqdm import tqdm - -WINDOW_SIZE = 0x1000 -WINDOW_MASK = WINDOW_SIZE - 1 -THRESHOLD = 3 -INPLACE_THRESHOLD = 0xA -LOOK_RANGE = 0x200 -MAX_LEN 
= 0xF + THRESHOLD - -def decompress(input): - input = BytesIO(input) - decompressed = bytearray() - - while True: - # wrap in bytes for py2 - flag = bytes(input.read(1))[0] - for i in range(8): - if (flag >> i) & 1 == 1: - decompressed.append(input.read(1)[0]) - else: - w = unpack('>H', input.read(2))[0] - position = (w >> 4) - length = (w & 0x0F) + THRESHOLD - if position == 0: - return bytes(decompressed) - - if position > len(decompressed): - diff = position - len(decompressed) - diff = min(diff, length) - decompressed.extend([0]*diff) - length -= diff - # optimise - if -position+length < 0: - decompressed.extend(decompressed[-position:-position+length]) - else: - for loop in range(length): - decompressed.append(decompressed[-position]) - -def match_window(in_data, offset): - '''Find the longest match for the string starting at offset in the preceeding data - ''' - window_start = max(offset - WINDOW_MASK, 0) - - for n in range(MAX_LEN, THRESHOLD-1, -1): - window_end = min(offset + n, len(in_data)) - # we've not got enough data left for a meaningful result - if window_end - offset < THRESHOLD: - return None - str_to_find = in_data[offset:window_end] - idx = in_data.rfind(str_to_find, window_start, window_end-n) - if idx != -1: - code_offset = offset - idx # - 1 - code_len = len(str_to_find) - return (code_offset, code_len) - - return None - -def compress(input, progress = False): - pbar = tqdm(total = len(input), leave = False, unit = 'b', unit_scale = True, - desc = 'Compressing', disable = not progress) - compressed = bytearray() - input = bytes([0]*WINDOW_SIZE) + bytes(input) - input_size = len(input) - current_pos = WINDOW_SIZE - bit = 0 - while current_pos < input_size: - flag_byte = 0; - buf = bytearray() - for _ in range(8): - if current_pos >= input_size: - bit = 0; - else: - match = match_window(input, current_pos) - if match: - pos, length = match - info = (pos << 4) | ((length - THRESHOLD) & 0x0F) - buf.extend(pack('>H', info)) - bit = 0 - current_pos 
+= length - pbar.update(length) - else: - buf.append(input[current_pos]) - current_pos += 1 - pbar.update(1) - bit = 1 - flag_byte = (flag_byte >> 1) | ((bit & 1) << 7) - compressed.append(flag_byte) - compressed.extend(buf) - compressed.append(0) - compressed.append(0) - compressed.append(0) - - pbar.close() - return bytes(compressed) - -def compress_dummy(input): - input_length = len(input) - compressed = bytearray() - - extra_bytes = input_length % 8 - - for i in range(0, input_length-extra_bytes, 8): - compressed.append(0xFF) - compressed.extend(input[i:i+8]) - - if extra_bytes > 0: - compressed.append(0xFF >> (8 - extra_bytes)) - compressed.extend(input[-extra_bytes:]) - - compressed.append(0) - compressed.append(0) - compressed.append(0) - - return bytes(compressed) +__all__ = ["compress", "decompress"]