Skip to content

Commit

Permalink
host: handle connectRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytill committed Jan 27, 2024
1 parent d4501e0 commit 341d7f5
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 25 deletions.
87 changes: 78 additions & 9 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
mod db;
pub mod message;

use std::path::Path;
use std::{fs, io, path::Path};

use directories::ProjectDirs;
use regex::Regex;
use rusqlite::Connection;
use serde_json::Value;

use message::{
Action, Query, Request, Response, ResponseAction, SaveResponsePayload, SearchResponsePayload,
Version,
Action, ConnectResponsePayload, Query, Request, Response, ResponseAction, SaveResponsePayload,
SearchResponsePayload, Version,
};

#[derive(Debug)]
enum ErrorImpl {
Io(io::Error),
Sqlite(rusqlite::Error),
Semver(semver::Error),
MissingHomeDir,
MissingVersion,
InvalidVersion,
}
Expand All @@ -35,14 +37,22 @@ impl Error {
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self.inner.as_ref() {
ErrorImpl::Io(e) => write!(f, "IO error: {}", e),
ErrorImpl::Sqlite(e) => write!(f, "SQLite error: {}", e),
ErrorImpl::Semver(e) => write!(f, "Semver error: {}", e),
ErrorImpl::MissingHomeDir => write!(f, "Missing home directory"),
ErrorImpl::MissingVersion => write!(f, "Missing version"),
ErrorImpl::InvalidVersion => write!(f, "Invalid version"),
}
}
}

impl From<io::Error> for Error {
fn from(other: io::Error) -> Self {
Self::new(ErrorImpl::Io(other))
}
}

impl From<rusqlite::Error> for Error {
fn from(other: rusqlite::Error) -> Self {
Self::new(ErrorImpl::Sqlite(other))
Expand Down Expand Up @@ -74,8 +84,40 @@ impl From<message::Error> for Error {

impl std::error::Error for Error {}

#[derive(Debug)]
enum Connection {
InMemory(rusqlite::Connection),
Persistent(rusqlite::Connection),
}

impl Connection {
fn inner(&self) -> &rusqlite::Connection {
match self {
Self::InMemory(connection) => connection,
Self::Persistent(connection) => connection,
}
}

fn inner_mut(&mut self) -> &mut rusqlite::Connection {
match self {
Self::InMemory(connection) => connection,
Self::Persistent(connection) => connection,
}
}

fn upgrade(&mut self, db_path: impl AsRef<Path>) -> Result<&mut rusqlite::Connection, Error> {
if let Connection::Persistent(connection) = self {
return Ok(connection);
}
let connection = rusqlite::Connection::open(db_path)?;
let _prev = std::mem::replace(self, Self::Persistent(connection));
let connection = self.inner_mut();
Ok(connection)
}
}

pub struct Context {
connection: rusqlite::Connection,
connection: Connection,
process: Box<dyn Fn(Query) -> String>,
}

Expand All @@ -87,9 +129,10 @@ fn make_process(re: Regex) -> impl Fn(Query) -> String {
}

impl Context {
pub fn new(db_path: impl AsRef<Path>) -> Result<Self, Error> {
let mut connection = Connection::open(db_path)?;
pub fn new() -> Result<Self, Error> {
let mut connection = rusqlite::Connection::open_in_memory()?;
db::init_tables(&mut connection)?;
let connection = Connection::InMemory(connection);
let process_regex = Regex::new(r"\W+").unwrap();
let process = Box::new(make_process(process_regex));
let context = Self {
Expand All @@ -100,13 +143,39 @@ impl Context {
}
}

fn get_project_dirs() -> Result<ProjectDirs, Error> {
ProjectDirs::from("com.github", "henrytill", "noematic")
.ok_or(Error::new(ErrorImpl::MissingHomeDir))
}

pub fn handle_request(context: &mut Context, request: Request) -> Result<Response, Error> {
let version = request.version;
let correlation_id = request.correlation_id;

let connection = context.connection.inner();

match request.action {
Action::ConnectRequest { payload } => {
if payload.persist {
let db_path = {
let project_dirs: ProjectDirs = get_project_dirs()?;
let db_dir = project_dirs.data_dir();
fs::create_dir_all(&db_dir)?;
db_dir.join("db.sqlite3")
};
let connection = context.connection.upgrade(db_path)?;
db::init_tables(connection)?;
}
let payload = ConnectResponsePayload {};
let response = Response {
version,
action: ResponseAction::ConnectResponse { payload },
correlation_id,
};
Ok(response)
}
Action::SaveRequest { payload } => {
db::upsert_site(&context.connection, payload)?;
db::upsert_site(connection, payload)?;
let payload = SaveResponsePayload {};
let action = ResponseAction::SaveResponse { payload };
let response = Response {
Expand All @@ -119,7 +188,7 @@ pub fn handle_request(context: &mut Context, request: Request) -> Result<Respons
Action::SearchRequest { payload } => {
let process = context.process.as_ref();
let query = payload.query.clone();
let results = db::search_sites(&context.connection, payload, process)?;
let results = db::search_sites(connection, payload, process)?;
let payload = SearchResponsePayload { query, results };
let action = ResponseAction::SearchResponse { payload };
let response = Response {
Expand Down
18 changes: 2 additions & 16 deletions host/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::{
fmt, fs,
fmt,
io::{self, BufReader, BufWriter, Read, Write},
};

use directories::ProjectDirs;
use serde_json::Value;

use noematic::{
Expand All @@ -18,7 +17,6 @@ enum Error {
Io(io::Error),
Json(serde_json::Error),
Noematic(noematic::Error),
MissingHomeDir,
UnsupportedVersion,
UnsupportedLength,
}
Expand All @@ -29,7 +27,6 @@ impl fmt::Display for Error {
Error::Io(e) => write!(f, "IO error: {}", e),
Error::Json(e) => write!(f, "JSON error: {}", e),
Error::Noematic(e) => write!(f, "{}", e),
Error::MissingHomeDir => write!(f, "Missing home directory"),
Error::UnsupportedVersion => write!(f, "Unsupported version"),
Error::UnsupportedLength => write!(f, "Unsupported length"),
}
Expand Down Expand Up @@ -99,22 +96,11 @@ fn write_response(writer: &mut impl Write, response: Response) -> Result<(), Err
Ok(())
}

fn get_project_dirs() -> Result<ProjectDirs, Error> {
ProjectDirs::from("com.github", "henrytill", "noematic").ok_or(Error::MissingHomeDir)
}

fn main() -> Result<(), Error> {
let mut reader = BufReader::new(io::stdin());
let mut writer = BufWriter::new(io::stdout());

let db_path = {
let project_dirs: ProjectDirs = get_project_dirs()?;
let db_dir = project_dirs.data_dir();
fs::create_dir_all(&db_dir)?;
db_dir.join("db.sqlite3")
};

let mut context = Context::new(db_path)?;
let mut context = Context::new()?;

while let Some(message) = read(&mut reader)? {
let json: Value = serde_json::from_slice(&message)?;
Expand Down
12 changes: 12 additions & 0 deletions host/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,19 @@ pub struct Request {
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "action")]
pub enum Action {
#[serde(rename = "connectRequest")]
ConnectRequest { payload: ConnectPayload },
#[serde(rename = "saveRequest")]
SaveRequest { payload: SavePayload },
#[serde(rename = "searchRequest")]
SearchRequest { payload: SearchPayload },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ConnectPayload {
pub persist: bool,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SavePayload {
pub url: Url,
Expand All @@ -123,12 +130,17 @@ pub struct Response {
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "action")]
pub enum ResponseAction {
#[serde(rename = "connectResponse")]
ConnectResponse { payload: ConnectResponsePayload },
#[serde(rename = "saveResponse")]
SaveResponse { payload: SaveResponsePayload },
#[serde(rename = "searchResponse")]
SearchResponse { payload: SearchResponsePayload },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ConnectResponsePayload {}

#[derive(Serialize, Deserialize, Debug)]
pub struct SaveResponsePayload {}

Expand Down

0 comments on commit 341d7f5

Please sign in to comment.