Skip to content

Commit

Permalink
python driver psycopg3 (apache#1793)
Browse files Browse the repository at this point in the history
* update for psycopg3

* set a default argparse namespace  in case tests are run in such a way that argparse is bypassed.
  • Loading branch information
mhmaguire authored and jrgemignani committed May 2, 2024
1 parent 320573d commit f3738a6
Show file tree
Hide file tree
Showing 16 changed files with 519 additions and 459 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ build.sh
*.dylib
age--*.*.*.sql
!age--*--*sql
__pycache__
**/__pycache__

drivers/python/build
12 changes: 8 additions & 4 deletions drivers/python/age/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# specific language governing permissions and limitations
# under the License.

import psycopg.conninfo as conninfo
from . import age
from .age import *
from .models import *
Expand All @@ -23,10 +24,13 @@ def version():
return VERSION.VERSION


def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=None, **kwargs):
ag = Age()
ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory, **kwargs)
return ag
def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, **kwargs):

dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs)

ag = Age()
ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory, **kwargs)
return ag

# Dummy ResultHandler
rawPrinter = DummyResultHandler()
Expand Down
79 changes: 45 additions & 34 deletions drivers/python/age/age.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,59 @@
# specific language governing permissions and limitations
# under the License.

import re
import psycopg2
from psycopg2 import errors
from psycopg2 import extensions as ext
from psycopg2 import sql
import re
import psycopg
from psycopg.types import TypeInfo
from psycopg.adapt import Loader
from psycopg import sql
from psycopg.client_cursor import ClientCursor
from .exceptions import *
from .builder import ResultHandler , parseAgeValue, newResultHandler
from .builder import parseAgeValue


_EXCEPTION_NoConnection = NoConnection()
_EXCEPTION_GraphNotSet = GraphNotSet()

WHITESPACE = re.compile('\s')

def setUpAge(conn:ext.connection, graphName:str):

class AgeDumper(psycopg.adapt.Dumper):
def dump(self, obj: Any) -> bytes | bytearray | memoryview:
pass


class AgeLoader(psycopg.adapt.Loader):
def load(self, data: bytes | bytearray | memoryview) -> Any | None:
return parseAgeValue(data.decode('utf-8'))


def setUpAge(conn:psycopg.connection, graphName:str):
with conn.cursor() as cursor:
cursor.execute("LOAD 'age';")
cursor.execute("SET search_path = ag_catalog, '$user', public;")

cursor.execute("SELECT typelem FROM pg_type WHERE typname='_agtype'")
oid = cursor.fetchone()[0]
if oid == None :
raise AgeNotSet()
ag_info = TypeInfo.fetch(conn, 'agtype')

AGETYPE = ext.new_type((oid,), 'AGETYPE', parseAgeValue)
ext.register_type(AGETYPE)
# ext.register_adapter(Path, marshalAgtValue)
if not ag_info:
raise AgeNotSet()

conn.adapters.register_loader(ag_info.oid, AgeLoader)
conn.adapters.register_loader(ag_info.array_oid, AgeLoader)

# Check graph exists
if graphName != None:
checkGraphCreated(conn, graphName)

# Create the graph, if it does not exist
def checkGraphCreated(conn:ext.connection, graphName:str):
def checkGraphCreated(conn:psycopg.connection, graphName:str):
with conn.cursor() as cursor:
cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName)))
if cursor.fetchone()[0] == 0:
cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName)))
conn.commit()


def deleteGraph(conn:ext.connection, graphName:str):
def deleteGraph(conn:psycopg.connection, graphName:str):
with conn.cursor() as cursor:
cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName)))
conn.commit()
Expand Down Expand Up @@ -82,7 +93,7 @@ def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
stmtArr.append(");")
return "".join(stmtArr)

def execSql(conn:ext.connection, stmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
def execSql(conn:psycopg.connection, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
if conn == None or conn.closed:
raise _EXCEPTION_NoConnection

Expand All @@ -101,14 +112,14 @@ def execSql(conn:ext.connection, stmt:str, commit:bool=False, params:tuple=None)
raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)


def querySql(conn:ext.connection, stmt:str, params:tuple=None) -> ext.cursor :
def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) -> psycopg.cursor :
return execSql(conn, stmt, False, params)

# Execute cypher statement and return cursor.
# If cypher statement changes data (create, set, remove),
# You must commit session(ag.commit())
# (Otherwise the execution cannot make any effect.)
def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
if conn == None or conn.closed:
raise _EXCEPTION_NoConnection

Expand All @@ -117,7 +128,7 @@ def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=Non
cypherStmt = cypherStmt.replace("\n", "")
cypherStmt = cypherStmt.replace("\t", "")
cypher = str(cursor.mogrify(cypherStmt, params))
cypher = cypher[2:len(cypher)-1]
cypher = cypher.strip()

preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"

Expand Down Expand Up @@ -145,12 +156,12 @@ def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=Non
raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)


def cypher(cursor:ext.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
#clean up the string for mogrification
cypherStmt = cypherStmt.replace("\n", "")
cypherStmt = cypherStmt.replace("\t", "")
cypher = str(cursor.mogrify(cypherStmt, params))
cypher = cypher[2:len(cypher)-1]
cypher = cypher.strip()

preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
Expand All @@ -159,22 +170,22 @@ def cypher(cursor:ext.cursor, graphName:str, cypherStmt:str, cols:list=None, par
cursor.execute(stmt)


# def execCypherWithReturn(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# def execCypherWithReturn(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# stmt = buildCypher(graphName, cypherStmt, columns)
# return execSql(conn, stmt, False, params)

# def queryCypher(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)


class Age:
def __init__(self):
self.connection = None # psycopg2 connection]
self.connection = None # psycopg connection]
self.graphName = None

# Connect to PostgreSQL Server and establish session and type extension environment.
def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=None, **kwargs):
conn = psycopg2.connect(dsn, connection_factory, cursor_factory, **kwargs)
def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor, **kwargs):
conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
setUpAge(conn, graph)
self.connection = conn
self.graphName = graph
Expand All @@ -194,21 +205,21 @@ def commit(self):
def rollback(self):
self.connection.rollback()

def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params)

def cypher(self, cursor:ext.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params)

# def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
# def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
# return execSql(self.connection, stmt, commit, params)


# def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
# def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
# return execCypher(self.connection, self.graphName, cypherStmt, commit, params)

# def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params)

# def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)
# def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)
6 changes: 3 additions & 3 deletions drivers/python/age/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from .gen.AgtypeVisitor import AgtypeVisitor
from .models import *
from .exceptions import *
from antlr4 import *
from antlr4.tree.Tree import *
from antlr4 import InputStream, CommonTokenStream, ParserRuleContext
from antlr4.tree.Tree import TerminalNode
from decimal import Decimal

resultHandler = None
Expand All @@ -42,7 +42,7 @@ def parseAgeValue(value, cursor=None):
try:
return resultHandler.parse(value)
except Exception as ex:
raise AGTypeError(value)
raise AGTypeError(value, ex)


class Antlr4ResultHandler(ResultHandler):
Expand Down
2 changes: 1 addition & 1 deletion drivers/python/age/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# specific language governing permissions and limitations
# under the License.

from psycopg2.errors import *
from psycopg.errors import *

class AgeNotSet(Exception):
def __init__(self, name):
Expand Down
6 changes: 3 additions & 3 deletions drivers/python/age/networkx/age_to_networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
# under the License.

from age import *
import psycopg2
import psycopg
import networkx as nx
from age.models import Vertex, Edge, Path
from .lib import *


def age_to_networkx(connection: psycopg2.connect,
def age_to_networkx(connection: psycopg.connect,
graphName: str,
G: None | nx.DiGraph = None,
query: str | None = None
) -> nx.DiGraph:
"""
@params
---------------------
connection - (psycopg2.connect) Connection object
connection - (psycopg.connect) Connection object
graphName - (str) Name of the graph
G - (networkx.DiGraph) Networkx directed Graph [optional]
query - (str) Cypher query [optional]
Expand Down
Loading

0 comments on commit f3738a6

Please sign in to comment.