From dd96df50f711e1bf654198eafa46d94c50c71a3d Mon Sep 17 00:00:00 2001 From: Alexander Joham <73483450+alexjoham@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:57:21 +0200 Subject: [PATCH] General: Track token usage of LLM service requests (#9455) --- .../artemis/athena/dto/ResponseMetaDTO.java | 17 ++ .../AthenaFeedbackSuggestionsService.java | 54 +++- .../aet/artemis/core/domain/LLMRequest.java | 14 ++ .../artemis/core/domain/LLMServiceType.java | 8 + .../core/domain/LLMTokenUsageRequest.java | 104 ++++++++ .../core/domain/LLMTokenUsageTrace.java | 111 +++++++++ .../LLMTokenUsageRequestRepository.java | 14 ++ .../LLMTokenUsageTraceRepository.java | 14 ++ .../core/service/LLMTokenUsageService.java | 143 +++++++++++ .../iris/dto/IrisChatWebsocketDTO.java | 8 +- .../IrisCompetencyGenerationService.java | 36 ++- .../iris/service/pyris/PyrisJobService.java | 19 +- .../pyris/PyrisStatusUpdateService.java | 38 +-- .../dto/chat/PyrisChatStatusUpdateDTO.java | 3 +- .../PyrisCompetencyStatusUpdateDTO.java | 4 +- .../pyris/dto/data/PyrisLLMCostDTO.java | 4 + .../pyris/job/CompetencyExtractionJob.java | 8 +- .../iris/service/pyris/job/CourseChatJob.java | 7 +- .../service/pyris/job/ExerciseChatJob.java | 7 +- .../job/TrackedSessionBasedPyrisJob.java | 14 ++ .../AbstractIrisChatSessionService.java | 78 +++++- .../session/IrisCourseChatSessionService.java | 38 +-- .../IrisExerciseChatSessionService.java | 37 +-- .../IrisTextExerciseChatSessionService.java | 6 +- .../websocket/IrisChatWebsocketService.java | 10 +- .../changelog/20241018053210_changelog.xml | 49 ++++ .../resources/config/liquibase/master.xml | 1 + .../iris/IrisChatMessageIntegrationTest.java | 2 +- .../IrisChatTokenTrackingIntegrationTest.java | 230 ++++++++++++++++++ .../artemis/iris/IrisChatWebsocketTest.java | 2 +- ...isCompetencyGenerationIntegrationTest.java | 6 +- ...extExerciseChatMessageIntegrationTest.java | 2 +- 32 files changed, 976 insertions(+), 112 deletions(-) create mode 100644 
src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java create mode 100644 src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java create mode 100644 src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml create mode 100644 src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java new file mode 100644 index 000000000000..44d36a033552 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java @@ -0,0 +1,17 @@ +package de.tum.cit.aet.artemis.athena.dto; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; + +import de.tum.cit.aet.artemis.core.domain.LLMRequest; + +/** + * DTO representing the meta information in the Athena response. 
+ */ +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public record ResponseMetaDTO(TotalUsage totalUsage, List llmRequests) { + + public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) { + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java index d9c81849b396..210b3c7ba859 100644 --- a/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java +++ b/src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java @@ -17,10 +17,18 @@ import de.tum.cit.aet.artemis.athena.dto.ExerciseBaseDTO; import de.tum.cit.aet.artemis.athena.dto.ModelingFeedbackDTO; import de.tum.cit.aet.artemis.athena.dto.ProgrammingFeedbackDTO; +import de.tum.cit.aet.artemis.athena.dto.ResponseMetaDTO; import de.tum.cit.aet.artemis.athena.dto.SubmissionBaseDTO; import de.tum.cit.aet.artemis.athena.dto.TextFeedbackDTO; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.User; import de.tum.cit.aet.artemis.core.exception.ConflictException; import de.tum.cit.aet.artemis.core.exception.NetworkingException; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.exercise.domain.Exercise; +import de.tum.cit.aet.artemis.exercise.domain.Submission; +import de.tum.cit.aet.artemis.exercise.domain.participation.StudentParticipation; import de.tum.cit.aet.artemis.modeling.domain.ModelingExercise; import de.tum.cit.aet.artemis.modeling.domain.ModelingSubmission; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; @@ -48,20 +56,24 @@ public class AthenaFeedbackSuggestionsService { private final AthenaDTOConverterService athenaDTOConverterService; + private final LLMTokenUsageService 
llmTokenUsageService; + /** * Create a new AthenaFeedbackSuggestionsService to receive feedback suggestions from the Athena service. * * @param athenaRestTemplate REST template used for the communication with Athena * @param athenaModuleService Athena module serviced used to determine the urls for different modules - * @param athenaDTOConverterService Service to convert exr + * @param athenaDTOConverterService Service to convert exercises and submissions to DTOs + * @param llmTokenUsageService Service to store the usage of LLM tokens */ public AthenaFeedbackSuggestionsService(@Qualifier("athenaRestTemplate") RestTemplate athenaRestTemplate, AthenaModuleService athenaModuleService, - AthenaDTOConverterService athenaDTOConverterService) { + AthenaDTOConverterService athenaDTOConverterService, LLMTokenUsageService llmTokenUsageService) { textAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOText.class); programmingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOProgramming.class); modelingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOModeling.class); this.athenaDTOConverterService = athenaDTOConverterService; this.athenaModuleService = athenaModuleService; + this.llmTokenUsageService = llmTokenUsageService; } @JsonInclude(JsonInclude.Include.NON_EMPTY) @@ -69,15 +81,15 @@ private record RequestDTO(ExerciseBaseDTO exercise, SubmissionBaseDTO submission } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOText(List data) { + private record ResponseDTOText(List data, ResponseMetaDTO meta) { } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOProgramming(List data) { + private record ResponseDTOProgramming(List data, ResponseMetaDTO meta) { } @JsonInclude(JsonInclude.Include.NON_EMPTY) - private record ResponseDTOModeling(List data) { + private record ResponseDTOModeling(List data, ResponseMetaDTO meta) { } /** @@ -100,6 +112,7 @@ public List 
getTextFeedbackSuggestions(TextExercise exercise, T final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOText response = textAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data.stream().toList(); } @@ -117,6 +130,7 @@ public List getProgrammingFeedbackSuggestions(Programmin final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOProgramming response = programmingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data.stream().toList(); } @@ -139,6 +153,36 @@ public List getModelingFeedbackSuggestions(ModelingExercise final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded); ResponseDTOModeling response = modelingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0); log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? 
"Graded" : "Non Graded", response.data); + storeTokenUsage(exercise, submission, response.meta, !isGraded); return response.data; } + + /** + * Store the usage of LLM tokens for a given submission + * + * @param exercise the exercise the submission belongs to + * @param submission the submission for which the tokens were used + * @param meta the meta information of the response from Athena + * @param isPreliminaryFeedback whether the feedback is preliminary or not + */ + private void storeTokenUsage(Exercise exercise, Submission submission, ResponseMetaDTO meta, Boolean isPreliminaryFeedback) { + if (meta == null) { + return; + } + Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId(); + Long userId; + if (submission.getParticipation() instanceof StudentParticipation studentParticipation) { + userId = studentParticipation.getStudent().map(User::getId).orElse(null); + } + else { + userId = null; + } + List llmRequests = meta.llmRequests(); + if (llmRequests == null) { + return; + } + + llmTokenUsageService.saveLLMTokenUsage(llmRequests, LLMServiceType.ATHENA, + (llmTokenUsageBuilder -> llmTokenUsageBuilder.withCourse(courseId).withExercise(exercise.getId()).withUser(userId))); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java new file mode 100644 index 000000000000..040b6ad88893 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java @@ -0,0 +1,14 @@ +package de.tum.cit.aet.artemis.core.domain; + +/** + * This record is used for the LLMTokenUsageService to provide relevant information about LLM Token usage + * + * @param model LLM model (e.g. 
gpt-4o) + * @param numInputTokens number of tokens of the LLM call + * @param costPerMillionInputToken cost in Euro per million input tokens + * @param numOutputTokens number of tokens of the LLM answer + * @param costPerMillionOutputToken cost in Euro per million output tokens + * @param pipelineId String with the pipeline name (e.g. IRIS_COURSE_CHAT_PIPELINE) + */ +public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java new file mode 100644 index 000000000000..22465bc57b5f --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -0,0 +1,8 @@ +package de.tum.cit.aet.artemis.core.domain; + +/** + * Enum representing different types of LLM (Large Language Model) services used in the system. + */ +public enum LLMServiceType { + IRIS, ATHENA +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java new file mode 100644 index 000000000000..81d7ca8f21a8 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java @@ -0,0 +1,104 @@ +package de.tum.cit.aet.artemis.core.domain; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.Table; + +import org.hibernate.annotations.Cache; +import org.hibernate.annotations.CacheConcurrencyStrategy; + +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * Represents the token usage details of a single LLM request, including model, service pipeline, token counts, and costs. 
+ */ +@Entity +@Table(name = "llm_token_usage_request") +@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public class LLMTokenUsageRequest extends DomainObject { + + /** + * LLM model (e.g. gpt-4o) + */ + @Column(name = "model") + private String model; + + /** + * pipeline that was called (e.g. IRIS_COURSE_CHAT_PIPELINE) + */ + @Column(name = "service_pipeline_id") + private String servicePipelineId; + + @Column(name = "num_input_tokens") + private int numInputTokens; + + @Column(name = "cost_per_million_input_tokens") + private float costPerMillionInputTokens; + + @Column(name = "num_output_tokens") + private int numOutputTokens; + + @Column(name = "cost_per_million_output_tokens") + private float costPerMillionOutputTokens; + + @ManyToOne + private LLMTokenUsageTrace trace; + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getServicePipelineId() { + return servicePipelineId; + } + + public void setServicePipelineId(String servicePipelineId) { + this.servicePipelineId = servicePipelineId; + } + + public float getCostPerMillionInputTokens() { + return costPerMillionInputTokens; + } + + public void setCostPerMillionInputTokens(float costPerMillionInputToken) { + this.costPerMillionInputTokens = costPerMillionInputToken; + } + + public float getCostPerMillionOutputTokens() { + return costPerMillionOutputTokens; + } + + public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) { + this.costPerMillionOutputTokens = costPerMillionOutputToken; + } + + public int getNumInputTokens() { + return numInputTokens; + } + + public void setNumInputTokens(int numInputTokens) { + this.numInputTokens = numInputTokens; + } + + public int getNumOutputTokens() { + return numOutputTokens; + } + + public void setNumOutputTokens(int numOutputTokens) { + this.numOutputTokens = numOutputTokens; + } + + public LLMTokenUsageTrace 
getTrace() { + return trace; + } + + public void setTrace(LLMTokenUsageTrace trace) { + this.trace = trace; + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java new file mode 100644 index 000000000000..1773a0c507da --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java @@ -0,0 +1,111 @@ +package de.tum.cit.aet.artemis.core.domain; + +import java.time.ZonedDateTime; +import java.util.HashSet; +import java.util.Set; + +import jakarta.annotation.Nullable; +import jakarta.persistence.CascadeType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.EnumType; +import jakarta.persistence.Enumerated; +import jakarta.persistence.FetchType; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Table; + +import org.hibernate.annotations.Cache; +import org.hibernate.annotations.CacheConcurrencyStrategy; + +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * This represents a trace that contains one or more requests of type {@link LLMTokenUsageRequest} + */ +@Entity +@Table(name = "llm_token_usage_trace") +@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) +@JsonInclude(JsonInclude.Include.NON_EMPTY) +public class LLMTokenUsageTrace extends DomainObject { + + @Column(name = "service") + @Enumerated(EnumType.STRING) + private LLMServiceType serviceType; + + @Nullable + @Column(name = "course_id") + private Long courseId; + + @Nullable + @Column(name = "exercise_id") + private Long exerciseId; + + @Column(name = "user_id") + private Long userId; + + @Column(name = "time") + private ZonedDateTime time = ZonedDateTime.now(); + + @Nullable + @Column(name = "iris_message_id") + private Long irisMessageId; + + @OneToMany(mappedBy = "trace", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true) + private Set llmRequests = new 
HashSet<>(); + + public LLMServiceType getServiceType() { + return serviceType; + } + + public void setServiceType(LLMServiceType serviceType) { + this.serviceType = serviceType; + } + + public Long getCourseId() { + return courseId; + } + + public void setCourseId(Long courseId) { + this.courseId = courseId; + } + + public Long getExerciseId() { + return exerciseId; + } + + public void setExerciseId(Long exerciseId) { + this.exerciseId = exerciseId; + } + + public Long getUserId() { + return userId; + } + + public void setUserId(Long userId) { + this.userId = userId; + } + + public ZonedDateTime getTime() { + return time; + } + + public void setTime(ZonedDateTime time) { + this.time = time; + } + + public Set getLLMRequests() { + return llmRequests; + } + + public void setLlmRequests(Set llmRequests) { + this.llmRequests = llmRequests; + } + + public Long getIrisMessageId() { + return irisMessageId; + } + + public void setIrisMessageId(Long messageId) { + this.irisMessageId = messageId; + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java new file mode 100644 index 000000000000..145383bf124a --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java @@ -0,0 +1,14 @@ +package de.tum.cit.aet.artemis.core.repository; + +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Repository; + +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; + +@Profile(PROFILE_CORE) +@Repository +public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java 
b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java new file mode 100644 index 000000000000..cc1b0e588c4e --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageTraceRepository.java @@ -0,0 +1,14 @@ +package de.tum.cit.aet.artemis.core.repository; + +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Repository; + +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; + +@Profile(PROFILE_CORE) +@Repository +public interface LLMTokenUsageTraceRepository extends ArtemisJpaRepository { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java new file mode 100644 index 000000000000..c3dc2af1e519 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -0,0 +1,143 @@ +package de.tum.cit.aet.artemis.core.service; + +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE; + +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Service; + +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; + +/** + * Service for managing the LLMTokenUsage by all LLMs in Artemis + */ +@Profile(PROFILE_CORE) +@Service +public class LLMTokenUsageService { 
+ + private final LLMTokenUsageTraceRepository llmTokenUsageTraceRepository; + + private final LLMTokenUsageRequestRepository llmTokenUsageRequestRepository; + + public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepository, LLMTokenUsageRequestRepository llmTokenUsageRequestRepository) { + this.llmTokenUsageTraceRepository = llmTokenUsageTraceRepository; + this.llmTokenUsageRequestRepository = llmTokenUsageRequestRepository; + } + + /** + * Saves the token usage to the database. + * This method records the usage of tokens by various LLM services in the system. + * + * @param llmRequests List of LLM requests containing details about the token usage. + * @param serviceType Type of the LLM service (e.g., IRIS, ATHENA). + * @param builderFunction A function that takes an LLMTokenUsageBuilder and returns a modified LLMTokenUsageBuilder. + * This function is used to set additional properties on the LLMTokenUsageTrace object, such as + * the course ID, user ID, exercise ID, and Iris message ID. + * Example usage: + * builder -> builder.withCourse(courseId).withUser(userId) + * @return The saved LLMTokenUsageTrace object, which includes the details of the token usage. 
+ */ + // TODO: this should ideally be done Async + public LLMTokenUsageTrace saveLLMTokenUsage(List llmRequests, LLMServiceType serviceType, Function builderFunction) { + LLMTokenUsageTrace llmTokenUsageTrace = new LLMTokenUsageTrace(); + llmTokenUsageTrace.setServiceType(serviceType); + + LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder()); + builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId); + builder.getCourseID().ifPresent(llmTokenUsageTrace::setCourseId); + builder.getExerciseID().ifPresent(llmTokenUsageTrace::setExerciseId); + builder.getUserID().ifPresent(llmTokenUsageTrace::setUserId); + + llmTokenUsageTrace.setLlmRequests(llmRequests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest) + .peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(llmTokenUsageTrace)).collect(Collectors.toSet())); + + return llmTokenUsageTraceRepository.save(llmTokenUsageTrace); + } + + private static LLMTokenUsageRequest convertLLMRequestToLLMTokenUsageRequest(LLMRequest llmRequest) { + LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest(); + llmTokenUsageRequest.setModel(llmRequest.model()); + llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens()); + llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens()); + llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken()); + llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken()); + llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId()); + return llmTokenUsageRequest; + } + + // TODO: this should ideally be done Async + public void appendRequestsToTrace(List requests, LLMTokenUsageTrace trace) { + var requestSet = requests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest).peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(trace)) + .collect(Collectors.toSet()); + 
llmTokenUsageRequestRepository.saveAll(requestSet); + } + + /** + * Finds an LLMTokenUsageTrace by its ID. + * + * @param id The ID of the LLMTokenUsageTrace to find. + * @return An Optional containing the LLMTokenUsageTrace if found, or an empty Optional otherwise. + */ + public Optional findLLMTokenUsageTraceById(Long id) { + return llmTokenUsageTraceRepository.findById(id); + } + + /** + * Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage() + */ + public static class LLMTokenUsageBuilder { + + private Optional courseID = Optional.empty(); + + private Optional irisMessageID = Optional.empty(); + + private Optional exerciseID = Optional.empty(); + + private Optional userID = Optional.empty(); + + public LLMTokenUsageBuilder withCourse(Long courseID) { + this.courseID = Optional.ofNullable(courseID); + return this; + } + + public LLMTokenUsageBuilder withIrisMessageID(Long irisMessageID) { + this.irisMessageID = Optional.ofNullable(irisMessageID); + return this; + } + + public LLMTokenUsageBuilder withExercise(Long exerciseID) { + this.exerciseID = Optional.ofNullable(exerciseID); + return this; + } + + public LLMTokenUsageBuilder withUser(Long userID) { + this.userID = Optional.ofNullable(userID); + return this; + } + + public Optional getCourseID() { + return courseID; + } + + public Optional getIrisMessageID() { + return irisMessageID; + } + + public Optional getExerciseID() { + return exerciseID; + } + + public Optional getUserID() { + return userID; + } + } +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java index 75b56488e513..9057b8229fb5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/dto/IrisChatWebsocketDTO.java @@ -7,6 +7,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import 
de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @@ -21,7 +22,7 @@ */ @JsonInclude(JsonInclude.Include.NON_EMPTY) public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, - List suggestions) { + List suggestions, List tokens) { /** * Creates a new IrisWebsocketDTO instance with the given parameters @@ -31,8 +32,9 @@ public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage me * @param rateLimitInfo the rate limit information * @param stages the stages of the Pyris pipeline */ - public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, List suggestions) { - this(determineType(message), message, rateLimitInfo, stages, suggestions); + public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List stages, List suggestions, + List tokens) { + this(determineType(message), message, rateLimitInfo, stages, suggestions, tokens); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 98182ae92b06..88906ff80628 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -7,7 +7,11 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy; import de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.CourseRepository; +import 
de.tum.cit.aet.artemis.core.repository.UserRepository; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyExtractionPipelineExecutionDTO; @@ -25,14 +29,24 @@ public class IrisCompetencyGenerationService { private final PyrisPipelineService pyrisPipelineService; + private final LLMTokenUsageService llmTokenUsageService; + + private final CourseRepository courseRepository; + private final IrisWebsocketService websocketService; private final PyrisJobService pyrisJobService; - public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { + private final UserRepository userRepository; + + public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, + IrisWebsocketService websocketService, PyrisJobService pyrisJobService, UserRepository userRepository) { this.pyrisPipelineService = pyrisPipelineService; + this.llmTokenUsageService = llmTokenUsageService; + this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; + this.userRepository = userRepository; } /** @@ -48,9 +62,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String pyrisPipelineService.executePipeline( "competency-extraction", "default", - pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getLogin())), + pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getId())), executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, 
CompetencyTaxonomy.values(), 5), - stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null)) + stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null)) ); // @formatter:on } @@ -58,12 +72,20 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String /** * Takes a status update from Pyris containing a new competency extraction result and sends it to the client via websocket * - * @param userLogin the login of the user - * @param courseId the id of the course + * @param job Job related to the status update * @param statusUpdate the status update containing the new competency recommendations + * @return the same job that was passed in */ - public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { - websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); + public CompetencyExtractionJob handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { + Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { + llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course.getId()).withUser(job.userId())); + } + + var user = userRepository.findById(job.userId()).orElseThrow(); + websocketService.send(user.getLogin(), websocketTopic(job.courseId()), statusUpdate); + + return job; } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java index 7933e9e20920..16e8969bc463 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java +++ 
b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisJobService.java @@ -78,14 +78,14 @@ public String createTokenForJob(Function tokenToJobFunction) { public String addExerciseChatJob(Long courseId, Long exerciseId, Long sessionId) { var token = generateJobIdToken(); - var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId); + var job = new ExerciseChatJob(token, courseId, exerciseId, sessionId, null); jobMap.put(token, job); return token; } public String addCourseChatJob(Long courseId, Long sessionId) { var token = generateJobIdToken(); - var job = new CourseChatJob(token, courseId, sessionId); + var job = new CourseChatJob(token, courseId, sessionId, null); jobMap.put(token, job); return token; } @@ -107,10 +107,19 @@ public String addIngestionWebhookJob() { /** * Remove a job from the job map. * - * @param token the token + * @param job the job to remove + */ + public void removeJob(PyrisJob job) { + jobMap.remove(job.jobId()); + } + + /** + * Store a job in the job map. 
+ * + * @param job the job to store */ - public void removeJob(String token) { - jobMap.remove(token); + public void updateJob(PyrisJob job) { + jobMap.put(job.jobId(), job); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java index 9403da9beb56..cdd398e5c683 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java @@ -20,7 +20,9 @@ import de.tum.cit.aet.artemis.iris.service.pyris.job.CourseChatJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.ExerciseChatJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.IngestionWebhookJob; +import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; import de.tum.cit.aet.artemis.iris.service.pyris.job.TextExerciseChatJob; +import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob; import de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; import de.tum.cit.aet.artemis.iris.service.session.IrisTextExerciseChatSessionService; @@ -52,15 +54,16 @@ public PyrisStatusUpdateService(PyrisJobService pyrisJobService, IrisExerciseCha } /** - * Handles the status update of a exercise chat job and forwards it to {@link IrisExerciseChatSessionService#handleStatusUpdate(ExerciseChatJob, PyrisChatStatusUpdateDTO)} + * Handles the status update of an exercise chat job and forwards it to + * {@link IrisExerciseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - irisExerciseChatSessionService.handleStatusUpdate(job, 
statusUpdate); + var updatedJob = irisExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** @@ -71,52 +74,55 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta * @param statusUpdate the status update */ public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { - irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); + var updatedJob = irisTextExerciseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** * Handles the status update of a course chat job and forwards it to - * {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(CourseChatJob, PyrisChatStatusUpdateDTO)} + * {@link de.tum.cit.aet.artemis.iris.service.session.IrisCourseChatSessionService#handleStatusUpdate(TrackedSessionBasedPyrisJob, PyrisChatStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - courseChatSessionService.handleStatusUpdate(job, statusUpdate); + var updatedJob = courseChatSessionService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** * Handles the status update of a competency extraction job and forwards it to - * {@link IrisCompetencyGenerationService#handleStatusUpdate(String, long, PyrisCompetencyStatusUpdateDTO)} + * {@link IrisCompetencyGenerationService#handleStatusUpdate(CompetencyExtractionJob, PyrisCompetencyStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate 
the status update */ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { - competencyGenerationService.handleStatusUpdate(job.userLogin(), job.courseId(), statusUpdate); + var updatedJob = competencyGenerationService.handleStatusUpdate(job, statusUpdate); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), updatedJob); } /** - * Removes the job from the job service if the status update indicates that the job is terminated. - * This is the case if all stages are in a terminal state. + * Removes the job from the job service if the status update indicates that the job is terminated; updates it to distribute changes otherwise. + * A job is terminated if all stages are in a terminal state. *

* * @see PyrisStageState#isTerminal() * * @param stages the stages of the status update - * @param job the job to remove + * @param job the job to remove or to update */ - private void removeJobIfTerminated(List stages, String job) { + private void removeJobIfTerminatedElseUpdate(List stages, PyrisJob job) { var isDone = stages.stream().map(PyrisStageDTO::state).allMatch(PyrisStageState::isTerminal); if (isDone) { pyrisJobService.removeJob(job); } + else { + pyrisJobService.updateJob(job); + } } /** @@ -128,6 +134,6 @@ private void removeJobIfTerminated(List stages, String job) { */ public void handleStatusUpdate(IngestionWebhookJob job, PyrisLectureIngestionStatusUpdateDTO statusUpdate) { statusUpdate.stages().forEach(stage -> log.info(stage.name() + ":" + stage.message())); - removeJobIfTerminated(statusUpdate.stages(), job.jobId()); + removeJobIfTerminatedElseUpdate(statusUpdate.stages(), job); } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java index cbfa0b2d98dd..5a1024c6315b 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/chat/PyrisChatStatusUpdateDTO.java @@ -4,8 +4,9 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions) { +public record PyrisChatStatusUpdateDTO(String result, List stages, List suggestions, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java 
b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java index 0956a52f26e8..465c8e5edb65 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/competency/PyrisCompetencyStatusUpdateDTO.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; /** @@ -13,7 +14,8 @@ * * @param stages List of stages of the generation process * @param result List of competencies recommendations that have been generated so far + * @param tokens List of token usages sent by Pyris for tracking the token usage and cost */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record PyrisCompetencyStatusUpdateDTO(List stages, List result) { +public record PyrisCompetencyStatusUpdateDTO(List stages, List result, List tokens) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java new file mode 100644 index 000000000000..43c000a879ae --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -0,0 +1,4 @@ +package de.tum.cit.aet.artemis.iris.service.pyris.dto.data; + +public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, String pipeline) { +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java 
b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CompetencyExtractionJob.java @@ -7,12 +7,12 @@ /** * A pyris job that extracts competencies from a course description. * - * @param jobId the job id - * @param courseId the course in which the competencies are being extracted - * @param userLogin the user login of the user who started the job + * @param jobId the job id + * @param courseId the course in which the competencies are being extracted + * @param userId the user who started the job */ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CompetencyExtractionJob(String jobId, long courseId, String userLogin) implements PyrisJob { +public record CompetencyExtractionJob(String jobId, long courseId, long userId) implements PyrisJob { @Override public boolean canAccess(Course course) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java index fb4b93a28854..2f389e22ed96 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/CourseChatJob.java @@ -9,10 +9,15 @@ * This job is used to reference the details of a course chat session when Pyris sends a status update. 
*/ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record CourseChatJob(String jobId, long courseId, long sessionId) implements PyrisJob { +public record CourseChatJob(String jobId, long courseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob { @Override public boolean canAccess(Course course) { return courseId == course.getId(); } + + @Override + public TrackedSessionBasedPyrisJob withTraceId(long traceId) { + return new CourseChatJob(jobId, courseId, sessionId, traceId); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java index 302ae274d8e2..f74e7360be82 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/ExerciseChatJob.java @@ -10,7 +10,7 @@ * This job is used to reference the details of a exercise chat session when Pyris sends a status update. 
*/ @JsonInclude(JsonInclude.Include.NON_EMPTY) -public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId) implements PyrisJob { +public record ExerciseChatJob(String jobId, long courseId, long exerciseId, long sessionId, Long traceId) implements TrackedSessionBasedPyrisJob { @Override public boolean canAccess(Course course) { @@ -21,4 +21,9 @@ public boolean canAccess(Course course) { public boolean canAccess(Exercise exercise) { return exercise.getId().equals(exerciseId); } + + @Override + public TrackedSessionBasedPyrisJob withTraceId(long traceId) { + return new ExerciseChatJob(jobId, courseId, exerciseId, sessionId, traceId); + } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java new file mode 100644 index 000000000000..bdd180103840 --- /dev/null +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/job/TrackedSessionBasedPyrisJob.java @@ -0,0 +1,14 @@ +package de.tum.cit.aet.artemis.iris.service.pyris.job; + +/** + * A Pyris job that has a session id and stores its own LLM usage tracing ID. + * This is used for chat jobs where we need to reference the trace ID later after chat suggestions have been generated. 
+ */ +public interface TrackedSessionBasedPyrisJob extends PyrisJob { + + long sessionId(); + + Long traceId(); + + TrackedSessionBasedPyrisJob withTraceId(long traceId); +} diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java index f732529aae72..6f0b5a9f411a 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/AbstractIrisChatSessionService.java @@ -1,22 +1,43 @@ package de.tum.cit.aet.artemis.iris.service.session; import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; +import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; +import de.tum.cit.aet.artemis.iris.service.IrisMessageService; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.job.TrackedSessionBasedPyrisJob; +import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; public abstract class AbstractIrisChatSessionService implements IrisChatBasedFeatureInterface, IrisRateLimitedFeatureInterface { private final IrisSessionRepository irisSessionRepository; + private final IrisMessageService irisMessageService; + + private final IrisChatWebsocketService irisChatWebsocketService; + 
+ private final LLMTokenUsageService llmTokenUsageService; + private final ObjectMapper objectMapper; - public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper) { + public AbstractIrisChatSessionService(IrisSessionRepository irisSessionRepository, ObjectMapper objectMapper, IrisMessageService irisMessageService, + IrisChatWebsocketService irisChatWebsocketService, LLMTokenUsageService llmTokenUsageService) { this.irisSessionRepository = irisSessionRepository; this.objectMapper = objectMapper; + this.irisMessageService = irisMessageService; + this.irisChatWebsocketService = irisChatWebsocketService; + this.llmTokenUsageService = llmTokenUsageService; } /** @@ -40,4 +61,59 @@ protected void updateLatestSuggestions(S session, List latestSuggestions throw new RuntimeException("Could not update latest suggestions for session " + session.getId(), e); } } + + /** + * Handles the status update of a ExerciseChatJob by sending the result to the student via the Websocket. 
+ * + * @param job The job that was executed + * @param statusUpdate The status update of the job + * @return the same job record or a new job record with the same job id if changes were made + */ + public TrackedSessionBasedPyrisJob handleStatusUpdate(TrackedSessionBasedPyrisJob job, PyrisChatStatusUpdateDTO statusUpdate) { + var session = (S) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); + IrisMessage savedMessage; + if (statusUpdate.result() != null) { + var message = new IrisMessage(); + message.addContent(new IrisTextMessageContent(statusUpdate.result())); + savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); + irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); + } + else { + savedMessage = null; + irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); + } + + AtomicReference updatedJob = new AtomicReference<>(job); + if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) { + if (savedMessage != null) { + // generated message is first sent and generated trace is saved + var llmTokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { + builder.withIrisMessageID(savedMessage.getId()).withUser(session.getUser().getId()); + this.setLLMTokenUsageParameters(builder, session); + return builder; + }); + + updatedJob.set(job.withTraceId(llmTokenUsageTrace.getId())); + } + else { + // interaction suggestion is sent and appended to the generated trace if it exists + Optional.ofNullable(job.traceId()).flatMap(llmTokenUsageService::findLLMTokenUsageTraceById) + .ifPresentOrElse(trace -> llmTokenUsageService.appendRequestsToTrace(statusUpdate.tokens(), trace), () -> { + var llmTokenUsage = llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> { + builder.withUser(session.getUser().getId()); + 
this.setLLMTokenUsageParameters(builder, session); + return builder; + }); + + updatedJob.set(job.withTraceId(llmTokenUsage.getId())); + }); + } + } + + updateLatestSuggestions(session, statusUpdate.suggestions()); + + return updatedJob.get(); + } + + protected abstract void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, S session); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index 6dea7a728ca6..d2743c2e71a5 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -19,9 +19,8 @@ import de.tum.cit.aet.artemis.core.exception.AccessForbiddenException; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.service.AuthorizationCheckService; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; -import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisCourseChatSession; import de.tum.cit.aet.artemis.iris.domain.settings.IrisSubSettingsType; import de.tum.cit.aet.artemis.iris.repository.IrisCourseChatSessionRepository; @@ -29,8 +28,6 @@ import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.CourseChatJob; import de.tum.cit.aet.artemis.iris.service.settings.IrisSettingsService; import 
de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; @@ -41,8 +38,6 @@ @Profile(PROFILE_IRIS) public class IrisCourseChatSessionService extends AbstractIrisChatSessionService { - private final IrisMessageService irisMessageService; - private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -57,11 +52,11 @@ public class IrisCourseChatSessionService extends AbstractIrisChatSessionService private final PyrisPipelineService pyrisPipelineService; - public IrisCourseChatSessionService(IrisMessageService irisMessageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, - AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, IrisRateLimitService rateLimitService, - IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, ObjectMapper objectMapper) { - super(irisSessionRepository, objectMapper); - this.irisMessageService = irisMessageService; + public IrisCourseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, + IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, + IrisRateLimitService rateLimitService, IrisCourseChatSessionRepository irisCourseChatSessionRepository, PyrisPipelineService pyrisPipelineService, + ObjectMapper objectMapper) { + super(irisSessionRepository, objectMapper, irisMessageService, irisChatWebsocketService, llmTokenUsageService); this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -126,24 +121,9 @@ private void requestAndHandleResponse(IrisCourseChatSession session, String vari pyrisPipelineService.executeCourseChatPipeline(variant, chatSession, competencyJol); } - 
/** - * Handles the status update of a CourseChatJob by sending the result to the student via the Websocket. - * - * @param job The job that was executed - * @param statusUpdate The status update of the job - */ - public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - var session = (IrisCourseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); - if (statusUpdate.result() != null) { - var message = new IrisMessage(); - message.addContent(new IrisTextMessageContent(statusUpdate.result())); - var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); - } - else { - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions()); - } - updateLatestSuggestions(session, statusUpdate.suggestions()); + @Override + protected void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, IrisCourseChatSession session) { + builder.withCourse(session.getCourse().getId()); } /** diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index d520540a2db4..a51f1730e98c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -15,18 +15,15 @@ import de.tum.cit.aet.artemis.core.exception.ConflictException; import de.tum.cit.aet.artemis.core.security.Role; import de.tum.cit.aet.artemis.core.service.AuthorizationCheckService; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.exercise.domain.Submission; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; -import 
de.tum.cit.aet.artemis.iris.domain.message.IrisMessageSender; -import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; import de.tum.cit.aet.artemis.iris.domain.session.IrisExerciseChatSession; import de.tum.cit.aet.artemis.iris.domain.settings.IrisSubSettingsType; import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.job.ExerciseChatJob; import de.tum.cit.aet.artemis.iris.service.settings.IrisSettingsService; import de.tum.cit.aet.artemis.iris.service.websocket.IrisChatWebsocketService; import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; @@ -42,8 +39,6 @@ @Profile(PROFILE_IRIS) public class IrisExerciseChatSessionService extends AbstractIrisChatSessionService implements IrisRateLimitedFeatureInterface { - private final IrisMessageService irisMessageService; - private final IrisSettingsService irisSettingsService; private final IrisChatWebsocketService irisChatWebsocketService; @@ -62,13 +57,12 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final ProgrammingExerciseRepository programmingExerciseRepository; - public IrisExerciseChatSessionService(IrisMessageService irisMessageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, - AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, + public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, + IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository 
irisSessionRepository, ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository, IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository, ObjectMapper objectMapper) { - super(irisSessionRepository, objectMapper); - this.irisMessageService = irisMessageService; + super(irisSessionRepository, objectMapper, irisMessageService, irisChatWebsocketService, llmTokenUsageService); this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -158,24 +152,9 @@ private Optional getLatestSubmissionIfExists(ProgrammingE .flatMap(sub -> programmingSubmissionRepository.findWithEagerResultsAndFeedbacksAndBuildLogsById(sub.getId())); } - /** - * Handles the status update of a ExerciseChatJob by sending the result to the student via the Websocket. 
- * - * @param job The job that was executed - * @param statusUpdate The status update of the job - */ - public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO statusUpdate) { - var session = (IrisExerciseChatSession) irisSessionRepository.findByIdWithMessagesAndContents(job.sessionId()); - if (statusUpdate.result() != null) { - var message = new IrisMessage(); - message.addContent(new IrisTextMessageContent(statusUpdate.result())); - var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); - } - else { - irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions()); - } - - updateLatestSuggestions(session, statusUpdate.suggestions()); + @Override + protected void setLLMTokenUsageParameters(LLMTokenUsageService.LLMTokenUsageBuilder builder, IrisExerciseChatSession session) { + var exercise = session.getExercise(); + builder.withCourse(exercise.getCourseViaExerciseGroupOrCourseMember().getId()).withExercise(exercise.getId()); } } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java index 4520417aad48..8702db7bdf54 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisTextExerciseChatSessionService.java @@ -115,8 +115,10 @@ public void requestAndHandleResponse(IrisTextExerciseChatSession irisSession) { * * @param job The job that is updated * @param statusUpdate The status update + * @return The same job that was passed in */ - public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { + public TextExerciseChatJob handleStatusUpdate(TextExerciseChatJob job, 
PyrisTextExerciseChatStatusUpdateDTO statusUpdate) { + // TODO: LLM Token Tracking - or better, make this class a subclass of AbstractIrisChatSessionService var session = (IrisTextExerciseChatSession) irisSessionRepository.findByIdElseThrow(job.sessionId()); if (statusUpdate.result() != null) { var message = session.newMessage(); @@ -127,6 +129,8 @@ public void handleStatusUpdate(TextExerciseChatJob job, PyrisTextExerciseChatSta else { irisChatWebsocketService.sendMessage(session, null, statusUpdate.stages()); } + + return job; } @Override diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java index 320a3103fe99..d6625dcc6f40 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/websocket/IrisChatWebsocketService.java @@ -7,6 +7,7 @@ import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.domain.session.IrisChatSession; import de.tum.cit.aet.artemis.iris.dto.IrisChatWebsocketDTO; @@ -41,7 +42,7 @@ public void sendMessage(IrisChatSession session, IrisMessage irisMessage, List

stages) { - this.sendStatusUpdate(session, stages, null); + this.sendStatusUpdate(session, stages, null, null); } /** @@ -61,12 +62,13 @@ public void sendStatusUpdate(IrisChatSession session, List stages * @param session the session to send the status update to * @param stages the stages to send * @param suggestions the suggestions to send + * @param tokens token usage and cost send by Pyris */ - public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions) { + public void sendStatusUpdate(IrisChatSession session, List stages, List suggestions, List tokens) { var user = session.getUser(); var rateLimitInfo = rateLimitService.getRateLimitInformation(user); var topic = "" + session.getId(); // Todo: add more specific topic - var payload = new IrisChatWebsocketDTO(null, rateLimitInfo, stages, suggestions); + var payload = new IrisChatWebsocketDTO(null, rateLimitInfo, stages, suggestions, tokens); websocketService.send(user.getLogin(), topic, payload); } } diff --git a/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml new file mode 100644 index 000000000000..e514ec8e5f58 --- /dev/null +++ b/src/main/resources/config/liquibase/changelog/20241018053210_changelog.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/resources/config/liquibase/master.xml b/src/main/resources/config/liquibase/master.xml index d496528a13ec..109eefaa1bbf 100644 --- a/src/main/resources/config/liquibase/master.xml +++ b/src/main/resources/config/liquibase/master.xml @@ -28,6 +28,7 @@ + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java index 8cda014838a1..96c047ad7345 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java +++ 
b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java @@ -446,7 +446,7 @@ public String toString() { private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions), + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null), HttpStatus.OK, headers); } } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java new file mode 100644 index 000000000000..adb5b009809f --- /dev/null +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatTokenTrackingIntegrationTest.java @@ -0,0 +1,230 @@ +package de.tum.cit.aet.artemis.iris; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.awaitility.Awaitility.await; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.eclipse.jgit.api.errors.GitAPIException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.util.LinkedMultiValueMap; + +import de.tum.cit.aet.artemis.core.connector.IrisRequestMockProvider; +import 
de.tum.cit.aet.artemis.core.domain.Course; +import de.tum.cit.aet.artemis.core.domain.LLMRequest; +import de.tum.cit.aet.artemis.core.domain.LLMServiceType; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest; +import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository; +import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository; +import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; +import de.tum.cit.aet.artemis.exercise.participation.util.ParticipationUtilService; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; +import de.tum.cit.aet.artemis.iris.domain.message.IrisMessageContent; +import de.tum.cit.aet.artemis.iris.domain.message.IrisTextMessageContent; +import de.tum.cit.aet.artemis.iris.domain.session.IrisSession; +import de.tum.cit.aet.artemis.iris.repository.IrisMessageRepository; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise; +import de.tum.cit.aet.artemis.programming.domain.ProgrammingExerciseStudentParticipation; +import de.tum.cit.aet.artemis.programming.domain.ProjectType; +import de.tum.cit.aet.artemis.programming.domain.SolutionProgrammingExerciseParticipation; +import de.tum.cit.aet.artemis.programming.domain.TemplateProgrammingExerciseParticipation; + +class IrisChatTokenTrackingIntegrationTest extends AbstractIrisIntegrationTest { + + private static final String TEST_PREFIX = "irischattokentrackingintegration"; + + @Autowired + private IrisExerciseChatSessionService irisExerciseChatSessionService; + + @Autowired + private IrisMessageRepository irisMessageRepository; + + 
@Autowired + private LLMTokenUsageService llmTokenUsageService; + + @Autowired + private LLMTokenUsageTraceRepository irisLLMTokenUsageTraceRepository; + + @Autowired + private LLMTokenUsageRequestRepository irisLLMTokenUsageRequestRepository; + + @Autowired + private IrisRequestMockProvider irisRequestMockProvider; + + @Autowired + private ParticipationUtilService participationUtilService; + + private ProgrammingExercise exercise; + + private Course course; + + private AtomicBoolean pipelineDone; + + @BeforeEach + void initTestCase() throws GitAPIException, IOException, URISyntaxException { + userUtilService.addUsers(TEST_PREFIX, 2, 0, 0, 0); + course = programmingExerciseUtilService.addCourseWithOneProgrammingExercise(); + exercise = exerciseUtilService.getFirstExerciseWithType(course, ProgrammingExercise.class); + String projectKey = exercise.getProjectKey(); + exercise.setProjectType(ProjectType.PLAIN_GRADLE); + exercise.setTestRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + projectKey.toLowerCase() + "-tests.git"); + programmingExerciseBuildConfigRepository.save(exercise.getBuildConfig()); + programmingExerciseRepository.save(exercise); + exercise = programmingExerciseRepository.findWithAllParticipationsAndBuildConfigById(exercise.getId()).orElseThrow(); + // Set the correct repository URIs for the template and the solution participation. 
+ String templateRepositorySlug = projectKey.toLowerCase() + "-exercise"; + TemplateProgrammingExerciseParticipation templateParticipation = exercise.getTemplateParticipation(); + templateParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + templateRepositorySlug + ".git"); + templateProgrammingExerciseParticipationRepository.save(templateParticipation); + String solutionRepositorySlug = projectKey.toLowerCase() + "-solution"; + SolutionProgrammingExerciseParticipation solutionParticipation = exercise.getSolutionParticipation(); + solutionParticipation.setRepositoryUri(localVCBaseUrl + "/git/" + projectKey + "/" + solutionRepositorySlug + ".git"); + solutionProgrammingExerciseParticipationRepository.save(solutionParticipation); + String assignmentRepositorySlug = projectKey.toLowerCase() + "-" + TEST_PREFIX + "student1"; + // Add a participation for student1. + ProgrammingExerciseStudentParticipation studentParticipation = participationUtilService.addStudentParticipationForProgrammingExercise(exercise, TEST_PREFIX + "student1"); + studentParticipation.setRepositoryUri(String.format(localVCBaseUrl + "/git/%s/%s.git", projectKey, assignmentRepositorySlug)); + studentParticipation.setBranch(defaultBranch); + programmingExerciseStudentParticipationRepository.save(studentParticipation); + // Prepare the repositories. + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, templateRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, projectKey.toLowerCase() + "-tests"); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, solutionRepositorySlug); + localVCLocalCITestService.createAndConfigureLocalRepository(projectKey, assignmentRepositorySlug); + // Check that the repository folders were created in the file system for all base repositories. 
+ localVCLocalCITestService.verifyRepositoryFoldersExist(exercise, localVCBasePath); + activateIrisGlobally(); + activateIrisFor(course); + activateIrisFor(exercise); + // Clean up the database + irisLLMTokenUsageRequestRepository.deleteAll(); + irisLLMTokenUsageTraceRepository.deleteAll(); + pipelineDone = new AtomicBoolean(false); + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingHandledExerciseChat() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + var tokens = getMockLLMCosts(); + List doneStage = new ArrayList<>(); + doneStage.add(new PyrisStageDTO("DoneTest", 10, PyrisStageState.DONE, "Done")); + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", doneStage, tokens)); + pipelineDone.set(true); + }); + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + await().until(pipelineDone::get); + List savedTokenUsageTraces = irisLLMTokenUsageTraceRepository.findAll(); + List savedTokenUsageRequests = irisLLMTokenUsageRequestRepository.findAll(); + assertThat(savedTokenUsageTraces).hasSize(1); + assertThat(savedTokenUsageTraces.getFirst().getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(savedTokenUsageTraces.getFirst().getExerciseId()).isEqualTo(exercise.getId()); + assertThat(savedTokenUsageTraces.getFirst().getCourseId()).isEqualTo(course.getId()); + assertThat(savedTokenUsageRequests).hasSize(5); + for (int i = 0; i < savedTokenUsageRequests.size(); i++) { + LLMTokenUsageRequest usage = savedTokenUsageRequests.get(i); + LLMRequest expectedCost = 
tokens.get(i); + assertThat(usage.getModel()).isEqualTo(expectedCost.model()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerMillionInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerMillionOutputToken()); + assertThat(usage.getServicePipelineId()).isEqualTo(expectedCost.pipelineId()); + } + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingSavedExerciseChat() { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var irisMessage = createDefaultMockMessage(irisSession); + irisMessageRepository.save(irisMessage); + var tokens = getMockLLMCosts(); + LLMTokenUsageTrace tokenUsageTrace = llmTokenUsageService.saveLLMTokenUsage(tokens, LLMServiceType.IRIS, + builder -> builder.withIrisMessageID(irisMessage.getId()).withExercise(exercise.getId()).withUser(irisSession.getUser().getId()).withCourse(course.getId())); + assertThat(tokenUsageTrace.getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(tokenUsageTrace.getIrisMessageId()).isEqualTo(irisMessage.getId()); + assertThat(tokenUsageTrace.getExerciseId()).isEqualTo(exercise.getId()); + assertThat(tokenUsageTrace.getUserId()).isEqualTo(irisSession.getUser().getId()); + assertThat(tokenUsageTrace.getCourseId()).isEqualTo(course.getId()); + } + + @Test + @WithMockUser(username = TEST_PREFIX + "student1", roles = "USER") + void testTokenTrackingExerciseChatWithPipelineFail() throws Exception { + var irisSession = irisExerciseChatSessionService.createChatSessionForProgrammingExercise(exercise, userUtilService.getUserByLogin(TEST_PREFIX + "student1")); + var messageToSend = createDefaultMockMessage(irisSession); + 
var tokens = getMockLLMCosts(); + List failedStages = new ArrayList<>(); + failedStages.add(new PyrisStageDTO("TestTokenFail", 10, PyrisStageState.ERROR, "Failed running pipeline")); + irisRequestMockProvider.mockProgrammingExerciseChatResponse(dto -> { + assertThat(dto.settings().authenticationToken()).isNotNull(); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, failedStages, tokens)); + pipelineDone.set(true); + }); + request.postWithoutResponseBody("/api/iris/sessions/" + irisSession.getId() + "/messages", messageToSend, HttpStatus.CREATED); + await().until(pipelineDone::get); + List savedTokenUsageTraces = irisLLMTokenUsageTraceRepository.findAll(); + List savedTokenUsageRequests = irisLLMTokenUsageRequestRepository.findAll(); + assertThat(savedTokenUsageTraces).hasSize(1); + assertThat(savedTokenUsageTraces.getFirst().getServiceType()).isEqualTo(LLMServiceType.IRIS); + assertThat(savedTokenUsageTraces.getFirst().getExerciseId()).isEqualTo(exercise.getId()); + assertThat(savedTokenUsageTraces.getFirst().getIrisMessageId()).isEqualTo(messageToSend.getId()); + assertThat(savedTokenUsageTraces.getFirst().getCourseId()).isEqualTo(course.getId()); + assertThat(savedTokenUsageRequests).hasSize(5); + for (int i = 0; i < savedTokenUsageRequests.size(); i++) { + LLMTokenUsageRequest usage = savedTokenUsageRequests.get(i); + LLMRequest expectedCost = tokens.get(i); + assertThat(usage.getModel()).isEqualTo(expectedCost.model()); + assertThat(usage.getNumInputTokens()).isEqualTo(expectedCost.numInputTokens()); + assertThat(usage.getNumOutputTokens()).isEqualTo(expectedCost.numOutputTokens()); + assertThat(usage.getCostPerMillionInputTokens()).isEqualTo(expectedCost.costPerMillionInputToken()); + assertThat(usage.getCostPerMillionOutputTokens()).isEqualTo(expectedCost.costPerMillionOutputToken()); + assertThat(usage.getServicePipelineId()).isEqualTo(expectedCost.pipelineId()); + } + } + + private List getMockLLMCosts() { + 
List costs = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + costs.add(new LLMRequest("test-llm", i * 10 + 5, i * 0.5f, i * 3 + 5, i * 0.12f, "IRIS_CHAT_EXERCISE_MESSAGE")); + } + return costs; + } + + private IrisMessage createDefaultMockMessage(IrisSession irisSession) { + var messageToSend = irisSession.newMessage(); + messageToSend.addContent(createMockTextContent(), createMockTextContent(), createMockTextContent()); + return messageToSend; + } + + private IrisMessageContent createMockTextContent() { + var text = "The happy dog jumped over the lazy dog."; + return new IrisTextMessageContent(text); + } + + private void sendStatus(String jobId, String result, List stages, List tokens) throws Exception { + var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, null, tokens), + HttpStatus.OK, headers); + } +} diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java index 03845b59efb7..03afd1453235 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatWebsocketTest.java @@ -53,7 +53,7 @@ void sendMessage() { message.setMessageDifferentiator(101010); irisChatWebsocketService.sendMessage(irisSession, message, List.of()); verify(websocketMessagingService, times(1)).sendMessageToUser(eq(TEST_PREFIX + "student1"), eq("/topic/iris/" + irisSession.getId()), - eq(new IrisChatWebsocketDTO(message, new IrisRateLimitService.IrisRateLimitInformation(0, -1, 0), List.of(), List.of()))); + eq(new IrisChatWebsocketDTO(message, new IrisRateLimitService.IrisRateLimitInformation(0, -1, 0), List.of(), List.of(), List.of()))); } private IrisTextMessageContent createMockContent() { diff --git 
a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index b4fef850f439..7b7279a25053 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -22,6 +22,7 @@ import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyStatusUpdateDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.pyris.job.CompetencyExtractionJob; class IrisCompetencyGenerationIntegrationTest extends AbstractIrisIntegrationTest { @@ -66,7 +67,10 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { List stages = List.of(new PyrisStageDTO("Generating Competencies", 10, PyrisStageState.DONE, null)); // In the real system, this would be triggered by Pyris via a REST call to the Artemis server - irisCompetencyGenerationService.handleStatusUpdate(TEST_PREFIX + "editor1", course.getId(), new PyrisCompetencyStatusUpdateDTO(stages, recommendations)); + String jobId = "testJobId"; + String userLogin = TEST_PREFIX + "editor1"; + CompetencyExtractionJob job = new CompetencyExtractionJob(jobId, course.getId(), userUtilService.getUserByLogin(userLogin).getId()); + irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); verify(websocketMessagingService, timeout(200).times(3)).sendMessageToUser(eq(TEST_PREFIX + "editor1"), eq("/topic/iris/competencies/" + course.getId()), diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java 
b/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java index 7be2d0e8abc9..0366317fd557 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisTextExerciseChatMessageIntegrationTest.java @@ -398,7 +398,7 @@ public String toString() { private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/text-exercise-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions), + request.postWithoutResponseBody("/api/public/pyris/pipelines/text-exercise-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null), HttpStatus.OK, headers); } }