diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index a1bec14b..db7e180f 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -155,16 +155,6 @@ fn get_app_host(config: &AppConfig) -> String { host } -fn build_absolute_uri(app_host: &str, relative_path: &str, scheme: &str) -> anyhow::Result { - let mut base_url = Url::parse(&format!("{scheme}://{app_host}")) - .with_context(|| format!("Failed to parse app_host: {app_host}"))?; - base_url.set_path(""); - let absolute_url = base_url - .join(relative_path) - .with_context(|| format!("Failed to join path {relative_path}"))?; - Ok(absolute_url.to_string()) -} - pub struct ClientWithTime { client: OidcClient, end_session_endpoint: Option, @@ -246,6 +236,29 @@ impl OidcState { .map_err(|e| anyhow::anyhow!("Could not verify the ID token: {e}"))?; Ok(claims) } + + /// Builds an absolute redirect URI by joining the relative redirect URI with the client's redirect URL + pub async fn build_absolute_redirect_uri( + &self, + relative_redirect_uri: &str, + ) -> anyhow::Result { + let client_guard = self.get_client().await; + let client_redirect_url = client_guard + .redirect_uri() + .ok_or_else(|| anyhow!("OIDC client has no redirect URL configured"))?; + let absolute_redirect_uri = client_redirect_url + .url() + .join(relative_redirect_uri) + .with_context(|| { + format!( + "Failed to join redirect URI {} with client redirect URL {}", + relative_redirect_uri, + client_redirect_url.url() + ) + })? + .to_string(); + Ok(absolute_redirect_uri) + } } pub async fn initialize_oidc_state( @@ -494,11 +507,12 @@ async fn process_oidc_logout( .ok() .flatten(); - let scheme = request.connection_info().scheme().to_string(); let mut response = if let Some(end_session_endpoint) = oidc_state.get_end_session_endpoint().await { - let absolute_redirect_uri = - build_absolute_uri(&oidc_state.config.app_host, ¶ms.redirect_uri, &scheme)?; + let absolute_redirect_uri = oidc_state + .build_absolute_redirect_uri(¶ms.redirect_uri) + .await?; + let post_logout_redirect_uri = PostLogoutRedirectUrl::new(absolute_redirect_uri.clone()).with_context(|| { format!("Invalid post_logout_redirect_uri: {absolute_redirect_uri}") diff --git a/tests/oidc/mod.rs b/tests/oidc/mod.rs index 9db2532c..2972c126 100644 --- a/tests/oidc/mod.rs +++ b/tests/oidc/mod.rs @@ -64,6 +64,7 @@ struct DiscoveryResponse { response_types_supported: Vec, subject_types_supported: Vec, id_token_signing_alg_values_supported: Vec, + end_session_endpoint: String, } #[derive(Serialize)] @@ -89,6 +90,7 @@ async fn discovery_endpoint(state: Data) -> impl Responder response_types_supported: vec!["code".to_string()], subject_types_supported: vec!["public".to_string()], id_token_signing_alg_values_supported: vec!["HS256".to_string()], + end_session_endpoint: format!("{}/logout", state.issuer_url), }; HttpResponse::Ok() .insert_header((header::CONTENT_TYPE, "application/json")) @@ -435,3 +437,46 @@ async fn test_oidc_expired_token_is_rejected() { }) .await; } + +#[actix_web::test] +async fn test_oidc_logout_uses_correct_scheme() { + use sqlpage::{ + app_config::{test_database_url, AppConfig}, + webserver::oidc::create_logout_url, + AppState, + }; + + crate::common::init_log(); + let provider = FakeOidcProvider::new().await; + + let db_url = test_database_url(); + let config_json = format!( + r#"{{ + "database_url": "{db_url}", + "oidc_issuer_url": "{}", + "oidc_client_id": "{}", + "oidc_client_secret": "{}", + "https_domain": "example.com" + }}"#, + provider.issuer_url, provider.client_id, provider.client_secret + ); + + let config: AppConfig = serde_json::from_str(&config_json).unwrap(); + let app_state = AppState::init(&config).await.unwrap(); + let app = test::init_service(create_app(Data::new(app_state))).await; + + let logout_path = create_logout_url("/logged_out", "", &provider.client_secret); + // make sure the logout path includes the configured domain + assert!(logout_path.starts_with("/sqlpage/oidc_logout")); + + let req = test::TestRequest::get().uri(&logout_path).to_request(); + let resp = test::call_service(&app, req).await; + + assert_eq!(resp.status(), StatusCode::SEE_OTHER); + let location = resp.headers().get("location").unwrap().to_str().unwrap(); + let location_url = Url::parse(location).unwrap(); + assert_eq!(location_url.path(), "/logout"); + let params: HashMap = location_url.query_pairs().into_owned().collect(); + let post_logout = params.get("post_logout_redirect_uri").unwrap(); + assert_eq!(post_logout, "https://example.com/logged_out"); +}