diff --git a/fedn/network/api/v1/prediction_routes.py b/fedn/network/api/v1/prediction_routes.py index d5dd804cc..e5ce8edb7 100644 --- a/fedn/network/api/v1/prediction_routes.py +++ b/fedn/network/api/v1/prediction_routes.py @@ -4,7 +4,7 @@ from fedn.network.api.auth import jwt_auth_required from fedn.network.api.shared import control -from fedn.network.api.v1.shared import api_version, mdb +from fedn.network.api.v1.shared import api_version, mdb, get_typed_list_headers, get_post_data_to_kwargs from fedn.network.storage.statestore.stores.model_store import ModelStore from fedn.network.storage.statestore.stores.prediction_store import PredictionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -49,3 +49,234 @@ def start_session(): return jsonify({"message": "Prediction session started"}), 200 except Exception: return jsonify({"message": "Failed to start prediction session"}), 500 + + +@bp.route("/", methods=["GET"]) +@jwt_auth_required(role="admin") +def get_predictions(): + """Get predictions + Retrieves a list of predictions based on the provided parameters. + By specifying a parameter in the url, you can filter the predictions based on that parameter, + and the response will contain only the predictions that match the filter. + --- + tags: + - Predictions + parameters: + - name: sender.name + in: query + required: false + type: string + description: Name of the sender + - name: sender.role + in: query + required: false + type: string + description: Role of the sender + - name: receiver.name + in: query + required: false + type: string + description: Name of the receiver + - name: receiver.role + in: query + required: false + type: string + description: Role of the receiver + - name: prediction_id + in: query + required: false + type: string + description: Prediction id of the prediction + - name: model_id + in: query + required: false + type: string + - name: correlation_id + in: query + required: false + type: string + description: Correlation id of the prediction + - name: X-Limit + in: header + required: false + type: integer + description: The maximum number of predictions to retrieve + - name: X-Skip + in: header + required: false + type: integer + description: The number of predictions to skip + - name: X-Sort-Key + in: header + required: false + type: string + description: The key to sort the predictions by + - name: X-Sort-Order + in: header + required: false + type: string + description: The order to sort the predictions in ('asc' or 'desc') + definitions: + Prediction: + type: object + properties: + id: + type: string + correlation_id: + type: string + prediction_id: + type: string + model_id: + type: string + timestamp: + type: object + format: date-time + data: + type: string + meta: + type: string + sender: + type: object + properties: + name: + type: string + role: + type: string + receiver: + type: object + properties: + name: + type: string + role: + type: string + responses: + 200: + description: A list of predictions and the total count. + schema: + type: object + properties: + count: + type: integer + result: + type: array + items: + $ref: '#/definitions/Prediction' + 500: + description: An error occurred + schema: + type: object + properties: + message: + type: string + """ + try: + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + kwargs = request.args.to_dict() + + predictions = prediction_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) + + result = [prediction.__dict__ for prediction in predictions["result"]] if use_typing else predictions["result"] + + response = {"count": predictions["count"], "result": result} + + return jsonify(response), 200 + except Exception: + return jsonify({"message": "An unexpected error occurred"}), 500 + + +@bp.route("/list", methods=["POST"]) +@jwt_auth_required(role="admin") +def list_predictions(): + """List predictions + Retrieves a list of predictions based on the provided parameters. + Works much like the GET predictions method, but allows for a more complex query. + By specifying a parameter in the body, you can filter the predictions based on that parameter, + and the response will contain only the predictions that match the filter. If the parameter value contains a comma, + the filter will be an "in" query, meaning that the predictions will be returned if the specified field contains any of the values in the parameter. + --- + tags: + - Predictions + parameters: + - name: prediction + in: body + required: false + schema: + type: object + properties: + sender.name: + type: string + description: Name of the sender + sender.role: + required: false + type: string + description: Role of the sender + receiver.name: + type: string + description: Name of the receiver + receiver.role: + required: false + type: string + description: Role of the receiver + prediction_id: + required: false + type: string + model_id: + required: false + type: string + correlation_id: + required: false + type: string + description: Correlation id of the status + - name: X-Limit + in: header + required: false + type: integer + description: The maximum number of predictions to retrieve + - name: X-Skip + in: header + required: false + type: integer + description: The number of predictions to skip + - name: X-Sort-Key + in: header + required: false + type: string + description: The key to sort the predictions by + - name: X-Sort-Order + in: header + required: false + type: string + description: The order to sort the predictions in ('asc' or 'desc') + responses: + 200: + description: A list of predictions and the total count. + schema: + type: object + properties: + count: + type: integer + result: + type: array + items: + $ref: '#/definitions/Prediction' + 500: + description: An error occurred + schema: + type: object + properties: + message: + type: string + """ + try: + limit, skip, sort_key, sort_order, use_typing = get_typed_list_headers(request.headers) + kwargs = get_post_data_to_kwargs(request) + + predictions = prediction_store.list(limit, skip, sort_key, sort_order, use_typing=use_typing, **kwargs) + + result = [prediction.__dict__ for prediction in predictions["result"]] if use_typing else predictions["result"] + + response = {"count": predictions["count"], "result": result} + + return jsonify(response), 200 + except Exception: + return jsonify({"message": "An unexpected error occurred"}), 500 diff --git a/fedn/network/storage/statestore/stores/prediction_store.py b/fedn/network/storage/statestore/stores/prediction_store.py index 3a14ec8b9..1ae29b94c 100644 --- a/fedn/network/storage/statestore/stores/prediction_store.py +++ b/fedn/network/storage/statestore/stores/prediction_store.py @@ -27,7 +27,7 @@ def from_dict(data: dict) -> "Prediction": data=data["data"] if "data" in data else None, correlation_id=data["correlationId"] if "correlationId" in data else None, timestamp=data["timestamp"] if "timestamp" in data else None, - session_id=data["sessionId"] if "sessionId" in data else None, + prediction_id=data["predictionId"] if "predictionId" in data else None, meta=data["meta"] if "meta" in data else None, sender=data["sender"] if "sender" in data else None, receiver=data["receiver"] if "receiver" in data else None,