(vibe coded) move dxt to rust as well

This commit is contained in:
Will Toohey 2026-04-26 00:10:41 +10:00
parent fb62a41730
commit aac8155bbe
5 changed files with 151 additions and 30 deletions

16
Cargo.lock generated
View File

@ -60,6 +60,7 @@ version = "0.1.0"
dependencies = [
"png",
"pyo3",
"texpresso",
]
[[package]]
@ -68,6 +69,12 @@ version = "0.2.186"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
[[package]]
name = "libm"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
@ -202,6 +209,15 @@ version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
[[package]]
name = "texpresso"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8550677e2259d675a7841cb1403db35f330cc9e58674c8c5caa12dd12c51dc71"
dependencies = [
"libm",
]
[[package]]
name = "unicode-ident"
version = "1.0.24"

View File

@ -11,6 +11,7 @@ path = "rust/lib.rs"
[dependencies]
png = "0.18.1"
pyo3 = { version = "0.28", features = ["extension-module", "abi3-py310"] }
texpresso = "2.0.2"
[profile.release]
lto = "fat"

113
rust/dxt.rs Normal file
View File

@ -0,0 +1,113 @@
//! DXT1 / DXT5 decoders for Konami's byte-swapped texture format.
//!
//! Konami stores standard DXT-compressed pixel data with each 16-bit word's
//! bytes swapped. Standard DDS/PIL/texpresso expect the canonical little-endian
//! layout, so we un-swap before handing the bytes off to texpresso.
use texpresso::Format;
#[derive(Debug)]
pub enum DxtError {
UnknownFormat(String),
SizeMismatch {
expected: usize,
got: usize,
format: &'static str,
},
OddByteCount(usize),
}
impl std::fmt::Display for DxtError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DxtError::UnknownFormat(s) => write!(f, "unknown DXT format: {}", s),
DxtError::SizeMismatch { expected, got, format } => write!(
f,
"{}: expected {} compressed bytes, got {}",
format, expected, got
),
DxtError::OddByteCount(n) => write!(f, "input length {} is not a multiple of 2", n),
}
}
}
impl std::error::Error for DxtError {}
fn parse_format(s: &str) -> Result<(Format, &'static str), DxtError> {
match s {
"dxt1" | "DXT1" | "bc1" => Ok((Format::Bc1, "DXT1")),
"dxt3" | "DXT3" | "bc2" => Ok((Format::Bc2, "DXT3")),
"dxt5" | "DXT5" | "bc3" => Ok((Format::Bc3, "DXT5")),
other => Err(DxtError::UnknownFormat(other.to_string())),
}
}
/// Decode Konami-format DXT data into raw RGBA8 pixels (`width * height * 4`
/// bytes). Pads short input with zero bytes to match the prior Python
/// behaviour, which warns and continues rather than failing on truncated
/// textures.
pub fn decode(
data: &[u8],
width: usize,
height: usize,
format: &str,
) -> Result<Vec<u8>, DxtError> {
let (fmt, fmt_name) = parse_format(format)?;
if data.len() % 2 != 0 {
return Err(DxtError::OddByteCount(data.len()));
}
let expected = fmt.compressed_size(width, height);
// Konami's per-WORD byte swap, with zero padding if the source is short.
let mut swapped = vec![0u8; expected];
let n = data.len().min(expected);
let pairs = n / 2;
for i in 0..pairs {
swapped[i * 2] = data[i * 2 + 1];
swapped[i * 2 + 1] = data[i * 2];
}
if n % 2 != 0 {
// Trailing odd byte (only possible when source is shorter than expected).
swapped[n - 1] = data[n - 1];
}
let _ = fmt_name; // currently only used for error formatting
let mut rgba = vec![0u8; width * height * 4];
fmt.decompress(&swapped, width, height, &mut rgba);
Ok(rgba)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dxt5_smoke() {
// Single 4x4 DXT5 block: alpha 0xFF, RGB midgray.
// Standard DXT5 layout (little-endian within blocks):
// alpha0 alpha1 [6 bytes alpha indices] [color block 8 bytes]
// We Konami-swap it (per u16) before feeding to decode().
let canonical = [
0xFFu8, 0xFF, // both alpha endpoints = 255
0, 0, 0, 0, 0, 0, // alpha indices (all zero → endpoint 0 = 255 everywhere)
0xFF, 0x7F, 0xFF, 0x7F, // both colors = 0x7FFF (ish gray)
0, 0, 0, 0,
];
let mut konami = canonical;
for chunk in konami.chunks_exact_mut(2) {
chunk.swap(0, 1);
}
let rgba = decode(&konami, 4, 4, "dxt5").unwrap();
assert_eq!(rgba.len(), 4 * 4 * 4);
// Alpha should be 0xFF everywhere.
for px in rgba.chunks_exact(4) {
assert_eq!(px[3], 0xFF, "alpha not opaque: {:?}", px);
}
}
#[test]
fn unknown_format_errors() {
assert!(decode(&[0; 16], 4, 4, "garbage").is_err());
}
}

View File

@ -2,6 +2,7 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
mod dxt;
mod lz77;
mod png_enc;
@ -41,11 +42,27 @@ fn py_encode_png<'py>(
Ok(PyBytes::new(py, &out))
}
#[pyfunction]
#[pyo3(name = "decode_dxt")]
fn py_decode_dxt<'py>(
py: Python<'py>,
data: Vec<u8>,
width: usize,
height: usize,
format: &str,
) -> PyResult<Bound<'py, PyBytes>> {
let out = py
.detach(|| dxt::decode(&data, width, height, format))
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(PyBytes::new(py, &out))
}
#[pymodule]
#[pyo3(name = "_native")]
fn _native(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(py_decompress, m)?)?;
m.add_function(wrap_pyfunction!(py_compress, m)?)?;
m.add_function(wrap_pyfunction!(py_encode_png, m)?)?;
m.add_function(wrap_pyfunction!(py_decode_dxt, m)?)?;
Ok(())
}

View File

@ -1,6 +1,4 @@
from array import array
from io import BytesIO
from struct import pack
from PIL import Image
from tqdm import tqdm
@ -24,20 +22,6 @@ def encode_png(im):
im = im.convert('RGBA')
return _native.encode_png(im.width, im.height, im.tobytes(), im.mode.lower())
# header for a standard DDS with DXT5 compression and RGBA pixels
# gap placed for image height/width insertion
dxt_start = b'DDS |\x00\x00\x00\x07\x10\x00\x00'
dxt_middle = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x04' + \
b'\x00\x00\x00'
dxt_end = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00' + \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
def check_size(ifs_img, data, bytes_per_pixel):
need = ifs_img.img_size[0] * ifs_img.img_size[1] * bytes_per_pixel
if len(data) < need:
@ -60,24 +44,14 @@ def decode_argb4444(ifs_img, data):
return Image.merge('RGBA', (b, g, r, a))
def decode_dxt(ifs_img, data, version):
b = BytesIO()
b.write(dxt_start)
b.write(pack('<2I', ifs_img.img_size[1], ifs_img.img_size[0]))
b.write(dxt_middle)
b.write(version)
b.write(dxt_end)
# the data has swapped endianness for every WORD
swapped = array('H')
swapped.frombytes(data)
swapped.byteswap()
b.write(swapped.tobytes())
return Image.open(b)
rgba = _native.decode_dxt(data, ifs_img.img_size[0], ifs_img.img_size[1], version)
return Image.frombytes('RGBA', ifs_img.img_size, rgba)
def decode_dxt5(ifs_img, data):
return decode_dxt(ifs_img, data, b'DXT5')
return decode_dxt(ifs_img, data, 'dxt5')
def decode_dxt1(ifs_img, data):
return decode_dxt(ifs_img, data, b'DXT1')
return decode_dxt(ifs_img, data, 'dxt1')
image_formats = {