Skip to content

Commit

Permalink
Add connected role config option
Browse files Browse the repository at this point in the history
  • Loading branch information
WorkingRobot committed May 5, 2024
1 parent 2413ff1 commit 98a08f8
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 23 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions web/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ max_connections_per_user: 3
discord:
redirect_uri: http://localhost:3000/api/v1/oauth/callback
queue_size_dm_threshold: 0
connected_role_id: 1236758675385749534
activity_update_interval: 60
activities:
- type: watching
Expand Down
2 changes: 1 addition & 1 deletion web/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ where

req.extensions_mut().insert(username);

Ok(srv.call(req).await?)
srv.call(req).await
}
.boxed_local()
}
Expand Down
3 changes: 2 additions & 1 deletion web/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde::Deserialize;
use serenity::all::{ActivityData, ActivityType, ChannelId, GuildId};
use serenity::all::{ActivityData, ActivityType, ChannelId, GuildId, RoleId};

#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize)]
pub enum DiscordActivityType {
Expand All @@ -25,6 +25,7 @@ pub struct DiscordConfig {
pub bot_token: String,
pub guild_id: GuildId,
pub log_channel_id: ChannelId,
pub connected_role_id: RoleId,
pub queue_size_dm_threshold: u32,
pub activities: Vec<DiscordActivity>,
pub activity_update_interval: u64,
Expand Down
10 changes: 10 additions & 0 deletions web/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,13 @@ pub async fn get_connection_ids_by_user_id(
.map(|id| DatabaseU64::from(id).0)
.collect())
}

pub async fn does_connection_id_exist(pool: &PgPool, connection_id: u64) -> Result<bool, Error> {
Ok(sqlx::query_scalar!(
r#"SELECT EXISTS(SELECT 1 FROM connections WHERE conn_user_id = $1)"#,
DatabaseU64(connection_id).as_db()
)
.fetch_one(pool)
.await?
.unwrap_or(false))
}
44 changes: 35 additions & 9 deletions web/src/discord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ use tokio::{
task::JoinHandle,
};

const COLOR_SUCCESS: Color = Color::from_rgb(16, 240, 12);
const COLOR_ERROR: Color = Color::from_rgb(235, 96, 94);
const COLOR_IN_QUEUE: Color = Color::BLITZ_BLUE;

#[derive(Clone)]
pub struct DiscordClient {
imp: Arc<DiscordClientImp>,
Expand Down Expand Up @@ -136,7 +140,7 @@ impl DiscordClient {
}

fn http(&self) -> &Http {
&self.imp.http.get().unwrap()
self.imp.http.get().unwrap()
}

fn config(&self) -> &DiscordConfig {
Expand Down Expand Up @@ -194,7 +198,7 @@ impl DiscordClient {
}))
.footer(CreateEmbedFooter::new("At"))
.timestamp(OffsetDateTime::now_utc())
.color(Color::from_rgb(16, 240, 12));
.color(COLOR_SUCCESS);

channel
.send_message(self.http(), CreateMessage::new().embed(embed))
Expand All @@ -208,7 +212,7 @@ impl DiscordClient {
.description("This discord account will now no longer receive queue notifications from me!\n\nNote: You'll still receive notifications for queues from other computers.")
.footer(CreateEmbedFooter::new("At"))
.timestamp(OffsetDateTime::now_utc())
.color(Color::from_rgb(235, 96, 94));
.color(COLOR_ERROR);

channel
.send_message(self.http(), CreateMessage::new().embed(embed))
Expand All @@ -217,6 +221,28 @@ impl DiscordClient {
Ok(())
}

pub async fn mark_user_connected(&self, user_id: UserId) -> Result<(), serenity::Error> {
self.http()
.add_member_role(
self.config().guild_id,
user_id,
self.config().connected_role_id,
Some("User is Connected"),
)
.await
}

pub async fn mark_user_disconnected(&self, user_id: UserId) -> Result<(), serenity::Error> {
self.http()
.remove_member_role(
self.config().guild_id,
user_id,
self.config().connected_role_id,
Some("User is Disconnected"),
)
.await
}

pub async fn send_queue_position(
&self,
user_id: UserId,
Expand Down Expand Up @@ -299,7 +325,7 @@ impl DiscordClient {
.description(format!("You've been logged in successfully! Thanks for using Waitingway!\n\nYour queue size was {}, which was completed in {}.", queue_start_size, format_duration(duration)))
.footer(CreateEmbedFooter::new("At"))
.timestamp(OffsetDateTime::now_utc())
.color(Color::from_rgb(16, 240, 12));
.color(COLOR_SUCCESS);

channel
.send_message(self.http(), CreateMessage::new().embed(embed))
Expand All @@ -326,7 +352,7 @@ impl DiscordClient {
)
.footer(CreateEmbedFooter::new("At"))
.timestamp(OffsetDateTime::now_utc())
.color(Color::from_rgb(235, 96, 94));
.color(COLOR_ERROR);

channel
.send_message(self.http(), CreateMessage::new().embed(embed))
Expand All @@ -350,7 +376,7 @@ impl DiscordClient {
))
.footer(CreateEmbedFooter::new("Last updated"))
.timestamp(now.assume_utc())
.color(Color::BLITZ_BLUE)
.color(COLOR_IN_QUEUE)
}
}

Expand Down Expand Up @@ -381,10 +407,10 @@ fn format_duration(duration: time::Duration) -> String {
let hours = hours % 24;

if days > 0 {
return format!("{}d {:02}:{:02}:{:02}", days, hours, minutes, seconds);
format!("{}d {:02}:{:02}:{:02}", days, hours, minutes, seconds)
} else if hours > 0 {
return format!("{:02}:{:02}:{:02}", hours, minutes, seconds);
format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
} else {
return format!("{:02}:{:02}", minutes, seconds);
format!("{:02}:{:02}", minutes, seconds)
}
}
14 changes: 12 additions & 2 deletions web/src/routes/api/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,26 @@ async fn delete_connection(
let id = id.into_inner();
let resp = db::delete_connection(&pool, *username, id)
.await
.map_err(|e| ErrorInternalServerError(e))?;
.map_err(ErrorInternalServerError)?;

if resp.rows_affected() == 0 {
return Err(ErrorNotFound("Connection not found"));
}


if !db::does_connection_id_exist(&pool, id)
.await
.map_err(ErrorInternalServerError)?
{
discord.mark_user_disconnected(UserId::new(id))
.await
.map_err(ErrorInternalServerError)?;
}

discord
.offboard_user(UserId::new(id))
.await
.map_err(|e| ErrorInternalServerError(e))?;
.map_err(ErrorInternalServerError)?;

Ok(HttpResponse::NoContent().finish())
}
2 changes: 1 addition & 1 deletion web/src/routes/api/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async fn create(

let connections = db::get_connection_ids_by_user_id(&pool, *username)
.await
.map_err(|e| ErrorInternalServerError(e))?;
.map_err(ErrorInternalServerError)?;

let discord = discord.into_inner();
let data = data.into_inner();
Expand Down
18 changes: 9 additions & 9 deletions web/src/routes/api/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async fn redirect(config: web::Data<Config>, username: web::ReqData<Uuid>) -> Re
.insert_header((
header::LOCATION,
oauth::get_redirect_url(&config.discord, *username)
.map_err(|e| ErrorInternalServerError(e))?
.map_err(ErrorInternalServerError)?
.to_string(),
))
.finish())
Expand Down Expand Up @@ -73,9 +73,7 @@ async fn callback(
.map_err(|_| ErrorBadRequest("Invalid state (uuid)"))?;
let token = oauth::exchange_code_for_token(&client, &config.discord, &query.code)
.await
.map_err(|e| match e {
_ => ErrorInternalServerError(e),
})?;
.map_err(ErrorInternalServerError)?;
if !token.token_type.eq_ignore_ascii_case("Bearer") {
// Can't kill the token if it's not a bearer token
return Err(ErrorInternalServerError("Invalid token type"));
Expand All @@ -88,9 +86,7 @@ async fn callback(
}
let identity = oauth::get_discord_identity(&client, &token.access_token)
.await
.map_err(|e| match e {
_ => ErrorInternalServerError(e),
})?;
.map_err(ErrorInternalServerError)?;

let conn_result = db::create_connection(
&pool,
Expand All @@ -106,16 +102,20 @@ async fn callback(
config.max_connections_per_user.into(),
)
.await
.map_err(|e| ErrorInternalServerError(e))?;
.map_err(ErrorInternalServerError)?;

if conn_result.rows_affected() == 0 {
return Err(ErrorBadRequest("You have too many connections already"));
}

discord.mark_user_connected(identity.id.get().into())
.await
.map_err(ErrorInternalServerError)?;

let message = discord
.onboard_user(identity.id, token.access_token)
.await
.map_err(|e| ErrorInternalServerError(e))?;
.map_err(ErrorInternalServerError)?;

Ok(HttpResponse::Found()
.append_header((header::LOCATION, message.link()))
Expand Down

0 comments on commit 98a08f8

Please sign in to comment.