commit 0a978fdb5dae84d0f98c4352b144a2a50b1e3b56
parent aa154be51db12163ffba7973211573d8812a15ac
Author: lu4nm3 <lmedina@lyft.com>
Date: Wed, 8 May 2019 14:51:19 -0700
Set default 'Server' header only if it isn't set.
Closes #996.
Diffstat:
2 files changed, 48 insertions(+), 2 deletions(-)
diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs
@@ -216,9 +216,12 @@ impl Rocket {
// Route the request and run the user's handlers.
let mut response = self.route_and_process(request, data);
- // Add the 'rocket' server header to the response and run fairings.
+ // Add the 'rocket' server header to the response and run fairings only if the header
+ // doesn't already exist.
// TODO: If removing Hyper, write out `Date` header too.
- response.set_header(Header::new("Server", "Rocket"));
+ if !response.headers().contains("Server") {
+ response.set_header(Header::new("Server", "Rocket"));
+ }
self.fairings.handle_response(request, &mut response);
// Strip the body if this is a `HEAD` request.
diff --git a/core/lib/tests/conditionally-set-server-header-997.rs b/core/lib/tests/conditionally-set-server-header-997.rs
@@ -0,0 +1,43 @@
+#![feature(proc_macro_hygiene, decl_macro)]
+
+#[macro_use]
+extern crate rocket;
+extern crate rocket_http;
+
+use rocket::response::Redirect;
+use rocket::Response;
+use rocket_http::{Header, Status};
+
+#[get("/do_not_overwrite")]
+fn do_not_overwrite<'r>() -> Result<Response<'r>, ()> {
+ let header = Header::new("Server", "Test");
+
+ Response::build()
+ .header(header)
+ .ok()
+}
+
+#[get("/use_default")]
+fn use_default<'r>() -> Result<Response<'r>, ()> {
+ Response::build()
+ .ok()
+}
+
+mod conditionally_set_server_header {
+ use super::*;
+ use rocket::local::Client;
+
+ #[test]
+ fn do_not_overwrite_server_header() {
+ let rocket = rocket::ignite().mount("/", routes![do_not_overwrite, use_default]);
+ let client = Client::new(rocket).unwrap();
+
+ let response = client.get("/do_not_overwrite").dispatch();
+ let server = response.headers().get_one("Server");
+ assert_eq!(server, Some("Test"));
+
+ let response = client.get("/use_default").dispatch();
+ let server = response.headers().get_one("Server");
+ assert_eq!(server, Some("Rocket"));
+ }
+}