Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

using a ThreadPoolExecutor to match socket pool #14

Merged
merged 8 commits into from
May 24, 2023
116 changes: 35 additions & 81 deletions pybase/region/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

import logging
import socket
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import BytesIO
from struct import pack, unpack
from threading import Condition, Lock
from threading import current_thread, Condition, Lock

from ..exceptions import (NoSuchColumnFamilyException, NotServingRegionException, PyBaseException,
RegionMovedException, RegionOpeningException, RegionServerException)
Expand Down Expand Up @@ -67,11 +68,10 @@ def __init__(self, host, port, secondary):
self.host = host.decode('utf8') if isinstance(host, bytes) else host
self.port = port.decode('utf8') if isinstance(port, bytes) else port
self.pool_size = 0
# We support connection pools so have lists of sockets and read/write
# mutexes on them.

self.thread_pool = None
self.sock_pool = []
self.write_lock_pool = []
self.read_lock_pool = []

# Why yes, we do have a mutex protecting a single variable.
self.call_lock = Lock()
self.call_id = 0
Expand Down Expand Up @@ -142,26 +142,11 @@ def _send_request(self, rq, lock_timeout=10):
to_send = pack(">IB", total_length - 4, len(serialized_header))
to_send += serialized_header + rpc_length_bytes + serialized_rpc

pool_id = my_id % self.pool_size
try:
# todo: quick hack to patch a deadlock happening here. Needs revisiting.
with acquire_timeout(self.write_lock_pool[pool_id], lock_timeout) as acquired:
if acquired:
logger.debug('Sending %s RPC to %s:%s on pool port %s',
rq.type, self.host, self.port, pool_id)
self.sock_pool[pool_id].send(to_send)
else:
logger.warning('Lock timeout sending %s RPC to %s:%s on pool port %s',
rq.type, self.host, self.port, pool_id)
raise RegionServerException(region_client=self)
except socket.error:
# RegionServer dead?
raise RegionServerException(region_client=self)
# Message is sent! Now go listen for the results.
return self._receive_rpc(my_id, rq)
# send and receive the request
future = self.thread_pool.submit(Client.send_and_receive_rpc, [self, my_id, rq, to_send])
return future.result()

# Called after sending an RPC, listens for the response and builds the
# correct pbResponse object.
# Sends an RPC, listens for the response, and builds the correct pbResponse object.
#
# The raw bytes we receive are composed (in order) -
#
Expand All @@ -171,32 +156,30 @@ def _send_request(self, rq, lock_timeout=10):
# 4. A varint representing the length of the serialized ResponseMessage.
# 5. The ResponseMessage.
#
def _receive_rpc(self, call_id, rq, data=None, lock_timeout=10):
@staticmethod
def send_and_receive_rpc(client, call_id, rq, to_send):
thread_name = current_thread().name
sp = thread_name.split("_") # i.e. splitting "ThreadPoolExecutor-1_0"
pool_id = int(sp[1]) # thread number is now responsible for only using its matching socket

client.sock_pool[pool_id].send(to_send)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move the send into the try below to be able to use the same exception handling?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


# If the field data is populated that means we should process from that
# instead of the socket.
full_data = data
if data is None:
pool_id = call_id % self.pool_size
# Total message length is going to be the first four bytes
# (big-endian uint32)
with acquire_timeout(self.read_lock_pool[pool_id], lock_timeout) as acquired:
if acquired:
try:
msg_length = self._recv_n(self.sock_pool[pool_id], 4)
if msg_length is None:
raise
msg_length = unpack(">I", msg_length)[0]
# The message is then going to be however many bytes the first four
# bytes specified. We don't want to overread or underread as that'll
# cause havoc.
full_data = self._recv_n(
self.sock_pool[pool_id], msg_length)
except socket.error:
raise RegionServerException(region_client=self)
else:
logger.warning('Lock timeout receive %s RPC to %s:%s on pool port %s',
rq.type, self.host, self.port, pool_id)
raise RegionServerException(region_client=self)
full_data = None
# Total message length is going to be the first four bytes
# (big-endian uint32)
try:
msg_length = Client._recv_n(self.sock_pool[pool_id], 4)
if msg_length is None:
raise
msg_length = unpack(">I", msg_length)[0]
# The message is then going to be however many bytes the first four
# bytes specified. We don't want to overread or underread as that'll
# cause havoc.
full_data = Client._recv_n(self.sock_pool[pool_id], msg_length)
except socket.error:
raise RegionServerException(region_client=self)
# Pass in the full data as well as your current position to the
# decoder. It'll then return two variables:
# - next_pos: The number of bytes of data specified by the varint
Expand All @@ -205,11 +188,7 @@ def _receive_rpc(self, call_id, rq, data=None, lock_timeout=10):
header = ResponseHeader()
header.ParseFromString(full_data[pos: pos + next_pos])
pos += next_pos
if header.call_id != call_id:
# call_ids don't match? Looks like a different thread nabbed our
# response.
return self._bad_call_id(call_id, rq, header.call_id, full_data)
elif header.exception.exception_class_name != '':
if header.exception.exception_class_name != '':
# If we're in here it means a remote exception has happened.
exception_class = header.exception.exception_class_name
if exception_class in \
Expand All @@ -234,35 +213,11 @@ def _receive_rpc(self, call_id, rq, data=None, lock_timeout=10):
# The rpc is fully built!
return rpc

# Receive an RPC with incorrect call_id?
# 1. Acquire lock
# 2. Place raw data into missed_rpcs with key call_id
# 3. Notify all other threads to wake up (nothing will happen until you release the lock)
# 4. WHILE: Your call_id is not in the dictionary
# 4.5 Call wait() on the conditional and get comfy.
# 5. Pop your data out
# 6. Release the lock
def _bad_call_id(self, my_id, my_request, msg_id, data, lock_timeout=10):
with acquire_timeout(self.missed_rpcs_lock, lock_timeout) as acquired:
if acquired:
logger.debug("Received invalid RPC ID. Got: %s, Expected: %s.", msg_id, my_id)
self.missed_rpcs[msg_id] = data
self.missed_rpcs_condition.notifyAll()
while my_id not in self.missed_rpcs:
if self.shutting_down:
raise RegionServerException(region_client=self)
self.missed_rpcs_condition.wait(lock_timeout)
new_data = self.missed_rpcs.pop(my_id)
logger.debug("Another thread found my RPC! RPC ID: %s", my_id)
else:
logger.warning('Lock timeout bad_call to %s:%s', self.host, self.port)
raise RegionServerException(region_client=self)
return self._receive_rpc(my_id, my_request, data=new_data)

# Receives exactly n bytes from the socket. Will block until n bytes are
# received. If a socket is closed (RegionServer died) then raise an
# exception that goes all the way back to the main client
def _recv_n(self, sock, n):
@staticmethod
def _recv_n(sock, n):
partial_str = BytesIO()
partial_len = 0
while partial_len < n:
Expand Down Expand Up @@ -291,14 +246,13 @@ def NewClient(host, port, pool_size, secondary=False):
c = Client(host, port, secondary)
try:
c.pool_size = pool_size
c.thread_pool = ThreadPoolExecutor(pool_size)
for x in range(pool_size):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((c.host, int(port)))
_send_hello(s)
s.settimeout(2)
c.sock_pool.append(s)
c.read_lock_pool.append(Lock())
c.write_lock_pool.append(Lock())
except (socket.error, socket.timeout):
return None
return c
Expand Down