diff --git a/docs/features.md b/docs/features.md index 18f7cdd51..195095a48 100644 --- a/docs/features.md +++ b/docs/features.md @@ -171,7 +171,7 @@ async def hello(request): return "Hello World" ``` -## Global Headers +## Global Request Headers You can also add global headers for every request. @@ -179,9 +179,17 @@ You can also add global headers for every request. app.add_request_header("server", "robyn") ``` +## Global Response Headers + +You can also add global response headers for every request. + +```python +app.add_response_header("content-type", "application/json") +``` + ## Per route headers -You can also add headers for every route. +You can also add request and response headers for every route. ```python @app.get("/request_headers") @@ -194,6 +202,14 @@ async def request_headers(): } ``` +```python +@app.get("/response_headers") +async def response_headers(): + return { + "headers": {"Header": "header_value"}, + } +``` + ## Query Params You can access query params from every HTTP method. diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 6bd656812..da0107731 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -511,7 +511,7 @@ async def post(request): if __name__ == "__main__": - app.add_request_header("server", "robyn") + app.add_response_header("server", "robyn") app.add_directory( route="/test_dir", directory_path=os.path.join(current_file_path, "build"), diff --git a/integration_tests/test_app.py b/integration_tests/test_app.py index 5081cc7fe..d8ffbb19f 100644 --- a/integration_tests/test_app.py +++ b/integration_tests/test_app.py @@ -9,6 +9,12 @@ def test_add_request_header(): assert app.request_headers == [Header(key="server", val="robyn")] +def test_add_response_header(): + app = Robyn(__file__) + app.add_response_header("content-type", "application/json") + assert app.response_headers == [Header(key="content-type", val="application/json")] + + def test_lifecycle_handlers(): def mock_startup_handler(): pass diff --git a/robyn/__init__.py b/robyn/__init__.py index 32a84da71..92cb24dc8 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -33,6 +33,7 @@ def __init__(self, file_object: str) -> None: self.middleware_router = MiddlewareRouter() self.web_socket_router = WebSocketRouter() self.request_headers: List[Header] = [] # This needs a better type + self.response_headers: List[Header] = [] # This needs a better type self.directories: List[Directory] = [] self.event_handlers = {} load_vars(project_root=directory_path) @@ -83,6 +84,9 @@ def add_directory( def add_request_header(self, key: str, value: str) -> None: self.request_headers.append(Header(key, value)) + def add_response_header(self, key: str, value: str) -> None: + self.response_headers.append(Header(key, value)) + def add_web_socket(self, endpoint: str, ws: WS) -> None: self.web_socket_router.add_route(endpoint, ws) @@ -126,6 +130,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080): self.event_handlers, self.config.workers, self.config.processes, + self.response_headers, ) else: event_handler = EventHandler( @@ -139,6 +144,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080): self.event_handlers, self.config.workers, self.config.processes, + self.response_headers, ) event_handler.start_server() logger.info( diff --git a/robyn/dev_event_handler.py b/robyn/dev_event_handler.py index ccbcfb9cd..e2ebc932c 100644 --- a/robyn/dev_event_handler.py +++ b/robyn/dev_event_handler.py @@ -23,11 +23,13 @@ def __init__( event_handlers: Dict[Events, FunctionInfo], workers: int, processes: int, + response_headers: List[Header], ) -> None: self.url = url self.port = port self.directories = directories self.request_headers = request_headers + self.response_headers = response_headers self.routes = routes self.middlewares = middlewares self.web_sockets = web_sockets @@ -48,6 +50,7 @@ def start_server(self): self.event_handlers, self.n_workers, self.n_processes, + self.response_headers, True, ) diff --git a/robyn/processpool.py b/robyn/processpool.py index 9a9666d32..0d5ea1ef1 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -23,6 +23,7 @@ def run_processes( event_handlers: Dict[Events, FunctionInfo], workers: int, processes: int, + response_headers: List[Header], from_event_handler: bool = False, ) -> List[Process]: socket = SocketHeld(url, port) @@ -37,6 +38,7 @@ def run_processes( socket, workers, processes, + response_headers, ) if not from_event_handler: @@ -66,6 +68,7 @@ def init_processpool( socket: SocketHeld, workers: int, processes: int, + response_headers: List[Header], ) -> List[Process]: process_pool = [] if sys.platform.startswith("win32"): @@ -78,6 +81,7 @@ def init_processpool( event_handlers, socket, workers, + response_headers, ) return process_pool @@ -95,6 +99,7 @@ def init_processpool( event_handlers, copied_socket, workers, + response_headers, ), ) process.start() @@ -128,6 +133,7 @@ def spawn_process( event_handlers: Dict[Events, FunctionInfo], socket: SocketHeld, workers: int, + response_headers: List[Header], ): """ This function is called by the main process handler to create a server runtime. @@ -156,6 +162,9 @@ def spawn_process( for header in request_headers: server.add_request_header(*header.as_list()) + for header in response_headers: + server.add_response_header(*header.as_list()) + for route in routes: route_type, endpoint, function, is_const = route server.add_route(route_type, endpoint, function, is_const) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index e0ec68658..b3d5845fa 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -37,6 +37,8 @@ class Server: pass def add_request_header(self, key: str, value: str) -> None: pass + def add_response_header(self, key: str, value: str) -> None: + pass def add_route( self, route_type: str, diff --git a/src/server.rs b/src/server.rs index cf8137e30..adc085dcb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -48,6 +48,7 @@ pub struct Server { websocket_router: Arc, middleware_router: Arc, global_request_headers: Arc>, + global_response_headers: Arc>, directories: Arc>>, startup_handler: Option>, shutdown_handler: Option>, @@ -63,6 +64,7 @@ impl Server { websocket_router: Arc::new(WebSocketRouter::new()), middleware_router: Arc::new(MiddlewareRouter::new()), global_request_headers: Arc::new(DashMap::new()), + global_response_headers: Arc::new(DashMap::new()), directories: Arc::new(RwLock::new(Vec::new())), startup_handler: None, shutdown_handler: None, @@ -92,6 +94,7 @@ impl Server { let middleware_router = self.middleware_router.clone(); let web_socket_router = self.websocket_router.clone(); let global_request_headers = self.global_request_headers.clone(); + let global_response_headers = self.global_response_headers.clone(); let directories = self.directories.clone(); let workers = Arc::new(workers); @@ -145,7 +148,8 @@ impl Server { .app_data(web::Data::new(router.clone())) .app_data(web::Data::new(const_router.clone())) .app_data(web::Data::new(middleware_router.clone())) - .app_data(web::Data::new(global_request_headers.clone())); + .app_data(web::Data::new(global_request_headers.clone())) + .app_data(web::Data::new(global_response_headers.clone())); let web_socket_map = web_socket_router.get_web_socket_map(); for (elem, value) in (web_socket_map.read().unwrap()).iter() { @@ -165,6 +169,7 @@ impl Server { const_router: web::Data>, middleware_router: web::Data>, global_request_headers, + global_response_headers, body, req| { pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move { @@ -173,6 +178,7 @@ impl Server { const_router, middleware_router, global_request_headers, + global_response_headers, body, req, ) @@ -223,19 +229,32 @@ impl Server { }); } - /// Adds a new header to our concurrent hashmap + /// Adds a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn add_request_header(&self, key: &str, value: &str) { self.global_request_headers .insert(key.to_string(), value.to_string()); } - /// Removes a new header to our concurrent hashmap + /// Adds a new response header to our concurrent hashmap + /// this can be called after the server has started. + pub fn add_response_header(&self, key: &str, value: &str) { + self.global_response_headers + .insert(key.to_string(), value.to_string()); + } + + /// Removes a new request header to our concurrent hashmap /// this can be called after the server has started. pub fn remove_header(&self, key: &str) { self.global_request_headers.remove(key); } + /// Removes a new response header to our concurrent hashmap + /// this can be called after the server has started. + pub fn remove_response_header(&self, key: &str) { + self.global_response_headers.remove(key); + } + /// Add a new route to the routing tables /// can be called after the server has been started pub fn add_route( @@ -345,6 +364,7 @@ async fn index( const_router: web::Data>, middleware_router: web::Data>, global_request_headers: web::Data>, + global_response_headers: web::Data>, body: Bytes, req: HttpRequest, ) -> impl Responder { @@ -360,6 +380,7 @@ async fn index( let mut response_builder = HttpResponse::Ok(); apply_dashmap_headers(&mut response_builder, &global_request_headers); + apply_dashmap_headers(&mut response_builder, &global_response_headers); apply_hashmap_headers(&mut response_builder, &request.headers); let response = if let Some(r) = const_router.get_route(req.method(), req.uri().path()) { diff --git a/src/types.rs b/src/types.rs index e0874ed34..301d8307e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -164,6 +164,7 @@ pub struct Response { #[pymethods] impl Response { + // To do: Add check for content-type in header and change response_type accordingly #[new] pub fn new(status_code: u16, headers: HashMap, body: &PyAny) -> PyResult { Ok(Self {