1use core::hash::BuildHasher;
4use derive_more::{Deref, DerefMut};
5use itertools::Itertools;
6use schnellru::{ByLength, Limiter, RandomState, Unlimited};
7use std::{fmt, hash::Hash};
8
9pub struct LruCache<T: Hash + Eq + fmt::Debug> {
14 limit: u32,
15 inner: LruMap<T, ()>,
16}
17
18impl<T: Hash + Eq + fmt::Debug> LruCache<T> {
19 pub fn new(limit: u32) -> Self {
21 Self { inner: LruMap::new(limit + 1), limit }
24 }
25
26 pub fn insert(&mut self, entry: T) -> bool {
35 let (new_entry, _evicted_val) = self.insert_and_get_evicted(entry);
36 new_entry
37 }
38
39 pub fn insert_and_get_evicted(&mut self, entry: T) -> (bool, Option<T>) {
42 let new = self.inner.peek(&entry).is_none();
43 let evicted =
44 (new && (self.limit as usize) <= self.inner.len()).then(|| self.remove_lru()).flatten();
45 _ = self.inner.get_or_insert(entry, || ());
46
47 (new, evicted)
48 }
49
50 pub fn get(&mut self, entry: &T) -> Option<&T> {
52 let _ = self.inner.get(entry)?;
53 self.iter().next()
54 }
55
56 pub fn find(&self, entry: &T) -> Option<&T> {
63 self.iter().find(|key| *key == entry)
64 }
65
66 #[inline]
71 fn remove_lru(&mut self) -> Option<T> {
72 self.inner.pop_oldest().map(|(k, ())| k)
73 }
74
75 pub fn remove(&mut self, value: &T) -> bool {
77 self.inner.remove(value).is_some()
78 }
79
80 pub fn contains(&self, value: &T) -> bool {
82 self.inner.peek(value).is_some()
83 }
84
85 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
87 self.inner.iter().map(|(k, ())| k)
88 }
89
90 #[allow(dead_code)]
92 pub fn len(&self) -> usize {
93 self.inner.len()
94 }
95
96 #[allow(dead_code)]
98 pub fn is_empty(&self) -> bool {
99 self.inner.is_empty()
100 }
101}
102
103impl<T> Extend<T> for LruCache<T>
104where
105 T: Eq + Hash + fmt::Debug,
106{
107 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
108 for item in iter {
109 _ = self.insert(item);
110 }
111 }
112}
113
114impl<T> fmt::Debug for LruCache<T>
115where
116 T: fmt::Debug + Hash + Eq,
117{
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 let mut debug_struct = f.debug_struct("LruCache");
120
121 debug_struct.field("limit", &self.limit);
122
123 debug_struct.field(
124 "ret %iter",
125 &format_args!("Iter: {{{} }}", self.iter().map(|k| format!(" {k:?}")).format(",")),
126 );
127
128 debug_struct.finish()
129 }
130}
131
132#[derive(Deref, DerefMut, Default)]
134pub struct LruMap<K, V, L = ByLength, S = RandomState>(schnellru::LruMap<K, V, L, S>)
135where
136 K: Hash + PartialEq,
137 L: Limiter<K, V>,
138 S: BuildHasher;
139
140impl<K, V, L, S> fmt::Debug for LruMap<K, V, L, S>
141where
142 K: Hash + PartialEq + fmt::Display,
143 V: fmt::Debug,
144 L: Limiter<K, V> + fmt::Debug,
145 S: BuildHasher,
146{
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 let mut debug_struct = f.debug_struct("LruMap");
149
150 debug_struct.field("limiter", self.limiter());
151
152 debug_struct.field(
153 "ret %iter",
154 &format_args!(
155 "Iter: {{{} }}",
156 self.iter().map(|(k, v)| format!(" {k}: {v:?}")).format(",")
157 ),
158 );
159
160 debug_struct.finish()
161 }
162}
163
164impl<K, V> LruMap<K, V>
165where
166 K: Hash + PartialEq,
167{
168 pub fn new(max_length: u32) -> Self {
170 Self(schnellru::LruMap::new(ByLength::new(max_length)))
171 }
172}
173
174impl<K, V> LruMap<K, V, Unlimited>
175where
176 K: Hash + PartialEq,
177{
178 pub fn new_unlimited() -> Self {
180 Self(schnellru::LruMap::new(Unlimited))
181 }
182}
183
#[cfg(test)]
mod test {
    use super::*;
    use derive_more::{Constructor, Display};
    use std::hash::Hasher;

    // Plain key type; `Display` is required by the `LruMap` `Debug` impl.
    #[derive(Debug, Hash, PartialEq, Eq, Display, Clone, Copy)]
    struct Key(i8);

    // Key whose equality and hash consider only `id`. Two values can compare equal while
    // differing in `other`, which lets the tests check which instance the cache returns.
    #[derive(Debug, Eq, Constructor, Clone, Copy)]
    struct CompoundKey {
        id: i8,
        other: i8,
    }

    impl PartialEq for CompoundKey {
        fn eq(&self, other: &Self) -> bool {
            self.id == other.id
        }
    }

    impl Hash for CompoundKey {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.id.hash(state)
        }
    }

    #[test]
    fn test_cache_should_insert_into_empty_set() {
        let mut cache = LruCache::new(5);
        let entry = "entry";
        assert!(cache.insert(entry));
        assert!(cache.contains(&entry));
    }

    #[test]
    fn test_cache_should_not_insert_same_value_twice() {
        let mut cache = LruCache::new(5);
        let entry = "entry";
        assert!(cache.insert(entry));
        // A second insert of an existing entry reports "not new".
        assert!(!cache.insert(entry));
    }

    #[test]
    fn test_cache_should_remove_oldest_element_when_exceeding_limit() {
        let mut cache = LruCache::new(2);
        let old_entry = "old_entry";
        let new_entry = "new_entry";
        cache.insert(old_entry);
        cache.insert("entry");
        // Third insert into a limit-2 cache evicts the least recently used (`old_entry`).
        cache.insert(new_entry);
        assert!(cache.contains(&new_entry));
        assert!(!cache.contains(&old_entry));
    }

    #[test]
    fn test_cache_should_extend_an_array() {
        let mut cache = LruCache::new(5);
        let entries = ["some_entry", "another_entry"];
        cache.extend(entries);
        for e in entries {
            assert!(cache.contains(&e));
        }
    }

    #[test]
    // `Value`'s field is only read by its derived `Debug` impl.
    #[allow(dead_code)]
    fn test_debug_impl_lru_map() {
        #[derive(Debug)]
        struct Value(i8);

        let mut cache = LruMap::new(2);
        let key_1 = Key(1);
        let value_1 = Value(11);
        cache.insert(key_1, value_1);
        let key_2 = Key(2);
        let value_2 = Value(22);
        cache.insert(key_2, value_2);

        // Entries are listed most recently inserted first.
        assert_eq!("LruMap { limiter: ByLength { max_length: 2 }, ret %iter: Iter: { 2: Value(22), 1: Value(11) } }", format!("{cache:?}"))
    }

    #[test]
    #[allow(dead_code)]
    fn test_debug_impl_lru_cache() {
        let mut cache = LruCache::new(2);
        let key_1 = Key(1);
        cache.insert(key_1);
        let key_2 = Key(2);
        cache.insert(key_2);

        assert_eq!(
            "LruCache { limit: 2, ret %iter: Iter: { Key(2), Key(1) } }",
            format!("{cache:?}")
        )
    }

    #[test]
    fn get() {
        let mut cache = LruCache::new(2);
        let key_1 = Key(1);
        cache.insert(key_1);
        let key_2 = Key(2);
        cache.insert(key_2);

        // `get` promotes `key_1` to most recently used.
        _ = cache.get(&key_1);

        assert_eq!(
            "LruCache { limit: 2, ret %iter: Iter: { Key(1), Key(2) } }",
            format!("{cache:?}")
        )
    }

    #[test]
    fn get_ty_custom_eq_impl() {
        let mut cache = LruCache::new(2);
        let key_1 = CompoundKey::new(1, 11);
        cache.insert(key_1);
        let key_2 = CompoundKey::new(2, 22);
        cache.insert(key_2);

        let key = cache.get(&key_1);

        // The returned value is the stored instance, compared on the non-`Eq` field.
        assert_eq!(key_1.other, key.unwrap().other)
    }

    #[test]
    fn peek() {
        let mut cache = LruCache::new(2);
        let key_1 = Key(1);
        cache.insert(key_1);
        let key_2 = Key(2);
        cache.insert(key_2);

        // Unlike `get`, `find` must not change the LRU order.
        _ = cache.find(&key_1);

        assert_eq!(
            "LruCache { limit: 2, ret %iter: Iter: { Key(2), Key(1) } }",
            format!("{cache:?}")
        )
    }

    #[test]
    fn peek_ty_custom_eq_impl() {
        let mut cache = LruCache::new(2);
        let key_1 = CompoundKey::new(1, 11);
        cache.insert(key_1);
        let key_2 = CompoundKey::new(2, 22);
        cache.insert(key_2);

        let key = cache.find(&key_1);

        assert_eq!(key_1.other, key.unwrap().other)
    }
}