Skip to content

Commit

Permalink
Add observation location to gemini identification request
Browse files Browse the repository at this point in the history
  • Loading branch information
eleurent committed Sep 8, 2024
1 parent 711de99 commit 8cb525d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
33 changes: 24 additions & 9 deletions backend/nature_go/identification/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
from django.db.models import Q

CONFIGURED = False
PROMPT_PREFIX = """Identify the species in the picture.
PROMPT_PREFIX = """Identify the species in the picture, taking metadata into account.
Use this JSON schema:
Result = {"commonName": str, "scientificName": str, "confidence": float}
Return: list[Result].
"""
BIRD_ID_FEW_SHOTS = {
'[{"commonName": "Common starling", "scientificName": "Sturnus vulgaris", "confidence": 0.90}, {"commonName": "Spotless starling", "scientificName": "Sturnus unicolor", "confidence": 0.05}]': 'https://preview.redd.it/whats-this-bird-v0-sx6m25i5pm0d1.jpeg?width=1080&crop=smart&auto=webp&s=004ec0c9d413ca38001a069172b4e35f729aabec',
'[{"commonName": "Eurasian jay", "scientificName": "Garrulus glandarius", confidence: 1.0}]': 'https://preview.redd.it/what-kind-of-bird-is-this-sighted-in-rome-italy-v0-wzm5jvwva20d1.jpg?width=1080&crop=smart&auto=webp&s=47eb746fd839673f4cceafa4896dd118d21b897d'
}
# Few-shot examples for bird identification. Each entry is a dict with:
#   'metadata': JSON string holding the latitude/longitude of the observation,
#   'url':      URL of the example picture,
#   'response': expected answer, a JSON list following the Result schema
#               described in PROMPT_PREFIX.
# Every value must be followed by a comma: adjacent string literals are
# concatenated by Python, so a missing comma silently merges (or breaks) entries.
BIRD_ID_FEW_SHOTS = [
    {
        'metadata': '{"latitude": 51.495780, "longitude": -0.176399}',
        'url': 'https://preview.redd.it/whats-this-bird-v0-sx6m25i5pm0d1.jpeg?width=1080&crop=smart&auto=webp&s=004ec0c9d413ca38001a069172b4e35f729aabec',
        'response': '[{"commonName": "Common starling", "scientificName": "Sturnus vulgaris", "confidence": 0.60}, {"commonName": "Spotless starling", "scientificName": "Sturnus unicolor", "confidence": 0.40}]',
    },
    {
        'metadata': '{"latitude": 41.909442, "longitude": 12.503025}',
        'url': 'https://preview.redd.it/what-kind-of-bird-is-this-sighted-in-rome-italy-v0-wzm5jvwva20d1.jpg?width=1080&crop=smart&auto=webp&s=47eb746fd839673f4cceafa4896dd118d21b897d',
        # The "confidence" key must be quoted so the example is valid JSON.
        'response': '[{"commonName": "Eurasian jay", "scientificName": "Garrulus glandarius", "confidence": 0.9}]',
    },
]


def configure():
Expand All @@ -27,12 +36,17 @@ def configure():
CONFIGURED = True


def gemini_identify_few_shot(image_path: str, few_shots: list[tuple[str, str]], model_id: str = 'models/gemini-1.5-flash-latest'):
def gemini_identify_few_shot(
image_path: str,
location: str,
few_shots: list[dict[str, str]],
model_id: str = 'models/gemini-1.5-flash-latest'
):
"""Identify a species through the Gemini API
Args:
image_path (str): path to an image file
few_shots (list): list of (expected_response, image_url) pairs
location (str): location metadata of the observation, sent to the model just before the image
few_shots (list): list of dicts with 'metadata', 'url' and 'response' keys
Returns:
str: response for the input image
Expand All @@ -44,9 +58,10 @@ def load_image_from_url(image_url: str) -> PIL.Image:
return PIL.Image.open(io.BytesIO(response.content))

multimodal_model = genai.GenerativeModel(model_id, generation_config={"response_mime_type": "application/json"})
examples = [(load_image_from_url(image_url), prompt) for prompt, image_url in few_shots.items()]
examples = [(few_shot['metadata'], load_image_from_url(few_shot['url']), few_shot['response']) for few_shot in few_shots]
image = PIL.Image.open(image_path)
contents = (PROMPT_PREFIX,) + sum(examples, ()) + (image,)
metadata = location
contents = (PROMPT_PREFIX,) + sum(examples, ()) + (metadata, image,)
response = multimodal_model.generate_content(contents)

# Parse response
Expand Down
4 changes: 3 additions & 1 deletion backend/nature_go/observation/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def create(self, request, *args, **kwargs):
observation.save()
elif observation.type == Species.BIRD_TYPE:
response = gemini.gemini_identify_few_shot(
image_path=observation.image.path, few_shots=gemini.BIRD_ID_FEW_SHOTS)
image_path=observation.image.path,
location=observation.location,
few_shots=gemini.BIRD_ID_FEW_SHOTS)
observation.identification_response = serialize_identification_response(response)
observation.save()

Expand Down

0 comments on commit 8cb525d

Please sign in to comment.