Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: python hostcall #860

Merged
merged 6 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/metagen/src/fdk_python/static/main.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from .{{ mod_name }}_types import {{ imports }}

{% for func in funcs %}
@typed_{{ func.name }}
def {{ func.name }}(inp: {{ func.input_name }}) -> {{ func.output_name }}:
def {{ func.name }}(inp: {{ func.input_name }}, ctx: Ctx) -> {{ func.output_name }}:
# TODO: write your logic here
raise Exception("{{ func.name }} not implemented")
{% endfor %}
11 changes: 8 additions & 3 deletions src/metagen/src/fdk_python/static/types.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ from dataclasses import dataclass, asdict, fields

FORWARD_REFS = {}


class Ctx:
def gql(self, query: str, variables: str) -> Any:
pass

class Struct:
def repr(self):
return asdict(self)
Expand Down Expand Up @@ -93,10 +98,10 @@ def __repr(value: Any):


{%for func in funcs %}
def typed_{{ func.name }}(user_fn: Callable[[{{ func.input_name }}], {{ func.output_name }}]):
def exported_wrapper(raw_inp):
def typed_{{ func.name }}(user_fn: Callable[[{{ func.input_name }}, Ctx], {{ func.output_name }}]):
def exported_wrapper(raw_inp, ctx):
inp: {{ func.input_name }} = Struct.new({{ func.input_name }}, raw_inp)
out: {{ func.output_name }} = user_fn(inp)
out: {{ func.output_name }} = user_fn(inp, ctx)
if isinstance(out, list):
return [__repr(v) for v in out]
return __repr(out)
Expand Down
68 changes: 43 additions & 25 deletions src/pyrt_wit_wire/main.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
import wit_wire.exports
import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import json
import os
import inspect
import sys
import traceback
import types
from typing import Any, Callable, Dict, TypeVar

# NOTE: all imports must be toplevel as constrained by `componentize-py`
# https://github.com/bytecodealliance/componentize-py/issues/23
# from pyrt.imports.typegate_wire import hostcall
import wit_wire.exports
from wit_wire.exports.mat_wire import (
Err,
HandleErr_HandlerErr,
HandleErr_InJsonErr,
HandleErr_NoHandler,
HandleReq,
InitArgs,
InitResponse,
InitError_UnexpectedMat,
InitError_Other,
InitError_UnexpectedMat,
InitResponse,
MatInfo,
HandleReq,
HandleErr_NoHandler,
HandleErr_InJsonErr,
HandleErr_HandlerErr,
Err,
)

import json
import types
from typing import Callable, Any, Dict
import importlib
import importlib.util
import importlib.abc
import importlib.machinery
import os
import sys
import traceback
# NOTE: all imports must be toplevel as constrained by `componentize-py`
# https://github.com/bytecodealliance/componentize-py/issues/23
# from pyrt.imports.typegate_wire import hostcall
from wit_wire.imports.typegate_wire import hostcall

# the `MatWire` class is instantiated for each
# external call. We have to put any persisted
# state here.
handlers = {}


T = TypeVar("T")
HandlerFn = Callable[..., T]


class Ctx:
def gql(self, query: str, variables: str) -> Any:
data = json.loads(
hostcall("gql", json=json.dumps({"query": query, "variables": variables}))
)
return data["data"]


class MatWire(wit_wire.exports.MatWire):
def init(self, args: InitArgs):
for op in args.expected_ops:
Expand Down Expand Up @@ -64,12 +78,16 @@ def handle(self, req: HandleReq):


class ErasedHandler:
def __init__(self, handler_fn: Callable[[Any], Any]) -> None:
def __init__(self, handler_fn: HandlerFn[T]) -> None:
self.handler_fn = handler_fn
self.param_count = len(inspect.signature(self.handler_fn).parameters)

def handle(self, req: HandleReq):
in_parsed = json.loads(req.in_json)
out = self.handler_fn(in_parsed)
if self.param_count == 1:
out = self.handler_fn(in_parsed)
else:
out = self.handler_fn(in_parsed, Ctx())
return json.dumps(out)


Expand All @@ -79,7 +97,7 @@ def op_to_handler(op: MatInfo) -> ErasedHandler:
module = types.ModuleType(op.op_name)
exec(data_parsed["source"], module.__dict__)
fn = module.__dict__[data_parsed["func_name"]]
return ErasedHandler(handler_fn=lambda inp: fn(inp))
return ErasedHandler(handler_fn=fn)
elif data_parsed["ty"] == "import_function":
prefix = data_parsed["func_name"]

Expand All @@ -101,7 +119,7 @@ def op_to_handler(op: MatInfo) -> ErasedHandler:
return ErasedHandler(handler_fn=getattr(module, data_parsed["func_name"]))
elif data_parsed["ty"] == "lambda":
fn = eval(data_parsed["source"])
return ErasedHandler(handler_fn=lambda inp: fn(inp))
return ErasedHandler(handler_fn=fn)
else:
raise Err(InitError_UnexpectedMat(op))

Expand Down
24 changes: 22 additions & 2 deletions src/typegate/src/runtimes/wit_wire/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

async handle(opName: string, args: ResolverArgs) {
const { _, ...inJson } = args;

let res;
try {
res = await Meta.wit_wire.handle(this.id, {
Expand Down Expand Up @@ -173,8 +174,12 @@
async function gql(cx: HostCallCtx, args: object) {
const argsValidator = zod.object({
query: zod.string(),
variables: zod.record(zod.string(), zod.unknown()),
variables: zod.union([
zod.string(),
zod.record(zod.string(), zod.unknown()),
]),
});

const parseRes = argsValidator.safeParse(args);
if (!parseRes.success) {
throw new Error("error validating gql args", {
Expand All @@ -184,6 +189,19 @@
});
}
const parsed = parseRes.data;

// Convert variables to an object if it's a string
let variables = parsed.variables;
if (typeof variables === "string") {
try {
variables = JSON.parse(variables);
} catch (error) {
throw new Error("Failed to parse variables string as JSON", {
cause: error,
});
}

Check warning on line 202 in src/typegate/src/runtimes/wit_wire/mod.ts

View check run for this annotation

Codecov / codecov/patch

src/typegate/src/runtimes/wit_wire/mod.ts#L199-L202

Added lines #L199 - L202 were not covered by tests
}

const request = new Request(cx.typegraphUrl, {
method: "POST",
headers: {
Expand All @@ -193,15 +211,17 @@
},
body: JSON.stringify({
query: parsed.query,
variables: parsed.variables,
variables: variables,
}),
});

//TODO: make `handle` more friendly to internal requests
const res = await cx.typegate.handle(request, {
port: 0,
hostname: "internal",
transport: "tcp",
});

if (!res.ok) {
const text = await res.text();
throw new Error(`gql fetch on ${cx.typegraphUrl} failed: ${text}`, {
Expand Down
11 changes: 10 additions & 1 deletion tests/internal/internal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typegraph.graph.typegraph import Graph
from typegraph.policy import Policy
from typegraph.runtimes.deno import DenoRuntime
from typegraph.runtimes.python import PythonRuntime

from typegraph import t, typegraph

Expand All @@ -11,6 +12,7 @@ def internal(g: Graph):
internal = Policy.internal()

deno = DenoRuntime()
python = PythonRuntime()

inp = t.struct({"first": t.float(), "second": t.float()})
out = t.float()
Expand All @@ -19,7 +21,14 @@ def internal(g: Graph):
sum=deno.import_(inp, out, module="ts/logic.ts", name="sum").with_policy(
internal
),
remoteSum=deno.import_(
remoteSumDeno=deno.import_(
inp, out, module="ts/logic.ts", name="remoteSum"
).with_policy(public),
remoteSumPy=python.import_(
inp,
out,
module="py/logic.py",
name="remote_sum",
deps=["./py/logic_types.py"],
).with_policy(public),
)
38 changes: 36 additions & 2 deletions tests/internal/internal_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
// SPDX-License-Identifier: Elastic-2.0

import { gql, Meta } from "../utils/mod.ts";
import { join } from "@std/path/join";
import { assertEquals } from "@std/assert";

Meta.test({
name: "client table suite",
}, async (_) => {
const scriptsPath = join(import.meta.dirname!, ".");

assertEquals(
(
await Meta.cli(
{
env: {
// RUST_BACKTRACE: "1",
},
},
...`-C ${scriptsPath} gen`.split(" "),
)
).code,
0,
);
});

Meta.test(
{
Expand All @@ -13,11 +35,23 @@ Meta.test(
await t.should("work on the default worker", async () => {
await gql`
query {
remoteSum(first: 1.2, second: 2.3)
remoteSumDeno(first: 1.2, second: 2.3)
}
`
.expectData({
remoteSumDeno: 3.5,
})
.on(e, `http://localhost:${t.port}`);
});

await t.should("hostcall python work", async () => {
await gql`
query {
remoteSumPy(first: 1.2, second: 2.3)
}
`
.expectData({
remoteSum: 3.5,
remoteSumPy: 3.5,
})
.on(e, `http://localhost:${t.port}`);
});
Expand Down
12 changes: 12 additions & 0 deletions tests/internal/metatype.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
typegates:
dev:
url: "http://localhost:7890"
username: admin
password: password

metagen:
targets:
main:
- generator: fdk_python
path: ./py/
typegraph_path: internal.py
18 changes: 18 additions & 0 deletions tests/internal/py/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Metatype OÜ, licensed under the Elastic License 2.0.
# SPDX-License-Identifier: Elastic-2.0

from .logic_types import Ctx
import json


def remote_sum(inp: dict, ctx: Ctx) -> float:
data = ctx.gql(
query="""
query q($first: Float!, $second: Float!) {
sum(first: $first, second: $second)
}
""",
variables=json.dumps(inp),
)
sum = data["sum"]
return sum
Loading
Loading