Skip to content

Commit

Permalink
Merge pull request #6 from tomvanderlee/feature/websockets
Browse files Browse the repository at this point in the history
Added websocket support
  • Loading branch information
tomvanderlee authored Aug 30, 2024
2 parents 53a8f30 + a72a048 commit 0f7c975
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 87 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: ghcr.io/tomvanderlee/ttun-server
tags: |
Expand All @@ -25,13 +25,13 @@ jobs:
- name: Login to DockerHub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v4
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
starlette ~= 0.17
starlette ~= 0.37
uvicorn[standard] ~= 0.16
aioredis ~= 2.0
redis
setuptools
4 changes: 3 additions & 1 deletion ttun_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from starlette.applications import Starlette
from starlette.routing import Route, WebSocketRoute, Host, Router

from ttun_server.endpoints import Proxy, Tunnel, Health
from ttun_server.endpoints import Proxy, Health
from .websockets import WebsocketProxy, Tunnel

logging.basicConfig(level=getattr(logging, os.environ.get('LOG_LEVEL', 'INFO')))

Expand All @@ -18,6 +19,7 @@
routes=[
Host(os.environ['TUNNEL_DOMAIN'], base_router, 'base'),
Route('/{path:path}', Proxy),
WebSocketRoute('/{path:path}', WebsocketProxy)
]
)

Expand Down
75 changes: 5 additions & 70 deletions ttun_server/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import asyncio
import logging
import os
from asyncio import create_task
from base64 import b64decode, b64encode
from typing import Optional, Any
from uuid import uuid4

from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket

import ttun_server
from ttun_server.proxy_queue import ProxyQueue
from ttun_server.types import RequestData, Config, Message, MessageType
from ttun_server.types import HttpRequestData, Message, HttpMessageType, HttpMessage

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,11 +37,11 @@ async def dispatch(self) -> None:

logger.debug('PROXY %s%s ', subdomain, request.url)
await request_queue.enqueue(
Message(
type=MessageType.request.value,
HttpMessage(
type=HttpMessageType.request.value,
identifier=identifier,
payload=
RequestData(
HttpRequestData(
method=request.method,
path=str(request.url).replace(str(request.base_url), '/'),
headers=list(request.headers.items()),
Expand Down Expand Up @@ -78,61 +71,3 @@ async def get(self, _) -> None:
await response(self.scope, self.receive, self.send)


class Tunnel(WebSocketEndpoint):
encoding = 'json'

def __init__(self, scope: Scope, receive: Receive, send: Send):
super().__init__(scope, receive, send)
self.request_task = None
self.config: Optional[Config] = None

async def handle_requests(self, websocket: WebSocket):
while request := await self.proxy_queue.dequeue():
create_task(websocket.send_json(request))

async def on_connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.config = await websocket.receive_json()

client_version = self.config.get('version', '1.0.0')
logger.debug('client_version %s', client_version)

if 'git' not in client_version and ttun_server.__version__ != 'development':
[client_major, *_] = [int(i) for i in client_version.split('.')[:3]]
[server_major, *_] = [int(i) for i in ttun_server.__version__.split('.')]

if client_major < server_major:
await websocket.close(4000, 'Your client is too old')

if client_major > server_major:
await websocket.close(4001, 'Your client is too new')


if self.config['subdomain'] is None \
or await ProxyQueue.has_connection(self.config['subdomain']):
self.config['subdomain'] = uuid4().hex


self.proxy_queue = await ProxyQueue.create_for_identifier(self.config['subdomain'])

hostname = os.environ.get("TUNNEL_DOMAIN")
protocol = "https" if os.environ.get("SECURE", False) else "http"

await websocket.send_json({
'url': f'{protocol}://{self.config["subdomain"]}.{hostname}'
})

self.request_task = asyncio.create_task(self.handle_requests(websocket))

async def on_receive(self, websocket: WebSocket, data: Message):
try:
response_queue = await ProxyQueue.get_for_identifier(f"{self.config['subdomain']}_{data['identifier']}")
await response_queue.enqueue(data)
except AssertionError:
pass

async def on_disconnect(self, websocket: WebSocket, close_code: int):
await self.proxy_queue.delete()

if self.request_task is not None:
self.request_task.cancel()
1 change: 1 addition & 0 deletions ttun_server/proxy_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import traceback
from typing import Type

from ttun_server.redis import RedisConnectionPool
Expand Down
7 changes: 3 additions & 4 deletions ttun_server/redis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import os
from asyncio import get_running_loop

from aioredis import ConnectionPool, Redis
from redis.asyncio import ConnectionPool, Redis


class RedisConnectionPool:
Expand All @@ -9,9 +11,6 @@ class RedisConnectionPool:
def __init__(self):
self.pool = ConnectionPool.from_url(os.environ.get('REDIS_URL'))

def __del__(self):
self.pool.disconnect()

@classmethod
def get_connection(cls) -> Redis:
if cls.instance is None:
Expand Down
51 changes: 45 additions & 6 deletions ttun_server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TypedDict, Optional


class MessageType(Enum):
class HttpMessageType(Enum):
request = 'request'
response = 'response'

Expand All @@ -13,23 +13,62 @@ class Config(TypedDict):
client_version: str


class RequestData(TypedDict):
class HttpRequestData(TypedDict):
method: str
path: str
headers: list[tuple[str, str]]
body: Optional[str]


class ResponseData(TypedDict):
class HttpResponseData(TypedDict):
status: int
headers: list[tuple[str, str]]
body: Optional[str]


class Message(TypedDict):
type: MessageType
class HttpMessage(TypedDict):
type: HttpMessageType
identifier: str
payload: Config | RequestData | ResponseData
payload: Config | HttpRequestData | HttpResponseData


class WebsocketMessageType(Enum):
connect = 'connect'
disconnect = 'disconnect'
message = 'message'
ack = 'ack'


class WebsocketConnectData(TypedDict):
path: str
headers: list[tuple[str, str]]


class WebsocketDisconnectData(TypedDict):
close_code: int


class WebsocketMessageData(TypedDict):
body: Optional[str]


class WebsocketMessage(TypedDict):
type: WebsocketMessageType
identifier: str
payload: WebsocketConnectData | WebsocketDisconnectData | WebsocketMessageData


class MessageType(Enum):
request = 'request'
response = 'response'

ws_connect = 'connect'
ws_disconnect = 'disconnect'
ws_message = 'message'
ws_ack = 'ack'


Message = HttpMessage | WebsocketMessage


class MemoryConnection(TypedDict):
Expand Down
Loading

0 comments on commit 0f7c975

Please sign in to comment.