From 8d3c6333619e511b5f103f382d21a9891dd0e794 Mon Sep 17 00:00:00 2001 From: omagdy Date: Tue, 22 Jul 2025 06:08:36 +0300 Subject: feat: Added a feature to proprely parse rdb files and added support for KEYS command --- Cargo.lock | 39 +++ Cargo.toml | 1 + src/lib.rs | 9 +- src/main.rs | 7 +- src/rdb.rs | 687 ++++++++++++++++++++++++++++++++++++++++ src/resp_commands.rs | 88 +++-- tests/test_commands.rs | 16 +- tests/test_parse_bulk_string.rs | 2 +- tests/test_parse_double.rs | 1 - 9 files changed, 811 insertions(+), 39 deletions(-) create mode 100644 src/rdb.rs diff --git a/Cargo.lock b/Cargo.lock index 7aab452..d549015 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "anyhow" version = "1.0.86" @@ -74,6 +83,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes", + "regex", "thiserror", "tokio", ] @@ -207,6 +217,35 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/Cargo.toml b/Cargo.toml index 396319b..aca1ff0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,6 @@ edition = "2021" [dependencies] anyhow = "1.0.59" # error handling bytes = "1.3.0" # helps manage buffers +regex = "1.11.1" thiserror = "1.0.32" # error handling tokio = { version = "1.23.0", features = ["full"] } # async networking diff --git a/src/lib.rs b/src/lib.rs index 060106b..c660fd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ use std::{env, sync::Arc}; #[macro_use] pub mod macros; +pub mod rdb; pub mod resp_commands; pub mod resp_parser; pub mod shared_cache; @@ -15,9 +16,13 @@ pub struct Config { pub type SharedConfig = Arc>; impl Config { - pub fn new() -> Result { + pub fn new() -> Result, String> { let args: Vec = env::args().collect(); + if args.len() == 1 { + return Ok(None); + } + let mut dir = None; let mut dbfilename = None; @@ -44,6 +49,6 @@ impl Config { } } - Ok(Config { dir, dbfilename }) + Ok(Some(Config { dir, dbfilename })) } } diff --git a/src/main.rs b/src/main.rs index b29e2ce..effc1e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,16 +55,17 @@ fn handle_client(mut stream: TcpStream, cache: SharedCache, config: SharedConfig fn main() -> std::io::Result<()> { let listener = TcpListener::bind("127.0.0.1:6379").unwrap(); let cache: SharedCache = Arc::new(Mutex::new(HashMap::new())); - let mut config: SharedConfig = None.into(); + let mut config: SharedConfig = Arc::new(None); spawn_cleanup_thread(cache.clone()); match Config::new() { Ok(conf) => { - config = Arc::new(Some((conf))); + if let Some(conf) = conf { + config = Arc::new(Some(conf)); + } } Err(e) => { - config = Arc::new(None); eprintln!("Error: {}", e); std::process::exit(1); } diff --git a/src/rdb.rs b/src/rdb.rs new file mode 100644 index 0000000..1a4de11 --- /dev/null +++ b/src/rdb.rs @@ -0,0 +1,687 @@ +// Helpful resource 
+// https://rdb.fnordig.de/file_format.html#zipmap-encoding +#![allow(unused)] + +use std::{ + collections::{HashMap, HashSet}, + fs, io, isize, + path::Path, +}; + +use thiserror::Error; + +use crate::resp_commands::ExpiryOption; + +/// Represents any possible value that a key can hold in Redis. +/// +/// Note: All "string" elements are represented as `Vec<u8>` because +/// Redis strings are binary-safe and not guaranteed to be valid UTF-8. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RedisValue { + /// The STRING type. + String(Vec), + + /// The Integer type. + Integer(i64), + + /// The LIST type. An ordered collection of strings. + List(Vec>), + + /// The SET type. An unordered collection of unique strings. + Set(HashSet>), + + /// The HASH type. A collection of field-value pairs. + Hash(HashMap, Vec>), +} + +impl RedisValue { + /// Convert RedisValue to bytes using Redis string encoding format + pub fn to_bytes(&self) -> Vec { + match self { + RedisValue::String(data) => encode_string(data), + RedisValue::Integer(value) => encode_integer(*value), + RedisValue::List(items) => { + let mut result = Vec::new(); + // Simplified list encoding: the element count comes first, + // followed by each item as a length-prefixed string + result.extend(encode_length(items.len())); + for item in items { + result.extend(encode_string(item)); + } + result + } + RedisValue::Set(items) => { + let mut result = Vec::new(); + result.extend(encode_length(items.len())); + for item in items { + result.extend(encode_string(item)); + } + result + } + RedisValue::Hash(map) => { + let mut result = Vec::new(); + result.extend(encode_length(map.len())); + for (key, value) in map { + result.extend(encode_string(key)); + result.extend(encode_string(value)); + } + result + } + } + } +} + +/// Encode a string using Redis length encoding + raw bytes +fn encode_string(data: &[u8]) -> Vec { + let mut result = encode_length(data.len()); + result.extend_from_slice(data); + result +} + +/// Encode 
an integer using Redis special encoding if possible, otherwise as string +fn encode_integer(value: i64) -> Vec { + // Try to use special integer encodings for efficiency + if value >= i8::MIN as i64 && value <= i8::MAX as i64 { + // 8-bit integer encoding: 0b11000000 (0xC0) followed by 1 byte + vec![0xC0, value as u8] + } else if value >= i16::MIN as i64 && value <= i16::MAX as i64 { + // 16-bit integer encoding: 0b11000001 (0xC1) followed by 2 bytes + let bytes = (value as i16).to_be_bytes(); + vec![0xC1, bytes[0], bytes[1]] + } else if value >= i32::MIN as i64 && value <= i32::MAX as i64 { + // 32-bit integer encoding: 0b11000010 (0xC2) followed by 4 bytes + let bytes = (value as i32).to_be_bytes(); + vec![0xC2, bytes[0], bytes[1], bytes[2], bytes[3]] + } else { + // For very large integers, encode as string + let string_repr = value.to_string(); + encode_string(string_repr.as_bytes()) + } +} + +/// Encode length using Redis length encoding format +fn encode_length(len: usize) -> Vec { + if len < 64 { + // 6-bit length (0-63): 0b00xxxxxx + vec![len as u8] + } else if len < 16384 { + // 14-bit length: 0b01xxxxxx xxxxxxxx + let first_byte = 0x40 | ((len >> 8) as u8); + let second_byte = (len & 0xFF) as u8; + vec![first_byte, second_byte] + } else if len <= u32::MAX as usize { + // 32-bit length: 0b10000000 followed by 4 bytes big-endian + let bytes = (len as u32).to_be_bytes(); + vec![0x80, bytes[0], bytes[1], bytes[2], bytes[3]] + } else { + panic!("Length too large for Redis encoding: {}", len); + } +} + +// Custom error type for parsing +#[derive(Debug, PartialEq)] +pub enum ParseError { + InvalidMagicNumber, + InvalidVersion, + UnexpectedEof, + InvalidMetadata, + InvalidLength, +} + +use std::fmt; + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ParseError::*; + let message = match self { + InvalidMagicNumber => "Invalid magic number", + InvalidVersion => "Invalid version", + UnexpectedEof => "Unexpected 
end of file", + InvalidMetadata => "Invalid metadata", + InvalidLength => "Invalid length", + }; + write!(f, "{}", message) + } +} + +impl std::error::Error for ParseError {} + +// Custom parsing trait that returns bytes consumed +pub trait FromBytes: Sized { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError>; +} + +#[derive(Debug, Copy, Clone)] +pub enum ValueType { + String = 0, + List = 1, + Set = 2, + SortedSet = 3, + Hash = 4, + Zipmap = 9, + Ziplist = 10, + Intset = 11, + SortedSetInZiplist = 12, + HashmapInZiplist = 13, + ListInQuicklist = 14, +} + +impl TryFrom for ValueType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(ValueType::String), + 1 => Ok(ValueType::List), + 2 => Ok(ValueType::Set), + 3 => Ok(ValueType::SortedSet), + 4 => Ok(ValueType::Hash), + 9 => Ok(ValueType::Zipmap), + 10 => Ok(ValueType::Ziplist), + 11 => Ok(ValueType::Intset), + 12 => Ok(ValueType::SortedSetInZiplist), + 13 => Ok(ValueType::HashmapInZiplist), + 14 => Ok(ValueType::ListInQuicklist), + _ => Err(()), + } + } +} + +pub struct KeyExpiry { + pub timestamp: u64, + pub unit: ExpiryUnit, +} + +pub enum ExpiryUnit { + Seconds, + Milliseconds, +} + +pub struct DatabaseEntry { + pub expiry: Option, + pub value_type: ValueType, + pub value: RedisValue, +} + +type LengthEncoding = usize; +type BytesConsumed = usize; + +/// Parses a Redis length-encoded integer. 
+fn parse_length(bytes: &[u8]) -> Result<(usize, usize), ParseError> { + let (first_byte, mut rest) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?; + let mut consumed = 1; + + match first_byte[0] >> 6 { + 0b00 => { + // 6-bit length + // 0x3F = 0011 1111 + let len = (first_byte[0] & 0x3F) as usize; + Ok((len, consumed)) + } + 0b01 => { + // 14-bit length + let (second_byte, _) = rest.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?; + consumed += 1; + + // The length is the low 6 bits of the first byte followed by the whole second byte: + // mask the first byte with 0x3F to drop the two type bits, shift it left by 8, and + // OR in the second byte (6 + 8 = 14 bits) + let len = (((first_byte[0] & 0x3F) as usize) << 8) | (second_byte[0] as usize); + Ok((len, consumed)) + } + 0b10 => { + // 32-bit length from next 4 bytes + let (len_bytes, _) = rest.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?; + consumed += 4; + + // Straightforward: skip the first byte and interpret the next 4 bytes as a big-endian u32 + let len = u32::from_be_bytes(len_bytes.try_into().unwrap()) as usize; + Ok((len, consumed)) + } + 0b11 => { + // Special format, not a length + Err(ParseError::InvalidLength) + } + _ => unreachable!(), + } +} + +fn parse_special_length( + special_type: u8, + bytes: &[u8], + bytes_consumed: &mut usize, +) -> Result<(RedisValue, usize), ParseError> { + match special_type { + 0 => { + let (int_bytes, bytes) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?; + *bytes_consumed += 1; + Ok((RedisValue::Integer(int_bytes[0] as i64), *bytes_consumed)) + } + 1 => { + let (int_bytes, bytes) = bytes.split_at_checked(2).ok_or(ParseError::UnexpectedEof)?; + *bytes_consumed += 2; + let value = i16::from_be_bytes([int_bytes[0], int_bytes[1]]) as i64; + Ok((RedisValue::Integer(value), *bytes_consumed)) + } + 2 => { + let (int_bytes, bytes) = 
bytes.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?; + *bytes_consumed += 4; + let value = + i32::from_be_bytes([int_bytes[0], int_bytes[1], int_bytes[2], int_bytes[3]]) as i64; + Ok((RedisValue::Integer(value), *bytes_consumed)) + } + _ => Err(ParseError::InvalidLength), + } +} + +impl FromBytes for RedisValue { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> { + let (length_bytes, rest) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?; + let mut bytes_consumed = 1; + + if length_bytes[0] >> 6 == 0b11 { + // Special encoding + let special_type = length_bytes[0] & 0x3F; + return parse_special_length(special_type, rest, &mut bytes_consumed); + } else { + // It's a string, use our new helper + let (len, len_consumed) = parse_length(bytes)?; + let (string_bytes, _) = bytes[len_consumed..] + .split_at_checked(len) + .ok_or(ParseError::UnexpectedEof)?; + + let total_consumed = len_consumed + len; + Ok((RedisValue::String(string_bytes.to_vec()), total_consumed)) + } + } +} + +#[derive(Debug, PartialEq, Default)] +pub struct RDBHeader { + pub magic_number: [u8; 5], + pub version: [u8; 4], +} + +impl FromBytes for RDBHeader { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> { + let mut rdb_header = RDBHeader::default(); + + // first 5 bytes should match the magic number + let (magic_number, rest) = bytes.split_at_checked(5).ok_or(ParseError::UnexpectedEof)?; + + if let b"REDIS" = magic_number { + rdb_header.magic_number = *b"REDIS"; + } else { + return Err(ParseError::InvalidMagicNumber); + } + + // The following 4 bytes should match the version number + let (version, _) = rest.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?; + + if let Ok(version_str) = std::str::from_utf8(version) { + if ("0001" <= version_str) && (version_str <= "0011") { + rdb_header.version = version.try_into().unwrap(); + } else { + return Err(ParseError::InvalidVersion); + } + } else { + return Err(ParseError::InvalidVersion); + } + + 
Ok((rdb_header, 9)) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct RDBMetaData { + pub metadata: HashMap, Vec>, +} + +impl FromBytes for RDBMetaData { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> { + let mut metadata = HashMap::new(); + let mut remaining = bytes; + let mut total_consumed = 0; + + // Keep parsing AUX entries until we hit a non-AUX byte + loop { + let (aux_byte, rest) = remaining + .split_at_checked(1) + .ok_or(ParseError::UnexpectedEof)?; + + if aux_byte[0] != RDBFile::AUX { + // Hit a non-AUX byte, we're done with metadata + break; + } + + remaining = rest; + total_consumed += 1; + + // Parse key string + let (key, key_consumed) = RedisValue::from_bytes(remaining)?; + remaining = &remaining[key_consumed..]; + total_consumed += key_consumed; + + // Parse value string + let (value, value_consumed) = RedisValue::from_bytes(remaining)?; + remaining = &remaining[value_consumed..]; + total_consumed += value_consumed; + + let key_data = match key { + RedisValue::String(data) => data, + RedisValue::Integer(data) => data.to_string().as_bytes().to_vec(), + _ => return Err(ParseError::InvalidMetadata), + }; + + let value_data = match value { + RedisValue::String(data) => data, + RedisValue::Integer(data) => data.to_string().as_bytes().to_vec(), + _ => return Err(ParseError::InvalidMetadata), + }; + + metadata.insert(key_data, value_data); + } + + Ok((RDBMetaData { metadata }, total_consumed)) + } +} + +#[derive(Debug, Default)] +pub struct HashTableSizeInfo { + pub hash_table_size: usize, + pub expired_hash_table_size: usize, +} + +pub type DatabaseIndex = usize; + +pub struct RDBDatabase { + pub database_index: DatabaseIndex, + pub size_hints: HashTableSizeInfo, + pub hash_table: HashMap, DatabaseEntry>, +} + +fn parse_db_key_value( + remaining: &mut &[u8], + total_consumed: &mut usize, + expiry: Option, + value_type: ValueType, + hash_table: &mut HashMap, DatabaseEntry>, +) -> Result<(), ParseError> { + // Parse key string + let 
(key, key_consumed) = RedisValue::from_bytes(remaining)?; + *remaining = &remaining[key_consumed..]; + *total_consumed += key_consumed; + + // Parse value string + let (value, value_consumed) = RedisValue::from_bytes(remaining)?; + *remaining = &remaining[value_consumed..]; + *total_consumed += value_consumed; + + let database_entry = DatabaseEntry { + expiry: expiry, + value_type: ValueType::String, + value: value, + }; + + let key_data = if let RedisValue::String(data) = key { + data + } else { + return Err(ParseError::UnexpectedEof); + }; + hash_table.insert(key_data, database_entry); + + Ok(()) +} + +impl FromBytes for RDBDatabase { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> { + let mut hash_table = HashMap::new(); + let mut remaining = bytes; + let mut size_hints = HashTableSizeInfo::default(); + let mut total_consumed = 0; + let mut database_index = 0; + + // Keep parsing db entries until we hit a EOF byte + loop { + let (next_byte, rest) = remaining + .split_at_checked(1) + .ok_or(ParseError::UnexpectedEof)?; + + if next_byte[0] == RDBFile::EOF { + break; + } + + remaining = rest; + total_consumed += 1; + + match next_byte[0] { + RDBFile::SELECT_DB => { + let (index, consumed) = parse_length(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + + database_index = index; + } + RDBFile::RESIZE_DB => { + let (len, consumed) = parse_length(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + + size_hints.hash_table_size = len; + + let (len, consumed) = parse_length(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + + size_hints.expired_hash_table_size = len; + } + RDBFile::EXPIRE_TIME_MS => { + let (timestamp_bytes, rest) = remaining + .split_at_checked(8) + .ok_or(ParseError::UnexpectedEof)?; + + remaining = rest; + total_consumed += 8; + + let timestamp = u64::from_le_bytes( + timestamp_bytes[0..8] + .try_into() + .expect("This should always be 
atleast 8 bytes"), + ); + + let (value_type_byte, rest) = remaining + .split_at_checked(1) + .ok_or(ParseError::UnexpectedEof)?; + + remaining = rest; + total_consumed += 1; + + let expiry = Some(KeyExpiry { + timestamp, + unit: ExpiryUnit::Milliseconds, + }); + + match ValueType::try_from(value_type_byte[0]).unwrap() { + ValueType::String => { + parse_db_key_value( + &mut remaining, + &mut total_consumed, + expiry, + ValueType::String, + &mut hash_table, + )?; + } + _ => unreachable!(), + } + } + RDBFile::EXPIRE_TIME => { + let (timestamp_bytes, rest) = remaining + .split_at_checked(4) + .ok_or(ParseError::UnexpectedEof)?; + + remaining = rest; + total_consumed += 4; + + let timestamp = u32::from_le_bytes( + timestamp_bytes[0..4] + .try_into() + .expect("This should always be atleast 4 bytes"), + ) as u64; + + let (value_type_byte, rest) = remaining + .split_at_checked(1) + .ok_or(ParseError::UnexpectedEof)?; + + remaining = rest; + total_consumed += 1; + + let expiry = Some(KeyExpiry { + timestamp, + unit: ExpiryUnit::Seconds, + }); + + match ValueType::try_from(value_type_byte[0]).unwrap() { + ValueType::String => { + parse_db_key_value( + &mut remaining, + &mut total_consumed, + expiry, + ValueType::String, + &mut hash_table, + )?; + } + _ => unreachable!(), + } + } + n @ 0..15 => match ValueType::try_from(n as u8).unwrap() { + ValueType::String => { + parse_db_key_value( + &mut remaining, + &mut total_consumed, + None, + ValueType::String, + &mut hash_table, + )?; + } + _ => unreachable!(), + }, + _ => break, + } + } + + Ok(( + RDBDatabase { + database_index, + size_hints, + hash_table, + }, + total_consumed, + )) + } +} + +pub struct RDBFile { + pub header: RDBHeader, + pub metadata: Option, + pub databases: HashMap, + pub checksum: u64, +} + +impl FromBytes for RDBFile { + fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> { + let mut remaining = bytes; + let mut total_consumed = 0; + let mut databases = HashMap::new(); + + // 1. 
Parse the RDB header ("REDIS" + version) + let (header, consumed) = RDBHeader::from_bytes(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + + // 2. Parse metadata (any AUX key-value fields) + // Your RDBMetaData::from_bytes implementation correctly handles this by + // consuming all sequential AUX fields until it hits another opcode. + let (metadata, consumed) = RDBMetaData::from_bytes(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + + // 3. Parse database sections + // The provided RDBDatabase::from_bytes is designed to parse the entire + // key-value section, including multiple DB selectors, until the final EOF. + let (database_section, consumed) = RDBDatabase::from_bytes(remaining)?; + total_consumed += consumed; + remaining = &remaining[consumed..]; + databases.insert(database_section.database_index, database_section); + + // 4. Parse the final EOF marker and checksum + // The RDBDatabase parser stops when it sees the EOF marker but doesn't consume it. + // We must consume it here. + let (eof_byte, rest) = remaining + .split_at_checked(1) + .ok_or(ParseError::UnexpectedEof)?; + if eof_byte[0] != Self::EOF { + return Err(ParseError::InvalidMetadata); // Expected EOF marker + } + total_consumed += 1; + remaining = rest; + + // The final 8 bytes of a valid RDB file are the checksum. + if remaining.len() >= 8 { + let (checksum_bytes, _) = remaining + .split_at_checked(8) + .ok_or(ParseError::UnexpectedEof)?; + let checksum = u64::from_le_bytes(checksum_bytes.try_into().unwrap()); + total_consumed += 8; + + let rdb_file = RDBFile { + header, + metadata: Some(metadata), + databases, + checksum, + }; + Ok((rdb_file, total_consumed)) + } else { + // Handle cases where checksum might be missing (older RDB versions or truncated file) + // For simplicity, we'll assume a checksum is always present. 
+ Err(ParseError::UnexpectedEof) + } + } +} + +impl RDBFile { + pub fn read(dir: String, dbfilename: String) -> Result { + let dir = Path::new(&dir); + let file_path = dir.join(dbfilename); + + // Read file to bytes + let bytes = fs::read(&file_path)?; + let (rdb_file, consumed) = RDBFile::from_bytes(&bytes)?; + + // sanity check + assert!(bytes.len() == consumed); + Ok(rdb_file) + } +} + +impl RDBFile { + /// The file starts off with the magic string “REDIS”. This is a quick sanity check to know we are dealing with a redis rdb file. + pub const MAGIC_NUMBER: [u8; 5] = *b"REDIS"; + + /// End of the RDB File + pub const EOF: u8 = 0xFF; + + /// Database Selector + pub const SELECT_DB: u8 = 0xFE; + + /// Expire time in seconds + pub const EXPIRE_TIME: u8 = 0xFD; + + /// Expire time in milliseconds + pub const EXPIRE_TIME_MS: u8 = 0xFC; + + /// Hash table sizes for the main keyspace and expires + pub const RESIZE_DB: u8 = 0xFB; + + /// Auxiliary fields. Arbitrary key-value settings + pub const AUX: u8 = 0xFA; +} diff --git a/src/resp_commands.rs b/src/resp_commands.rs index 37273f8..9d35b1b 100644 --- a/src/resp_commands.rs +++ b/src/resp_commands.rs @@ -1,5 +1,7 @@ +use crate::rdb::RDBFile; +use crate::SharedConfig; use crate::{resp_parser::*, shared_cache::*}; -use crate::{Config, SharedConfig}; +use regex::Regex; use std::time::{SystemTime, UNIX_EPOCH}; #[derive(Debug, Clone)] @@ -108,26 +110,23 @@ fn extract_string(resp: &RespType) -> Option { } } -// Helper function to parse u64 from BulkString -fn parse_u64(resp: &RespType) -> Option { - extract_string(resp)?.parse().ok() -} - pub enum RedisCommands { - PING, - ECHO(String), - GET(String), - SET(SetCommand), - CONFIG_GET(String), + Ping, + Echo(String), + Get(String), + Set(SetCommand), + ConfigGet(String), + Keys(String), Invalid, } impl RedisCommands { pub fn execute(self, cache: SharedCache, config: SharedConfig) -> Vec { + use RedisCommands as RC; match self { - RedisCommands::PING => resp!("PONG"), - 
RedisCommands::ECHO(echo_string) => resp!(echo_string), - RedisCommands::GET(key) => { + RC::Ping => resp!("PONG"), + RC::Echo(echo_string) => resp!(echo_string), + RC::Get(key) => { let mut cache = cache.lock().unwrap(); match cache.get(&key).cloned() { Some(entry) => { @@ -141,7 +140,7 @@ impl RedisCommands { None => resp!(null), } } - RedisCommands::SET(command) => { + RC::Set(command) => { let mut cache = cache.lock().unwrap(); // Check conditions (NX/XX) @@ -194,7 +193,7 @@ impl RedisCommands { None => return resp!(null), } } - RedisCommands::CONFIG_GET(s) => { + RC::ConfigGet(s) => { use RespType as RT; let config = config.clone(); if let Some(conf) = config.as_ref() { @@ -217,7 +216,44 @@ impl RedisCommands { unreachable!() } } - RedisCommands::Invalid => todo!(), + RC::Keys(query) => { + use RespType as RT; + + let query = query.replace('*', ".*"); + + let cache = cache.lock().unwrap(); + let regex = Regex::new(&query).unwrap(); + let config = config.clone(); + + if let Some(conf) = config.as_ref() { + let dir = conf.dir.clone().unwrap(); + let dbfilename = conf.dbfilename.clone().unwrap(); + let rdb_file = RDBFile::read(dir, dbfilename).unwrap(); + + let hash_table = &rdb_file.databases.get(&0).unwrap().hash_table; + let matching_keys: Vec = hash_table + .keys() + .map(|key| str::from_utf8(key).unwrap()) + .filter_map(|key| { + regex + .is_match(key) + .then(|| RT::BulkString(key.as_bytes().to_vec())) + }) + .collect(); + RT::Array(matching_keys).to_resp_bytes() + } else { + let matching_keys: Vec = cache + .keys() + .filter_map(|key| { + regex + .is_match(key) + .then(|| RT::BulkString(key.as_bytes().to_vec())) + }) + .collect(); + RT::Array(matching_keys).to_resp_bytes() + } + } + RC::Invalid => todo!(), } } } @@ -338,17 +374,17 @@ impl From for RedisCommands { match cmd_name.to_ascii_uppercase().as_str() { "PING" => { if args.next().is_none() { - Self::PING + Self::Ping } else { Self::Invalid } } "ECHO" => match (args.next(), args.next()) { - 
(Some(echo_string), None) => Self::ECHO(echo_string), + (Some(echo_string), None) => Self::Echo(echo_string), _ => Self::Invalid, }, "GET" => match (args.next(), args.next()) { - (Some(key), None) => Self::GET(key), + (Some(key), None) => Self::Get(key), _ => Self::Invalid, }, "SET" => { @@ -362,15 +398,21 @@ impl From for RedisCommands { let options: Vec = args.collect(); if options.is_empty() { - Self::SET(SetCommand::new(key, value)) + Self::Set(SetCommand::new(key, value)) } else { let parser = SetOptionParser::new(key, value); match parser.parse_options(&options) { - Ok(set_command) => Self::SET(set_command), + Ok(set_command) => Self::Set(set_command), Err(_) => Self::Invalid, } } } + "KEYS" => { + let Some(query) = args.next() else { + return Self::Invalid; + }; + Self::Keys(query) + } "CONFIG" => { let Some(sub_command) = args.next() else { return Self::Invalid; @@ -379,7 +421,7 @@ impl From for RedisCommands { return Self::Invalid; }; if &sub_command.to_uppercase() == &"GET" { - return Self::CONFIG_GET(key); + return Self::ConfigGet(key); } Self::Invalid } diff --git a/tests/test_commands.rs b/tests/test_commands.rs index 1031c8b..e71db38 100644 --- a/tests/test_commands.rs +++ b/tests/test_commands.rs @@ -37,13 +37,13 @@ mod command_parser_tests { #[test] fn test_parse_ping() { let cmd = build_command_from_str_slice(&["PING"]); - assert!(matches!(RedisCommands::from(cmd), RedisCommands::PING)); + assert!(matches!(RedisCommands::from(cmd), RedisCommands::Ping)); } #[test] fn test_parse_ping_case_insensitive() { let cmd = build_command_from_str_slice(&["pInG"]); - assert!(matches!(RedisCommands::from(cmd), RedisCommands::PING)); + assert!(matches!(RedisCommands::from(cmd), RedisCommands::Ping)); } #[test] @@ -56,7 +56,7 @@ mod command_parser_tests { fn test_parse_echo() { let cmd = build_command_from_str_slice(&["ECHO", "hello world"]); match RedisCommands::from(cmd) { - RedisCommands::ECHO(s) => assert_eq!(s, "hello world"), + RedisCommands::Echo(s) => 
assert_eq!(s, "hello world"), _ => panic!("Expected ECHO command"), } } @@ -71,7 +71,7 @@ mod command_parser_tests { fn test_parse_get() { let cmd = build_command_from_str_slice(&["GET", "mykey"]); match RedisCommands::from(cmd) { - RedisCommands::GET(k) => assert_eq!(k, "mykey"), + RedisCommands::Get(k) => assert_eq!(k, "mykey"), _ => panic!("Expected GET command"), } } @@ -80,7 +80,7 @@ mod command_parser_tests { fn test_parse_simple_set() { let cmd = build_command_from_str_slice(&["SET", "mykey", "myvalue"]); match RedisCommands::from(cmd) { - RedisCommands::SET(c) => { + RedisCommands::Set(c) => { assert_eq!(c.key, "mykey"); assert_eq!(c.value, "myvalue"); assert!(c.condition.is_none() && c.expiry.is_none() && !c.get_old_value); @@ -93,7 +93,7 @@ mod command_parser_tests { fn test_parse_set_with_all_options() { let cmd = build_command_from_str_slice(&["SET", "k", "v", "NX", "PX", "5000", "GET"]); match RedisCommands::from(cmd) { - RedisCommands::SET(c) => { + RedisCommands::Set(c) => { assert!(matches!(c.condition, Some(SetCondition::NotExists))); assert!(matches!(c.expiry, Some(ExpiryOption::Milliseconds(5000)))); assert!(c.get_old_value); @@ -106,7 +106,7 @@ mod command_parser_tests { fn test_parse_set_options_case_insensitive() { let cmd = build_command_from_str_slice(&["set", "k", "v", "nx", "px", "100"]); match RedisCommands::from(cmd) { - RedisCommands::SET(c) => { + RedisCommands::Set(c) => { assert!(matches!(c.condition, Some(SetCondition::NotExists))); assert!(matches!(c.expiry, Some(ExpiryOption::Milliseconds(100)))); } @@ -267,8 +267,6 @@ mod set_command_tests { use codecrafters_redis::resp_commands::{ExpiryOption, SetCommand}; - use super::*; - #[test] fn test_calculate_expiry_seconds() { let cmd = diff --git a/tests/test_parse_bulk_string.rs b/tests/test_parse_bulk_string.rs index 1543262..a247b6e 100644 --- a/tests/test_parse_bulk_string.rs +++ b/tests/test_parse_bulk_string.rs @@ -30,7 +30,7 @@ fn test_valid_bulk_strings() { // large string let 
large_content = "x".repeat(1000); let large_bulk = format!("$1000\r\n{}\r\n", large_content); - if let RespType::BulkString(bulk) = parse_bulk_strings(large_bulk.as_bytes()).unwrap().0 {} + if let RespType::BulkString(_) = parse_bulk_strings(large_bulk.as_bytes()).unwrap().0 {} assert_eq!( parse_bulk_strings(large_bulk.as_bytes()).unwrap().0, diff --git a/tests/test_parse_double.rs b/tests/test_parse_double.rs index f8fa550..e69de29 100644 --- a/tests/test_parse_double.rs +++ b/tests/test_parse_double.rs @@ -1 +0,0 @@ -use codecrafters_redis::resp_parser::*; -- cgit v1.2.3