aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs9
-rw-r--r--src/main.rs7
-rw-r--r--src/rdb.rs687
-rw-r--r--src/resp_commands.rs88
4 files changed, 763 insertions, 28 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 060106b..c660fd8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@ use std::{env, sync::Arc};
#[macro_use]
pub mod macros;
+pub mod rdb;
pub mod resp_commands;
pub mod resp_parser;
pub mod shared_cache;
@@ -15,9 +16,13 @@ pub struct Config {
pub type SharedConfig = Arc<Option<Config>>;
impl Config {
- pub fn new() -> Result<Config, String> {
+ pub fn new() -> Result<Option<Config>, String> {
let args: Vec<String> = env::args().collect();
+ if args.len() == 1 {
+ return Ok(None);
+ }
+
let mut dir = None;
let mut dbfilename = None;
@@ -44,6 +49,6 @@ impl Config {
}
}
- Ok(Config { dir, dbfilename })
+ Ok(Some(Config { dir, dbfilename }))
}
}
diff --git a/src/main.rs b/src/main.rs
index b29e2ce..effc1e7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -55,16 +55,17 @@ fn handle_client(mut stream: TcpStream, cache: SharedCache, config: SharedConfig
fn main() -> std::io::Result<()> {
let listener = TcpListener::bind("127.0.0.1:6379").unwrap();
let cache: SharedCache = Arc::new(Mutex::new(HashMap::new()));
- let mut config: SharedConfig = None.into();
+ let mut config: SharedConfig = Arc::new(None);
spawn_cleanup_thread(cache.clone());
match Config::new() {
Ok(conf) => {
- config = Arc::new(Some((conf)));
+ if let Some(conf) = conf {
+ config = Arc::new(Some(conf));
+ }
}
Err(e) => {
- config = Arc::new(None);
eprintln!("Error: {}", e);
std::process::exit(1);
}
diff --git a/src/rdb.rs b/src/rdb.rs
new file mode 100644
index 0000000..1a4de11
--- /dev/null
+++ b/src/rdb.rs
@@ -0,0 +1,687 @@
+// Helpful resource
+// https://rdb.fnordig.de/file_format.html#zipmap-encoding
+#![allow(unused)]
+
+use std::{
+ collections::{HashMap, HashSet},
+ fs, io, isize,
+ path::Path,
+};
+
+use thiserror::Error;
+
+use crate::resp_commands::ExpiryOption;
+
+/// Represents any possible value that a key can hold in Redis.
+///
+/// Note: All "string" elements are represented as `Vec<u8>` because
+/// Redis strings are binary-safe and not guaranteed to be valid UTF-8.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum RedisValue {
+ /// The STRING type.
+ String(Vec<u8>),
+
+ /// The Integer type.
+ Integer(i64),
+
+ /// The LIST type. An ordered collection of strings.
+ List(Vec<Vec<u8>>),
+
+ /// The SET type. An unordered collection of unique strings.
+ Set(HashSet<Vec<u8>>),
+
+ /// The HASH type. A collection of field-value pairs.
+ Hash(HashMap<Vec<u8>, Vec<u8>>),
+}
+
+impl RedisValue {
+ /// Convert RedisValue to bytes using Redis string encoding format
+ pub fn to_bytes(&self) -> Vec<u8> {
+ match self {
+ RedisValue::String(data) => encode_string(data),
+ RedisValue::Integer(value) => encode_integer(*value),
+ RedisValue::List(items) => {
+ let mut result = Vec::new();
+ // For lists, we'd typically encode each item separately
+ // This is a simplified version that just encodes the count
+ result.extend(encode_length(items.len()));
+ for item in items {
+ result.extend(encode_string(item));
+ }
+ result
+ }
+ RedisValue::Set(items) => {
+ let mut result = Vec::new();
+ result.extend(encode_length(items.len()));
+ for item in items {
+ result.extend(encode_string(item));
+ }
+ result
+ }
+ RedisValue::Hash(map) => {
+ let mut result = Vec::new();
+ result.extend(encode_length(map.len()));
+ for (key, value) in map {
+ result.extend(encode_string(key));
+ result.extend(encode_string(value));
+ }
+ result
+ }
+ }
+ }
+}
+
+/// Encode a string using Redis length encoding + raw bytes
+fn encode_string(data: &[u8]) -> Vec<u8> {
+ let mut result = encode_length(data.len());
+ result.extend_from_slice(data);
+ result
+}
+
+/// Encode an integer using Redis special encoding if possible, otherwise as string
+fn encode_integer(value: i64) -> Vec<u8> {
+ // Try to use special integer encodings for efficiency
+ if value >= i8::MIN as i64 && value <= i8::MAX as i64 {
+ // 8-bit integer encoding: 0b11000000 (0xC0) followed by 1 byte
+ vec![0xC0, value as u8]
+ } else if value >= i16::MIN as i64 && value <= i16::MAX as i64 {
+ // 16-bit integer encoding: 0b11000001 (0xC1) followed by 2 bytes
+ let bytes = (value as i16).to_be_bytes();
+ vec![0xC1, bytes[0], bytes[1]]
+ } else if value >= i32::MIN as i64 && value <= i32::MAX as i64 {
+ // 32-bit integer encoding: 0b11000010 (0xC2) followed by 4 bytes
+ let bytes = (value as i32).to_be_bytes();
+ vec![0xC2, bytes[0], bytes[1], bytes[2], bytes[3]]
+ } else {
+ // For very large integers, encode as string
+ let string_repr = value.to_string();
+ encode_string(string_repr.as_bytes())
+ }
+}
+
+/// Encode length using Redis length encoding format
+fn encode_length(len: usize) -> Vec<u8> {
+ if len < 64 {
+ // 6-bit length (0-63): 0b00xxxxxx
+ vec![len as u8]
+ } else if len < 16384 {
+ // 14-bit length: 0b01xxxxxx xxxxxxxx
+ let first_byte = 0x40 | ((len >> 8) as u8);
+ let second_byte = (len & 0xFF) as u8;
+ vec![first_byte, second_byte]
+ } else if len <= u32::MAX as usize {
+ // 32-bit length: 0b10000000 followed by 4 bytes big-endian
+ let bytes = (len as u32).to_be_bytes();
+ vec![0x80, bytes[0], bytes[1], bytes[2], bytes[3]]
+ } else {
+ panic!("Length too large for Redis encoding: {}", len);
+ }
+}
+
+// Custom error type for parsing
+#[derive(Debug, PartialEq)]
+pub enum ParseError {
+ InvalidMagicNumber,
+ InvalidVersion,
+ UnexpectedEof,
+ InvalidMetadata,
+ InvalidLength,
+}
+
+use std::fmt;
+
+impl fmt::Display for ParseError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ use ParseError::*;
+ let message = match self {
+ InvalidMagicNumber => "Invalid magic number",
+ InvalidVersion => "Invalid version",
+ UnexpectedEof => "Unexpected end of file",
+ InvalidMetadata => "Invalid metadata",
+ InvalidLength => "Invalid length",
+ };
+ write!(f, "{}", message)
+ }
+}
+
+impl std::error::Error for ParseError {}
+
+// Custom parsing trait that returns bytes consumed
+pub trait FromBytes: Sized {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError>;
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum ValueType {
+ String = 0,
+ List = 1,
+ Set = 2,
+ SortedSet = 3,
+ Hash = 4,
+ Zipmap = 9,
+ Ziplist = 10,
+ Intset = 11,
+ SortedSetInZiplist = 12,
+ HashmapInZiplist = 13,
+ ListInQuicklist = 14,
+}
+
+impl TryFrom<u8> for ValueType {
+ type Error = ();
+
+ fn try_from(value: u8) -> Result<Self, Self::Error> {
+ match value {
+ 0 => Ok(ValueType::String),
+ 1 => Ok(ValueType::List),
+ 2 => Ok(ValueType::Set),
+ 3 => Ok(ValueType::SortedSet),
+ 4 => Ok(ValueType::Hash),
+ 9 => Ok(ValueType::Zipmap),
+ 10 => Ok(ValueType::Ziplist),
+ 11 => Ok(ValueType::Intset),
+ 12 => Ok(ValueType::SortedSetInZiplist),
+ 13 => Ok(ValueType::HashmapInZiplist),
+ 14 => Ok(ValueType::ListInQuicklist),
+ _ => Err(()),
+ }
+ }
+}
+
+pub struct KeyExpiry {
+ pub timestamp: u64,
+ pub unit: ExpiryUnit,
+}
+
+pub enum ExpiryUnit {
+ Seconds,
+ Milliseconds,
+}
+
+pub struct DatabaseEntry {
+ pub expiry: Option<KeyExpiry>,
+ pub value_type: ValueType,
+ pub value: RedisValue,
+}
+
+type LengthEncoding = usize;
+type BytesConsumed = usize;
+
+/// Parses a Redis length-encoded integer.
+fn parse_length(bytes: &[u8]) -> Result<(usize, usize), ParseError> {
+ let (first_byte, mut rest) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?;
+ let mut consumed = 1;
+
+ match first_byte[0] >> 6 {
+ 0b00 => {
+ // 6-bit length
+ // 0x3F = 0011 1111
+ let len = (first_byte[0] & 0x3F) as usize;
+ Ok((len, consumed))
+ }
+ 0b01 => {
+ // 14-bit length
+ let (second_byte, _) = rest.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?;
+ consumed += 1;
+
+ // We need to get the first last 6 bits(most right bits) of the first byte and then the
+ // whole next byte which can be done by a trivial so we mask out the first 6 bits and
+ // then shift them left by a byte and oring the second byte (6 + 8) = our 14-bit length
+ let len = (((first_byte[0] & 0x3F) as usize) << 8) | (second_byte[0] as usize);
+ Ok((len, consumed))
+ }
+ 0b10 => {
+ // 32-bit length from next 4 bytes
+ let (len_bytes, _) = rest.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?;
+ consumed += 4;
+
+ // pretty straight forward just ignore the first byte and interpret the next 4 bytes as a u32
+ let len = u32::from_be_bytes(len_bytes.try_into().unwrap()) as usize;
+ Ok((len, consumed))
+ }
+ 0b11 => {
+ // Special format, not a length
+ Err(ParseError::InvalidLength)
+ }
+ _ => unreachable!(),
+ }
+}
+
+fn parse_special_length(
+ special_type: u8,
+ bytes: &[u8],
+ bytes_consumed: &mut usize,
+) -> Result<(RedisValue, usize), ParseError> {
+ match special_type {
+ 0 => {
+ let (int_bytes, bytes) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?;
+ *bytes_consumed += 1;
+ Ok((RedisValue::Integer(int_bytes[0] as i64), *bytes_consumed))
+ }
+ 1 => {
+ let (int_bytes, bytes) = bytes.split_at_checked(2).ok_or(ParseError::UnexpectedEof)?;
+ *bytes_consumed += 2;
+ let value = i16::from_be_bytes([int_bytes[0], int_bytes[1]]) as i64;
+ Ok((RedisValue::Integer(value), *bytes_consumed))
+ }
+ 2 => {
+ let (int_bytes, bytes) = bytes.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?;
+ *bytes_consumed += 4;
+ let value =
+ i32::from_be_bytes([int_bytes[0], int_bytes[1], int_bytes[2], int_bytes[3]]) as i64;
+ Ok((RedisValue::Integer(value), *bytes_consumed))
+ }
+ _ => Err(ParseError::InvalidLength),
+ }
+}
+
+impl FromBytes for RedisValue {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> {
+ let (length_bytes, rest) = bytes.split_at_checked(1).ok_or(ParseError::UnexpectedEof)?;
+ let mut bytes_consumed = 1;
+
+ if length_bytes[0] >> 6 == 0b11 {
+ // Special encoding
+ let special_type = length_bytes[0] & 0x3F;
+ return parse_special_length(special_type, rest, &mut bytes_consumed);
+ } else {
+ // It's a string, use our new helper
+ let (len, len_consumed) = parse_length(bytes)?;
+ let (string_bytes, _) = bytes[len_consumed..]
+ .split_at_checked(len)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ let total_consumed = len_consumed + len;
+ Ok((RedisValue::String(string_bytes.to_vec()), total_consumed))
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Default)]
+pub struct RDBHeader {
+ pub magic_number: [u8; 5],
+ pub version: [u8; 4],
+}
+
+impl FromBytes for RDBHeader {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> {
+ let mut rdb_header = RDBHeader::default();
+
+ // first 5 bytes should match the magic number
+ let (magic_number, rest) = bytes.split_at_checked(5).ok_or(ParseError::UnexpectedEof)?;
+
+ if let b"REDIS" = magic_number {
+ rdb_header.magic_number = *b"REDIS";
+ } else {
+ return Err(ParseError::InvalidMagicNumber);
+ }
+
+ // The following 4 bytes should match the version number
+ let (version, _) = rest.split_at_checked(4).ok_or(ParseError::UnexpectedEof)?;
+
+ if let Ok(version_str) = std::str::from_utf8(version) {
+ if ("0001" <= version_str) && (version_str <= "0011") {
+ rdb_header.version = version.try_into().unwrap();
+ } else {
+ return Err(ParseError::InvalidVersion);
+ }
+ } else {
+ return Err(ParseError::InvalidVersion);
+ }
+
+ Ok((rdb_header, 9))
+ }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct RDBMetaData {
+ pub metadata: HashMap<Vec<u8>, Vec<u8>>,
+}
+
+impl FromBytes for RDBMetaData {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> {
+ let mut metadata = HashMap::new();
+ let mut remaining = bytes;
+ let mut total_consumed = 0;
+
+ // Keep parsing AUX entries until we hit a non-AUX byte
+ loop {
+ let (aux_byte, rest) = remaining
+ .split_at_checked(1)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ if aux_byte[0] != RDBFile::AUX {
+ // Hit a non-AUX byte, we're done with metadata
+ break;
+ }
+
+ remaining = rest;
+ total_consumed += 1;
+
+ // Parse key string
+ let (key, key_consumed) = RedisValue::from_bytes(remaining)?;
+ remaining = &remaining[key_consumed..];
+ total_consumed += key_consumed;
+
+ // Parse value string
+ let (value, value_consumed) = RedisValue::from_bytes(remaining)?;
+ remaining = &remaining[value_consumed..];
+ total_consumed += value_consumed;
+
+ let key_data = match key {
+ RedisValue::String(data) => data,
+ RedisValue::Integer(data) => data.to_string().as_bytes().to_vec(),
+ _ => return Err(ParseError::InvalidMetadata),
+ };
+
+ let value_data = match value {
+ RedisValue::String(data) => data,
+ RedisValue::Integer(data) => data.to_string().as_bytes().to_vec(),
+ _ => return Err(ParseError::InvalidMetadata),
+ };
+
+ metadata.insert(key_data, value_data);
+ }
+
+ Ok((RDBMetaData { metadata }, total_consumed))
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct HashTableSizeInfo {
+ pub hash_table_size: usize,
+ pub expired_hash_table_size: usize,
+}
+
+pub type DatabaseIndex = usize;
+
+pub struct RDBDatabase {
+ pub database_index: DatabaseIndex,
+ pub size_hints: HashTableSizeInfo,
+ pub hash_table: HashMap<Vec<u8>, DatabaseEntry>,
+}
+
+fn parse_db_key_value(
+ remaining: &mut &[u8],
+ total_consumed: &mut usize,
+ expiry: Option<KeyExpiry>,
+ value_type: ValueType,
+ hash_table: &mut HashMap<Vec<u8>, DatabaseEntry>,
+) -> Result<(), ParseError> {
+ // Parse key string
+ let (key, key_consumed) = RedisValue::from_bytes(remaining)?;
+ *remaining = &remaining[key_consumed..];
+ *total_consumed += key_consumed;
+
+ // Parse value string
+ let (value, value_consumed) = RedisValue::from_bytes(remaining)?;
+ *remaining = &remaining[value_consumed..];
+ *total_consumed += value_consumed;
+
+ let database_entry = DatabaseEntry {
+ expiry: expiry,
+ value_type: ValueType::String,
+ value: value,
+ };
+
+ let key_data = if let RedisValue::String(data) = key {
+ data
+ } else {
+ return Err(ParseError::UnexpectedEof);
+ };
+ hash_table.insert(key_data, database_entry);
+
+ Ok(())
+}
+
+impl FromBytes for RDBDatabase {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> {
+ let mut hash_table = HashMap::new();
+ let mut remaining = bytes;
+ let mut size_hints = HashTableSizeInfo::default();
+ let mut total_consumed = 0;
+ let mut database_index = 0;
+
+ // Keep parsing db entries until we hit a EOF byte
+ loop {
+ let (next_byte, rest) = remaining
+ .split_at_checked(1)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ if next_byte[0] == RDBFile::EOF {
+ break;
+ }
+
+ remaining = rest;
+ total_consumed += 1;
+
+ match next_byte[0] {
+ RDBFile::SELECT_DB => {
+ let (index, consumed) = parse_length(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+
+ database_index = index;
+ }
+ RDBFile::RESIZE_DB => {
+ let (len, consumed) = parse_length(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+
+ size_hints.hash_table_size = len;
+
+ let (len, consumed) = parse_length(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+
+ size_hints.expired_hash_table_size = len;
+ }
+ RDBFile::EXPIRE_TIME_MS => {
+ let (timestamp_bytes, rest) = remaining
+ .split_at_checked(8)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ remaining = rest;
+ total_consumed += 8;
+
+ let timestamp = u64::from_le_bytes(
+ timestamp_bytes[0..8]
+ .try_into()
+ .expect("This should always be atleast 8 bytes"),
+ );
+
+ let (value_type_byte, rest) = remaining
+ .split_at_checked(1)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ remaining = rest;
+ total_consumed += 1;
+
+ let expiry = Some(KeyExpiry {
+ timestamp,
+ unit: ExpiryUnit::Milliseconds,
+ });
+
+ match ValueType::try_from(value_type_byte[0]).unwrap() {
+ ValueType::String => {
+ parse_db_key_value(
+ &mut remaining,
+ &mut total_consumed,
+ expiry,
+ ValueType::String,
+ &mut hash_table,
+ )?;
+ }
+ _ => unreachable!(),
+ }
+ }
+ RDBFile::EXPIRE_TIME => {
+ let (timestamp_bytes, rest) = remaining
+ .split_at_checked(4)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ remaining = rest;
+ total_consumed += 4;
+
+ let timestamp = u32::from_le_bytes(
+ timestamp_bytes[0..4]
+ .try_into()
+ .expect("This should always be atleast 4 bytes"),
+ ) as u64;
+
+ let (value_type_byte, rest) = remaining
+ .split_at_checked(1)
+ .ok_or(ParseError::UnexpectedEof)?;
+
+ remaining = rest;
+ total_consumed += 1;
+
+ let expiry = Some(KeyExpiry {
+ timestamp,
+ unit: ExpiryUnit::Seconds,
+ });
+
+ match ValueType::try_from(value_type_byte[0]).unwrap() {
+ ValueType::String => {
+ parse_db_key_value(
+ &mut remaining,
+ &mut total_consumed,
+ expiry,
+ ValueType::String,
+ &mut hash_table,
+ )?;
+ }
+ _ => unreachable!(),
+ }
+ }
+ n @ 0..15 => match ValueType::try_from(n as u8).unwrap() {
+ ValueType::String => {
+ parse_db_key_value(
+ &mut remaining,
+ &mut total_consumed,
+ None,
+ ValueType::String,
+ &mut hash_table,
+ )?;
+ }
+ _ => unreachable!(),
+ },
+ _ => break,
+ }
+ }
+
+ Ok((
+ RDBDatabase {
+ database_index,
+ size_hints,
+ hash_table,
+ },
+ total_consumed,
+ ))
+ }
+}
+
+pub struct RDBFile {
+ pub header: RDBHeader,
+ pub metadata: Option<RDBMetaData>,
+ pub databases: HashMap<DatabaseIndex, RDBDatabase>,
+ pub checksum: u64,
+}
+
+impl FromBytes for RDBFile {
+ fn from_bytes(bytes: &[u8]) -> Result<(Self, usize), ParseError> {
+ let mut remaining = bytes;
+ let mut total_consumed = 0;
+ let mut databases = HashMap::new();
+
+ // 1. Parse the RDB header ("REDIS" + version)
+ let (header, consumed) = RDBHeader::from_bytes(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+
+ // 2. Parse metadata (any AUX key-value fields)
+ // Your RDBMetaData::from_bytes implementation correctly handles this by
+ // consuming all sequential AUX fields until it hits another opcode.
+ let (metadata, consumed) = RDBMetaData::from_bytes(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+
+ // 3. Parse database sections
+ // The provided RDBDatabase::from_bytes is designed to parse the entire
+ // key-value section, including multiple DB selectors, until the final EOF.
+ let (database_section, consumed) = RDBDatabase::from_bytes(remaining)?;
+ total_consumed += consumed;
+ remaining = &remaining[consumed..];
+ databases.insert(database_section.database_index, database_section);
+
+ // 4. Parse the final EOF marker and checksum
+ // The RDBDatabase parser stops when it sees the EOF marker but doesn't consume it.
+ // We must consume it here.
+ let (eof_byte, rest) = remaining
+ .split_at_checked(1)
+ .ok_or(ParseError::UnexpectedEof)?;
+ if eof_byte[0] != Self::EOF {
+ return Err(ParseError::InvalidMetadata); // Expected EOF marker
+ }
+ total_consumed += 1;
+ remaining = rest;
+
+ // The final 8 bytes of a valid RDB file are the checksum.
+ if remaining.len() >= 8 {
+ let (checksum_bytes, _) = remaining
+ .split_at_checked(8)
+ .ok_or(ParseError::UnexpectedEof)?;
+ let checksum = u64::from_le_bytes(checksum_bytes.try_into().unwrap());
+ total_consumed += 8;
+
+ let rdb_file = RDBFile {
+ header,
+ metadata: Some(metadata),
+ databases,
+ checksum,
+ };
+ Ok((rdb_file, total_consumed))
+ } else {
+ // Handle cases where checksum might be missing (older RDB versions or truncated file)
+ // For simplicity, we'll assume a checksum is always present.
+ Err(ParseError::UnexpectedEof)
+ }
+ }
+}
+
+impl RDBFile {
+ pub fn read(dir: String, dbfilename: String) -> Result<Self, anyhow::Error> {
+ let dir = Path::new(&dir);
+ let file_path = dir.join(dbfilename);
+
+ // Read file to bytes
+ let bytes = fs::read(&file_path)?;
+ let (rdb_file, consumed) = RDBFile::from_bytes(&bytes)?;
+
+ // sanity check
+ assert!(bytes.len() == consumed);
+ Ok(rdb_file)
+ }
+}
+
+impl RDBFile {
+ /// The file starts off with the magic string “REDIS”. This is a quick sanity check to know we are dealing with a redis rdb file.
+ pub const MAGIC_NUMBER: [u8; 5] = *b"REDIS";
+
+ /// End of the RDB File
+ pub const EOF: u8 = 0xFF;
+
+ /// Database Selector
+ pub const SELECT_DB: u8 = 0xFE;
+
+ /// Expire time in seconds
+ pub const EXPIRE_TIME: u8 = 0xFD;
+
+ /// Expire time in milliseconds
+ pub const EXPIRE_TIME_MS: u8 = 0xFC;
+
+ /// Hash table sizes for the main keyspace and expires
+ pub const RESIZE_DB: u8 = 0xFB;
+
+ /// Auxiliray fields. Arbitrary key-value settings
+ pub const AUX: u8 = 0xFA;
+}
diff --git a/src/resp_commands.rs b/src/resp_commands.rs
index 37273f8..9d35b1b 100644
--- a/src/resp_commands.rs
+++ b/src/resp_commands.rs
@@ -1,5 +1,7 @@
+use crate::rdb::RDBFile;
+use crate::SharedConfig;
use crate::{resp_parser::*, shared_cache::*};
-use crate::{Config, SharedConfig};
+use regex::Regex;
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Debug, Clone)]
@@ -108,26 +110,23 @@ fn extract_string(resp: &RespType) -> Option<String> {
}
}
-// Helper function to parse u64 from BulkString
-fn parse_u64(resp: &RespType) -> Option<u64> {
- extract_string(resp)?.parse().ok()
-}
-
pub enum RedisCommands {
- PING,
- ECHO(String),
- GET(String),
- SET(SetCommand),
- CONFIG_GET(String),
+ Ping,
+ Echo(String),
+ Get(String),
+ Set(SetCommand),
+ ConfigGet(String),
+ Keys(String),
Invalid,
}
impl RedisCommands {
pub fn execute(self, cache: SharedCache, config: SharedConfig) -> Vec<u8> {
+ use RedisCommands as RC;
match self {
- RedisCommands::PING => resp!("PONG"),
- RedisCommands::ECHO(echo_string) => resp!(echo_string),
- RedisCommands::GET(key) => {
+ RC::Ping => resp!("PONG"),
+ RC::Echo(echo_string) => resp!(echo_string),
+ RC::Get(key) => {
let mut cache = cache.lock().unwrap();
match cache.get(&key).cloned() {
Some(entry) => {
@@ -141,7 +140,7 @@ impl RedisCommands {
None => resp!(null),
}
}
- RedisCommands::SET(command) => {
+ RC::Set(command) => {
let mut cache = cache.lock().unwrap();
// Check conditions (NX/XX)
@@ -194,7 +193,7 @@ impl RedisCommands {
None => return resp!(null),
}
}
- RedisCommands::CONFIG_GET(s) => {
+ RC::ConfigGet(s) => {
use RespType as RT;
let config = config.clone();
if let Some(conf) = config.as_ref() {
@@ -217,7 +216,44 @@ impl RedisCommands {
unreachable!()
}
}
- RedisCommands::Invalid => todo!(),
+ RC::Keys(query) => {
+ use RespType as RT;
+
+ let query = query.replace('*', ".*");
+
+ let cache = cache.lock().unwrap();
+ let regex = Regex::new(&query).unwrap();
+ let config = config.clone();
+
+ if let Some(conf) = config.as_ref() {
+ let dir = conf.dir.clone().unwrap();
+ let dbfilename = conf.dbfilename.clone().unwrap();
+ let rdb_file = RDBFile::read(dir, dbfilename).unwrap();
+
+ let hash_table = &rdb_file.databases.get(&0).unwrap().hash_table;
+ let matching_keys: Vec<RT> = hash_table
+ .keys()
+ .map(|key| str::from_utf8(key).unwrap())
+ .filter_map(|key| {
+ regex
+ .is_match(key)
+ .then(|| RT::BulkString(key.as_bytes().to_vec()))
+ })
+ .collect();
+ RT::Array(matching_keys).to_resp_bytes()
+ } else {
+ let matching_keys: Vec<RT> = cache
+ .keys()
+ .filter_map(|key| {
+ regex
+ .is_match(key)
+ .then(|| RT::BulkString(key.as_bytes().to_vec()))
+ })
+ .collect();
+ RT::Array(matching_keys).to_resp_bytes()
+ }
+ }
+ RC::Invalid => todo!(),
}
}
}
@@ -338,17 +374,17 @@ impl From<RespType> for RedisCommands {
match cmd_name.to_ascii_uppercase().as_str() {
"PING" => {
if args.next().is_none() {
- Self::PING
+ Self::Ping
} else {
Self::Invalid
}
}
"ECHO" => match (args.next(), args.next()) {
- (Some(echo_string), None) => Self::ECHO(echo_string),
+ (Some(echo_string), None) => Self::Echo(echo_string),
_ => Self::Invalid,
},
"GET" => match (args.next(), args.next()) {
- (Some(key), None) => Self::GET(key),
+ (Some(key), None) => Self::Get(key),
_ => Self::Invalid,
},
"SET" => {
@@ -362,15 +398,21 @@ impl From<RespType> for RedisCommands {
let options: Vec<String> = args.collect();
if options.is_empty() {
- Self::SET(SetCommand::new(key, value))
+ Self::Set(SetCommand::new(key, value))
} else {
let parser = SetOptionParser::new(key, value);
match parser.parse_options(&options) {
- Ok(set_command) => Self::SET(set_command),
+ Ok(set_command) => Self::Set(set_command),
Err(_) => Self::Invalid,
}
}
}
+ "KEYS" => {
+ let Some(query) = args.next() else {
+ return Self::Invalid;
+ };
+ Self::Keys(query)
+ }
"CONFIG" => {
let Some(sub_command) = args.next() else {
return Self::Invalid;
@@ -379,7 +421,7 @@ impl From<RespType> for RedisCommands {
return Self::Invalid;
};
if &sub_command.to_uppercase() == &"GET" {
- return Self::CONFIG_GET(key);
+ return Self::ConfigGet(key);
}
Self::Invalid
}