General: Track token usage of LLM service requests (#9455)
alexjoham authored Oct 23, 2024
1 parent c17b2c4 commit dd96df5
Showing 32 changed files with 976 additions and 112 deletions.
17 changes: 17 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/athena/dto/ResponseMetaDTO.java
@@ -0,0 +1,17 @@
package de.tum.cit.aet.artemis.athena.dto;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;

import de.tum.cit.aet.artemis.core.domain.LLMRequest;

/**
* DTO representing the meta information in the Athena response.
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record ResponseMetaDTO(TotalUsage totalUsage, List<LLMRequest> llmRequests) {

public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) {
}
}
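For orientation, a minimal sketch of the shape this DTO captures when Athena's meta block is deserialized. The model name, token counts, rates, cost, and pipeline id below are illustrative placeholders, not values taken from the commit:

// Illustrative only: constructing the DTO by hand with made-up usage numbers.
ResponseMetaDTO meta = new ResponseMetaDTO(
        new ResponseMetaDTO.TotalUsage(1200, 350, 1550, 0.0065f),
        List.of(new LLMRequest("gpt-4o", 1200, 2.5f, 350, 10.0f, "ATHENA_FEEDBACK_SUGGESTIONS")));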
src/main/java/de/tum/cit/aet/artemis/athena/service/AthenaFeedbackSuggestionsService.java
@@ -17,10 +17,18 @@
import de.tum.cit.aet.artemis.athena.dto.ExerciseBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.ModelingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ProgrammingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ResponseMetaDTO;
import de.tum.cit.aet.artemis.athena.dto.SubmissionBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.TextFeedbackDTO;
import de.tum.cit.aet.artemis.core.domain.LLMRequest;
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.exception.ConflictException;
import de.tum.cit.aet.artemis.core.exception.NetworkingException;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.exercise.domain.Submission;
import de.tum.cit.aet.artemis.exercise.domain.participation.StudentParticipation;
import de.tum.cit.aet.artemis.modeling.domain.ModelingExercise;
import de.tum.cit.aet.artemis.modeling.domain.ModelingSubmission;
import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise;
@@ -48,36 +56,40 @@ public class AthenaFeedbackSuggestionsService {

private final AthenaDTOConverterService athenaDTOConverterService;

private final LLMTokenUsageService llmTokenUsageService;

/**
* Create a new AthenaFeedbackSuggestionsService to receive feedback suggestions from the Athena service.
*
* @param athenaRestTemplate REST template used for the communication with Athena
* @param athenaModuleService Athena module service used to determine the URLs for different modules
* @param athenaDTOConverterService Service to convert exr
* @param athenaDTOConverterService Service to convert exercises and submissions to DTOs
* @param llmTokenUsageService Service to store the usage of LLM tokens
*/
public AthenaFeedbackSuggestionsService(@Qualifier("athenaRestTemplate") RestTemplate athenaRestTemplate, AthenaModuleService athenaModuleService,
AthenaDTOConverterService athenaDTOConverterService) {
AthenaDTOConverterService athenaDTOConverterService, LLMTokenUsageService llmTokenUsageService) {
textAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOText.class);
programmingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOProgramming.class);
modelingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOModeling.class);
this.athenaDTOConverterService = athenaDTOConverterService;
this.athenaModuleService = athenaModuleService;
this.llmTokenUsageService = llmTokenUsageService;
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record RequestDTO(ExerciseBaseDTO exercise, SubmissionBaseDTO submission, boolean isGraded) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOText(List<TextFeedbackDTO> data) {
private record ResponseDTOText(List<TextFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOProgramming(List<ProgrammingFeedbackDTO> data) {
private record ResponseDTOProgramming(List<ProgrammingFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOModeling(List<ModelingFeedbackDTO> data) {
private record ResponseDTOModeling(List<ModelingFeedbackDTO> data, ResponseMetaDTO meta) {
}

/**
Expand All @@ -100,6 +112,7 @@ public List<TextFeedbackDTO> getTextFeedbackSuggestions(TextExercise exercise, T
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOText response = textAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

Expand All @@ -117,6 +130,7 @@ public List<ProgrammingFeedbackDTO> getProgrammingFeedbackSuggestions(Programmin
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOProgramming response = programmingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

Expand All @@ -139,6 +153,36 @@ public List<ModelingFeedbackDTO> getModelingFeedbackSuggestions(ModelingExercise
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOModeling response = modelingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data;
}

/**
* Store the usage of LLM tokens for a given submission
*
* @param exercise the exercise the submission belongs to
* @param submission the submission for which the tokens were used
* @param meta the meta information of the response from Athena
* @param isPreliminaryFeedback whether the feedback is preliminary or not
*/
private void storeTokenUsage(Exercise exercise, Submission submission, ResponseMetaDTO meta, Boolean isPreliminaryFeedback) {
if (meta == null) {
return;
}
Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId();
Long userId;
if (submission.getParticipation() instanceof StudentParticipation studentParticipation) {
userId = studentParticipation.getStudent().map(User::getId).orElse(null);
}
else {
userId = null;
}
List<LLMRequest> llmRequests = meta.llmRequests();
if (llmRequests == null) {
return;
}

llmTokenUsageService.saveLLMTokenUsage(llmRequests, LLMServiceType.ATHENA,
(llmTokenUsageBuilder -> llmTokenUsageBuilder.withCourse(courseId).withExercise(exercise.getId()).withUser(userId)));
}
}
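The LLMTokenUsageService used above is added by this commit but not shown in this excerpt. As a rough sketch of what the builder-based saveLLMTokenUsage call could translate to, assuming a simple builder that exposes the optional course, exercise, and user ids and a trace repository whose save cascades to the per-request rows (all names and details below are assumptions, not the actual implementation):

// Hypothetical sketch only -- the real LLMTokenUsageService is not part of this excerpt.
public LLMTokenUsageTrace saveLLMTokenUsage(List<LLMRequest> llmRequests, LLMServiceType serviceType,
        Function<LLMTokenUsageBuilder, LLMTokenUsageBuilder> builderFunction) {
    LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder());
    LLMTokenUsageTrace trace = new LLMTokenUsageTrace();
    trace.setServiceType(serviceType);
    // assumed builder accessors returning Optional<Long> for the ids set via withCourse/withExercise/withUser
    builder.getCourse().ifPresent(trace::setCourseId);
    builder.getExercise().ifPresent(trace::setExerciseId);
    builder.getUser().ifPresent(trace::setUserId);
    // map each LLMRequest onto an LLMTokenUsageRequest entity attached to the trace (see entities below)
    trace.setLlmRequests(llmRequests.stream().map(llmRequest -> {
        LLMTokenUsageRequest usageRequest = new LLMTokenUsageRequest();
        usageRequest.setModel(llmRequest.model());
        usageRequest.setServicePipelineId(llmRequest.pipelineId());
        usageRequest.setNumInputTokens(llmRequest.numInputTokens());
        usageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken());
        usageRequest.setNumOutputTokens(llmRequest.numOutputTokens());
        usageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken());
        usageRequest.setTrace(trace);
        return usageRequest;
    }).collect(Collectors.toSet()));
    return llmTokenUsageTraceRepository.save(trace);  // hypothetical repository field
}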
14 changes: 14 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMRequest.java
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* This record is used by the LLMTokenUsageService to provide relevant information about LLM token usage.
*
* @param model LLM model (e.g. gpt-4o)
* @param numInputTokens number of input tokens of the LLM call
* @param costPerMillionInputToken cost in Euro per million input tokens
* @param numOutputTokens number of output tokens of the LLM answer
* @param costPerMillionOutputToken cost in Euro per million output tokens
* @param pipelineId String with the pipeline name (e.g. IRIS_COURSE_CHAT_PIPELINE)
*/
public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) {
}
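Since the rates are stored per million tokens, the Euro cost of a single request would presumably be derived by applying them linearly to the token counts. A small illustrative helper (not part of this commit):

// Illustrative only: derive the Euro cost of one LLMRequest from the per-million rates.
static float calculateCostInEuro(LLMRequest request) {
    return (request.numInputTokens() * request.costPerMillionInputToken()
            + request.numOutputTokens() * request.costPerMillionOutputToken()) / 1_000_000f;
}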
8 changes: 8 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java
@@ -0,0 +1,8 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* Enum representing different types of LLM (Large Language Model) services used in the system.
*/
public enum LLMServiceType {
IRIS, ATHENA
}
104 changes: 104 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageRequest.java
@@ -0,0 +1,104 @@
package de.tum.cit.aet.artemis.core.domain;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

/**
* Represents the token usage details of a single LLM request, including model, service pipeline, token counts, and costs.
*/
@Entity
@Table(name = "llm_token_usage_request")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageRequest extends DomainObject {

/**
* LLM model (e.g. gpt-4o)
*/
@Column(name = "model")
private String model;

/**
* pipeline that was called (e.g. IRIS_COURSE_CHAT_PIPELINE)
*/
@Column(name = "service_pipeline_id")
private String servicePipelineId;

@Column(name = "num_input_tokens")
private int numInputTokens;

@Column(name = "cost_per_million_input_tokens")
private float costPerMillionInputTokens;

@Column(name = "num_output_tokens")
private int numOutputTokens;

@Column(name = "cost_per_million_output_tokens")
private float costPerMillionOutputTokens;

@ManyToOne
private LLMTokenUsageTrace trace;

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public String getServicePipelineId() {
return servicePipelineId;
}

public void setServicePipelineId(String servicePipelineId) {
this.servicePipelineId = servicePipelineId;
}

public float getCostPerMillionInputTokens() {
return costPerMillionInputTokens;
}

public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
this.costPerMillionInputTokens = costPerMillionInputToken;
}

public float getCostPerMillionOutputTokens() {
return costPerMillionOutputTokens;
}

public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
this.costPerMillionOutputTokens = costPerMillionOutputToken;
}

public int getNumInputTokens() {
return numInputTokens;
}

public void setNumInputTokens(int numInputTokens) {
this.numInputTokens = numInputTokens;
}

public int getNumOutputTokens() {
return numOutputTokens;
}

public void setNumOutputTokens(int numOutputTokens) {
this.numOutputTokens = numOutputTokens;
}

public LLMTokenUsageTrace getTrace() {
return trace;
}

public void setTrace(LLMTokenUsageTrace trace) {
this.trace = trace;
}
}
111 changes: 111 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsageTrace.java
@@ -0,0 +1,111 @@
package de.tum.cit.aet.artemis.core.domain;

import java.time.ZonedDateTime;
import java.util.HashSet;
import java.util.Set;

import jakarta.annotation.Nullable;
import jakarta.persistence.CascadeType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.FetchType;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

/**
* This represents a trace that contains one or more requests of type {@link LLMTokenUsageRequest}.
*/
@Entity
@Table(name = "llm_token_usage_trace")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageTrace extends DomainObject {

@Column(name = "service")
@Enumerated(EnumType.STRING)
private LLMServiceType serviceType;

@Nullable
@Column(name = "course_id")
private Long courseId;

@Nullable
@Column(name = "exercise_id")
private Long exerciseId;

@Column(name = "user_id")
private Long userId;

@Column(name = "time")
private ZonedDateTime time = ZonedDateTime.now();

@Nullable
@Column(name = "iris_message_id")
private Long irisMessageId;

@OneToMany(mappedBy = "trace", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
private Set<LLMTokenUsageRequest> llmRequests = new HashSet<>();

public LLMServiceType getServiceType() {
return serviceType;
}

public void setServiceType(LLMServiceType serviceType) {
this.serviceType = serviceType;
}

public Long getCourseId() {
return courseId;
}

public void setCourseId(Long courseId) {
this.courseId = courseId;
}

public Long getExerciseId() {
return exerciseId;
}

public void setExerciseId(Long exerciseId) {
this.exerciseId = exerciseId;
}

public Long getUserId() {
return userId;
}

public void setUserId(Long userId) {
this.userId = userId;
}

public ZonedDateTime getTime() {
return time;
}

public void setTime(ZonedDateTime time) {
this.time = time;
}

public Set<LLMTokenUsageRequest> getLLMRequests() {
return llmRequests;
}

public void setLlmRequests(Set<LLMTokenUsageRequest> llmRequests) {
this.llmRequests = llmRequests;
}

public Long getIrisMessageId() {
return irisMessageId;
}

public void setIrisMessageId(Long messageId) {
this.irisMessageId = messageId;
}
}
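Because the llmRequests association is mapped with CascadeType.ALL and orphanRemoval, persisting a trace is expected to write its request rows in the same operation. A short hedged sketch of wiring the bidirectional association (ids, values, and the repository name below are assumptions, not from the commit):

// Sketch: set both sides of the association before saving, so the cascade
// writes the llm_token_usage_request rows together with the trace.
Long courseId = 42L;                     // illustrative id
LLMTokenUsageTrace trace = new LLMTokenUsageTrace();
trace.setServiceType(LLMServiceType.ATHENA);
trace.setCourseId(courseId);

LLMTokenUsageRequest usageRequest = new LLMTokenUsageRequest();
usageRequest.setModel("gpt-4o");         // illustrative values
usageRequest.setNumInputTokens(1200);
usageRequest.setNumOutputTokens(350);
usageRequest.setTrace(trace);            // owning side of the association
trace.getLLMRequests().add(usageRequest);

llmTokenUsageTraceRepository.save(trace);  // hypothetical repository, cascades to the request row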
14 changes: 14 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRequestRepository.java
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Profile(PROFILE_CORE)
@Repository
public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository<LLMTokenUsageRequest, Long> {
}
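ArtemisJpaRepository is the project's Spring Data base interface; its exact API is not part of this excerpt. Assuming it inherits the standard JpaRepository operations, a minimal usage sketch for reporting on the stored requests could look like this:

// Illustrative only: sum the recorded input tokens across all persisted requests.
List<LLMTokenUsageRequest> requests = llmTokenUsageRequestRepository.findAll();
long totalInputTokens = requests.stream().mapToLong(LLMTokenUsageRequest::getNumInputTokens).sum();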