From 4a1877547620a63fe6215d3bafc10cbd69ab8183 Mon Sep 17 00:00:00 2001 From: vance Date: Thu, 5 Jan 2023 18:32:57 -0800 Subject: [PATCH] initial commit --- .gitignore | 4 + Cargo.toml | 22 +++ README.md | 2 + config.toml | 2 + src/main.rs | 439 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 469 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 config.toml create mode 100644 src/main.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e81acc2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/target +/.idea +Cargo.lock +ircd.db \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8dedd08 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "ircd" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +futures= "*" +anyhow = "*" +tokio = { version = "*", features = ["full"] } +tokio-util = { version = "*", features = ["codec"] } +bytes = "*" +tokio-stream = { version = "*", features = ["net"] } +sqlx = { version = "*", features = ["runtime-tokio-rustls", "all-databases"] } +toml = { version = "*" } +serde = { version = "*" } +serde_derive = { version = "*" } + +[profile.release] +lto = "thin" +opt-level = "z" \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..a1d9cc7 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# ircd +a simple rust ircd \ No newline at end of file diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..bb39e07 --- /dev/null +++ b/config.toml @@ -0,0 +1,2 @@ +addr = "127.0.0.1" +port = 6667 \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..7b21e46 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,439 @@ +use std::collections::HashMap; +use std::env; +use std::fs; +use std::net::SocketAddr; +use std::path::Path; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use anyhow::Result; +use bytes::BytesMut; +use futures::{ready, Sink, SinkExt, Stream}; +use serde_derive::Deserialize; +use sqlx::{AnyPool, SqlitePool}; +use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, Lines}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::StreamExt; +use tokio_util::codec::{Decoder, Encoder, Framed, LinesCodec}; + +struct UnregisteredUserState { + nick: Option, + user: Option, + realname: Option, +} + +impl UnregisteredUserState { + fn new() -> Self { + Self { + nick: None, + user: None, + realname: None, + } + } +} + +struct RegisteredUserState { + nick: String, + user: String, + realname: String, +} + +enum UserState { + Unregistered(UnregisteredUserState), + Registered(RegisteredUserState), +} + +struct User { + server_state: Arc, + io: Framed, + state: UserState, + addr: SocketAddr, + channels: HashMap, +} + +impl User { + fn from(server_state: Arc, socket: TcpStream) -> Result { + let addr = socket.peer_addr()?; + Ok(Self { + server_state, + io: Framed::new(socket, MessageCodec::new()), + state: UserState::Unregistered(UnregisteredUserState::new()), + addr, + channels: HashMap::new(), + }) + } +} + +#[derive(Clone, Debug, PartialEq)] +struct Message { + prefix: Option, + command: String, + args: Vec, +} + +impl ToString for Message { + fn to_string(&self) -> String { + let mut line = String::new(); + if let Some(ref prefix) = self.prefix { + line = line + ":" + &prefix + " "; + } + line += &self.command; + for (i, arg) in self.args.iter().enumerate() { + if i == self.args.len() - 1 + && (arg.contains(' ') || arg.contains(':') || arg.is_empty()) + { + line = line + " :" + arg; + } else { + line = line + " " + arg; + } + } + line + } +} + +impl From for Message { + fn from(s: String) -> Self { + let mut prefix = None; + let mut command = String::new(); + let mut args = Vec::new(); + + let words = &mut s.split(' ').enumerate(); + while let Some((index, word)) = words.next() { + if word.chars().next() == Some(':') { + if index == 0 { + prefix = Some(word[1..].to_string()); + } else { + args.push(words.fold(word[1..].to_string(), |s, (i, w)| s + " " + w)); + } + continue; + } + if command.is_empty() { + command = word.to_string(); + } else { + args.push(word.to_string()) + } + } + Self { + prefix, + command, + args, + } + } +} + +struct MessageStream { + lines: Lines, +} + +impl MessageStream { + fn new(io: T) -> MessageStream { + Self { lines: io.lines() } + } +} + +impl Stream for MessageStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + let line = ready!(Pin::new(&mut this.lines).poll_next_line(cx))?; + let line = line.map(|s| Message::from(s)); + Poll::Ready(Ok(line).transpose()) + } +} + +struct MessageSink { + io: Pin>, + send_buffer: Vec, +} + +impl MessageSink { + fn new(io: T) -> MessageSink { + MessageSink { + io: Box::pin(io), + send_buffer: Vec::new(), + } + } +} + +impl Sink for MessageSink { + type Error = anyhow::Error; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_flush(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> std::result::Result<(), Self::Error> { + self.send_buffer + .extend_from_slice(item.to_string().as_bytes()); + self.send_buffer.push('\n' as u8); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = Pin::into_inner(self); + + while !this.send_buffer.is_empty() { + match this.io.as_mut().poll_write(cx, &this.send_buffer) { + Poll::Ready(Ok(n)) => this.send_buffer.drain(0..n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Pending => return Poll::Pending, + }; + } + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_flush(cx) + } +} + +struct MessageCodec { + codec: LinesCodec, +} + +impl MessageCodec { + fn new() -> Self { + Self { + codec: LinesCodec::new(), + } + } +} + +impl Decoder for MessageCodec { + type Item = Message; + type Error = anyhow::Error; + + fn decode( + &mut self, + src: &mut BytesMut, + ) -> std::result::Result, Self::Error> { + match self.codec.decode(src) { + Ok(result) => match result { + Some(s) => return Ok(Some(Message::from(s))), + None => Ok(None), + }, + Err(e) => Err(e.into()), + } + } +} + +impl Encoder for MessageCodec { + type Error = anyhow::Error; + + fn encode( + &mut self, + item: Message, + dst: &mut BytesMut, + ) -> std::result::Result<(), Self::Error> { + match self.codec.encode(item.to_string(), dst) { + Ok(result) => Ok(result), + Err(e) => Err(e.into()), + } + } +} + +struct Channel { + name: String, + topic: Option, + users: HashMap, +} + +#[derive(Deserialize, Debug)] +struct Config { + addr: String, + port: u16, +} + +impl Config { + fn from_toml>(file_path: P) -> Result { + let contents = fs::read_to_string(file_path)?; + Ok(toml::from_str::(&contents)?) + } + fn listen_addr(&self) -> Result { + let listen_addr = format!("{}:{}", self.addr, self.port); + Ok(listen_addr.parse::()?) + } +} + +struct IrcdState { + config: Config, + pool: AnyPool, + users: Mutex>, + channels: Mutex>, +} + +impl IrcdState { + fn new(config: Config, pool: AnyPool) -> Arc { + Arc::new(Self { + config, + pool, + users: Mutex::new(HashMap::new()), + channels: Mutex::new(HashMap::new()), + }) + } +} + +struct Ircd { + state: Arc, +} + +impl Ircd { + fn new(config: Config, pool: AnyPool) -> Self { + Self { + state: IrcdState::new(config, pool), + } + } + async fn start(&mut self) -> Result<()> { + let listener = TcpListener::bind(self.state.config.listen_addr()?).await?; + let mut incoming = TcpListenerStream::new(listener); + while let Some(socket) = incoming.next().await { + let socket = socket?; + let addr = socket.peer_addr()?; + let client = self.accept_user(socket).await?; + tokio::spawn(Self::handle_user(self.state.clone(), client)); + } + Ok(()) + } + async fn accept_user(&self, socket: TcpStream) -> Result { + Ok(User::from(self.state.clone(), socket)?) + } + + async fn handle_user(state: Arc, mut user: User) -> Result<()> { + let addr = user.addr; + println!("New user from: {}", &addr); + while let Ok(Some(message)) = user.io.next().await.transpose() { + Self::process_message(state.clone(), (), message).await? + } + println!("User at {} disconnected", &addr); + Ok(()) + } + + async fn process_message(state: Arc, user: (), message: Message) -> Result<()> { + println!("{:?}", message); + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let config = Config::from_toml("config.toml")?; + let pool = AnyPool::connect(&env::var("DATABASE_URL")?).await?; + let mut server = Ircd::new(config, pool); + server.start().await + //Ok(()) +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::BufReader; + use tokio::io::BufWriter; + use tokio_util::codec::Framed; + + use super::*; + + #[test] + fn message_from_string() { + let mut message = Message::from(":vance!vance@localhost PART #dcc :hey :yall".to_owned()); + assert_eq!( + message, + Message { + prefix: Some("vance!vance@localhost".to_owned()), + command: "PART".to_owned(), + args: vec!["#dcc".to_owned(), "hey :yall".to_owned()], + } + ); + message = Message::from("PING localhost".to_owned()); + assert_eq!( + message, + Message { + prefix: None, + command: "PING".to_owned(), + args: vec!["localhost".to_owned()], + } + ); + } + + #[test] + fn message_to_string() { + let mut message = Message { + prefix: Some("vance!vance@localhost".to_owned()), + command: "PART".to_owned(), + args: vec!["#dcc".to_owned(), "hey :yall".to_owned()], + }; + assert_eq!( + message.to_string(), + ":vance!vance@localhost PART #dcc :hey :yall" + ); + message = Message { + prefix: None, + command: "PING".to_owned(), + args: vec!["localhost".to_owned()], + }; + assert_eq!(message.to_string(), "PING localhost".to_owned()) + } + + #[tokio::test] + async fn message_stream() { + let io = BufReader::new(Cursor::new( + ":vance!vance@localhost PART #dcc :hey :yall".to_string(), + )); + let mut stream = MessageStream::new(io); + while let Some(message) = stream.next().await { + assert_eq!( + message.unwrap(), + Message { + prefix: Some("vance!vance@localhost".to_owned()), + command: "PART".to_owned(), + args: vec!["#dcc".to_owned(), "hey :yall".to_owned()], + } + ) + } + } + + #[tokio::test] + async fn message_sink() { + let io = BufWriter::new(Vec::new()); + let mut sink = MessageSink::new(io); + let mut message = Message { + prefix: Some("vance!vance@localhost".to_owned()), + command: "PART".to_owned(), + args: vec!["#dcc".to_owned(), "hey :yall".to_owned()], + }; + let mut string = message.to_string(); + string.push('\n'); + sink.send(message).await.unwrap(); + assert_eq!(sink.io.buffer(), string.as_bytes()) + } + + #[tokio::test] + async fn message_decoder() { + let mut io = Framed::new( + BufReader::new(BufWriter::new(Cursor::new(Vec::new()))), + MessageCodec::new(), + ); + let mut sent_message = Message { + prefix: Some("vance!vance@localhost".to_owned()), + command: "PART".to_owned(), + args: vec!["#dcc".to_owned(), "hey :yall".to_owned()], + }; + let _ = io.send(sent_message.clone()); + while let Ok(Some(recv_message)) = io.next().await.transpose() { + assert_eq!(sent_message, recv_message); + } + } +}