Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: don't use 2 threads to create FastAPI server #2592

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions bittensor/core/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,6 @@ def run_in_thread(self):
self.should_exit = True
thread.join()

def _wrapper_run(self):
"""
A wrapper method for the :func:`run_in_thread` context manager. This method is used internally by the ``start`` method to initiate the server's execution in a separate thread.
"""
with self.run_in_thread():
while not self.should_exit:
time.sleep(1e-3)

def start(self):
"""
Starts the FastAPI server in a separate thread if it is not already running. This method sets up the server to handle HTTP requests concurrently, enabling the Axon server to efficiently manage incoming network requests.
Expand All @@ -158,7 +150,7 @@ def start(self):
"""
if not self.is_running:
self.should_exit = False
thread = threading.Thread(target=self._wrapper_run, daemon=True)
thread = threading.Thread(target=self.run, daemon=True)
thread.start()
self.is_running = True

Expand Down
47 changes: 46 additions & 1 deletion tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@


import re
import threading
import time
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import fastapi
import netaddr
import pydantic
import pytest
import uvicorn
from fastapi.testclient import TestClient
from starlette.requests import Request

from bittensor.core.axon import AxonMiddleware, Axon
from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer
from bittensor.core.errors import RunException
from bittensor.core.settings import version_as_int
from bittensor.core.stream import StreamingSynapse
Expand Down Expand Up @@ -785,3 +788,45 @@ async def forward_fn(synapse: streaming_synapse_cls):
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
},
)


@pytest.mark.asyncio
async def test_threaded_fastapi():
server_started = threading.Event()
server_stopped = threading.Event()

async def lifespan(app):
server_started.set()
yield
server_stopped.set()

app = fastapi.FastAPI(
lifespan=lifespan,
)
app.get("/")(lambda: "Hello World")

server = FastAPIThreadedServer(
uvicorn.Config(
app,
),
)
server.start()

server_started.wait()

assert server.is_running is True

async with aiohttp.ClientSession(
base_url="http://127.0.0.1:8000",
) as session:
async with session.get("/") as response:
assert await response.text() == '"Hello World"'

server.stop()

assert server.should_exit is True

server_stopped.wait()

with pytest.raises(aiohttp.ClientConnectorError):
await session.get("/")
Loading