Skip to content

Commit

Permalink
Added Bedrock as a provider (#211)
Browse files Browse the repository at this point in the history
* Bedrock client

* Bedrock entry point and installer
  • Loading branch information
machinewrapped authored Jan 26, 2025
1 parent 8ef080c commit c30c8fe
Show file tree
Hide file tree
Showing 5 changed files with 507 additions and 4 deletions.
142 changes: 142 additions & 0 deletions PySubtitle/Providers/Bedrock/BedrockClient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import logging

def _structure_messages(messages : list[str]) -> list[dict]:
"""
Structure the messages to be sent to the API
"""
return [
{
'role' : message['role'],
'content' : [{ 'text': message['content'] }]
}
for message in messages]

try:
    import boto3

    from PySubtitle.Helpers import FormatMessages
    from PySubtitle.Translation import Translation
    from PySubtitle.TranslationClient import TranslationClient
    from PySubtitle.TranslationPrompt import TranslationPrompt
    from PySubtitle.SubtitleError import TranslationImpossibleError, TranslationResponseError

    class BedrockClient(TranslationClient):
        """
        Handles communication with Amazon Bedrock to request translations

        Credentials, region, model and token limit are all read from the settings
        dict passed to the constructor.
        NOTE(review): `self.settings`, `self.aborted` and `self.supports_system_prompt`
        are not defined here - presumably provided by TranslationClient; confirm.
        """
        def __init__(self, settings : dict):
            super().__init__(settings)

            logging.info(f"Translating with Bedrock model {self.model_id}, using region: {self.aws_region}")

            self.client = boto3.client(
                'bedrock-runtime',
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_access_key,
                region_name=self.aws_region
            )

        @property
        def access_key(self):
            """ AWS access key from the settings (None if not set) """
            return self.settings.get('access_key')

        @property
        def secret_access_key(self):
            """ AWS secret access key from the settings (None if not set) """
            return self.settings.get('secret_access_key')

        @property
        def aws_region(self):
            """ AWS region name from the settings (None if not set) """
            return self.settings.get('aws_region')

        @property
        def model_id(self):
            """ Identifier of the model to use, read from the 'model' setting """
            return self.settings.get('model')

        @property
        def max_tokens(self):
            """ Maximum number of tokens to generate per request (defaults to 4096) """
            return self.settings.get('max_tokens', 4096)

        def _request_translation(self, prompt : TranslationPrompt, temperature : float = None) -> Translation:
            """
            Request a translation based on the provided prompt

            Returns a Translation wrapping the response, or None if no response
            was produced (e.g. the request was aborted).

            Raises TranslationImpossibleError if a required setting is missing.
            """
            if not self.access_key:
                raise TranslationImpossibleError('Access key must be set in .env or provided as an argument')

            if not self.secret_access_key:
                raise TranslationImpossibleError('Secret access key must be set in .env or provided as an argument')

            if not self.aws_region:
                raise TranslationImpossibleError('AWS region must be set in .env or provided as an argument')

            if not self.model_id:
                raise TranslationImpossibleError('Model ID must be provided as an argument')

            logging.debug(f"Messages:\n{FormatMessages(prompt.messages)}")

            content = _structure_messages(prompt.messages)

            response = self._send_messages(prompt.system_prompt, content, temperature=temperature)

            return Translation(response) if response else None

        def _send_messages(self, system_prompt : str, messages : list[dict], temperature : float = None) -> dict:
            """
            Make a request to the Amazon Bedrock API to provide a translation

            Returns a dict that may contain 'finish_reason', 'prompt_tokens',
            'output_tokens', 'total_tokens' and 'text', depending on what the
            API returned. Returns None if the client was aborted before or
            during the request.

            Raises TranslationResponseError if the result has no output, and
            TranslationImpossibleError for any other failure.
            """
            if self.aborted:
                return None

            try:
                request = {
                    'modelId' : self.model_id,
                    'messages' : messages,
                    'inferenceConfig' : {
                        'temperature' : temperature or 0.0,
                        'maxTokens' : self.max_tokens
                    }
                }

                # Only pass a system prompt when the client is configured to support one
                if self.supports_system_prompt and system_prompt:
                    request['system'] = [{ 'text' : system_prompt }]

                result = self.client.converse(**request)

                if self.aborted:
                    return None

                output = result.get('output')

                if not output:
                    raise TranslationResponseError("No output returned in the response", response=result)

                response = {}

                if 'stopReason' in result:
                    response['finish_reason'] = result['stopReason']

                if 'usage' in result:
                    usage = result['usage']
                    response['prompt_tokens'] = usage.get('inputTokens')
                    response['output_tokens'] = usage.get('outputTokens')
                    response['total_tokens'] = usage.get('totalTokens')

                message = output.get('message')
                if message and message.get('role') == 'assistant':
                    # Skip content blocks that carry no text, joining a None would raise TypeError
                    text = [
                        block['text']
                        for block in (message.get('content') or [])
                        if block.get('text') is not None
                    ]
                    response['text'] = '\n'.join(text)

                # Return the response if the API call succeeds
                return response

            except TranslationResponseError:
                # Raised deliberately above - don't let the generic handler relabel it
                raise

            except Exception as e:
                raise TranslationImpossibleError(f"Error communicating with Bedrock: {str(e)}", error=e) from e

except ImportError:
    logging.debug("AWS Boto3 SDK not installed.")
171 changes: 171 additions & 0 deletions PySubtitle/Providers/Provider_Bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import logging
import os

try:
    import boto3

    from PySubtitle.Providers.Bedrock.BedrockClient import BedrockClient
    from PySubtitle.TranslationClient import TranslationClient
    from PySubtitle.TranslationProvider import TranslationProvider

    class BedrockProvider(TranslationProvider):
        """
        Translation provider that uses Amazon Bedrock

        Settings fall back to the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
        AWS_REGION environment variables when not supplied explicitly.
        NOTE(review): `self.settings` and `self.available_models` are not defined
        here - presumably provided by TranslationProvider; confirm.
        """
        name = "Bedrock"

        information = """
        <p>Bedrock API provider.</p>
        <p>NOTE: Amazon Bedrock is not recommended for most users. The setup is complex, and model capabilities can be unpredictable - some models do not fulfil translation requests.</p>
        <p>To use Bedrock as a provider you need to provide an access key and secret access key. This involves setting up an IAM user in the AWS console and <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html">enabling model access</a> for them.</p>
        <p>You must also specify an AWS region to use for requests - this will affect the available models.</p>
        """

        def __init__(self, settings : dict):
            super().__init__(self.name, {
                "access_key": settings.get('access_key', os.getenv('AWS_ACCESS_KEY_ID')),
                "secret_access_key": settings.get('secret_access_key', os.getenv('AWS_SECRET_ACCESS_KEY')),
                "aws_region": settings.get('aws_region', os.getenv('AWS_REGION', 'eu-west-1')),
                "model": settings.get('model', 'Amazon-Titan-Text-G1'),
                "max_tokens": settings.get('max_tokens', 8192),
                #TODO: add options for supports system messages and prompt?
                'temperature': settings.get('temperature', 0.0),
                "rate_limit": settings.get('rate_limit', None)
            })

            # Changing any of these settings invalidates the list of available models
            self.refresh_when_changed = ['access_key', 'secret_access_key', 'aws_region']
            self._regions = None

        @property
        def access_key(self):
            """ AWS access key from the settings (None if not set) """
            return self.settings.get('access_key')

        @property
        def secret_access_key(self):
            """ AWS secret access key from the settings (None if not set) """
            return self.settings.get('secret_access_key')

        @property
        def aws_region(self):
            """ AWS region name from the settings (None if not set) """
            return self.settings.get('aws_region')

        @property
        def regions(self):
            """ List of AWS regions for Bedrock, fetched on first use and cached while non-empty """
            if not self._regions:
                self._regions = self.get_aws_regions()
            return self._regions

        def GetTranslationClient(self, settings : dict) -> TranslationClient:
            """
            Create a BedrockClient from the provider settings, overridden by the supplied settings

            Conversation support is always enabled; system messages and system
            prompts are always disabled.
            """
            client_settings = self.settings.copy()
            client_settings.update(settings)
            client_settings.update({
                'supports_conversation': True,
                'supports_system_messages': False,
                'supports_system_prompt': False         # Apparently some models do?
            })
            return BedrockClient(client_settings)

        def GetOptions(self) -> dict:
            """
            Returns the configurable options for the provider

            The model, token and rate limit options are only offered once the
            access keys and region have been provided.
            """
            options = {
                'access_key': (str, "An AWS access key is required"),
                'secret_access_key': (str, "An AWS secret access key is required"),
            }

            # Offer a choice of regions if the list is available, otherwise free text
            regions = self.regions
            if not regions:
                options['aws_region'] = (str, "The AWS region to use for requests must be specified.")
            else:
                options['aws_region'] = (regions, "The AWS region to use for requests.")

            if self.access_key and self.secret_access_key and self.aws_region:
                models = self.available_models or ["Unable to retrieve model list"]
                options.update({
                    'model': (models, "AI model to use as the translator. Model access must be enabled in the AWS Console. Some models may not translate the subtitles."),
                    'max_tokens': (int, "The maximum number of tokens to generate in a single request"),
                    'rate_limit': (float, "The maximum number of requests to make per minute")
                })
            return options

        def GetInformation(self) -> str:
            """
            Returns the provider information, with the validation message appended if the settings are invalid
            """
            information = self.information
            if not self.ValidateSettings():
                information = information + f"<p>{self.validation_message}</p>"
            return information

        def GetAvailableModels(self) -> list[str]:
            """
            Returns a list of possible values for the model

            Returns an empty list if the access keys are missing or the model
            list cannot be retrieved.
            """
            try:
                if not self.access_key or not self.secret_access_key:
                    logging.debug("AWS access keys not provided")
                    return []

                client = boto3.client(
                    'bedrock',
                    aws_access_key_id=self.access_key,
                    aws_secret_access_key=self.secret_access_key,
                    region_name=self.aws_region
                )

                response = client.list_foundation_models()

                if not response or 'modelSummaries' not in response:
                    return []

                model_details = response['modelSummaries']

                # Define valid statuses for filtering
                valid_status = ['ACTIVE','AVAILABLE']

                # Filter for models that accept text input and are in a valid status.
                # A summary with no lifecycle information is skipped rather than
                # failing the whole lookup.
                text_models = [
                    model['modelId']
                    for model in model_details
                    if 'TEXT' in (model.get('inputModalities') or [])
                    and (model.get('modelLifecycle') or {}).get('status') in valid_status
                ]

                # If no text models are available, fall back to all available models
                model_list = text_models or [ model['modelId'] for model in model_details ]

                # Return sorted list of model IDs
                return sorted(model_list)

            except Exception as e:
                logging.error(f"Unable to retrieve available AI models: {str(e)}")
                return []

        def ValidateSettings(self) -> bool:
            """
            Validate the settings for the provider

            Sets `validation_message` and returns False for the first missing setting.
            """
            if not self.access_key:
                self.validation_message = "AWS access key is required"
                return False

            if not self.secret_access_key:
                self.validation_message = "AWS secret access key is required"
                return False

            if not self.aws_region:
                self.validation_message = "AWS region is required"
                return False

            return True

        def _allow_multithreaded_translation(self) -> bool:
            """
            Assume the Bedrock provider can handle multiple requests
            """
            return True

        def get_aws_regions(self) -> list[str]:
            """
            Fetches a list of AWS regions that support Bedrock from the boto3 SDK (may become out of date)

            Returns an empty list if the regions cannot be determined.
            """
            try:
                session = boto3.session.Session()
                bedrock_regions = session.get_available_regions("bedrock")
                return sorted(bedrock_regions)
            except Exception as e:
                # Log rather than print, consistent with the rest of the provider
                logging.error(f"Error fetching AWS regions: {e}")
                return []

except ImportError:
    logging.info("Amazon Boto3 SDK is not installed. Bedrock provider will not be available")
Loading

0 comments on commit c30c8fe

Please sign in to comment.