Skip to content

Commit

Permalink
Trial mode added for Qiskit Functions (#1571)
Browse files Browse the repository at this point in the history
* modified the model to support the new field

* created trial and fixed retrieving a function

* move the provider to the serializer

* remove unneeded change

* updated test to check the environment variable

* added allow_null in the serializer

* change env_var name
  • Loading branch information
Tansito authored Jan 21, 2025
1 parent ccd159a commit 1ca0a18
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.16 on 2025-01-17 15:59

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("auth", "0012_alter_user_first_name_max_length"),
("api", "0032_computeresource_gpu_job_gpu"),
]

operations = [
migrations.AddField(
model_name="job",
name="trial",
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name="program",
name="trial_instances",
field=models.ManyToManyField(
blank=True, related_name="program_trial_instances", to="auth.group"
),
),
migrations.AlterField(
model_name="program",
name="instances",
field=models.ManyToManyField(
blank=True, related_name="program_instances", to="auth.group"
),
),
]
30 changes: 17 additions & 13 deletions gateway/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ class Program(ExportModelOperationsMixin("program"), models.Model):
env_vars = models.TextField(null=False, blank=True, default="{}")
dependencies = models.TextField(null=False, blank=True, default="[]")

instances = models.ManyToManyField(Group, blank=True)
instances = models.ManyToManyField(
Group, blank=True, related_name="program_instances"
)
trial_instances = models.ManyToManyField(
Group, blank=True, related_name="program_trial_instances"
)
author = models.ForeignKey(
settings.AUTH_USER_MODEL,
on_delete=models.CASCADE,
Expand Down Expand Up @@ -174,36 +179,35 @@ class Job(models.Model):
created = models.DateTimeField(auto_now_add=True, editable=False)
updated = models.DateTimeField(auto_now=True, null=True)

program = models.ForeignKey(to=Program, on_delete=models.SET_NULL, null=True)
arguments = models.TextField(null=False, blank=True, default="{}")
env_vars = models.TextField(null=False, blank=True, default="{}")
gpu = models.BooleanField(default=False, null=False)
logs = models.TextField(default="No logs yet.")
ray_job_id = models.CharField(max_length=255, null=True, blank=True)
result = models.TextField(null=True, blank=True)
author = models.ForeignKey(
settings.AUTH_USER_MODEL,
on_delete=models.CASCADE,
)
status = models.CharField(
max_length=10,
choices=JOB_STATUSES,
default=QUEUED,
)
trial = models.BooleanField(default=False, null=False)
version = IntegerVersionField()

author = models.ForeignKey(
settings.AUTH_USER_MODEL,
on_delete=models.CASCADE,
)
compute_resource = models.ForeignKey(
ComputeResource, on_delete=models.SET_NULL, null=True, blank=True
)
ray_job_id = models.CharField(max_length=255, null=True, blank=True)
logs = models.TextField(default="No logs yet.")

version = IntegerVersionField()

config = models.ForeignKey(
to=JobConfig,
on_delete=models.CASCADE,
default=None,
null=True,
blank=True,
)

gpu = models.BooleanField(default=False, null=False)
program = models.ForeignKey(to=Program, on_delete=models.SET_NULL, null=True)

def __str__(self):
return f"<Job {self.id} | {self.status}>"
Expand Down
17 changes: 14 additions & 3 deletions gateway/api/repositories/functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""
Repository implementation for Programs model
"""
import logging

import logging
from typing import List

from django.db.models import Q
from django.contrib.auth.models import Group

from api.models import Program as Function

from api.repositories.users import UserRepository


Expand Down Expand Up @@ -222,3 +221,15 @@ def get_function_by_permission(
)

return self.get_user_function(author=user, title=function_title)

def get_trial_instances(self, function: Function) -> List[Group]:
"""
Returns the details of the function groups from trial_instances field
Args:
function: the instance of the Function
Returns:
[Group]: list of available groups
"""
return function.trial_instances.all()
25 changes: 23 additions & 2 deletions gateway/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from django.db.models import Q
from rest_framework import serializers

from api.repositories.functions import FunctionRepository
from api.repositories.users import UserRepository
from api.utils import build_env_variables, encrypt_env_vars, sanitize_name
from .models import (
Provider,
Expand Down Expand Up @@ -182,6 +184,7 @@ class RunProgramSerializer(serializers.Serializer):
title = serializers.CharField(max_length=255)
arguments = serializers.CharField()
config = serializers.JSONField()
provider = serializers.CharField(required=False, allow_null=True)

def retrieve_one_by_title(self, title, author):
"""
Expand All @@ -202,12 +205,28 @@ def create(self, validated_data):

class RunJobSerializer(serializers.ModelSerializer):
"""
Job serializer for the /run and /run end-point
Job serializer for the /run and end-point
"""

class Meta:
model = Job

def is_trial(self, function: Program, author) -> bool:
"""
This method checks if a group with run permissions from the author
is assigned to a trial instance in a function
"""

function_repository = FunctionRepository()
user_repository = UserRepository()

trial_groups = function_repository.get_trial_instances(function=function)
user_run_groups = user_repository.get_groups_by_permissions(
user=author, permission_name=RUN_PROGRAM_PERMISSION
)

return any(group in trial_groups for group in user_run_groups)

def create(self, validated_data):
logger.info("Creating Job with RunExistingJobSerializer")
status = Job.QUEUED
Expand All @@ -219,15 +238,17 @@ def create(self, validated_data):
token = validated_data.pop("token")
carrier = validated_data.pop("carrier")

trial = self.is_trial(program, author)
job = Job(
trial=trial,
status=status,
program=program,
arguments=arguments,
author=author,
config=config,
)

env = encrypt_env_vars(build_env_variables(token, job, arguments))
env = encrypt_env_vars(build_env_variables(token, job, trial, arguments))
try:
env["traceparent"] = carrier["traceparent"]
except KeyError:
Expand Down
5 changes: 4 additions & 1 deletion gateway/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def decrypt_string(string: str) -> str:
return fernet.decrypt(string.encode("utf-8")).decode("utf-8")


def build_env_variables(token, job: Job, args: str = None) -> Dict[str, str]:
def build_env_variables(
token, job: Job, trial_mode: bool, args: str = None
) -> Dict[str, str]:
"""Builds env variables for job.
Args:
Expand Down Expand Up @@ -149,6 +151,7 @@ def build_env_variables(token, job: Job, args: str = None) -> Dict[str, str]:
"ENV_JOB_GATEWAY_HOST": str(settings.SITE_HOST),
"ENV_JOB_ID_GATEWAY": str(job.id),
"ENV_JOB_ARGUMENTS": arguments,
"ENV_ACCESS_TRIAL": str(trial_mode),
},
**extra,
}
Expand Down
79 changes: 22 additions & 57 deletions gateway/api/views/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import logging
import os

from django.db.models import Q
from django.contrib.auth.models import Group, Permission

# pylint: disable=duplicate-code
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
Expand Down Expand Up @@ -56,7 +53,7 @@ class ProgramViewSet(viewsets.GenericViewSet):

BASE_NAME = "programs"

program_repository = FunctionRepository()
function_repository = FunctionRepository()

@staticmethod
def get_serializer_job_config(*args, **kwargs):
Expand Down Expand Up @@ -104,45 +101,6 @@ def get_serializer_class(self):
def get_object(self):
logger.warning("ProgramViewSet.get_object not implemented")

def get_run_queryset(self):
"""get run queryset"""
author = self.request.user

logger.info("ProgramViewSet get run_program permission")
run_program_permission = Permission.objects.get(codename=RUN_PROGRAM_PERMISSION)

# Groups logic
user_criteria = Q(user=author)
run_permission_criteria = Q(permissions=run_program_permission)
author_groups_with_run_permissions = Group.objects.filter(
user_criteria & run_permission_criteria
)
author_groups_with_run_permissions_count = (
author_groups_with_run_permissions.count()
)
logger.info(
"ProgramViewSet get author [%s] groups [%s]",
author.id,
author_groups_with_run_permissions_count,
)

# Programs logic
author_criteria = Q(author=author)
author_groups_with_run_permissions_criteria = Q(
instances__in=author_groups_with_run_permissions
)
author_programs = Program.objects.filter(
author_criteria | author_groups_with_run_permissions_criteria
).distinct()
author_programs_count = author_programs.count()
logger.info(
"ProgramViewSet get author [%s] programs [%s]",
author.id,
author_programs_count,
)

return author_programs

def list(self, request):
"""List programs:"""
tracer = trace.get_tracer("gateway.tracer")
Expand All @@ -156,18 +114,18 @@ def list(self, request):
# Serverless filter only returns functions created by the author
# with the next criterias:
# - user is the author of the function and there is no provider
functions = self.program_repository.get_user_functions(author)
functions = self.function_repository.get_user_functions(author)
elif type_filter == TypeFilter.CATALOG:
# Catalog filter only returns providers functions that user has access:
# author has view permissions and the function has a provider assigned
functions = (
self.program_repository.get_provider_functions_by_permission(
self.function_repository.get_provider_functions_by_permission(
author, permission_name=RUN_PROGRAM_PERMISSION
)
)
else:
# If filter is not applied we return author and providers functions together
functions = self.program_repository.get_functions_by_permission(
functions = self.function_repository.get_functions_by_permission(
author, permission_name=VIEW_PROGRAM_PERMISSION
)

Expand Down Expand Up @@ -241,26 +199,33 @@ def run(self, request):
serializer = self.get_serializer_run_program(data=request.data)
if not serializer.is_valid():
logger.error(
"RunExistingProgramSerializer validation failed:\n %s",
"RunProgramSerializer validation failed:\n %s",
serializer.errors,
)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

author_program = self.get_run_queryset()
author = request.user
title = sanitize_name(serializer.data.get("title"))
program = author_program.filter(title=title).first()
if program is None:
logger.error("Qiskit Pattern [%s] was not found.", title)
# The sanitization should happen in the serializer
# but it's here until we can refactor the /run end-point
provider_name = sanitize_name(serializer.data.get("provider"))
function_title = sanitize_name(serializer.data.get("title"))
function = self.function_repository.get_function_by_permission(
user=author,
permission_name=RUN_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if function is None:
logger.error("Qiskit Pattern [%s] was not found.", function_title)
return Response(
{"message": f"Qiskit Pattern [{title}] was not found."},
{"message": f"Qiskit Pattern [{function_title}] was not found."},
status=status.HTTP_404_NOT_FOUND,
)

jobconfig = None
config_json = serializer.data.get("config")
if config_json:
logger.info("Configuration for [%s] was found.", title)
logger.info("Configuration for [%s] was found.", function_title)
job_config_serializer = self.get_serializer_job_config(data=config_json)
if not job_config_serializer.is_valid():
logger.error(
Expand All @@ -279,7 +244,7 @@ def run(self, request):
token = ""
if request.auth:
token = request.auth.token.decode()
job_data = {"arguments": arguments, "program": program.id}
job_data = {"arguments": arguments, "program": function.id}
job_serializer = self.get_serializer_run_job(data=job_data)
if not job_serializer.is_valid():
logger.error(
Expand Down Expand Up @@ -309,14 +274,14 @@ def get_by_title(self, request, title):
)

if provider_name:
function = self.program_repository.get_provider_function_by_permission(
function = self.function_repository.get_provider_function_by_permission(
author=author,
permission_name=VIEW_PROGRAM_PERMISSION,
title=function_title,
provider_name=provider_name,
)
else:
function = self.program_repository.get_user_function(
function = self.function_repository.get_user_function(
author=author, title=function_title
)

Expand Down
Loading

0 comments on commit 1ca0a18

Please sign in to comment.