Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemini Support #469

Merged
merged 8 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions refact_known_models/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,55 @@
"pp1000t_prompt": 150,
"pp1000t_generated": 600, # TODO: don't know the price
"filter_caps": ["chat", "completion"],
}
}
},

# Gemini and Gemma share the same tokenizer,
# according to https://medium.com/google-cloud/a-gemini-and-gemma-tokenizer-in-java-e18831ac9677
# A downloadable tokenizer.json does not exist for Gemini; the proposed alternative (the vertexai Python lib) relies on web requests
# For pricing, consult: https://ai.google.dev/pricing
# Pricing below assumes a context of <= 128_000 tokens

"gemini-2.0-flash-exp": {
"backend": "litellm",
"provider": "gemini",
"tokenizer_path": "google/gemma-7b",
valaises marked this conversation as resolved.
Show resolved Hide resolved
"resolve_as": "gemini-2.0-flash-exp",
"T": 1_048_576,
"T_out": 8_192,
"pp1000t_prompt": 75, # $0.075 / 1M tokens
"pp1000t_generated": 300, # $0.30 / 1M tokens
"filter_caps": ["chat", "tools", "completion", "multimodal"],
},
"gemini-1.5-flash": {
"backend": "litellm",
"provider": "gemini",
"tokenizer_path": "google/gemma-7b",
"resolve_as": "gemini-1.5-flash",
"T": 1_048_576,
"T_out": 8_192,
"pp1000t_prompt": 75, # $0.075 / 1M tokens
"pp1000t_generated": 300, # $0.30 / 1M tokens
"filter_caps": ["chat", "tools", "completion", "multimodal"],
},
"gemini-1.5-flash-8b": {
"backend": "litellm",
"provider": "gemini",
"tokenizer_path": "google/gemma-7b",
"resolve_as": "gemini-1.5-flash-8b",
"T": 1_048_576,
"T_out": 8_192,
"pp1000t_prompt": 37.5, # $0.0375 / 1M tokens
"pp1000t_generated": 150, # $0.15 / 1M tokens
"filter_caps": ["chat", "tools", "completion", "multimodal"],
},
"gemini-1.5-pro": {
"backend": "litellm",
"provider": "gemini",
"tokenizer_path": "google/gemma-7b",
"resolve_as": "gemini-1.5-pro",
"T": 2_097_152,
"T_out": 8_192,
"pp1000t_prompt": 1250, # $1.25 / 1M tokens
"pp1000t_generated": 5000, # $5.00 / 1M tokens
"filter_caps": ["chat", "tools", "completion", "multimodal"],
}}
3 changes: 3 additions & 0 deletions refact_utils/finetune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def _add_results_for_passthrough_provider(provider: str) -> None:
if data.get('cerebras_api_enable'):
_add_results_for_passthrough_provider('cerebras')

if data.get('gemini_api_enable'):
_add_results_for_passthrough_provider('gemini')

if data.get('groq_api_enable'):
_add_results_for_passthrough_provider('groq')

Expand Down
7 changes: 6 additions & 1 deletion refact_webgui/webgui/selfhost_fastapi_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from fastapi import APIRouter, HTTPException, Query, Header
from fastapi.responses import Response, StreamingResponse

from refact_utils.huggingface.utils import huggingface_hub_token
valaises marked this conversation as resolved.
Show resolved Hide resolved
from refact_utils.scripts import env
from refact_utils.finetune.utils import running_models_and_loras
from refact_webgui.webgui.selfhost_model_resolve import resolve_model_context_size
Expand Down Expand Up @@ -233,6 +234,7 @@ def _integrations_env_setup(env_var_name: str, api_key_name: str, api_enable_nam
_integrations_env_setup("ANTHROPIC_API_KEY", "anthropic_api_key", "anthropic_api_enable")
_integrations_env_setup("GROQ_API_KEY", "groq_api_key", "groq_api_enable")
_integrations_env_setup("CEREBRAS_API_KEY", "cerebras_api_key", "cerebras_api_enable")
_integrations_env_setup("GEMINI_API_KEY", "gemini_api_key", "gemini_api_enable")

def _models_available_dict_rewrite(self, models_available: List[str]) -> Dict[str, Any]:
rewrite_dict = {}
Expand Down Expand Up @@ -337,7 +339,10 @@ async def _passthrough_tokenizer(self, model_path: str) -> str:
try:
async with aiohttp.ClientSession() as session:
tokenizer_url = f"https://huggingface.co/{model_path}/resolve/main/tokenizer.json"
async with session.get(tokenizer_url) as resp:
headers = {}
if hf_token := huggingface_hub_token():
headers["Authorization"] = f"Bearer {hf_token}"
async with session.get(tokenizer_url, headers=headers) as resp:
valaises marked this conversation as resolved.
Show resolved Hide resolved
return await resp.text()
except:
raise HTTPException(404, detail=f"can't load tokenizer.json for passthrough {model_path}")
Expand Down
2 changes: 2 additions & 0 deletions refact_webgui/webgui/selfhost_model_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def first_run(self):
"anthropic_api_enable": False,
"groq_api_enable": False,
"cerebras_api_enable": False,
"gemini_api_enable": False,
}
self.models_to_watchdog_configs(default_config)

Expand Down Expand Up @@ -259,6 +260,7 @@ def model_assignment(self):
j = json.load(open(env.CONFIG_INFERENCE, "r"))
j["groq_api_enable"] = j.get("groq_api_enable", False)
j["cerebras_api_enable"] = j.get("cerebras_api_enable", False)
j["gemini_api_enable"] = j.get("gemini_api_enable", False)
else:
j = {"model_assign": {}}

Expand Down
2 changes: 2 additions & 0 deletions refact_webgui/webgui/selfhost_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def _add_models_for_passthrough_provider(provider):
_add_models_for_passthrough_provider('groq')
if j.get("cerebras_api_enable"):
_add_models_for_passthrough_provider('cerebras')
if j.get("gemini_api_enable"):
_add_models_for_passthrough_provider('gemini')

return self._models_available

Expand Down
5 changes: 5 additions & 0 deletions refact_webgui/webgui/static/tab-model-hosting.html
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ <h3>3rd Party APIs</h3>
<input class="form-check-input" type="checkbox" role="switch" id="enable_cerebras">
<label class="form-check-label" for="enable_cerebras">Enable Cerebras API</label>
</div>
<div class="form-check form-switch">
<input class="form-check-input" type="checkbox" role="switch" id="enable_gemini">
<label class="form-check-label" for="enable_gemini">Enable Gemini API</label>
</div>

<div class="chat-enabler-status">
To enable Chat GPT add your API key in the <span id="redirect2credentials" class="main-tab-button fake-link" data-tab="settings">API Keys tab</span>.
</div>
Expand Down
5 changes: 5 additions & 0 deletions refact_webgui/webgui/static/tab-model-hosting.js
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ function get_models()
integration_switch_init('enable_anthropic', models_data['anthropic_api_enable']);
integration_switch_init('enable_groq', models_data['groq_api_enable']);
integration_switch_init('enable_cerebras', models_data['cerebras_api_enable']);
integration_switch_init('enable_gemini', models_data['gemini_api_enable']);


const more_gpus_notification = document.querySelector('.model-hosting-error');
if(data.hasOwnProperty('more_models_than_gpus') && data.more_models_than_gpus) {
Expand All @@ -144,6 +146,8 @@ function save_model_assigned() {
const anthropic_enable = document.querySelector('#enable_anthropic');
const groq_enable = document.querySelector('#enable_groq');
const cerebras_enable = document.querySelector('#enable_cerebras');
const gemini_enable = document.querySelector('#enable_gemini');

const data = {
model_assign: {
...models_data.model_assign,
Expand All @@ -152,6 +156,7 @@ function save_model_assigned() {
anthropic_api_enable: anthropic_enable.checked,
groq_api_enable: groq_enable.checked,
cerebras_api_enable: cerebras_enable.checked,
gemini_api_enable: gemini_enable.checked,
};
console.log(data);
fetch("/tab-host-models-assign", {
Expand Down
3 changes: 3 additions & 0 deletions refact_webgui/webgui/static/tab-settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ <h2>API Integrations</h2>
<input type="text" name="groq_api_key" value="" class="form-control" id="groq_api_key">
<label for="cerebras_api_key" class="form-label mt-4">Cerebras API Key</label>
<input type="text" name="cerebras_api_key" value="" class="form-control" id="cerebras_api_key">
<label for="gemini_api_key" class="form-label mt-4">Gemini API Key</label>
<input type="text" name="gemini_api_key" value="" class="form-control" id="gemini_api_key">

<!-- <div class="d-flex flex-row-reverse mt-3"><button type="button" class="btn btn-primary" id="integrations-save">Save</button></div>-->
</div>
</div>
Expand Down
8 changes: 8 additions & 0 deletions refact_webgui/webgui/static/tab-settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ function save_integration_api_keys() {
const anthropic_api_key = document.getElementById('anthropic_api_key');
const groq_api_key = document.getElementById('groq_api_key');
const cerebras_api_key = document.getElementById('cerebras_api_key');
const gemini_api_key = document.getElementById("gemini_api_key");

const huggingface_api_key = document.getElementById('huggingface_api_key');
fetch("/tab-settings-integrations-save", {
method: "POST",
Expand All @@ -185,6 +187,8 @@ function save_integration_api_keys() {
anthropic_api_key: anthropic_api_key.getAttribute('data-value'),
groq_api_key: groq_api_key.getAttribute('data-value'),
cerebras_api_key: cerebras_api_key.getAttribute('data-value'),
gemini_api_key: gemini_api_key.getAttribute("data-value"),

huggingface_api_key: huggingface_api_key.getAttribute('data-value'),
})
})
Expand All @@ -195,6 +199,8 @@ function save_integration_api_keys() {
anthropic_api_key.setAttribute('data-saved-value', anthropic_api_key.getAttribute('data-value'))
groq_api_key.setAttribute('data-saved-value', groq_api_key.getAttribute('data-value'))
cerebras_api_key.setAttribute('data-saved-value', cerebras_api_key.getAttribute('data-value'))
gemini_api_key.setAttribute('data-saved-value', gemini_api_key.getAttribute('data-value'))

huggingface_api_key.setAttribute('data-saved-value', huggingface_api_key.getAttribute('data-value'))
});
}
Expand Down Expand Up @@ -230,6 +236,8 @@ export function tab_settings_integrations_get() {
integrations_input_init(document.getElementById('anthropic_api_key'), data['anthropic_api_key']);
integrations_input_init(document.getElementById('groq_api_key'), data['groq_api_key']);
integrations_input_init(document.getElementById('cerebras_api_key'), data['cerebras_api_key']);
integrations_input_init(document.getElementById('gemini_api_key'), data['gemini_api_key']);

integrations_input_init(document.getElementById('huggingface_api_key'), data['huggingface_api_key']);
});
}
Expand Down
1 change: 1 addition & 0 deletions refact_webgui/webgui/tab_models_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class TabHostModelsAssign(BaseModel):
anthropic_api_enable: bool = False
groq_api_enable: bool = False
cerebras_api_enable: bool = False
gemini_api_enable: bool = False

model_config = ConfigDict(protected_namespaces=()) # avoiding model_ namespace protection

Expand Down
2 changes: 2 additions & 0 deletions refact_webgui/webgui/tab_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Integrations(BaseModel):
anthropic_api_key: Optional[str] = None
groq_api_key: Optional[str] = None
cerebras_api_key: Optional[str] = None
gemini_api_key: Optional[str] = None

huggingface_api_key: Optional[str] = None

def __init__(self, models_assigner: ModelAssigner, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PyPackage:
"refact_webgui": PyPackage(
requires=["aiohttp", "aiofiles", "cryptography", "fastapi==0.100.0", "giturlparse", "pydantic>=2",
"starlette==0.27.0", "uvicorn", "uvloop", "termcolor", "python-multipart", "more_itertools",
"scyllapy==1.3.0", "pandas>=2.0.3", "litellm>=1.49.5"],
"scyllapy==1.3.0", "pandas>=2.0.3", "litellm>=1.55.3"],
requires_packages=["refact_known_models", "refact_utils"],
data=["webgui/static/*", "webgui/static/components/modals/*",
"webgui/static/dashboards/*", "webgui/static/assets/*", "webgui/static/utils/*",]),
Expand All @@ -45,7 +45,7 @@ class PyPackage:
"bitsandbytes", "safetensors", "peft", "triton",
"torchinfo", "mpi4py", "deepspeed>=0.15.3",
"sentence-transformers", "huggingface-hub>=0.26.2",
"aiohttp", "setproctitle"],
"aiohttp", "setproctitle", "google-auth>=2.37.0"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

google-auth? for what purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

litellm requires it to access the Google API

optional=["ninja", "flash-attn"],
requires_packages=["refact_known_models", "refact_data_pipeline",
"refact_webgui", "refact_utils"],
Expand Down
Loading