1use std::{
4 fs, io,
5 path::{Path, PathBuf},
6};
7
/// Name of the file, inside the database directory, that records the database version.
pub const DB_VERSION_FILE_NAME: &str = "database.version";
/// Database version written by [`create_db_version_file`]; [`check_db_version_file`] rejects
/// any database whose version file records a different value.
pub const DB_VERSION: u64 = 2;
13
/// Errors that can occur while reading or checking the database version file.
#[derive(thiserror::Error, Debug)]
pub enum DatabaseVersionError {
    /// The version file does not exist in the database directory.
    #[error("unable to determine the version of the database, file is missing")]
    MissingFile,
    /// The version file exists but its contents could not be parsed as a `u64`.
    #[error("unable to determine the version of the database, file is malformed")]
    MalformedFile,
    /// The version recorded in the file differs from [`DB_VERSION`].
    #[error(
        "breaking database change detected: your database version (v{version}) \
        is incompatible with the latest database version (v{DB_VERSION})"
    )]
    VersionMismatch {
        /// The version that was read from the version file.
        version: u64,
    },
    /// Reading the version file failed for a reason other than the file being missing.
    #[error("IO error occurred while reading {path}: {err}")]
    IORead {
        /// The underlying I/O error.
        err: io::Error,
        /// Path of the version file that could not be read.
        path: PathBuf,
    },
}
43
44pub fn check_db_version_file<P: AsRef<Path>>(db_path: P) -> Result<(), DatabaseVersionError> {
49 let version = get_db_version(db_path)?;
50 if version != DB_VERSION {
51 return Err(DatabaseVersionError::VersionMismatch { version })
52 }
53
54 Ok(())
55}
56
57pub fn get_db_version<P: AsRef<Path>>(db_path: P) -> Result<u64, DatabaseVersionError> {
62 let version_file_path = db_version_file_path(db_path);
63 match fs::read_to_string(&version_file_path) {
64 Ok(raw_version) => {
65 Ok(raw_version.parse::<u64>().map_err(|_| DatabaseVersionError::MalformedFile)?)
66 }
67 Err(err) if err.kind() == io::ErrorKind::NotFound => Err(DatabaseVersionError::MissingFile),
68 Err(err) => Err(DatabaseVersionError::IORead { err, path: version_file_path }),
69 }
70}
71
72pub fn create_db_version_file<P: AsRef<Path>>(db_path: P) -> io::Result<()> {
78 fs::write(db_version_file_path(db_path), DB_VERSION.to_string())
79}
80
81pub fn db_version_file_path<P: AsRef<Path>>(db_path: P) -> PathBuf {
83 db_path.as_ref().join(DB_VERSION_FILE_NAME)
84}
85
#[cfg(test)]
mod tests {
    use super::{
        check_db_version_file, create_db_version_file, db_version_file_path, get_db_version,
        DatabaseVersionError, DB_VERSION,
    };
    use assert_matches::assert_matches;
    use std::fs;
    use tempfile::tempdir;

    /// A directory without a version file reports `MissingFile`.
    #[test]
    fn missing_file() {
        let dir = tempdir().unwrap();

        let result = check_db_version_file(&dir);
        assert_matches!(result, Err(DatabaseVersionError::MissingFile));
    }

    /// Non-numeric contents report `MalformedFile`.
    #[test]
    fn malformed_file() {
        let dir = tempdir().unwrap();
        fs::write(db_version_file_path(&dir), "invalid-version").unwrap();

        let result = check_db_version_file(&dir);
        assert_matches!(result, Err(DatabaseVersionError::MalformedFile));
    }

    /// A parseable but different version reports `VersionMismatch` with the recorded value.
    #[test]
    fn version_mismatch() {
        let dir = tempdir().unwrap();
        fs::write(db_version_file_path(&dir), "0").unwrap();

        let result = check_db_version_file(&dir);
        assert_matches!(result, Err(DatabaseVersionError::VersionMismatch { version: 0 }));
    }

    /// A version file produced by `create_db_version_file` round-trips and passes the check.
    #[test]
    fn valid_version() {
        let dir = tempdir().unwrap();
        create_db_version_file(&dir).unwrap();

        assert_matches!(get_db_version(&dir), Ok(DB_VERSION));
        assert_matches!(check_db_version_file(&dir), Ok(()));
    }

    /// Read failures other than `NotFound` surface as `IORead`. Here the version "file" is a
    /// directory, which cannot be read as text.
    #[test]
    fn unreadable_file() {
        let dir = tempdir().unwrap();
        fs::create_dir(db_version_file_path(&dir)).unwrap();

        let result = check_db_version_file(&dir);
        assert_matches!(result, Err(DatabaseVersionError::IORead { .. }));
    }
}