diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index a18f6c57f..fd8c76d77 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -3,7 +3,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-deprecations #-} -- Pattern match 'PersistDbSpecific' @@ -126,7 +125,7 @@ withPostgresqlPool :: (MonadLoggerIO m, MonadUnliftIO m) -- ^ Action to be executed that uses the -- connection pool. -> m a -withPostgresqlPool ci = withPostgresqlPoolWithVersion getServerVersion ci +withPostgresqlPool = withPostgresqlPoolWithVersion getServerVersion -- | Same as 'withPostgresPool', but takes a callback for obtaining -- the server version (to work around an Amazon Redshift bug). @@ -146,7 +145,7 @@ withPostgresqlPoolWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) -> m a withPostgresqlPoolWithVersion getVerDouble ci = do let getVer = oldGetVersionToNew getVerDouble - withSqlPool $ open' (const $ return ()) getVer ci + withSqlPool $ open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer }) ci -- | Same as 'withPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'. -- @@ -159,9 +158,7 @@ withPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m) -- connection pool. -> m a withPostgresqlPoolWithConf conf hooks = do - let getVer = pgConfHooksGetServerVersion hooks - modConn = pgConfHooksAfterCreate hooks - let logFuncToBackend = open' modConn getVer (pgConnStr conf) + let logFuncToBackend = open' hooks (pgConnStr conf) withSqlPoolWithConfig logFuncToBackend (postgresConfToConnectionPoolConfig conf) -- | Create a PostgreSQL connection pool. Note that it's your @@ -207,7 +204,11 @@ createPostgresqlPoolModifiedWithVersion -> m (Pool SqlBackend) createPostgresqlPoolModifiedWithVersion getVerDouble modConn ci = do let getVer = oldGetVersionToNew getVerDouble - createSqlPool $ open' modConn getVer ci + hooks = defaultPostgresConfHooks + { pgConfHooksAfterCreate = modConn + , pgConfHooksGetServerVersion = getVer + } + createSqlPool $ open' hooks ci -- | Same as 'createPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'. -- @@ -218,9 +219,7 @@ createPostgresqlPoolWithConf -> PostgresConfHooks -- ^ Record of callback functions -> m (Pool SqlBackend) createPostgresqlPoolWithConf conf hooks = do - let getVer = pgConfHooksGetServerVersion hooks - modConn = pgConfHooksAfterCreate hooks - createSqlPoolWithConfig (open' modConn getVer (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf) + createSqlPoolWithConfig (open' hooks (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf) postgresConfToConnectionPoolConfig :: PostgresConf -> ConnectionPoolConfig postgresConfToConnectionPoolConfig conf = @@ -249,17 +248,18 @@ withPostgresqlConnWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) -> m a withPostgresqlConnWithVersion getVerDouble = do let getVer = oldGetVersionToNew getVerDouble - withSqlConn . open' (const $ return ()) getVer + withSqlConn . open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer }) open' - :: (PG.Connection -> IO ()) - -> (PG.Connection -> IO (NonEmpty Word)) - -> ConnectionString -> LogFunc -> IO SqlBackend -open' modConn getVer cstr logFunc = do + :: PostgresConfHooks + -> ConnectionString + -> LogFunc + -> IO SqlBackend +open' PostgresConfHooks{..} cstr logFunc = do conn <- PG.connectPostgreSQL cstr - modConn conn - ver <- getVer conn - smap <- newIORef $ Map.empty + pgConfHooksAfterCreate conn + ver <- pgConfHooksGetServerVersion conn + smap <- pgConfHooksCreateStatementCache return $ createBackend logFunc ver smap conn -- | Gets the PostgreSQL server version @@ -295,10 +295,9 @@ getServerVersionNonEmpty conn = do -- so depending upon that we have to choose how the sql query is generated. -- upsertFunction :: Double -> Maybe (EntityDef -> Text -> Text) upsertFunction :: a -> NonEmpty Word -> Maybe a -upsertFunction f version = if (version >= postgres9dot5) +upsertFunction f version = if version >= postgres9dot5 then Just f else Nothing - where postgres9dot5 :: NonEmpty Word postgres9dot5 = 9 NEL.:| [5] @@ -310,7 +309,7 @@ minimumPostgresVersion :: NonEmpty Word minimumPostgresVersion = 9 NEL.:| [4] oldGetVersionToNew :: (PG.Connection -> IO (Maybe Double)) -> (PG.Connection -> IO (NonEmpty Word)) -oldGetVersionToNew oldFn = \conn -> do +oldGetVersionToNew oldFn conn = do mDouble <- oldFn conn case mDouble of Nothing -> pure minimumPostgresVersion @@ -328,14 +327,14 @@ openSimpleConn = openSimpleConnWithVersion getServerVersion -- @since 2.9.1 openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend openSimpleConnWithVersion getVerDouble logFunc conn = do - smap <- newIORef $ Map.empty + smap <- makeSimpleStatementCache serverVersion <- oldGetVersionToNew getVerDouble conn return $ createBackend logFunc serverVersion smap conn -- | Create the backend given a logging function, server version, mutable statement cell, -- and connection. createBackend :: LogFunc -> NonEmpty Word - -> IORef (Map.Map Text Statement) -> PG.Connection -> SqlBackend + -> StatementCache -> PG.Connection -> SqlBackend createBackend logFunc serverVersion smap conn = do SqlBackend { connPrepare = prepare' conn @@ -422,7 +421,7 @@ upsertSql' ent uniqs updateVal = wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs singleClause :: FieldNameDB -> Text - singleClause field = escapeE (entityDB ent) <> "." <> (escapeF field) <> " =?" + singleClause field = escapeE (entityDB ent) <> "." <> escapeF field <> " =?" -- | SQL for inserting multiple rows at once and returning their primary keys. insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult @@ -608,7 +607,7 @@ instance PGFF.FromField PgInterval where nominalDiffTime :: P.Parser NominalDiffTime nominalDiffTime = do (s, h, m, ss) <- interval - let pico = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h)) + let pico = ss + 60 * fromIntegral m + 60 * 60 * fromIntegral (abs h) return . fromRational . toRational $ if s then (-pico) else pico fromPersistValueError :: Text -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64" @@ -799,7 +798,7 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do -- for https://github.com/yesodweb/persistent/issues/152 createText newcols fdefs_ udspair = - (addTable newcols entity) : uniques ++ references ++ foreignsAlt + addTable newcols entity : uniques ++ references ++ foreignsAlt where uniques = flip concatMap udspair $ \(uname, ucols) -> [AlterTable name $ AddUniqueConstraint uname ucols] @@ -1076,7 +1075,7 @@ getColumn getter tableName' [ PersistText columnName let cname = FieldNameDB columnName - ref <- lift $ fmap join $ traverse (getRef cname) refName_ + ref <- lift $ join <$> traverse (getRef cname) refName_ return Column { cName = cname @@ -1538,9 +1537,9 @@ instance FromJSON PostgresConf where port <- o .:? "port" .!= 5432 user <- o .: "user" password <- o .: "password" - poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig) - poolStripes <- o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig) - poolIdleTimeout <- o .:? "idleTimeout" .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig) + poolSize <- o .:? "poolsize" .!= connectionPoolConfigSize defaultPoolConfig + poolStripes <- o .:? "stripes" .!= connectionPoolConfigStripes defaultPoolConfig + poolIdleTimeout <- o .:? "idleTimeout" .!= floor (connectionPoolConfigIdleTimeout defaultPoolConfig) let ci = PG.ConnectInfo { PG.connectHost = host , PG.connectPort = port @@ -1605,6 +1604,7 @@ data PostgresConfHooks = PostgresConfHooks -- The default implementation does nothing. -- -- @since 2.11.0 + , pgConfHooksCreateStatementCache :: IO StatementCache } -- | Default settings for 'PostgresConfHooks'. See the individual fields of 'PostgresConfHooks' for the default values. @@ -1614,6 +1614,7 @@ defaultPostgresConfHooks :: PostgresConfHooks defaultPostgresConfHooks = PostgresConfHooks { pgConfHooksGetServerVersion = getServerVersionNonEmpty , pgConfHooksAfterCreate = const $ pure () + , pgConfHooksCreateStatementCache = makeSimpleStatementCache } @@ -1695,7 +1696,7 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do -- with the difference that an actual database is not needed. mockMigration :: Migration -> IO () mockMigration mig = do - smap <- newIORef $ Map.empty + smap <- makeSimpleStatementCache let sqlbackend = SqlBackend { connPrepare = \_ -> do return Statement { stmtFinalize = return () diff --git a/persistent/Database/Persist/Sql/Raw.hs b/persistent/Database/Persist/Sql/Raw.hs index 40d2eb123..b8113e1b3 100644 --- a/persistent/Database/Persist/Sql/Raw.hs +++ b/persistent/Database/Persist/Sql/Raw.hs @@ -9,7 +9,6 @@ import Control.Monad.Trans.Resource (MonadResource,release) import Data.Acquire (allocateAcquire, Acquire, mkAcquire, with) import Data.Conduit import Data.IORef (writeIORef, readIORef, newIORef) -import qualified Data.Map as Map import Data.Int (Int64) import Data.Text (Text, pack) import qualified Data.Text as T @@ -17,6 +16,7 @@ import qualified Data.Text as T import Database.Persist import Database.Persist.Sql.Types import Database.Persist.Sql.Class +import Database.Persist.Sql.Types.Internal (statementCacheLookup, StatementCache (statementCacheInsert)) rawQuery :: (MonadResource m, MonadReader env m, BackendCompatible SqlBackend env) => Text @@ -74,8 +74,8 @@ getStmt sql = do getStmtConn :: SqlBackend -> Text -> IO Statement getStmtConn conn sql = do - smap <- liftIO $ readIORef $ connStmtMap conn - case Map.lookup sql smap of + smap <- liftIO $ statementCacheLookup (connStmtMap conn) sql + case smap of Just stmt -> connStatementMiddleware conn sql stmt Nothing -> do stmt' <- liftIO $ connPrepare conn sql @@ -99,7 +99,7 @@ getStmtConn conn sql = do then stmtQuery stmt' x else liftIO $ throwIO $ StatementAlreadyFinalized sql } - liftIO $ writeIORef (connStmtMap conn) $ Map.insert sql stmt smap + liftIO $ statementCacheInsert (connStmtMap conn) sql stmt connStatementMiddleware conn sql stmt -- | Execute a raw SQL statement and return its results as a diff --git a/persistent/Database/Persist/Sql/Run.hs b/persistent/Database/Persist/Sql/Run.hs index 2bc79b3ea..622c6c738 100644 --- a/persistent/Database/Persist/Sql/Run.hs +++ b/persistent/Database/Persist/Sql/Run.hs @@ -9,15 +9,13 @@ import qualified Control.Monad.Reader as MonadReader import Control.Monad.Trans.Reader hiding (local) import Control.Monad.Trans.Resource import Data.Acquire (Acquire, ReleaseType(..), mkAcquireType, with) -import Data.IORef (readIORef) import Data.Pool (Pool) import Data.Pool as P -import qualified Data.Map as Map import qualified Data.Text as T import Database.Persist.Class.PersistStore import Database.Persist.Sql.Types -import Database.Persist.Sql.Types.Internal (IsolationLevel) +import Database.Persist.Sql.Types.Internal (IsolationLevel, StatementCache (..)) import Database.Persist.Sql.Raw -- | Get a connection from the pool, run the given action, and then return the @@ -184,7 +182,7 @@ withSqlPool -> Int -- ^ connection count -> (Pool backend -> m a) -> m a -withSqlPool mkConn connCount f = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } ) f +withSqlPool mkConn connCount = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } ) -- | Creates a pool of connections to a SQL database which can be used by the @Pool backend -> m a@ function. -- After the function completes, the connections are destroyed. @@ -297,5 +295,6 @@ withSqlConn open f = do close' :: (BackendCompatible SqlBackend backend) => backend -> IO () close' conn = do - readIORef (connStmtMap $ projectBackend conn) >>= mapM_ stmtFinalize . Map.elems - connClose $ projectBackend conn + let backend = projectBackend conn + statementCacheClear $ connStmtMap backend + connClose backend diff --git a/persistent/Database/Persist/Sql/Types.hs b/persistent/Database/Persist/Sql/Types.hs index 9d5e870d7..cce97b66b 100644 --- a/persistent/Database/Persist/Sql/Types.hs +++ b/persistent/Database/Persist/Sql/Types.hs @@ -6,6 +6,8 @@ module Database.Persist.Sql.Types , SqlBackendCanRead, SqlBackendCanWrite, SqlReadT, SqlWriteT, IsSqlBackend , OverflowNatural(..) , ConnectionPoolConfig(..) + , StatementCache(..) + , makeSimpleStatementCache ) where import Database.Persist.Types.Base (FieldCascade) diff --git a/persistent/Database/Persist/Sql/Types/Internal.hs b/persistent/Database/Persist/Sql/Types/Internal.hs index 93712e4de..304db458e 100644 --- a/persistent/Database/Persist/Sql/Types/Internal.hs +++ b/persistent/Database/Persist/Sql/Types/Internal.hs @@ -19,6 +19,8 @@ module Database.Persist.Sql.Types.Internal , SqlReadT , SqlWriteT , IsSqlBackend + , StatementCache(..) + , makeSimpleStatementCache ) where import Data.List.NonEmpty (NonEmpty(..)) @@ -29,8 +31,8 @@ import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask) import Data.Acquire (Acquire) import Data.Conduit (ConduitM) import Data.Int (Int64) -import Data.IORef (IORef) -import Data.Map (Map) +import Data.IORef +import qualified Data.Map as Map import Data.Monoid ((<>)) import Data.String (IsString) import Data.Text (Text) @@ -45,6 +47,7 @@ import Database.Persist.Class ) import Database.Persist.Class.PersistStore (IsPersistBackend (..)) import Database.Persist.Types +import Data.Foldable (traverse_) type LogFunc = Loc -> LogSource -> LogLevel -> LogStr -> IO () @@ -76,6 +79,26 @@ makeIsolationLevelStatement l = "SET TRANSACTION ISOLATION LEVEL " <> case l of RepeatableRead -> "REPEATABLE READ" Serializable -> "SERIALIZABLE" +data StatementCache = StatementCache + { statementCacheLookup :: Text -> IO (Maybe Statement) + , statementCacheInsert :: Text -> Statement -> IO () + , statementCacheClear :: IO () + , statementCacheSize :: IO Int + } + +makeSimpleStatementCache :: IO StatementCache +makeSimpleStatementCache = do + stmtMap <- newIORef Map.empty + pure $ StatementCache + { statementCacheLookup = \sql -> Map.lookup sql <$> readIORef stmtMap + , statementCacheInsert = \sql stmt -> + modifyIORef' stmtMap (Map.insert sql stmt) + , statementCacheClear = do + oldStatements <- atomicModifyIORef' stmtMap (\oldStatements -> (Map.empty, oldStatements)) + traverse_ stmtFinalize oldStatements + , statementCacheSize = Map.size <$> readIORef stmtMap + } + -- | A 'SqlBackend' represents a handle or connection to a database. It -- contains functions and values that allow databases to have more -- optimized implementations, as well as references that benefit @@ -127,7 +150,7 @@ data SqlBackend = SqlBackend -- When left as 'Nothing', we default to using 'defaultPutMany'. -- -- @since 2.8.1 - , connStmtMap :: IORef (Map Text Statement) + , connStmtMap :: StatementCache -- ^ A reference to the cache of statements. 'Statement's are keyed by -- the 'Text' queries that generated them. , connClose :: IO ()