Skip to content

Commit

Permalink
chore: fix linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jtsextonMITRE committed Jan 24, 2025
1 parent 3c93e37 commit ac92dbb
Show file tree
Hide file tree
Showing 21 changed files with 885 additions and 565 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ select = B,C,E,F,W,B9
ignore = E203,E302,E501,W503,B905,B907,B909
per-file-ignores =
examples/*/src/*.py:E402
tests/unit/restapi/v1/signature_analysis/test*.py:F821,F401,B006,B950,E402
extend-exclude =
.ipynb_checkpoints
alembic
7 changes: 6 additions & 1 deletion src/dioptra/restapi/v1/workflows/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from injector import inject
from structlog.stdlib import BoundLogger

from .schema import FileTypes, JobFilesDownloadQueryParametersSchema, SignatureAnalysisSchema, SignatureAnalysisOutputSchema
from .schema import (
FileTypes,
JobFilesDownloadQueryParametersSchema,
SignatureAnalysisOutputSchema,
SignatureAnalysisSchema,
)
from .service import JobFilesDownloadService, SignatureAnalysisService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand Down
67 changes: 27 additions & 40 deletions src/dioptra/restapi/v1/workflows/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,96 +42,83 @@ class JobFilesDownloadQueryParametersSchema(Schema):
default=FileTypes.TAR_GZ.value,
)


class SignatureAnalysisSchema(Schema):

fileContents = fields.String(
attribute="file_contents",
metadata=dict(
description="The contents of the file"
)
attribute="file_contents", metadata=dict(description="The contents of the file")
)

filename = fields.String(
attribute="filename",
metadata=dict(
description="The name of the file"
)
attribute="filename", metadata=dict(description="The name of the file")
)


class SignatureAnalysisSignatureParamSchema(Schema):
name = fields.String(
attribute="name",
metadata=dict(
description="The name of the parameter"
)
attribute="name", metadata=dict(description="The name of the parameter")
)
type = fields.String(
attribute="type",
metadata=dict(
description="The type of the parameter"
)
attribute="type", metadata=dict(description="The type of the parameter")
)


class SignatureAnalysisSignatureInputSchema(SignatureAnalysisSignatureParamSchema):
required = fields.Boolean(
attribute="required",
metadata=dict(
description="Whether this is a required parameter"
)
metadata=dict(description="Whether this is a required parameter"),
)


class SignatureAnalysisSignatureOutputSchema(SignatureAnalysisSignatureParamSchema):
''' No additional fields. '''
"""No additional fields."""


class SignatureAnalysisSuggestedTypes(Schema):
# this should be an integer or a list of integer resource ids on the next iteration
proposed_type = fields.String(
proposed_type = fields.String(
attribute="proposed_type",
metadata=dict(
description="A suggestion for the name of the type"
)
metadata=dict(description="A suggestion for the name of the type"),
)

missing_type = fields.String(
attribute="missing_type",
metadata=dict(
description="The annotation the suggestion is attempting to represent"
)
),
)


class SignatureAnalysisSignatureSchema(Schema):
name = fields.String(
attribute="name",
metadata=dict(
description="The name of the function"
)
attribute="name", metadata=dict(description="The name of the function")
)
inputs = fields.Nested(
SignatureAnalysisSignatureInputSchema,
metadata=dict(
description="A list of objects describing the input parameters."
),
many=True
metadata=dict(description="A list of objects describing the input parameters."),
many=True,
)
outputs = fields.Nested(
SignatureAnalysisSignatureOutputSchema,
metadata=dict(
description="A list of objects describing the output parameters."
),
many=True
many=True,
)
missing_types = fields.Nested(
SignatureAnalysisSuggestedTypes,
metadata=dict(
description="A list of suggested types for non-primitives defined by the file"
description="A list of missing types for non-primitives defined by the file"
),
many=True
many=True,
)


class SignatureAnalysisOutputSchema(Schema):
plugins = fields.Nested(
SignatureAnalysisSignatureSchema,
metadata=dict(
description="A list of signature analyses for the plugins in the input file"
),
many=True
)
),
many=True,
)
54 changes: 36 additions & 18 deletions src/dioptra/restapi/v1/workflows/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""The server-side functions that perform workflows endpoint operations."""
from typing import IO, Any, Final, Iterator, List
from typing import IO, Any, Final, List

import structlog
from structlog.stdlib import BoundLogger

from dioptra.restapi.v1.lib.signature_analysis import get_plugin_signatures

from .lib import views
from .lib.package_job_files import package_job_files
from .schema import FileTypes
Expand Down Expand Up @@ -67,10 +68,13 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]:
logger=log,
)


class SignatureAnalysisService(object):
"""The service methods for performing signature analysis on a file."""

def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dict[str, Any]]]:
def post(
self, filename: str, fileContents: str, **kwargs
) -> dict[str, List[dict[str, Any]]]:
"""Perform signature analysis on a file.
Args:
Expand All @@ -81,14 +85,19 @@ def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dic
A dictionary containing the signature analysis.
"""
log: BoundLogger = kwargs.get("log", LOGGER.new())
log.debug("Performing signature analysis", filename=filename, python_source=fileContents)
log.debug(
"Performing signature analysis",
filename=filename,
python_source=fileContents,
)

signatures = list(
get_plugin_signatures(
python_source=fileContents,
filepath=filename,
)
)

signatures = list(get_plugin_signatures(
python_source=fileContents,
filepath=filename,
))

print(signatures)
endpoint_analyses = []
for signature in signatures:
Expand All @@ -98,19 +107,28 @@ def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dic
inferences = signature["suggested_types"]
endpoint_analysis = {}

endpoint_analysis['name'] = function_name
endpoint_analysis['inputs'] = function_inputs
endpoint_analysis['outputs'] = function_outputs
endpoint_analysis["name"] = function_name
endpoint_analysis["inputs"] = function_inputs
endpoint_analysis["outputs"] = function_outputs

# Compute the suggestions for the unknown types

missing_types = []

for inference in inferences:
suggested_type = inference['suggestion'] # replace this with resource id's for suggestions
original_annotation = inference['type_annotation'] # do a database lookup with this
missing_types += [{ 'missing_type': original_annotation, 'proposed_type': suggested_type}]

endpoint_analysis['missing_types'] = missing_types
suggested_type = inference[
"suggestion"
] # replace this with resource id's for suggestions
original_annotation = inference[
"type_annotation"
] # do a database lookup with this
missing_types += [
{
"missing_type": original_annotation,
"proposed_type": suggested_type,
}
]

endpoint_analysis["missing_types"] = missing_types
endpoint_analyses += [endpoint_analysis]
return {"plugins": endpoint_analyses}
return {"plugins": endpoint_analyses}
3 changes: 2 additions & 1 deletion tests/unit/restapi/lib/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ def remove_tag(
follow_redirects=True,
)


def post_metrics(
client: FlaskClient, job_id: int, metric_name: str, metric_value: float
) -> TestResponse:
Expand Down Expand Up @@ -835,4 +836,4 @@ def post_mlflowruns(
).get_json()
responses[key] = mlflowrun_response

return responses
return responses
4 changes: 1 addition & 3 deletions tests/unit/restapi/lib/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,7 @@ def assert_creating_another_existing_draft_fails(
Raises:
AssertionError: If the response status code is not 400.
"""
response = drafts_client.create(
*resource_ids, **payload
)
response = drafts_client.create(*resource_ids, **payload)
assert response.status_code == HTTPStatus.BAD_REQUEST


Expand Down
7 changes: 6 additions & 1 deletion tests/unit/restapi/lib/mock_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def get_run(self, id: str) -> MockMlflowRun:
return run

def log_metric(
self, id: str, key: str, value: float, step: Optional[int] = None, timestamp: Optional[int] = None
self,
id: str,
key: str,
value: float,
step: Optional[int] = None,
timestamp: Optional[int] = None,
):
if id not in active_runs:
active_runs[id] = []
Expand Down
Loading

0 comments on commit ac92dbb

Please sign in to comment.