Skip to content

Commit

Permalink
feat: support set environment variables on flow.dag.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Feb 27, 2024
1 parent 82584dc commit 841505c
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 19 deletions.
4 changes: 4 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
"""Schema for flow.dag.yaml file."""

environment_variables = fields.Dict(
fields.Str(),
fields.Str(),
)
additional_includes = fields.List(LocalPathField())


Expand Down
11 changes: 8 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def __init__(
self._column_mapping = column_mapping or {}
self._variant = variant
self._connections = connections or {}
self._environment_variables = environment_variables or {}

self._inputs = FlowComponentInputDict()
self._outputs = FlowComponentOutputDict()
Expand All @@ -266,8 +265,14 @@ def __init__(
# file existence has been checked in _get_flow_definition
# we don't need to rebase additional_includes as we have updated base_path
with open(Path(self.base_path, self._flow), "r", encoding="utf-8") as f:
flow_content = f.read()
additional_includes = yaml.safe_load(flow_content).get("additional_includes", None)
flow_content = yaml.safe_load(f.read())
additional_includes = flow_content.get("additional_includes", None)
environment_variables_from_flow = flow_content.get("environment_variables", {})
for key, value in (environment_variables or {}).items():
environment_variables_from_flow[key] = value
environment_variables = environment_variables_from_flow

self._environment_variables = environment_variables or {}
self._additional_includes = additional_includes or []

# unlike other Component, code is a private property in FlowComponent and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def test_component_load_from_dag(self):
"is_deterministic": True,
"code": "/subscriptions/xxx/resourceGroups/xxx/workspaces/xxx/codes/xxx/versions/1",
"flow_file_name": "flow.dag.yaml",
"environment_variables": {
"AZURE_OPENAI_API_BASE": "${my_connection.api_base}",
"AZURE_OPENAI_API_KEY": "${my_connection.api_key}",
"AZURE_OPENAI_API_TYPE": "azure",
"AZURE_OPENAI_API_VERSION": "2023-03-15-preview",
},
},
"description": "test load component from flow",
"is_anonymous": False,
"is_archived": False,
"properties": {
"client_component_hash": "b503491e-be3a-de50-0413-30c8c8abb43a",
"client_component_hash": "19278001-3d52-0e43-dc43-4082128d8243",
},
"tags": {},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ nodes:
max_tokens: "120"
environment:
python_requirements_txt: requirements.txt
environment_variables:
AZURE_OPENAI_API_TYPE: azure
AZURE_OPENAI_API_VERSION: 2023-03-15-preview
AZURE_OPENAI_API_KEY: ${my_connection.api_key}
AZURE_OPENAI_API_BASE: ${my_connection.api_base}
38 changes: 25 additions & 13 deletions sdk/ml/azure-ai-ml/tests/test_configs/flows/basic/hello.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import openai
from dotenv import load_dotenv
from openai.version import VERSION as OPENAI_VERSION
from promptflow import tool

# The inputs section will change based on the arguments of the tool function, after you save the code
Expand All @@ -13,6 +13,27 @@ def to_bool(value) -> bool:
return str(value).lower() == "true"


def get_client():
if OPENAI_VERSION.startswith("0."):
raise Exception(
"Please upgrade your OpenAI package to version >= 1.0.0 or using the command: pip install --upgrade openai."
)
api_key = os.environ["AZURE_OPENAI_API_KEY"]
conn = dict(
api_key=os.environ["AZURE_OPENAI_API_KEY"],
)
if api_key.startswith("sk-"):
from openai import OpenAI as Client
else:
from openai import AzureOpenAI as Client

conn.update(
azure_endpoint=os.environ["AZURE_OPENAI_API_BASE"],
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-07-01-preview"),
)
return Client(**conn)


@tool
def my_python_tool(
prompt: str,
Expand All @@ -38,21 +59,14 @@ def my_python_tool(
load_dotenv()

if "AZURE_OPENAI_API_KEY" not in os.environ:
raise Exception("Please sepecify environment variables: AZURE_OPENAI_API_KEY")

conn = dict(
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_base=os.environ["AZURE_OPENAI_API_BASE"],
api_type=os.environ.get("AZURE_OPENAI_API_TYPE", "azure"),
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
)
raise Exception("Please specify environment variables: AZURE_OPENAI_API_KEY")

# TODO: remove below type conversion after client can pass json rather than string.
echo = to_bool(echo)

response = openai.Completion.create(
response = get_client().completions.create(
prompt=prompt,
engine=deployment_name,
model=deployment_name,
# empty string suffix should be treated as None.
suffix=suffix if suffix else None,
max_tokens=int(max_tokens),
Expand All @@ -69,8 +83,6 @@ def my_python_tool(
# Logit bias must be a dict if we passed it to openai api.
logit_bias=logit_bias if logit_bias else {},
user=user,
request_timeout=30,
**conn,
)

# get first element because prompt is single.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ environment_variables:
# environment variables from connection
AZURE_OPENAI_API_KEY: ${azure_open_ai_connection.api_key}
AZURE_OPENAI_API_BASE: ${azure_open_ai_connection.api_base}
AZURE_OPENAI_API_TYPE: azure
AZURE_OPENAI_API_VERSION: 2023-03-15-preview
connections:
llm:
connection: azure_open_ai_connection
Expand Down

0 comments on commit 841505c

Please sign in to comment.