Skip to content

Commit

Permalink
fix(client): Handle invalid reddit response of base URL location
Browse files Browse the repository at this point in the history
  • Loading branch information
sigaloid committed Jun 29, 2024
1 parent ea87ec3 commit 6ed4c40
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
use hyper::client::HttpConnector;
use hyper::header::HeaderValue;
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
use hyper_rustls::HttpsConnector;
use libflate::gzip;
Expand All @@ -21,6 +22,7 @@ use crate::server::RequestExt;
use crate::utils::format_url;

const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";

pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
let https = hyper_rustls::HttpsConnectorBuilder::new()
Expand Down Expand Up @@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
if !redirect {
return Ok(response);
};

let location_header = response.headers().get(header::LOCATION);
if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) {
return Err("Reddit response was invalid".to_string());
}
return request(
method,
response
.headers()
.get(header::LOCATION)
location_header
.map(|val| {
// We need to make adjustments to the URI
// we get back from Reddit. Namely, we
Expand All @@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// required.
//
// 2. Percent-encode the path.
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string();
let new_path = percent_encode(val.as_bytes(), CONTROLS)
.to_string()
.trim_start_matches(REDDIT_URL_BASE)
.trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE)
.to_string();
format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" })
})
.unwrap_or_default()
Expand Down Expand Up @@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
}
}
Err(e) => {
dbg_msg!("{} {}: {}", method, path, e);
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);

Err(e.to_string())
}
Expand All @@ -312,6 +319,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// Make a request to a Reddit API and parse the JSON response
#[cached(size = 100, time = 30, result = true)]
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
trace!("going to get {path}");
// Closure to quickly build errors
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
// eprintln!("{} - {}: {}", url, msg, e);
Expand Down

0 comments on commit 6ed4c40

Please sign in to comment.