aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
blob: 5fdd6c3423e0dfa0aba2de41c7ddf893ca4ddf5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#![allow(unused_imports)]
use core::time;
use std::{
    collections::HashMap,
    env,
    io::{Read, Write},
    net::{TcpListener, TcpStream},
    sync::{Arc, Mutex},
    thread,
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use codecrafters_redis::{
    rdb::{KeyExpiry, ParseError, RDBFile, RedisValue},
    shared_cache::*,
};
use codecrafters_redis::{resp_commands::RedisCommands, Config};
use codecrafters_redis::{
    resp_parser::{parse, RespType},
    SharedConfig,
};

/// Starts a detached background thread that periodically purges expired keys.
///
/// Every 10 seconds the thread locks `cache`, takes the current time as
/// milliseconds since the Unix epoch, and drops every entry whose
/// `expires_at` is strictly earlier than that instant. Entries without an
/// expiry are always kept. The thread loops forever and is never joined.
///
/// # Panics
/// The background thread panics if the cache mutex is poisoned or if the
/// system clock reports a time before the Unix epoch.
fn spawn_cleanup_thread(cache: SharedCache) {
    let shared = cache.clone();
    std::thread::spawn(move || loop {
        // Sweep interval.
        std::thread::sleep(Duration::from_secs(10));

        let mut entries = shared.lock().unwrap();
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64;

        // Keep an entry while its deadline has not passed (deadline == now
        // still counts as live); entries with no deadline never expire here.
        entries.retain(|_, entry| match entry.expires_at {
            Some(deadline) => deadline >= now_ms,
            None => true,
        });
    });
}

/// Serves a single client connection until it closes or fails.
///
/// Each successful `read` is treated as one RESP request. The bytes received
/// are parsed, converted into a [`RedisCommands`] and executed against the
/// shared cache and configuration, and the encoded reply is written back in
/// full.
///
/// The function returns (dropping, and thus closing, the stream) instead of
/// panicking when:
/// - the peer closes the connection;
/// - a read or write fails;
/// - a request cannot be parsed. In this case a RESP error reply is sent
///   first, on a best-effort basis.
fn handle_client(mut stream: TcpStream, cache: SharedCache, config: SharedConfig) {
    let mut buffer = [0; 512];

    loop {
        let n = match stream.read(&mut buffer) {
            // 0 bytes means the peer closed the connection; errors also end the session.
            Ok(0) | Err(_) => return,
            Ok(n) => n,
        };

        // Parse only the bytes delivered by this read: the tail of `buffer`
        // can still hold stale data from an earlier, longer request.
        // NOTE(review): a request split across several reads (or longer than
        // the 512-byte buffer) is not reassembled here.
        let request = match parse(&buffer[..n]) {
            Ok(parsed) => parsed.0,
            Err(_) => {
                // Malformed input comes from the network; reject it instead of
                // panicking the connection thread. Ignore a failed error write:
                // the connection is being dropped either way.
                let _ = stream.write_all(b"-ERR Protocol error\r\n");
                return;
            }
        };

        let response = RedisCommands::from(request).execute(cache.clone(), config.clone());

        // Write the response back to the client. `write` may accept only part
        // of the buffer, so use `write_all`. A failure here usually means the
        // client went away, so end the session.
        if stream.write_all(&response).is_err() {
            return;
        }
    }
}

/// Loads the configuration and, if one is given, preloads the cache from the
/// configured RDB file. Then serves Redis clients on `127.0.0.1:<port>`.
///
/// The port comes from the configuration and defaults to 6379. Each accepted
/// connection is handled on its own thread by `handle_client`. A background
/// thread evicts expired keys.
///
/// The process exits with status 1 if the configuration cannot be built.
///
/// # Errors
/// Returns an error if the listening socket cannot be bound.
fn main() -> std::io::Result<()> {
    const DEFAULT_PORT: &str = "6379";

    let cache: SharedCache = Arc::new(Mutex::new(HashMap::new()));

    let conf = match Config::new() {
        Ok(conf) => conf,
        Err(e) => {
            eprintln!("Error: {}", e);
            std::process::exit(1);
        }
    };

    let port = conf
        .as_ref()
        .and_then(|c| c.port.clone())
        .unwrap_or_else(|| DEFAULT_PORT.to_string());

    if let Some(conf) = &conf {
        let dir = conf.dir.clone().unwrap_or_default();
        let dbfilename = conf.dbfilename.clone().unwrap_or_default();
        load_rdb_into_cache(&cache, dir, dbfilename);
    }
    let config: SharedConfig = Arc::new(conf);

    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))?;

    spawn_cleanup_thread(Arc::clone(&cache));

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                let cache = Arc::clone(&cache);
                let config = Arc::clone(&config);
                thread::spawn(move || handle_client(stream, cache, config));
            }
            Err(e) => eprintln!("Connection failed: {}", e),
        }
    }
    Ok(())
}

/// Copies the keys of database 0 from the RDB file `dir`/`dbfilename` into `cache`.
///
/// Loading is best-effort so that a bad snapshot cannot abort server startup:
/// - The cache is left untouched when `RDBFile::read` fails or yields no file,
///   or when the file has no database 0.
/// - Values other than strings and integers are skipped with a warning.
/// - Invalid UTF-8 in keys or values is replaced with U+FFFD.
///
/// Each key's expiry timestamp, if any, is stored unchanged as `expires_at`.
fn load_rdb_into_cache(cache: &SharedCache, dir: String, dbfilename: String) {
    let rdb = match RDBFile::read(dir, dbfilename) {
        Ok(Some(rdb)) => rdb,
        // NOTE(review): read errors are ignored (as before); startup proceeds
        // with an empty cache.
        _ => return,
    };
    let db = match rdb.databases.get(&0) {
        Some(db) => db,
        None => return,
    };

    // No client threads exist yet, so this lock is uncontended.
    let mut cache = cache.lock().expect("cache mutex poisoned during startup");
    for (key, db_entry) in db.hash_table.iter() {
        let key = String::from_utf8_lossy(key).into_owned();
        let value = match &db_entry.value {
            RedisValue::String(data) => String::from_utf8_lossy(data).into_owned(),
            RedisValue::Integer(data) => data.to_string(),
            _ => {
                // The string-valued cache cannot represent other RDB types.
                eprintln!("Skipping key {:?}: unsupported RDB value type", key);
                continue;
            }
        };
        let expires_at = db_entry.expiry.as_ref().map(|expiry| expiry.timestamp);
        cache.insert(key, CacheEntry { value, expires_at });
    }
}