From b299db6ca372f3c90971ccbd5cf05054db0c8598 Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Mon, 16 Sep 2024 16:11:28 -0400 Subject: [PATCH] fix(oauth): catch network policy violation and rate limit --- src/client.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 6e73a592..33ea6fe4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,7 @@ use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; use hyper::client::HttpConnector; use hyper::header::HeaderValue; +use hyper::StatusCode; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; @@ -60,10 +61,9 @@ pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); pub async fn canonical_path(path: String) -> Result, String> { let res = reddit_head(path.clone(), true).await?; let status = res.status().as_u16(); + let policy_error = res.headers().get(header::RETRY_AFTER).is_some(); match status { - 429 => Err("Too many requests.".to_string()), - // If Reddit responds with a 2xx, then the path is already canonical. 200..=299 => Ok(Some(path)), @@ -94,6 +94,12 @@ pub async fn canonical_path(path: String) -> Result, String> { // as above), return a None. 300..=399 => Ok(None), + // Rate limiting + 429 => Err("Too many requests.".to_string()), + + // Special condition rate limiting - https://github.com/redlib-org/redlib/issues/229 + 403 if policy_error => Err("Too many requests.".to_string()), + _ => Ok( res .headers() @@ -257,6 +263,12 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo .await; }; + // Special condition rate limiting - https://github.com/redlib-org/redlib/issues/229 + if response.status() == StatusCode::FORBIDDEN && response.headers().get("retry-after").unwrap_or(&HeaderValue::from_static("0")).to_str().unwrap_or("0") == "0" { + force_refresh_token().await; + return Err("Rate limit - try refreshing soon".to_string()); + } + match response.headers().get(header::CONTENT_ENCODING) { // Content not compressed. None => Ok(response),