diff --git a/ext/src/save.rs b/ext/src/save.rs index 5e82648df..4bb2ee3dd 100644 --- a/ext/src/save.rs +++ b/ext/src/save.rs @@ -1,7 +1,7 @@ use std::{ collections::HashSet, fs::File, - io::{BufRead, BufReader, BufWriter, Error, ErrorKind, Read, Write}, + io::{BufRead, BufReader, BufWriter, Cursor, Error, ErrorKind, Read, Write}, path::{Path, PathBuf}, sync::{Arc, Mutex}, }; @@ -300,8 +300,35 @@ impl std::ops::Drop for ChecksumReader { } /// Open the file pointed to by `path` as a `Box`. If the file does not exist, look for -/// compressed versions. -fn open_file(path: PathBuf) -> Option> { +/// compressed versions. If `early_check` is true, we check the checksum before returning the file. +fn open_file(path: PathBuf, early_check: bool) -> Option> { + fn do_early_check(path: PathBuf, mut reader: impl Read) -> Option> { + let mut file_contents = Vec::new(); + let num_bytes = std::io::copy(&mut reader, &mut file_contents) + .unwrap_or_else(|e| panic!("Error when reading from {path:?}: {e}")); + if num_bytes < 4 { + tracing::warn!("File {path:?} is too short to contain a checksum. Deleting file."); + std::fs::remove_file(&path) + .unwrap_or_else(|e| panic!("Error when deleting {path:?}: {e}")); + return None; + } + + let checksum_pos = num_bytes as usize - 4; + let (content_bytes, mut checksum_bytes) = file_contents.split_at(checksum_pos); + let mut adler = adler::Adler32::new(); + adler.write_slice(content_bytes); + let checksum = checksum_bytes.read_u32::().unwrap(); + + if adler.checksum() == checksum { + Some(Box::new(Cursor::new(file_contents))) + } else { + tracing::warn!("Checksum mismatch for {path:?}. Deleting file."); + std::fs::remove_file(&path) + .unwrap_or_else(|e| panic!("Error when deleting {path:?}: {e}")); + None + } + } + // We should try in decreasing order of access speed. match File::open(&path) { Ok(f) => { @@ -316,7 +343,11 @@ fn open_file(path: PathBuf) -> Option> { .unwrap_or_else(|e| panic!("Error when deleting empty file {path:?}: {e}")); return None; } - return Some(Box::new(ChecksumReader::new(reader))); + return if early_check { + do_early_check(path, reader) + } else { + Some(Box::new(ChecksumReader::new(reader))) + }; } Err(e) => { if e.kind() != ErrorKind::NotFound { @@ -331,9 +362,12 @@ fn open_file(path: PathBuf) -> Option> { path.set_extension("zst"); match File::open(&path) { Ok(f) => { - return Some(Box::new(ChecksumReader::new( - zstd::stream::Decoder::new(f).unwrap(), - ))) + let reader = zstd::stream::Decoder::new(f).unwrap(); + return if early_check { + do_early_check(path, reader) + } else { + Some(Box::new(ChecksumReader::new(reader))) + }; } Err(e) => { if e.kind() != ErrorKind::NotFound { @@ -399,6 +433,17 @@ impl SaveFile { Ok(()) } + /// Whether we should load the file in memory and check the checksum before returning it. This + /// only returns false for quasi-inverses because they are our largest files by far. This is a + /// function of `SaveFile` and not just `SaveKind` because we may want to change the behavior + /// depending on the stem or some other heuristic. + fn should_check_early(&self) -> bool { + !matches!( + self.kind, + SaveKind::AugmentationQi | SaveKind::NassauQi | SaveKind::ResQi + ) + } + /// This panics if there is no save dir fn get_save_path(&self, mut dir: PathBuf) -> PathBuf { if let Some(idx) = self.idx { @@ -422,7 +467,7 @@ impl SaveFile { pub fn open_file(&self, dir: PathBuf) -> Option> { let file_path = self.get_save_path(dir); let path_string = file_path.to_string_lossy().into_owned(); - if let Some(mut f) = open_file(file_path) { + if let Some(mut f) = open_file(file_path, self.should_check_early()) { self.validate_header(&mut f).unwrap(); tracing::info!("success open_read: {}", path_string); Some(f) diff --git a/ext/tests/save_load_resolution.rs b/ext/tests/save_load_resolution.rs index 556161514..87a94062a 100644 --- a/ext/tests/save_load_resolution.rs +++ b/ext/tests/save_load_resolution.rs @@ -275,8 +275,7 @@ fn test_load_secondary() { } #[test] -#[should_panic(expected = "Invalid file checksum")] -fn test_checksum() { +fn test_checksum_early() { use std::{ fs::OpenOptions, io::{Seek, SeekFrom, Write}, @@ -300,6 +299,73 @@ fn test_checksum() { file.seek(SeekFrom::Start(41)).unwrap(); file.write_all(&[1]).unwrap(); + // Differentials are checked early for integrity, and silently replaced if they are malformed + construct_standard::("S_2", Some(tempdir.path().into())) + .unwrap() + .compute_through_bidegree(Bidegree::s_t(2, 2)); +} + +#[test] +#[should_panic(expected = "Error when deleting")] +fn test_checksum_early_locked() { + use std::{ + fs::OpenOptions, + io::{Seek, SeekFrom, Write}, + }; + + let tempdir = tempfile::TempDir::new().unwrap(); + + construct_standard::("S_2", Some(tempdir.path().into())) + .unwrap() + .compute_through_bidegree(Bidegree::s_t(2, 2)); + + let mut path = tempdir.path().to_owned(); + path.push("differentials/2_2_differential"); + + let mut file = OpenOptions::new() + .read(true) + .write(true) + .open(path) + .unwrap(); + + file.seek(SeekFrom::Start(41)).unwrap(); + file.write_all(&[1]).unwrap(); + + lock_tempdir(tempdir.path()); + + // This should try to delete the file and panic + construct_standard::("S_2", Some(tempdir.path().into())) + .unwrap() + .compute_through_bidegree(Bidegree::s_t(2, 2)); +} + +#[test] +#[should_panic(expected = "Invalid file checksum")] +fn test_checksum_late() { + use std::{ + fs::OpenOptions, + io::{Seek, SeekFrom, Write}, + }; + + let tempdir = tempfile::TempDir::new().unwrap(); + + construct_standard::("S_2", Some(tempdir.path().into())) + .unwrap() + .compute_through_bidegree(Bidegree::s_t(2, 2)); + + let mut path = tempdir.path().to_owned(); + path.push("res_qis/1_2_res_qi"); + + let mut file = OpenOptions::new() + .read(true) + .write(true) + .open(path) + .unwrap(); + + file.seek(SeekFrom::Start(41)).unwrap(); + file.write_all(&[1]).unwrap(); + + // Quasi-inverses are checked after using them, and we panic if the check fails construct_standard::("S_2", Some(tempdir.path().into())) .unwrap() .compute_through_bidegree(Bidegree::s_t(2, 2));