initial commit

This commit is contained in:
vance 2023-01-05 18:32:57 -08:00
commit 4a18775476
5 changed files with 469 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
/target
/.idea
Cargo.lock
ircd.db

22
Cargo.toml Normal file
View File

@ -0,0 +1,22 @@
[package]
name = "ircd"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
futures= "*"
anyhow = "*"
tokio = { version = "*", features = ["full"] }
tokio-util = { version = "*", features = ["codec"] }
bytes = "*"
tokio-stream = { version = "*", features = ["net"] }
sqlx = { version = "*", features = ["runtime-tokio-rustls", "all-databases"] }
toml = { version = "*" }
serde = { version = "*" }
serde_derive = { version = "*" }
[profile.release]
lto = "thin"
opt-level = "z"

2
README.md Normal file
View File

@ -0,0 +1,2 @@
# ircd
a simple rust ircd

2
config.toml Normal file
View File

@ -0,0 +1,2 @@
addr = "127.0.0.1"
port = 6667

439
src/main.rs Normal file
View File

@ -0,0 +1,439 @@
use std::collections::HashMap;
use std::env;
use std::fs;
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use anyhow::Result;
use bytes::BytesMut;
use futures::{ready, Sink, SinkExt, Stream};
use serde_derive::Deserialize;
use sqlx::{AnyPool, SqlitePool};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, Lines};
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::StreamExt;
use tokio_util::codec::{Decoder, Encoder, Framed, LinesCodec};
struct UnregisteredUserState {
nick: Option<String>,
user: Option<String>,
realname: Option<String>,
}
impl UnregisteredUserState {
fn new() -> Self {
Self {
nick: None,
user: None,
realname: None,
}
}
}
struct RegisteredUserState {
nick: String,
user: String,
realname: String,
}
enum UserState {
Unregistered(UnregisteredUserState),
Registered(RegisteredUserState),
}
struct User {
server_state: Arc<IrcdState>,
io: Framed<TcpStream, MessageCodec>,
state: UserState,
addr: SocketAddr,
channels: HashMap<String, Channel>,
}
impl User {
fn from(server_state: Arc<IrcdState>, socket: TcpStream) -> Result<Self> {
let addr = socket.peer_addr()?;
Ok(Self {
server_state,
io: Framed::new(socket, MessageCodec::new()),
state: UserState::Unregistered(UnregisteredUserState::new()),
addr,
channels: HashMap::new(),
})
}
}
#[derive(Clone, Debug, PartialEq)]
struct Message {
prefix: Option<String>,
command: String,
args: Vec<String>,
}
impl ToString for Message {
fn to_string(&self) -> String {
let mut line = String::new();
if let Some(ref prefix) = self.prefix {
line = line + ":" + &prefix + " ";
}
line += &self.command;
for (i, arg) in self.args.iter().enumerate() {
if i == self.args.len() - 1
&& (arg.contains(' ') || arg.contains(':') || arg.is_empty())
{
line = line + " :" + arg;
} else {
line = line + " " + arg;
}
}
line
}
}
impl From<String> for Message {
fn from(s: String) -> Self {
let mut prefix = None;
let mut command = String::new();
let mut args = Vec::new();
let words = &mut s.split(' ').enumerate();
while let Some((index, word)) = words.next() {
if word.chars().next() == Some(':') {
if index == 0 {
prefix = Some(word[1..].to_string());
} else {
args.push(words.fold(word[1..].to_string(), |s, (i, w)| s + " " + w));
}
continue;
}
if command.is_empty() {
command = word.to_string();
} else {
args.push(word.to_string())
}
}
Self {
prefix,
command,
args,
}
}
}
struct MessageStream<T: AsyncRead + AsyncBufRead + Unpin> {
lines: Lines<T>,
}
impl<T: AsyncRead + AsyncBufRead + Unpin> MessageStream<T> {
fn new(io: T) -> MessageStream<T> {
Self { lines: io.lines() }
}
}
impl<T: AsyncRead + AsyncBufRead + Unpin> Stream for MessageStream<T> {
type Item = Result<Message>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = Pin::into_inner(self);
let line = ready!(Pin::new(&mut this.lines).poll_next_line(cx))?;
let line = line.map(|s| Message::from(s));
Poll::Ready(Ok(line).transpose())
}
}
struct MessageSink<T: AsyncWrite + Unpin> {
io: Pin<Box<T>>,
send_buffer: Vec<u8>,
}
impl<T: AsyncWrite + Unpin> MessageSink<T> {
fn new(io: T) -> MessageSink<T> {
MessageSink {
io: Box::pin(io),
send_buffer: Vec::new(),
}
}
}
impl<T: AsyncWrite + Unpin> Sink<Message> for MessageSink<T> {
type Error = anyhow::Error;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> std::result::Result<(), Self::Error> {
self.send_buffer
.extend_from_slice(item.to_string().as_bytes());
self.send_buffer.push('\n' as u8);
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
let this = Pin::into_inner(self);
while !this.send_buffer.is_empty() {
match this.io.as_mut().poll_write(cx, &this.send_buffer) {
Poll::Ready(Ok(n)) => this.send_buffer.drain(0..n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
Poll::Pending => return Poll::Pending,
};
}
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.poll_flush(cx)
}
}
struct MessageCodec {
codec: LinesCodec,
}
impl MessageCodec {
fn new() -> Self {
Self {
codec: LinesCodec::new(),
}
}
}
impl Decoder for MessageCodec {
type Item = Message;
type Error = anyhow::Error;
fn decode(
&mut self,
src: &mut BytesMut,
) -> std::result::Result<Option<Self::Item>, Self::Error> {
match self.codec.decode(src) {
Ok(result) => match result {
Some(s) => return Ok(Some(Message::from(s))),
None => Ok(None),
},
Err(e) => Err(e.into()),
}
}
}
impl Encoder<Message> for MessageCodec {
type Error = anyhow::Error;
fn encode(
&mut self,
item: Message,
dst: &mut BytesMut,
) -> std::result::Result<(), Self::Error> {
match self.codec.encode(item.to_string(), dst) {
Ok(result) => Ok(result),
Err(e) => Err(e.into()),
}
}
}
struct Channel {
name: String,
topic: Option<String>,
users: HashMap<String, User>,
}
#[derive(Deserialize, Debug)]
struct Config {
addr: String,
port: u16,
}
impl Config {
fn from_toml<P: AsRef<Path>>(file_path: P) -> Result<Self> {
let contents = fs::read_to_string(file_path)?;
Ok(toml::from_str::<Self>(&contents)?)
}
fn listen_addr(&self) -> Result<SocketAddr> {
let listen_addr = format!("{}:{}", self.addr, self.port);
Ok(listen_addr.parse::<SocketAddr>()?)
}
}
struct IrcdState {
config: Config,
pool: AnyPool,
users: Mutex<HashMap<String, User>>,
channels: Mutex<HashMap<String, Channel>>,
}
impl IrcdState {
fn new(config: Config, pool: AnyPool) -> Arc<Self> {
Arc::new(Self {
config,
pool,
users: Mutex::new(HashMap::new()),
channels: Mutex::new(HashMap::new()),
})
}
}
struct Ircd {
state: Arc<IrcdState>,
}
impl Ircd {
fn new(config: Config, pool: AnyPool) -> Self {
Self {
state: IrcdState::new(config, pool),
}
}
async fn start(&mut self) -> Result<()> {
let listener = TcpListener::bind(self.state.config.listen_addr()?).await?;
let mut incoming = TcpListenerStream::new(listener);
while let Some(socket) = incoming.next().await {
let socket = socket?;
let addr = socket.peer_addr()?;
let client = self.accept_user(socket).await?;
tokio::spawn(Self::handle_user(self.state.clone(), client));
}
Ok(())
}
async fn accept_user(&self, socket: TcpStream) -> Result<User> {
Ok(User::from(self.state.clone(), socket)?)
}
async fn handle_user(state: Arc<IrcdState>, mut user: User) -> Result<()> {
let addr = user.addr;
println!("New user from: {}", &addr);
while let Ok(Some(message)) = user.io.next().await.transpose() {
Self::process_message(state.clone(), (), message).await?
}
println!("User at {} disconnected", &addr);
Ok(())
}
async fn process_message(state: Arc<IrcdState>, user: (), message: Message) -> Result<()> {
println!("{:?}", message);
Ok(())
}
}
#[tokio::main]
async fn main() -> Result<()> {
let config = Config::from_toml("config.toml")?;
let pool = AnyPool::connect(&env::var("DATABASE_URL")?).await?;
let mut server = Ircd::new(config, pool);
server.start().await
//Ok(())
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use tokio::io::BufReader;
use tokio::io::BufWriter;
use tokio_util::codec::Framed;
use super::*;
#[test]
fn message_from_string() {
let mut message = Message::from(":vance!vance@localhost PART #dcc :hey :yall".to_owned());
assert_eq!(
message,
Message {
prefix: Some("vance!vance@localhost".to_owned()),
command: "PART".to_owned(),
args: vec!["#dcc".to_owned(), "hey :yall".to_owned()],
}
);
message = Message::from("PING localhost".to_owned());
assert_eq!(
message,
Message {
prefix: None,
command: "PING".to_owned(),
args: vec!["localhost".to_owned()],
}
);
}
#[test]
fn message_to_string() {
let mut message = Message {
prefix: Some("vance!vance@localhost".to_owned()),
command: "PART".to_owned(),
args: vec!["#dcc".to_owned(), "hey :yall".to_owned()],
};
assert_eq!(
message.to_string(),
":vance!vance@localhost PART #dcc :hey :yall"
);
message = Message {
prefix: None,
command: "PING".to_owned(),
args: vec!["localhost".to_owned()],
};
assert_eq!(message.to_string(), "PING localhost".to_owned())
}
#[tokio::test]
async fn message_stream() {
let io = BufReader::new(Cursor::new(
":vance!vance@localhost PART #dcc :hey :yall".to_string(),
));
let mut stream = MessageStream::new(io);
while let Some(message) = stream.next().await {
assert_eq!(
message.unwrap(),
Message {
prefix: Some("vance!vance@localhost".to_owned()),
command: "PART".to_owned(),
args: vec!["#dcc".to_owned(), "hey :yall".to_owned()],
}
)
}
}
#[tokio::test]
async fn message_sink() {
let io = BufWriter::new(Vec::new());
let mut sink = MessageSink::new(io);
let mut message = Message {
prefix: Some("vance!vance@localhost".to_owned()),
command: "PART".to_owned(),
args: vec!["#dcc".to_owned(), "hey :yall".to_owned()],
};
let mut string = message.to_string();
string.push('\n');
sink.send(message).await.unwrap();
assert_eq!(sink.io.buffer(), string.as_bytes())
}
#[tokio::test]
async fn message_decoder() {
let mut io = Framed::new(
BufReader::new(BufWriter::new(Cursor::new(Vec::new()))),
MessageCodec::new(),
);
let mut sent_message = Message {
prefix: Some("vance!vance@localhost".to_owned()),
command: "PART".to_owned(),
args: vec!["#dcc".to_owned(), "hey :yall".to_owned()],
};
let _ = io.send(sent_message.clone());
while let Ok(Some(recv_message)) = io.next().await.transpose() {
assert_eq!(sent_message, recv_message);
}
}
}