(vibe coded) Add a Rust implementation of the IFS LZ77 (LZSS) codec, exposed to Python via PyO3, with a pure-Python fallback

This commit is contained in:
Will Toohey 2026-04-25 22:57:42 +10:00
parent fbe5ee17ec
commit d3d377e35c
8 changed files with 616 additions and 122 deletions

3
.gitignore vendored
View File

@ -8,3 +8,6 @@ ifstools.egg-info/
venv/
ifstools.spec
/ifstools-*/
target/
*.so
*.dll

133
Cargo.lock generated Normal file
View File

@ -0,0 +1,133 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "ifstools-native"
version = "0.1.0"
dependencies = [
"pyo3",
]
[[package]]
name = "libc"
version = "0.2.186"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
[[package]]
name = "once_cell"
version = "1.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]]
name = "proc-macro2"
version = "1.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
dependencies = [
"unicode-ident",
]
[[package]]
name = "pyo3"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
dependencies = [
"libc",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
]
[[package]]
name = "pyo3-build-config"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
dependencies = [
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn",
]
[[package]]
name = "quote"
version = "1.0.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
dependencies = [
"proc-macro2",
]
[[package]]
name = "syn"
version = "2.0.117"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "target-lexicon"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
[[package]]
name = "unicode-ident"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"

16
Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
[package]
name = "ifstools-native"
version = "0.1.0"
edition = "2021"
# Built as a CPython extension module (cdylib) by maturin; rlib is kept so
# `cargo test` can link the crate for the unit tests in rust/lz77.rs.
[lib]
name = "_lz77_native"
crate-type = ["cdylib", "rlib"]
path = "rust/lib.rs"
[dependencies]
# abi3-py310: one wheel per platform, usable on CPython >= 3.10.
pyo3 = { version = "0.28", features = ["extension-module", "abi3-py310"] }
# Favour runtime speed of the compression hot loop over compile time.
[profile.release]
lto = "fat"
codegen-units = 1

View File

@ -1,7 +1,10 @@
[build-system]
requires = ["maturin>=1.7,<2"]
build-backend = "maturin"
[project]
name = "ifstools"
version = "1.5"
version = "2.0"
description = "Extractor/repacker for Konmai IFS files"
readme = "README.md"
authors = [
@ -21,6 +24,7 @@ Homepage = "https://github.com/mon/ifstools/"
[project.scripts]
ifstools = "ifstools:main"
[build-system]
requires = ["setuptools>=78"]
build-backend = "setuptools.build_meta"
[tool.maturin]
python-source = "src"
module-name = "ifstools.handlers._lz77_native"
features = ["pyo3/extension-module"]

34
rust/lib.rs Normal file
View File

@ -0,0 +1,34 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
mod lz77;
/// Python-visible `decompress(data: bytes) -> bytes`.
///
/// Runs the Rust decoder with the GIL released; a malformed (truncated)
/// stream surfaces to Python as `ValueError`.
#[pyfunction]
#[pyo3(name = "decompress")]
fn py_decompress<'py>(py: Python<'py>, data: Vec<u8>) -> PyResult<Bound<'py, PyBytes>> {
    let decoded = py.detach(|| lz77::decompress(&data));
    match decoded {
        Ok(plain) => Ok(PyBytes::new(py, &plain)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
/// Python-visible `compress(data: bytes, progress: bool = False) -> bytes`.
///
/// `progress` exists only for signature parity with the pure-Python
/// implementation; the native matcher reports nothing. Runs with the
/// GIL released.
#[pyfunction]
#[pyo3(name = "compress", signature = (data, progress=false))]
fn py_compress<'py>(
    py: Python<'py>,
    data: Vec<u8>,
    progress: bool,
) -> PyResult<Bound<'py, PyBytes>> {
    let _ = progress; // Accepted and ignored on purpose.
    let packed = py.detach(|| lz77::compress(&data));
    Ok(PyBytes::new(py, &packed))
}
/// Module initialiser for `ifstools.handlers._lz77_native`.
#[pymodule]
#[pyo3(name = "_lz77_native")]
fn _lz77_native(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Register both codec entry points on the module.
    for func in [
        wrap_pyfunction!(py_decompress, m)?,
        wrap_pyfunction!(py_compress, m)?,
    ] {
        m.add_function(func)?;
    }
    Ok(())
}

297
rust/lz77.rs Normal file
View File

@ -0,0 +1,297 @@
//! IFS / AVS Konami LZSS (Okumura-style).
//!
//! Format: 4 KB window, 3..18 byte match length, 8-codes-per-flag-byte framing.
//! Match token is a big-endian u16: top 12 bits = back-distance, bottom 4 = (length - 3).
//! Distance == 0 is the EOS sentinel. Distances point into a virtual zero-prefilled
//! window before the start of the stream.
const WINDOW: usize = 0x1000;
const MAX_DIST: usize = WINDOW - 1; // 4095, since 0 is the EOS sentinel
const F: usize = 18; // max match length
const THRESHOLD: usize = 3; // min match length
const HASH_BITS: u32 = 13;
const HASH_SIZE: usize = 1 << HASH_BITS;
const HASH_MASK: u32 = HASH_SIZE as u32 - 1;
const NIL: u32 = u32::MAX;
#[derive(Debug)]
pub enum DecompressError {
Truncated,
}
impl std::fmt::Display for DecompressError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DecompressError::Truncated => write!(f, "truncated lz77 stream"),
}
}
}
impl std::error::Error for DecompressError {}
pub fn decompress(input: &[u8]) -> Result<Vec<u8>, DecompressError> {
let mut out: Vec<u8> = Vec::with_capacity(input.len() * 2);
let mut i = 0usize;
loop {
if i >= input.len() {
return Err(DecompressError::Truncated);
}
let flag = input[i];
i += 1;
for bit in 0..8 {
if (flag >> bit) & 1 == 1 {
if i >= input.len() {
return Err(DecompressError::Truncated);
}
out.push(input[i]);
i += 1;
} else {
if i + 1 >= input.len() {
return Err(DecompressError::Truncated);
}
let w = u16::from_be_bytes([input[i], input[i + 1]]);
i += 2;
let pos = (w >> 4) as usize;
let mut len = (w & 0x0F) as usize + THRESHOLD;
if pos == 0 {
return Ok(out);
}
// References into the virtual zero-prefilled window before stream start.
if pos > out.len() {
let diff = (pos - out.len()).min(len);
out.extend(std::iter::repeat(0u8).take(diff));
len -= diff;
}
// Self-overlapping copy: each output byte may feed the next.
let start = out.len() - pos;
for k in 0..len {
let b = out[start + k];
out.push(b);
}
}
}
}
}
/// Hash a 3-byte prefix into the chain table.
///
/// Multiplicative hash taking the top HASH_BITS bits of the product.
/// Must stay byte-for-byte stable: the encoder's match choices (and
/// therefore its output) depend on the exact values produced here.
#[inline(always)]
fn hash3(a: u8, b: u8, c: u8) -> usize {
    let key = u32::from_be_bytes([0, a, b, c]);
    let mixed = key.wrapping_mul(0x1e35a7bd);
    ((mixed >> (32 - HASH_BITS)) & HASH_MASK) as usize
}
/// Compress a buffer. Output is byte-identical to the reference Python encoder
/// when the matcher decisions agree (greedy longest-match here vs. greedy
/// longest-match there).
///
/// Framing (mirrors `decompress`): one flag byte per 8 codes (LSB first,
/// 1 = literal, 0 = big-endian match token), terminated by a token with
/// back-distance 0.
pub fn compress(input: &[u8]) -> Vec<u8> {
    // Pre-pend a 4 KB zero window so matches at the start can legitimately
    // reference into the zero-prefilled history that the decoder synthesises.
    let mut buf = vec![0u8; WINDOW];
    buf.extend_from_slice(input);
    let buf_len = buf.len();
    // head[h]: most recent position whose 3-byte prefix hashed to h.
    // prev[p & (WINDOW-1)]: previous position on the same chain.
    let mut head = vec![NIL; HASH_SIZE];
    let mut prev = vec![NIL; WINDOW];
    // Seed the hash chain with the zero prefix so the encoder can find matches
    // pointing into it. Stop before the input cursor and respect buf_len so we
    // never index past the end on tiny inputs.
    let seed_limit = (WINDOW - 1).min(buf_len.saturating_sub(2));
    for p in 0..seed_limit {
        let h = hash3(buf[p], buf[p + 1], buf[p + 2]);
        prev[p & (WINDOW - 1)] = head[h];
        head[h] = p as u32;
    }
    let mut out: Vec<u8> = Vec::with_capacity(input.len());
    let mut pos = WINDOW;
    while pos < buf_len {
        let mut flag_byte: u8 = 0;
        // Payload bytes of the current 8-code group, flushed after its flag.
        let mut group: Vec<u8> = Vec::with_capacity(8 * 2);
        for bit in 0..8 {
            if pos >= buf_len {
                // Out of input mid-group: leave the bit as match (0). The
                // decoder will read into our trailing 0x00 0x00 0x00 sentinel
                // and exit on the position == 0 check before reaching this
                // phantom slot.
                continue;
            }
            let (best_len, best_dist) = find_match(&buf, pos, &head, &prev);
            if best_len >= THRESHOLD {
                // Token layout: top 12 bits back-distance, low 4 bits length-3.
                let dist = best_dist as u16;
                let info: u16 = (dist << 4) | ((best_len - THRESHOLD) as u16);
                group.extend_from_slice(&info.to_be_bytes());
                // Insert hash entries for every position covered by the match,
                // including the start (which we are about to skip past).
                for k in 0..best_len {
                    let p = pos + k;
                    if p + 2 < buf_len {
                        let h = hash3(buf[p], buf[p + 1], buf[p + 2]);
                        prev[p & (WINDOW - 1)] = head[h];
                        head[h] = p as u32;
                    }
                }
                pos += best_len;
            } else {
                // No usable match: emit a literal and mark its flag bit.
                group.push(buf[pos]);
                flag_byte |= 1 << bit;
                if pos + 2 < buf_len {
                    let h = hash3(buf[pos], buf[pos + 1], buf[pos + 2]);
                    prev[pos & (WINDOW - 1)] = head[h];
                    head[h] = pos as u32;
                }
                pos += 1;
            }
        }
        out.push(flag_byte);
        out.extend_from_slice(&group);
    }
    // EOS sentinel: flag byte saying "next code is a match", then a 0x0000
    // match token (distance == 0 → decoder returns).
    out.push(0);
    out.push(0);
    out.push(0);
    out
}
/// Greedy longest-match search at `pos` over the hash chains.
///
/// Returns `(length, back_distance)`; a length below THRESHOLD means
/// "no match". Matches are capped at F bytes and MAX_DIST of history.
#[inline]
fn find_match(buf: &[u8], pos: usize, head: &[u32], prev: &[u32]) -> (usize, usize) {
    let buf_len = buf.len();
    if pos + THRESHOLD > buf_len {
        return (0, 0);
    }
    let max_len = (buf_len - pos).min(F);
    if max_len < THRESHOLD {
        return (0, 0);
    }
    let h = hash3(buf[pos], buf[pos + 1], buf[pos + 2]);
    let mut candidate = head[h];
    // Oldest position still addressable by a 12-bit back-distance.
    let limit = pos.saturating_sub(MAX_DIST);
    let mut best_len = 0usize;
    let mut best_dist = 0usize;
    // Bound on chain depth — at N=4096 chains are short anyway, but cap to keep
    // worst-case behaviour bounded on highly repetitive input.
    // NOTE(review): prev[] slots are reused modulo WINDOW, so entries can go
    // stale once positions wrap the mask; the chain is then not strictly
    // decreasing and this cap is also what bounds walking stale links.
    let mut chain_remaining: u32 = 4096;
    while candidate != NIL {
        let cand = candidate as usize;
        if cand < limit {
            break;
        }
        chain_remaining -= 1;
        // Quick reject: if the byte at best_len doesn't match, skip.
        // (Safe: best_len < max_len here, so both indices stay in bounds.)
        if best_len > 0 && buf[cand + best_len] != buf[pos + best_len] {
            // fall through to chain step
        } else {
            // Compare bytes up to max_len.
            let mut len = 0usize;
            while len < max_len && buf[cand + len] == buf[pos + len] {
                len += 1;
            }
            if len > best_len {
                best_len = len;
                best_dist = pos - cand;
                if best_len >= max_len {
                    break;
                }
            }
        }
        if chain_remaining == 0 {
            break;
        }
        candidate = prev[cand & (WINDOW - 1)];
    }
    (best_len, best_dist)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compress then decompress, asserting the round trip is lossless.
    fn assert_roundtrip(data: &[u8]) {
        let packed = compress(data);
        let unpacked = decompress(&packed).unwrap();
        assert_eq!(unpacked, data);
    }

    #[test]
    fn roundtrip_small() {
        assert_roundtrip(b"hello hello hello world! world world world world");
    }

    #[test]
    fn roundtrip_empty() {
        assert_roundtrip(&[]);
    }

    #[test]
    fn roundtrip_short() {
        assert_roundtrip(b"abc");
    }

    #[test]
    fn roundtrip_repetitive() {
        assert_roundtrip(&vec![0xAAu8; 10_000]);
    }

    #[test]
    fn roundtrip_random_ish() {
        // Pseudo-random but reproducible (LCG).
        let mut data = vec![0u8; 50_000];
        let mut state: u32 = 0x12345678;
        for byte in data.iter_mut() {
            state = state.wrapping_mul(1664525).wrapping_add(1013904223);
            *byte = (state >> 16) as u8;
        }
        assert_roundtrip(&data);
    }

    #[test]
    fn known_test_vector() {
        // The decoder fixture from _lz77.py main block.
        let fixture = [
            0x88, 0x46, 0x23, 0x20, 0x00, 0x20, 0x47, 0x20, 0x41, 0x00, 0x10, 0xa2, 0x47,
            0x01, 0xa0, 0x45, 0x20, 0x44, 0x00, 0x08, 0x45, 0x01, 0x50, 0x79, 0x00, 0xc0,
            0x45, 0x20, 0x05, 0x24, 0x13, 0x88, 0x05, 0xb4, 0x02, 0x4a, 0x44, 0xef, 0x03,
            0x58, 0x02, 0x8c, 0x09, 0x16, 0x01, 0x48, 0x45, 0x00, 0xbe, 0x00, 0x9e, 0x00,
            0x04, 0x01, 0x18, 0x90, 0x00, 0x00,
        ];
        let decoded = decompress(&fixture).unwrap();
        // Re-encode with our own encoder to confirm it emits a legal stream
        // (not necessarily byte-identical bits).
        assert_roundtrip(&decoded);
    }
}

View File

@ -0,0 +1,119 @@
# consistency with py 2/3
from builtins import bytes
from io import BytesIO
from struct import pack, unpack
from tqdm import tqdm
WINDOW_SIZE = 0x1000
WINDOW_MASK = WINDOW_SIZE - 1
THRESHOLD = 3
INPLACE_THRESHOLD = 0xA
LOOK_RANGE = 0x200
MAX_LEN = 0xF + THRESHOLD

def decompress(input):
    """Decode a Konami LZSS stream.

    Codes come in groups of eight, described by a leading flag byte
    (LSB first: 1 = literal byte, 0 = two-byte big-endian match token).
    A match token with back-distance 0 terminates the stream.
    """
    stream = BytesIO(input)
    out = bytearray()
    while True:
        # wrap in bytes for py2
        control = bytes(stream.read(1))[0]
        for shift in range(8):
            if (control >> shift) & 1:
                # Literal byte, copied straight through.
                out.append(stream.read(1)[0])
                continue
            token = unpack('>H', stream.read(2))[0]
            distance = token >> 4
            if distance == 0:
                # Zero distance is the end-of-stream sentinel.
                return bytes(out)
            length = (token & 0x0F) + THRESHOLD
            # Distances reaching before the start of the output refer to a
            # virtual zero-filled window: synthesise that part as zeros.
            if distance > len(out):
                pad = min(distance - len(out), length)
                out.extend([0] * pad)
                length -= pad
            if length < distance:
                # Non-overlapping: copy in one slice.
                out.extend(out[-distance:-distance + length])
            else:
                # Overlapping copy: bytes we emit feed later iterations.
                for _ in range(length):
                    out.append(out[-distance])
def match_window(in_data, offset):
    '''Find the longest match for the string starting at offset in the preceding data.

    Returns (back_distance, length) for the longest match of at least
    THRESHOLD bytes beginning at most WINDOW_MASK bytes back, or None
    when no qualifying match exists.
    '''
    window_start = max(offset - WINDOW_MASK, 0)
    # Try candidate lengths longest-first, so the first hit is the longest.
    for n in range(MAX_LEN, THRESHOLD-1, -1):
        window_end = min(offset + n, len(in_data))
        # we've not got enough data left for a meaningful result
        if window_end - offset < THRESHOLD:
            return None
        str_to_find = in_data[offset:window_end]
        # NOTE(review): the rfind end bound window_end-n forces any hit to lie
        # entirely before `offset`, so this encoder never emits the
        # self-overlapping matches the decoder is able to handle.
        idx = in_data.rfind(str_to_find, window_start, window_end-n)
        if idx != -1:
            code_offset = offset - idx # - 1
            code_len = len(str_to_find)
            return (code_offset, code_len)
    return None
def compress(input, progress = False):
    """Encode `input` with the greedy longest-match LZSS encoder.

    Framing matches decompress(): one flag byte per 8 codes (LSB first,
    1 = literal), big-endian match tokens, and a zero-distance token as
    terminator.  `progress` enables a tqdm progress bar.
    """
    pbar = tqdm(total = len(input), leave = False, unit = 'b', unit_scale = True,
        desc = 'Compressing', disable = not progress)
    compressed = bytearray()
    # Prepend a zero window so early matches can reference the same
    # zero-filled history the decoder synthesises.
    input = bytes([0]*WINDOW_SIZE) + bytes(input)
    input_size = len(input)
    current_pos = WINDOW_SIZE
    bit = 0
    while current_pos < input_size:
        flag_byte = 0;
        buf = bytearray()
        for _ in range(8):
            if current_pos >= input_size:
                # Out of input mid-group: mark the slot as a match (0) with
                # no payload; the decoder hits the terminator token first.
                bit = 0;
            else:
                match = match_window(input, current_pos)
                if match:
                    pos, length = match
                    # Token: top 12 bits back-distance, low 4 bits length-3.
                    info = (pos << 4) | ((length - THRESHOLD) & 0x0F)
                    buf.extend(pack('>H', info))
                    bit = 0
                    current_pos += length
                    pbar.update(length)
                else:
                    buf.append(input[current_pos])
                    current_pos += 1
                    pbar.update(1)
                    bit = 1
            # Flags accumulate LSB-first: after 8 shifts, the first code's
            # flag ends up in bit 0.
            flag_byte = (flag_byte >> 1) | ((bit & 1) << 7)
        compressed.append(flag_byte)
        compressed.extend(buf)
    # End-of-stream: flag byte, then a zero-distance match token.
    compressed.append(0)
    compressed.append(0)
    compressed.append(0)
    pbar.close()
    return bytes(compressed)
def compress_dummy(input):
    """Store `input` uncompressed, framed as an all-literal LZSS stream.

    Each full group of eight bytes is preceded by a 0xFF flag byte (all
    literals); a partial tail gets a flag with one bit set per byte, and
    the stream ends with the standard three-byte terminator.
    """
    total = len(input)
    tail = total % 8
    out = bytearray()
    for start in range(0, total - tail, 8):
        out.append(0xFF)
        out.extend(input[start:start + 8])
    if tail > 0:
        out.append(0xFF >> (8 - tail))
        out.extend(input[-tail:])
    out.extend(b"\x00\x00\x00")
    return bytes(out)

View File

@ -1,119 +1,7 @@
# consistency with py 2/3
from builtins import bytes
from io import BytesIO
from struct import pack, unpack
# Prefer the compiled Rust codec; fall back to the pure-Python port in
# _lz77_py (same compress/decompress API) when the extension isn't built.
try:
    from ._lz77_native import compress, decompress
except ImportError:
    print("WARNING: using native-python LZ77, operations will be slow")
    from ._lz77_py import compress, decompress
from tqdm import tqdm
WINDOW_SIZE = 0x1000
WINDOW_MASK = WINDOW_SIZE - 1
THRESHOLD = 3
INPLACE_THRESHOLD = 0xA
LOOK_RANGE = 0x200
MAX_LEN = 0xF + THRESHOLD
def decompress(input):
input = BytesIO(input)
decompressed = bytearray()
while True:
# wrap in bytes for py2
flag = bytes(input.read(1))[0]
for i in range(8):
if (flag >> i) & 1 == 1:
decompressed.append(input.read(1)[0])
else:
w = unpack('>H', input.read(2))[0]
position = (w >> 4)
length = (w & 0x0F) + THRESHOLD
if position == 0:
return bytes(decompressed)
if position > len(decompressed):
diff = position - len(decompressed)
diff = min(diff, length)
decompressed.extend([0]*diff)
length -= diff
# optimise
if -position+length < 0:
decompressed.extend(decompressed[-position:-position+length])
else:
for loop in range(length):
decompressed.append(decompressed[-position])
def match_window(in_data, offset):
'''Find the longest match for the string starting at offset in the preceeding data
'''
window_start = max(offset - WINDOW_MASK, 0)
for n in range(MAX_LEN, THRESHOLD-1, -1):
window_end = min(offset + n, len(in_data))
# we've not got enough data left for a meaningful result
if window_end - offset < THRESHOLD:
return None
str_to_find = in_data[offset:window_end]
idx = in_data.rfind(str_to_find, window_start, window_end-n)
if idx != -1:
code_offset = offset - idx # - 1
code_len = len(str_to_find)
return (code_offset, code_len)
return None
def compress(input, progress = False):
pbar = tqdm(total = len(input), leave = False, unit = 'b', unit_scale = True,
desc = 'Compressing', disable = not progress)
compressed = bytearray()
input = bytes([0]*WINDOW_SIZE) + bytes(input)
input_size = len(input)
current_pos = WINDOW_SIZE
bit = 0
while current_pos < input_size:
flag_byte = 0;
buf = bytearray()
for _ in range(8):
if current_pos >= input_size:
bit = 0;
else:
match = match_window(input, current_pos)
if match:
pos, length = match
info = (pos << 4) | ((length - THRESHOLD) & 0x0F)
buf.extend(pack('>H', info))
bit = 0
current_pos += length
pbar.update(length)
else:
buf.append(input[current_pos])
current_pos += 1
pbar.update(1)
bit = 1
flag_byte = (flag_byte >> 1) | ((bit & 1) << 7)
compressed.append(flag_byte)
compressed.extend(buf)
compressed.append(0)
compressed.append(0)
compressed.append(0)
pbar.close()
return bytes(compressed)
def compress_dummy(input):
input_length = len(input)
compressed = bytearray()
extra_bytes = input_length % 8
for i in range(0, input_length-extra_bytes, 8):
compressed.append(0xFF)
compressed.extend(input[i:i+8])
if extra_bytes > 0:
compressed.append(0xFF >> (8 - extra_bytes))
compressed.extend(input[-extra_bytes:])
compressed.append(0)
compressed.append(0)
compressed.append(0)
return bytes(compressed)
# Public surface of this module: just the two codec entry points.
__all__ = ["compress", "decompress"]