Skip to content

Commit

Permalink
Allow providing alternate statement cache implementations to SqlBackends
Browse files Browse the repository at this point in the history
  • Loading branch information
iand675 committed Mar 31, 2021
1 parent b6d092d commit d340716
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 44 deletions.
63 changes: 32 additions & 31 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-} -- Pattern match 'PersistDbSpecific'
Expand Down Expand Up @@ -126,7 +125,7 @@ withPostgresqlPool :: (MonadLoggerIO m, MonadUnliftIO m)
-- ^ Action to be executed that uses the
-- connection pool.
-> m a
withPostgresqlPool ci = withPostgresqlPoolWithVersion getServerVersion ci
withPostgresqlPool = withPostgresqlPoolWithVersion getServerVersion

-- | Same as 'withPostgresPool', but takes a callback for obtaining
-- the server version (to work around an Amazon Redshift bug).
Expand All @@ -146,7 +145,7 @@ withPostgresqlPoolWithVersion :: (MonadUnliftIO m, MonadLoggerIO m)
-> m a
withPostgresqlPoolWithVersion getVerDouble ci = do
let getVer = oldGetVersionToNew getVerDouble
withSqlPool $ open' (const $ return ()) getVer ci
withSqlPool $ open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer }) ci

-- | Same as 'withPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'.
--
Expand All @@ -159,9 +158,7 @@ withPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m)
-- connection pool.
-> m a
withPostgresqlPoolWithConf conf hooks = do
let getVer = pgConfHooksGetServerVersion hooks
modConn = pgConfHooksAfterCreate hooks
let logFuncToBackend = open' modConn getVer (pgConnStr conf)
let logFuncToBackend = open' hooks (pgConnStr conf)
withSqlPoolWithConfig logFuncToBackend (postgresConfToConnectionPoolConfig conf)

-- | Create a PostgreSQL connection pool. Note that it's your
Expand Down Expand Up @@ -207,7 +204,11 @@ createPostgresqlPoolModifiedWithVersion
-> m (Pool SqlBackend)
createPostgresqlPoolModifiedWithVersion getVerDouble modConn ci = do
let getVer = oldGetVersionToNew getVerDouble
createSqlPool $ open' modConn getVer ci
hooks = defaultPostgresConfHooks
{ pgConfHooksAfterCreate = modConn
, pgConfHooksGetServerVersion = getVer
}
createSqlPool $ open' hooks ci

-- | Same as 'createPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'.
--
Expand All @@ -218,9 +219,7 @@ createPostgresqlPoolWithConf
-> PostgresConfHooks -- ^ Record of callback functions
-> m (Pool SqlBackend)
createPostgresqlPoolWithConf conf hooks = do
let getVer = pgConfHooksGetServerVersion hooks
modConn = pgConfHooksAfterCreate hooks
createSqlPoolWithConfig (open' modConn getVer (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf)
createSqlPoolWithConfig (open' hooks (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf)

postgresConfToConnectionPoolConfig :: PostgresConf -> ConnectionPoolConfig
postgresConfToConnectionPoolConfig conf =
Expand Down Expand Up @@ -249,17 +248,18 @@ withPostgresqlConnWithVersion :: (MonadUnliftIO m, MonadLoggerIO m)
-> m a
withPostgresqlConnWithVersion getVerDouble = do
let getVer = oldGetVersionToNew getVerDouble
withSqlConn . open' (const $ return ()) getVer
withSqlConn . open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer })

open'
:: (PG.Connection -> IO ())
-> (PG.Connection -> IO (NonEmpty Word))
-> ConnectionString -> LogFunc -> IO SqlBackend
open' modConn getVer cstr logFunc = do
:: PostgresConfHooks
-> ConnectionString
-> LogFunc
-> IO SqlBackend
open' PostgresConfHooks{..} cstr logFunc = do
conn <- PG.connectPostgreSQL cstr
modConn conn
ver <- getVer conn
smap <- newIORef $ Map.empty
pgConfHooksAfterCreate conn
ver <- pgConfHooksGetServerVersion conn
smap <- pgConfHooksCreateStatementCache
return $ createBackend logFunc ver smap conn

-- | Gets the PostgreSQL server version
Expand Down Expand Up @@ -295,10 +295,9 @@ getServerVersionNonEmpty conn = do
-- so depending upon that we have to choose how the sql query is generated.
-- upsertFunction :: Double -> Maybe (EntityDef -> Text -> Text)
upsertFunction :: a -> NonEmpty Word -> Maybe a
upsertFunction f version = if (version >= postgres9dot5)
upsertFunction f version = if version >= postgres9dot5
then Just f
else Nothing
where

postgres9dot5 :: NonEmpty Word
postgres9dot5 = 9 NEL.:| [5]
Expand All @@ -310,7 +309,7 @@ minimumPostgresVersion :: NonEmpty Word
minimumPostgresVersion = 9 NEL.:| [4]

oldGetVersionToNew :: (PG.Connection -> IO (Maybe Double)) -> (PG.Connection -> IO (NonEmpty Word))
oldGetVersionToNew oldFn = \conn -> do
oldGetVersionToNew oldFn conn = do
mDouble <- oldFn conn
case mDouble of
Nothing -> pure minimumPostgresVersion
Expand All @@ -328,14 +327,14 @@ openSimpleConn = openSimpleConnWithVersion getServerVersion
-- @since 2.9.1
openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend
openSimpleConnWithVersion getVerDouble logFunc conn = do
smap <- newIORef $ Map.empty
smap <- makeSimpleStatementCache
serverVersion <- oldGetVersionToNew getVerDouble conn
return $ createBackend logFunc serverVersion smap conn

-- | Create the backend given a logging function, server version, mutable statement cell,
-- and connection.
createBackend :: LogFunc -> NonEmpty Word
-> IORef (Map.Map Text Statement) -> PG.Connection -> SqlBackend
-> StatementCache -> PG.Connection -> SqlBackend
createBackend logFunc serverVersion smap conn = do
SqlBackend
{ connPrepare = prepare' conn
Expand Down Expand Up @@ -422,7 +421,7 @@ upsertSql' ent uniqs updateVal =
wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs

singleClause :: FieldNameDB -> Text
singleClause field = escapeE (entityDB ent) <> "." <> (escapeF field) <> " =?"
singleClause field = escapeE (entityDB ent) <> "." <> escapeF field <> " =?"

-- | SQL for inserting multiple rows at once and returning their primary keys.
insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult
Expand Down Expand Up @@ -608,7 +607,7 @@ instance PGFF.FromField PgInterval where
nominalDiffTime :: P.Parser NominalDiffTime
nominalDiffTime = do
(s, h, m, ss) <- interval
let pico = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h))
let pico = ss + 60 * fromIntegral m + 60 * 60 * fromIntegral (abs h)
return . fromRational . toRational $ if s then (-pico) else pico

fromPersistValueError :: Text -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64"
Expand Down Expand Up @@ -799,7 +798,7 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do
-- for https://github.com/yesodweb/persistent/issues/152

createText newcols fdefs_ udspair =
(addTable newcols entity) : uniques ++ references ++ foreignsAlt
addTable newcols entity : uniques ++ references ++ foreignsAlt
where
uniques = flip concatMap udspair $ \(uname, ucols) ->
[AlterTable name $ AddUniqueConstraint uname ucols]
Expand Down Expand Up @@ -1076,7 +1075,7 @@ getColumn getter tableName' [ PersistText columnName

let cname = FieldNameDB columnName

ref <- lift $ fmap join $ traverse (getRef cname) refName_
ref <- lift $ join <$> traverse (getRef cname) refName_

return Column
{ cName = cname
Expand Down Expand Up @@ -1538,9 +1537,9 @@ instance FromJSON PostgresConf where
port <- o .:? "port" .!= 5432
user <- o .: "user"
password <- o .: "password"
poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig)
poolStripes <- o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig)
poolIdleTimeout <- o .:? "idleTimeout" .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig)
poolSize <- o .:? "poolsize" .!= connectionPoolConfigSize defaultPoolConfig
poolStripes <- o .:? "stripes" .!= connectionPoolConfigStripes defaultPoolConfig
poolIdleTimeout <- o .:? "idleTimeout" .!= floor (connectionPoolConfigIdleTimeout defaultPoolConfig)
let ci = PG.ConnectInfo
{ PG.connectHost = host
, PG.connectPort = port
Expand Down Expand Up @@ -1605,6 +1604,7 @@ data PostgresConfHooks = PostgresConfHooks
-- The default implementation does nothing.
--
-- @since 2.11.0
, pgConfHooksCreateStatementCache :: IO StatementCache
}

-- | Default settings for 'PostgresConfHooks'. See the individual fields of 'PostgresConfHooks' for the default values.
Expand All @@ -1614,6 +1614,7 @@ defaultPostgresConfHooks :: PostgresConfHooks
defaultPostgresConfHooks = PostgresConfHooks
{ pgConfHooksGetServerVersion = getServerVersionNonEmpty
, pgConfHooksAfterCreate = const $ pure ()
, pgConfHooksCreateStatementCache = makeSimpleStatementCache
}


Expand Down Expand Up @@ -1695,7 +1696,7 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do
-- with the difference that an actual database is not needed.
mockMigration :: Migration -> IO ()
mockMigration mig = do
smap <- newIORef $ Map.empty
smap <- makeSimpleStatementCache
let sqlbackend = SqlBackend { connPrepare = \_ -> do
return Statement
{ stmtFinalize = return ()
Expand Down
8 changes: 4 additions & 4 deletions persistent/Database/Persist/Sql/Raw.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import Control.Monad.Trans.Resource (MonadResource,release)
import Data.Acquire (allocateAcquire, Acquire, mkAcquire, with)
import Data.Conduit
import Data.IORef (writeIORef, readIORef, newIORef)
import qualified Data.Map as Map
import Data.Int (Int64)
import Data.Text (Text, pack)
import qualified Data.Text as T

import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Class
import Database.Persist.Sql.Types.Internal (statementCacheLookup, StatementCache (statementCacheInsert))

rawQuery :: (MonadResource m, MonadReader env m, BackendCompatible SqlBackend env)
=> Text
Expand Down Expand Up @@ -74,8 +74,8 @@ getStmt sql = do

getStmtConn :: SqlBackend -> Text -> IO Statement
getStmtConn conn sql = do
smap <- liftIO $ readIORef $ connStmtMap conn
case Map.lookup sql smap of
smap <- liftIO $ statementCacheLookup (connStmtMap conn) sql
case smap of
Just stmt -> connStatementMiddleware conn sql stmt
Nothing -> do
stmt' <- liftIO $ connPrepare conn sql
Expand All @@ -99,7 +99,7 @@ getStmtConn conn sql = do
then stmtQuery stmt' x
else liftIO $ throwIO $ StatementAlreadyFinalized sql
}
liftIO $ writeIORef (connStmtMap conn) $ Map.insert sql stmt smap
liftIO $ statementCacheInsert (connStmtMap conn) sql stmt
connStatementMiddleware conn sql stmt

-- | Execute a raw SQL statement and return its results as a
Expand Down
11 changes: 5 additions & 6 deletions persistent/Database/Persist/Sql/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ import qualified Control.Monad.Reader as MonadReader
import Control.Monad.Trans.Reader hiding (local)
import Control.Monad.Trans.Resource
import Data.Acquire (Acquire, ReleaseType(..), mkAcquireType, with)
import Data.IORef (readIORef)
import Data.Pool (Pool)
import Data.Pool as P
import qualified Data.Map as Map
import qualified Data.Text as T

import Database.Persist.Class.PersistStore
import Database.Persist.Sql.Types
import Database.Persist.Sql.Types.Internal (IsolationLevel)
import Database.Persist.Sql.Types.Internal (IsolationLevel, StatementCache (..))
import Database.Persist.Sql.Raw

-- | Get a connection from the pool, run the given action, and then return the
Expand Down Expand Up @@ -184,7 +182,7 @@ withSqlPool
-> Int -- ^ connection count
-> (Pool backend -> m a)
-> m a
withSqlPool mkConn connCount f = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } ) f
withSqlPool mkConn connCount = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } )

-- | Creates a pool of connections to a SQL database which can be used by the @Pool backend -> m a@ function.
-- After the function completes, the connections are destroyed.
Expand Down Expand Up @@ -297,5 +295,6 @@ withSqlConn open f = do

close' :: (BackendCompatible SqlBackend backend) => backend -> IO ()
close' conn = do
readIORef (connStmtMap $ projectBackend conn) >>= mapM_ stmtFinalize . Map.elems
connClose $ projectBackend conn
let backend = projectBackend conn
statementCacheClear $ connStmtMap backend
connClose backend
2 changes: 2 additions & 0 deletions persistent/Database/Persist/Sql/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module Database.Persist.Sql.Types
, SqlBackendCanRead, SqlBackendCanWrite, SqlReadT, SqlWriteT, IsSqlBackend
, OverflowNatural(..)
, ConnectionPoolConfig(..)
, StatementCache(..)
, makeSimpleStatementCache
) where

import Database.Persist.Types.Base (FieldCascade)
Expand Down
29 changes: 26 additions & 3 deletions persistent/Database/Persist/Sql/Types/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ module Database.Persist.Sql.Types.Internal
, SqlReadT
, SqlWriteT
, IsSqlBackend
, StatementCache(..)
, makeSimpleStatementCache
) where

import Data.List.NonEmpty (NonEmpty(..))
Expand All @@ -29,8 +31,8 @@ import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask)
import Data.Acquire (Acquire)
import Data.Conduit (ConduitM)
import Data.Int (Int64)
import Data.IORef (IORef)
import Data.Map (Map)
import Data.IORef
import qualified Data.Map as Map
import Data.Monoid ((<>))
import Data.String (IsString)
import Data.Text (Text)
Expand All @@ -45,6 +47,7 @@ import Database.Persist.Class
)
import Database.Persist.Class.PersistStore (IsPersistBackend (..))
import Database.Persist.Types
import Data.Foldable (traverse_)

type LogFunc = Loc -> LogSource -> LogLevel -> LogStr -> IO ()

Expand Down Expand Up @@ -76,6 +79,26 @@ makeIsolationLevelStatement l = "SET TRANSACTION ISOLATION LEVEL " <> case l of
RepeatableRead -> "REPEATABLE READ"
Serializable -> "SERIALIZABLE"

data StatementCache = StatementCache
{ statementCacheLookup :: Text -> IO (Maybe Statement)
, statementCacheInsert :: Text -> Statement -> IO ()
, statementCacheClear :: IO ()
, statementCacheSize :: IO Int
}

makeSimpleStatementCache :: IO StatementCache
makeSimpleStatementCache = do
stmtMap <- newIORef Map.empty
pure $ StatementCache
{ statementCacheLookup = \sql -> Map.lookup sql <$> readIORef stmtMap
, statementCacheInsert = \sql stmt ->
modifyIORef' stmtMap (Map.insert sql stmt)
, statementCacheClear = do
oldStatements <- atomicModifyIORef' stmtMap (\oldStatements -> (Map.empty, oldStatements))
traverse_ stmtFinalize oldStatements
, statementCacheSize = Map.size <$> readIORef stmtMap
}

-- | A 'SqlBackend' represents a handle or connection to a database. It
-- contains functions and values that allow databases to have more
-- optimized implementations, as well as references that benefit
Expand Down Expand Up @@ -127,7 +150,7 @@ data SqlBackend = SqlBackend
-- When left as 'Nothing', we default to using 'defaultPutMany'.
--
-- @since 2.8.1
, connStmtMap :: IORef (Map Text Statement)
, connStmtMap :: StatementCache
-- ^ A reference to the cache of statements. 'Statement's are keyed by
-- the 'Text' queries that generated them.
, connClose :: IO ()
Expand Down

0 comments on commit d340716

Please sign in to comment.