diff --git a/src/client.rs b/src/client.rs index 7281df13..73d0071b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; use hyper::client::HttpConnector; +use hyper::header::HeaderValue; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; @@ -21,6 +22,7 @@ use crate::server::RequestExt; use crate::utils::format_url; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; +const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; pub static CLIENT: Lazy>> = Lazy::new(|| { let https = hyper_rustls::HttpsConnectorBuilder::new() @@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo if !redirect { return Ok(response); }; - + let location_header = response.headers().get(header::LOCATION); + if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) { + return Err("Reddit response was invalid".to_string()); + } return request( method, - response - .headers() - .get(header::LOCATION) + location_header .map(|val| { // We need to make adjustments to the URI // we get back from Reddit. Namely, we @@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo // required. // // 2. Percent-encode the path. - let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string(); + let new_path = percent_encode(val.as_bytes(), CONTROLS) + .to_string() + .trim_start_matches(REDDIT_URL_BASE) + .trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE) + .to_string(); format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" }) }) .unwrap_or_default() @@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo } } Err(e) => { - dbg_msg!("{} {}: {}", method, path, e); + dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e); Err(e.to_string()) } @@ -312,6 +319,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo // Make a request to a Reddit API and parse the JSON response #[cached(size = 100, time = 30, result = true)] pub async fn json(path: String, quarantine: bool) -> Result { + trace!("going to get {path}"); // Closure to quickly build errors let err = |msg: &str, e: String, path: String| -> Result { // eprintln!("{} - {}: {}", url, msg, e);