Skip to content

Commit

Permalink
feat: Allow global level Response headers (#410)
Browse files Browse the repository at this point in the history
* feat: Allow global level response headers

* fix(base_routes): Replace add_request_headers with add_response_header

---------

Co-authored-by: Sanskar Jethi <[email protected]>
  • Loading branch information
ParthS007 and sansyrox authored Feb 25, 2023
1 parent afcf057 commit 91ffd1c
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 6 deletions.
20 changes: 18 additions & 2 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,25 @@ async def hello(request):
return "Hello World"
```

## Global Headers
## Global Request Headers

You can also add global headers for every request.

```python
app.add_request_header("server", "robyn")
```

## Global Response Headers

You can also add global response headers for every request.

```python
app.add_response_header("content-type", "application/json")
```

## Per route headers

You can also add headers for every route.
You can also add request and response headers for every route.

```python
@app.get("/request_headers")
Expand All @@ -194,6 +202,14 @@ async def request_headers():
}
```

```python
@app.get("/response_headers")
async def response_headers():
return {
"headers": {"Header": "header_value"},
}
```

## Query Params

You can access query params from every HTTP method.
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ async def post(request):


if __name__ == "__main__":
app.add_request_header("server", "robyn")
app.add_response_header("server", "robyn")
app.add_directory(
route="/test_dir",
directory_path=os.path.join(current_file_path, "build"),
Expand Down
6 changes: 6 additions & 0 deletions integration_tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def test_add_request_header():
assert app.request_headers == [Header(key="server", val="robyn")]


def test_add_response_header():
app = Robyn(__file__)
app.add_response_header("content-type", "application/json")
assert app.response_headers == [Header(key="content-type", val="application/json")]


def test_lifecycle_handlers():
def mock_startup_handler():
pass
Expand Down
6 changes: 6 additions & 0 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, file_object: str) -> None:
self.middleware_router = MiddlewareRouter()
self.web_socket_router = WebSocketRouter()
self.request_headers: List[Header] = [] # This needs a better type
self.response_headers: List[Header] = [] # This needs a better type
self.directories: List[Directory] = []
self.event_handlers = {}
load_vars(project_root=directory_path)
Expand Down Expand Up @@ -83,6 +84,9 @@ def add_directory(
def add_request_header(self, key: str, value: str) -> None:
self.request_headers.append(Header(key, value))

def add_response_header(self, key: str, value: str) -> None:
self.response_headers.append(Header(key, value))

def add_web_socket(self, endpoint: str, ws: WS) -> None:
self.web_socket_router.add_route(endpoint, ws)

Expand Down Expand Up @@ -126,6 +130,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080):
self.event_handlers,
self.config.workers,
self.config.processes,
self.response_headers,
)
else:
event_handler = EventHandler(
Expand All @@ -139,6 +144,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080):
self.event_handlers,
self.config.workers,
self.config.processes,
self.response_headers,
)
event_handler.start_server()
logger.info(
Expand Down
3 changes: 3 additions & 0 deletions robyn/dev_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(
event_handlers: Dict[Events, FunctionInfo],
workers: int,
processes: int,
response_headers: List[Header],
) -> None:
self.url = url
self.port = port
self.directories = directories
self.request_headers = request_headers
self.response_headers = response_headers
self.routes = routes
self.middlewares = middlewares
self.web_sockets = web_sockets
Expand All @@ -48,6 +50,7 @@ def start_server(self):
self.event_handlers,
self.n_workers,
self.n_processes,
self.response_headers,
True,
)

Expand Down
9 changes: 9 additions & 0 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def run_processes(
event_handlers: Dict[Events, FunctionInfo],
workers: int,
processes: int,
response_headers: List[Header],
from_event_handler: bool = False,
) -> List[Process]:
socket = SocketHeld(url, port)
Expand All @@ -37,6 +38,7 @@ def run_processes(
socket,
workers,
processes,
response_headers,
)

if not from_event_handler:
Expand Down Expand Up @@ -66,6 +68,7 @@ def init_processpool(
socket: SocketHeld,
workers: int,
processes: int,
response_headers: List[Header],
) -> List[Process]:
process_pool = []
if sys.platform.startswith("win32"):
Expand All @@ -78,6 +81,7 @@ def init_processpool(
event_handlers,
socket,
workers,
response_headers,
)

return process_pool
Expand All @@ -95,6 +99,7 @@ def init_processpool(
event_handlers,
copied_socket,
workers,
response_headers,
),
)
process.start()
Expand Down Expand Up @@ -128,6 +133,7 @@ def spawn_process(
event_handlers: Dict[Events, FunctionInfo],
socket: SocketHeld,
workers: int,
response_headers: List[Header],
):
"""
This function is called by the main process handler to create a server runtime.
Expand Down Expand Up @@ -156,6 +162,9 @@ def spawn_process(
for header in request_headers:
server.add_request_header(*header.as_list())

for header in response_headers:
server.add_response_header(*header.as_list())

for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
Expand Down
2 changes: 2 additions & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Server:
pass
def add_request_header(self, key: str, value: str) -> None:
pass
def add_response_header(self, key: str, value: str) -> None:
pass
def add_route(
self,
route_type: str,
Expand Down
27 changes: 24 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct Server {
websocket_router: Arc<WebSocketRouter>,
middleware_router: Arc<MiddlewareRouter>,
global_request_headers: Arc<DashMap<String, String>>,
global_response_headers: Arc<DashMap<String, String>>,
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<FunctionInfo>>,
shutdown_handler: Option<Arc<FunctionInfo>>,
Expand All @@ -63,6 +64,7 @@ impl Server {
websocket_router: Arc::new(WebSocketRouter::new()),
middleware_router: Arc::new(MiddlewareRouter::new()),
global_request_headers: Arc::new(DashMap::new()),
global_response_headers: Arc::new(DashMap::new()),
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
Expand Down Expand Up @@ -92,6 +94,7 @@ impl Server {
let middleware_router = self.middleware_router.clone();
let web_socket_router = self.websocket_router.clone();
let global_request_headers = self.global_request_headers.clone();
let global_response_headers = self.global_response_headers.clone();
let directories = self.directories.clone();
let workers = Arc::new(workers);

Expand Down Expand Up @@ -145,7 +148,8 @@ impl Server {
.app_data(web::Data::new(router.clone()))
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()));
.app_data(web::Data::new(global_request_headers.clone()))
.app_data(web::Data::new(global_response_headers.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read().unwrap()).iter() {
Expand All @@ -165,6 +169,7 @@ impl Server {
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers,
global_response_headers,
body,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
Expand All @@ -173,6 +178,7 @@ impl Server {
const_router,
middleware_router,
global_request_headers,
global_response_headers,
body,
req,
)
Expand Down Expand Up @@ -223,19 +229,32 @@ impl Server {
});
}

/// Adds a new header to our concurrent hashmap
/// Adds a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_request_header(&self, key: &str, value: &str) {
self.global_request_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new header to our concurrent hashmap
/// Adds a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_response_header(&self, key: &str, value: &str) {
self.global_response_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_header(&self, key: &str) {
self.global_request_headers.remove(key);
}

/// Removes a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_response_header(&self, key: &str) {
self.global_response_headers.remove(key);
}

/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
Expand Down Expand Up @@ -345,6 +364,7 @@ async fn index(
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers: web::Data<Arc<Headers>>,
global_response_headers: web::Data<Arc<Headers>>,
body: Bytes,
req: HttpRequest,
) -> impl Responder {
Expand All @@ -360,6 +380,7 @@ async fn index(

let mut response_builder = HttpResponse::Ok();
apply_dashmap_headers(&mut response_builder, &global_request_headers);
apply_dashmap_headers(&mut response_builder, &global_response_headers);
apply_hashmap_headers(&mut response_builder, &request.headers);

let response = if let Some(r) = const_router.get_route(req.method(), req.uri().path()) {
Expand Down
1 change: 1 addition & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ pub struct Response {

#[pymethods]
impl Response {
// To do: Add check for content-type in header and change response_type accordingly
#[new]
pub fn new(status_code: u16, headers: HashMap<String, String>, body: &PyAny) -> PyResult<Self> {
Ok(Self {
Expand Down

0 comments on commit 91ffd1c

Please sign in to comment.