diff --git a/src/commands.rs b/src/commands.rs index 0b5a6b5..234cf01 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -23,9 +23,7 @@ async fn amdctl(ctx: Context<'_>) -> Result<(), Error> { Ok(()) } -/// Every function that is defined *should* be added to the -/// returned vector in get_commands to ensure it is registered (available for the user) -/// when the bot goes online. +/// Returns a vector containg [Poise Commands][`poise::Command`] pub fn get_commands() -> Vec> { vec![amdctl()] } diff --git a/src/graphql/models.rs b/src/graphql/models.rs index bf111d4..2826e3c 100644 --- a/src/graphql/models.rs +++ b/src/graphql/models.rs @@ -25,6 +25,7 @@ pub struct Streak { pub max_streak: i32, } +/// Represents a record of the Member relation in [Root][https://www.github.com/amfoss/root]. #[derive(Clone, Debug, Deserialize)] pub struct Member { #[serde(rename = "memberId")] @@ -35,5 +36,5 @@ pub struct Member { #[serde(rename = "groupId")] pub group_id: u32, #[serde(default)] - pub streak: Vec, + pub streak: Vec, // Note that Root will NOT have multiple Streak elements but it may be an empty list which is why we use a vector here } diff --git a/src/graphql/queries.rs b/src/graphql/queries.rs index 971f770..c8256c3 100644 --- a/src/graphql/queries.rs +++ b/src/graphql/queries.rs @@ -16,10 +16,10 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . */ use crate::graphql::models::{Member, Streak}; -use anyhow::Context; +use anyhow::{anyhow, Context}; -pub async fn fetch_members() -> Result, anyhow::Error> { - let request_url = std::env::var("ROOT_URL").expect("ROOT_URL not found"); +pub async fn fetch_members() -> anyhow::Result> { + let request_url = std::env::var("ROOT_URL").context("ROOT_URL not found in ENV")?; let client = reqwest::Client::new(); let query = r#" @@ -43,6 +43,10 @@ pub async fn fetch_members() -> Result, anyhow::Error> { .await .context("Failed to successfully post request")?; + if !response.status().is_success() { + return Err(anyhow!("Server responded with an error: {:?}", response.status())); + } + let response_json: serde_json::Value = response .json() .await @@ -52,7 +56,7 @@ pub async fn fetch_members() -> Result, anyhow::Error> { .get("data") .and_then(|data| data.get("members")) .and_then(|members| members.as_array()) - .ok_or_else(|| anyhow::anyhow!("Malformed response: 'members' field missing or invalid"))?; + .ok_or_else(|| anyhow::anyhow!("Malformed response: Could not access Members from {}", response_json))?; let members: Vec = serde_json::from_value(serde_json::Value::Array(members.clone())) .context("Failed to parse 'members' into Vec")?; @@ -61,7 +65,7 @@ pub async fn fetch_members() -> Result, anyhow::Error> { } pub async fn increment_streak(member: &mut Member) -> anyhow::Result<()> { - let request_url = std::env::var("ROOT_URL").context("ROOT_URL was not found")?; + let request_url = std::env::var("ROOT_URL").context("ROOT_URL was not found in ENV")?; let client = reqwest::Client::new(); let mutation = format!( @@ -73,12 +77,17 @@ pub async fn increment_streak(member: &mut Member) -> anyhow::Result<()> { }}"#, member.member_id ); + let response = client .post(request_url) .json(&serde_json::json!({"query": mutation})) .send() .await - .context("Root Request failed")?; + .context("Failed to succesfully post query to Root")?; + + if !response.status().is_success() { + return Err(anyhow!("Server responded with an error: {:?}", response.status())); + } // Handle the streak vector if member.streak.is_empty() { @@ -101,7 +110,7 @@ pub async fn increment_streak(member: &mut Member) -> anyhow::Result<()> { } pub async fn reset_streak(member: &mut Member) -> anyhow::Result<()> { - let request_url = std::env::var("ROOT_URL").context("ROOT_URL was not found")?; + let request_url = std::env::var("ROOT_URL").context("ROOT_URL was not found in the ENV")?; let client = reqwest::Client::new(); let mutation = format!( @@ -120,34 +129,41 @@ pub async fn reset_streak(member: &mut Member) -> anyhow::Result<()> { .json(&serde_json::json!({ "query": mutation })) .send() .await - .context("Root Request failed")?; + .context("Failed to succesfully post query to Root")?; + + if !response.status().is_success() { + return Err(anyhow!("Server responded with an error: {:?}", response.status())); + } let response_json: serde_json::Value = response .json() .await .context("Failed to parse response JSON")?; + if let Some(data) = response_json .get("data") .and_then(|data| data.get("resetStreak")) { - let current_streak = data.get("currentStreak").and_then(|v| v.as_i64()).unwrap(); - - let max_streak = data.get("maxStreak").and_then(|v| v.as_i64()).unwrap(); + let current_streak = data.get("currentStreak").and_then(|v| v.as_i64()).ok_or_else(|| anyhow!("current_streak was parsed as None"))? as i32; + let max_streak = data.get("maxStreak").and_then(|v| v.as_i64()).ok_or_else(|| anyhow!("max_streak was parsed as None"))? as i32; // Update the member's streak vector if member.streak.is_empty() { // If the streak vector is empty, initialize it with the returned values member.streak.push(Streak { - current_streak: current_streak as i32, - max_streak: max_streak as i32, + current_streak, + max_streak, }); } else { // Otherwise, update the first streak entry for streak in &mut member.streak { - streak.current_streak = current_streak as i32; - streak.max_streak = max_streak as i32; + streak.current_streak = current_streak; + streak.max_streak = max_streak; } } + } else { + return Err(anyhow!("Failed to access data from {}", response_json)); } + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 6c17bd5..dca4f68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,24 +15,20 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ -/// Stores all the commands for the bot. +/// Contains all the commands for the bot. mod commands; -/// Responsible for queries, models and mutation requests sent to and from -/// [root's](https://www.github.com/amfoss/root) graphql interace. +/// Interact with [Root's](https://www.github.com/amfoss/root) GraphQL interace. mod graphql; -/// Stores Discord IDs that are needed across the bot. +/// Contains Discord IDs that may be needed across the bot. mod ids; -/// This module is a simple cron equivalent. It spawns threads for the regular [`Task`]s that need to be completed. +/// This module is a simple cron equivalent. It spawns threads for the [`Task`]s that need to be completed. mod scheduler; -/// An interface to define a job that needs to be executed regularly, for example checking for status updates daily. +/// A trait to define a job that needs to be executed regularly, for example checking for status updates daily. mod tasks; /// Misc. helper functions that don't really have a place anywhere else. mod utils; -use ids::{ - AI_ROLE_ID, ARCHIVE_ROLE_ID, DEVOPS_ROLE_ID, MOBILE_ROLE_ID, RESEARCH_ROLE_ID, - ROLES_MESSAGE_ID, SYSTEMS_ROLE_ID, WEB_ROLE_ID, -}; +use std::collections::HashMap; use anyhow::Context as _; use poise::{Context as PoiseContext, Framework, FrameworkOptions, PrefixFrameworkOptions}; @@ -41,21 +37,21 @@ use serenity::{ client::{Context as SerenityContext, FullEvent}, model::{gateway::GatewayIntents, id::MessageId}, }; -use std::collections::HashMap; + +use ids::{ + AI_ROLE_ID, ARCHIVE_ROLE_ID, DEVOPS_ROLE_ID, MOBILE_ROLE_ID, RESEARCH_ROLE_ID, + ROLES_MESSAGE_ID, SYSTEMS_ROLE_ID, WEB_ROLE_ID, +}; pub type Error = Box; pub type Context<'a> = PoiseContext<'a, Data, Error>; -/// Runtime allocated storage for the bot. pub struct Data { pub reaction_roles: HashMap, } /// This function is responsible for allocating the necessary fields /// in [`Data`], before it is passed to the bot. -/// -/// Currently, it only needs to store the (emoji, [`RoleId`]) pair used -/// for assigning roles to users who react to a particular message. pub fn initialize_data() -> Data { let mut data = Data { reaction_roles: HashMap::new(), @@ -92,7 +88,6 @@ pub fn initialize_data() -> Data { ), ]; - // Populate reaction_roles map. data.reaction_roles .extend::>(roles.into()); @@ -102,25 +97,20 @@ pub fn initialize_data() -> Data { #[tokio::main] async fn main() -> Result<(), Error> { dotenv::dotenv().ok(); - let discord_token = std::env::var("DISCORD_TOKEN").context("'DISCORD_TOKEN' was not found")?; + let discord_token = std::env::var("DISCORD_TOKEN").context("DISCORD_TOKEN was not found in the ENV")?; let framework = Framework::builder() .options(FrameworkOptions { - // Load bot commands commands: commands::get_commands(), - // Pass the event handler function event_handler: |ctx, event, framework, data| { Box::pin(event_handler(ctx, event, framework, data)) }, - // General bot settings, set to default except for prefix prefix_options: PrefixFrameworkOptions { prefix: Some(String::from("$")), ..Default::default() }, ..Default::default() }) - // This function that's passed to setup() is called just as - // the bot is ready to start. .setup(|ctx, _ready, framework| { Box::pin(async move { poise::builtins::register_globally(ctx, &framework.options().commands).await?; @@ -141,18 +131,12 @@ async fn main() -> Result<(), Error> { .await .context("Failed to create the Serenity client")?; - client.start().await.context("Error running the bot")?; + client.start().await.context("Failed to start the Serenity client")?; Ok(()) } -/// Handles various events from Discord, such as reactions. -/// -/// Current functionality includes: -/// - Adding roles to users based on reactions. -/// - Removing roles from users when their reactions are removed. -/// -/// TODO: Refactor for better readability and modularity. +/// Handles various events from Discord, such as sending messages or adding reactions to messages. async fn event_handler( ctx: &SerenityContext, event: &FullEvent, @@ -160,70 +144,67 @@ async fn event_handler( data: &Data, ) -> Result<(), Error> { match event { - // Handle reactions being added. FullEvent::ReactionAdd { add_reaction } => { - // Check if a role needs to be added i.e check if the reaction was added to [`ROLES_MESSAGE_ID`] if is_relevant_reaction(add_reaction.message_id, &add_reaction.emoji, data) { // This check for a guild_id isn't strictly necessary, since we're already checking // if the reaction was added to the [`ROLES_MESSAGE_ID`] which *should* point to a // message in the server. if let Some(guild_id) = add_reaction.guild_id { - if let Ok(member) = guild_id.member(ctx, add_reaction.user_id.unwrap()).await { - if let Err(e) = member - .add_role( - &ctx.http, - data.reaction_roles + if let Some(user_id) = add_reaction.user_id { + if let Ok(member) = guild_id.member(ctx, user_id).await { + if let Err(e) = member + .add_role( + &ctx.http, + data.reaction_roles .get(&add_reaction.emoji) .expect("Hard coded value verified earlier."), - ) - .await - { - // TODO: Replace with tracing - eprintln!("Error adding role: {:?}", e); + ) + .await + { + // TODO: Replace with tracing + eprintln!("Error adding role: {:?}", e); + } } } } } } - // Handle reactions being removed. FullEvent::ReactionRemove { removed_reaction } => { - // Check if a role needs to be added i.e check if the reaction was added to [`ROLES_MESSAGE_ID`] if is_relevant_reaction(removed_reaction.message_id, &removed_reaction.emoji, data) { // This check for a guild_id isn't strictly necessary, since we're already checking // if the reaction was added to the [`ROLES_MESSAGE_ID`] which *should* point to a // message in the server. if let Some(guild_id) = removed_reaction.guild_id { - if let Ok(member) = guild_id - .member(ctx, removed_reaction.user_id.unwrap()) - .await - { - if let Err(e) = member - .remove_role( - &ctx.http, - *data - .reaction_roles + if let Some(user_id) = removed_reaction.user_id { + if let Ok(member) = guild_id + .member(ctx, user_id) + .await + { + if let Err(e) = member + .remove_role( + &ctx.http, + *data.reaction_roles .get(&removed_reaction.emoji) .expect("Hard coded value verified earlier."), - ) - .await - { - eprintln!("Error removing role: {:?}", e); + ) + .await + { + eprintln!("Error removing role: {:?}", e); + } } } } } } - // Ignore all other events for now. _ => {} } Ok(()) } -/// Helper function to check if a reaction was made to [`ROLES_MESSAGE_ID`] and if -/// [`Data::reaction_roles`] contains a relevant (emoji, role) pair. +/// Helper function to check if a reaction was made to [`ids::ROLES_MESSAGE_ID`] and if [`Data::reaction_roles`] contains a relevant (emoji, role) pair. fn is_relevant_reaction(message_id: MessageId, emoji: &ReactionType, data: &Data) -> bool { message_id == MessageId::new(ROLES_MESSAGE_ID) && data.reaction_roles.contains_key(emoji) } diff --git a/src/scheduler.rs b/src/scheduler.rs index eed470f..918227c 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -20,11 +20,7 @@ use serenity::client::Context as SerenityContext; use tokio::spawn; -/// Spawns a thread for each [`Task`]. -/// -/// [`SerenityContext`] is passed along with it so that they can -/// call any required Serenity functions without creating a new [`serenity::http`] -/// interface with a Discord token. +/// Spawns a sleepy thread for each [`Task`]. pub async fn run_scheduler(ctx: SerenityContext) { let tasks = get_tasks(); @@ -39,6 +35,8 @@ async fn schedule_task(ctx: SerenityContext, task: Box) { let next_run_in = task.run_in(); tokio::time::sleep(next_run_in).await; - task.run(ctx.clone()).await; + if let Err(e) = task.run(ctx.clone()).await { + eprint!("Error running task: {}", e); + } } } diff --git a/src/tasks/mod.rs b/src/tasks/mod.rs index acb5fb6..7a598ec 100644 --- a/src/tasks/mod.rs +++ b/src/tasks/mod.rs @@ -22,6 +22,7 @@ use crate::{tasks::status_update::check_status_updates, utils::time::time_until} use async_trait::async_trait; use serenity::client::Context; use tokio::time::Duration; +use anyhow::Result; /// A [`Task`] is any job that needs to be executed on a regular basis. /// A task has a function [`Task::run_in`] that returns the time till the @@ -31,7 +32,7 @@ use tokio::time::Duration; pub trait Task: Send + Sync { fn name(&self) -> &'static str; fn run_in(&self) -> Duration; - async fn run(&self, ctx: Context); + async fn run(&self, ctx: Context) -> Result<()>; } /// Analogous to [`crate::commands::get_commands`], every task that is defined @@ -50,10 +51,10 @@ impl Task for StatusUpdateCheck { } fn run_in(&self) -> Duration { - time_until(00, 40) + time_until(05, 40) } - async fn run(&self, ctx: Context) { - check_status_updates(ctx).await; + async fn run(&self, ctx: Context) -> anyhow::Result<()> { + check_status_updates(ctx).await } } diff --git a/src/tasks/status_update.rs b/src/tasks/status_update.rs index d1b93a0..ace119c 100644 --- a/src/tasks/status_update.rs +++ b/src/tasks/status_update.rs @@ -15,30 +15,34 @@ GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ +use anyhow::Result; use serenity::all::{ self, ChannelId, Context, CreateEmbed, CreateEmbedAuthor, CreateMessage, Embed, Member as DiscordMember, Message, MessageId, Timestamp, }; use crate::{ - graphql::{queries::{fetch_members, increment_streak, reset_streak}, models:: Member}, + graphql::{ + models::Member, + queries::{fetch_members, increment_streak, reset_streak}, + }, ids::{ GROUP_FOUR_CHANNEL_ID, GROUP_ONE_CHANNEL_ID, GROUP_THREE_CHANNEL_ID, GROUP_TWO_CHANNEL_ID, STATUS_UPDATE_CHANNEL_ID, }, - utils::time::get_five_am_timestamp, }; use std::fs::File; use std::{collections::HashMap, io::Write, str::FromStr}; +use chrono::{Local, Timelike}; use chrono_tz::Asia; -pub async fn check_status_updates(ctx: Context) { +pub async fn check_status_updates(ctx: Context) -> Result<()> { let mut members = match fetch_members().await { Ok(members) => members, Err(e) => { eprintln!("Failed to fetch members from Root. {}", e); - return; + return Err(e); } }; @@ -60,6 +64,8 @@ pub async fn check_status_updates(ctx: Context) { Err(e) => eprintln!("{}", e), _ => (), }; + + Ok(()) } async fn send_and_save_limiting_messages(channel_ids: &Vec, ctx: &Context) { @@ -88,7 +94,12 @@ async fn collect_updates(channel_ids: &Vec, ctx: &Context) -> Vec. */ -use chrono::{DateTime, Datelike, Local, TimeZone}; -use chrono_tz::Tz; -use tokio::time::Duration; +use chrono::{DateTime, Datelike, Local, NaiveDateTime, TimeZone}; +use chrono_tz::{Asia::Kolkata, Tz}; +use std::time::Duration; pub fn time_until(hour: u32, minute: u32) -> Duration { - let now = chrono::Local::now().with_timezone(&chrono_tz::Asia::Kolkata); - let today_run = now.date().and_hms(hour, minute, 0); + let now = Local::now().with_timezone(&Kolkata); + let today_run = Kolkata + .with_ymd_and_hms(now.year(), now.month(), now.day(), hour, minute, 0) + .single() + .expect("Valid datetime must be created"); let next_run = if now < today_run { today_run @@ -29,13 +32,6 @@ pub fn time_until(hour: u32, minute: u32) -> Duration { today_run + chrono::Duration::days(1) }; - let time_until = (next_run - now).to_std().unwrap(); - Duration::from_secs(time_until.as_secs()) -} - -pub fn get_five_am_timestamp(now: DateTime) -> DateTime { - chrono::Local - .ymd(now.year(), now.month(), now.day()) - .and_hms_opt(5, 0, 0) - .expect("Chrono must work.") + let duration = next_run.signed_duration_since(now); + Duration::from_secs(duration.num_seconds().max(0) as u64) }