Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #69 #71

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions postgresql-simple.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ test-suite test

build-depends:
aeson
, async
, base
, base16-bytestring
, bytestring
Expand Down
123 changes: 88 additions & 35 deletions src/Database/PostgreSQL/Simple/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ module Database.PostgreSQL.Simple.Internal where
import Control.Applicative
import Control.Exception
import Control.Concurrent.MVar
import Control.Monad(MonadPlus(..))
import Control.Monad(MonadPlus(..), when)
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
Expand Down Expand Up @@ -77,6 +77,10 @@ data Connection = Connection {
connectionHandle :: {-# UNPACK #-} !(MVar PQ.Connection)
, connectionObjects :: {-# UNPACK #-} !(MVar TypeInfoCache)
, connectionTempNameCounter :: {-# UNPACK #-} !(IORef Int64)
, connectionMayHaveOrphanedStatement :: {-# UNPACK #-} !(IORef Bool)
-- ^ True if there could be a statement running in postgres in this connection, but
-- postgresql-simple is not waiting for results from it. This can happen when
-- postgresql-simple is interrupted by asynchronous exceptions.
} deriving (Typeable)

instance Eq Connection where
Expand Down Expand Up @@ -238,6 +242,7 @@ connectPostgreSQL connstr = do
connectionHandle <- newMVar conn
connectionObjects <- newMVar (IntMap.empty)
connectionTempNameCounter <- newIORef 0
connectionMayHaveOrphanedStatement <- newIORef False
let wconn = Connection{..}
version <- PQ.serverVersion conn
let settings
Expand Down Expand Up @@ -330,43 +335,90 @@ exec conn sql =
Just res -> return res
#else
exec conn sql =
withConnection conn $ \h -> do
success <- PQ.sendQuery h sql
if success
then awaitResult h Nothing
else throwLibPQError h "PQsendQuery failed"
withConnection conn $ \h -> withSocket h $ \socket-> uninterruptibleMask $ \restore -> do
-- 1. If postgresql-simple was interrupted when waiting for query results
-- before, cancel that query (it may even have completed by now, but that's fine)
-- before issuing a new one.
restore $ do
needsToCancel <- readIORef (connectionMayHaveOrphanedStatement conn)
when needsToCancel $ do
cancelRunningQuery h socket
writeIORef (connectionMayHaveOrphanedStatement conn) False

-- 2. Ideally, the code that issues the query and waits for results
-- should not throw exceptions. That way we know an exception means
-- postgresql-simple was interrupted and the query might still be running.
-- Still, even if the code throws exceptions for other reasons, it means
-- we'll try to cancel a running query later once, which is fairly inocuous
-- as long as such exceptions are rare (which they should be).
restore (sendQueryAndWaitForResults h socket)
`onException` writeIORef (connectionMayHaveOrphanedStatement conn) True

where
awaitResult h mres = do
mfd <- PQ.socket h
case mfd of
Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
Just fd -> do
threadWaitRead fd
_ <- PQ.consumeInput h -- FIXME?
getResult h mres
withSocket h f = do
mfd <- PQ.socket h
case mfd of
Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
Just socket -> f socket

sendQueryAndWaitForResults h socket = do
success <- PQ.sendQuery h sql
if success then do
consumeUntilNotBusy h socket
getResult h Nothing
else throwLibPQError h "PQsendQuery failed"

cancelRunningQuery h socket = do
mcncl <- PQ.getCancel h
case mcncl of
Nothing -> pure ()
Just cncl -> do
cancelStatus <- PQ.cancel cncl
case cancelStatus of
Left _ -> PQ.errorMessage h >>= \mmsg -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.cancelRunningQuery: " <> fromMaybe "Unknown error" mmsg
<> "\nIt looks like postgresql-simple was previously interrupted by an exception while waiting for query results."
<> " Because of that, before issuing a new query, we tried to cancel that previous query that was interrupted, but failed to do so.")
Right () -> do
consumeUntilNotBusy h socket
waitForNullResult h

waitForNullResult h = do
mres <- PQ.getResult h
case mres of
Nothing -> pure ()
Just _ -> waitForNullResult h

-- | Waits until results are ready to be fetched.
consumeUntilNotBusy h socket = do
-- According to https://www.postgresql.org/docs/current/libpq-async.html :
-- 1. The isBusy status only changes by calling PQConsumeInput
-- 2. In case of errors, "PQgetResult should be called until it returns a null pointer, to allow libpq to process the error information completely"
-- 3. Also, "A typical application using these functions will have a main loop that uses select() or poll() ... When the main loop detects input ready, it should call PQconsumeInput to read the input. It can then call PQisBusy, followed by PQgetResult if PQisBusy returns false (0)"
busy <- PQ.isBusy h
when busy $ do
threadWaitRead socket
someError <- not <$> PQ.consumeInput h
when someError $ PQ.errorMessage h >>= \mmsg -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.consumeUntilNotBusy: " <> fromMaybe "Unknown error" mmsg)
consumeUntilNotBusy h socket

getResult h mres = do
isBusy <- PQ.isBusy h
if isBusy
then awaitResult h mres
else do
mres' <- PQ.getResult h
case mres' of
Nothing -> case mres of
Nothing -> throwLibPQError h "PQgetResult returned no results"
Just res -> return res
Just res -> do
status <- PQ.resultStatus res
case status of
-- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
PQ.EmptyQuery -> getResult h mres'
PQ.CommandOk -> getResult h mres'
PQ.TuplesOk -> getResult h mres'
PQ.CopyOut -> return res
PQ.CopyIn -> return res
PQ.BadResponse -> getResult h mres'
PQ.NonfatalError -> getResult h mres'
PQ.FatalError -> getResult h mres'
mres' <- PQ.getResult h
case mres' of
Nothing -> case mres of
Nothing -> throwLibPQError h "PQgetResult returned no results"
Just res -> return res
Just res -> do
status <- PQ.resultStatus res
case status of
-- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
PQ.EmptyQuery -> getResult h mres'
PQ.CommandOk -> getResult h mres'
PQ.TuplesOk -> getResult h mres'
PQ.CopyOut -> return res
PQ.CopyIn -> return res
PQ.BadResponse -> getResult h mres'
PQ.NonfatalError -> getResult h mres'
PQ.FatalError -> getResult h mres'
#endif

-- | A version of 'execute' that does not perform query substitution.
Expand Down Expand Up @@ -450,6 +502,7 @@ newNullConnection = do
connectionHandle <- newMVar =<< PQ.newNullConnection
connectionObjects <- newMVar IntMap.empty
connectionTempNameCounter <- newIORef 0
connectionMayHaveOrphanedStatement <- newIORef False
return Connection{..}

data Row = Row {
Expand Down
135 changes: 133 additions & 2 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ import Database.PostgreSQL.Simple.ToField (ToField)
import Database.PostgreSQL.Simple.FromField (FromField)
import Database.PostgreSQL.Simple.HStore
import Database.PostgreSQL.Simple.Newtypes
import Database.PostgreSQL.Simple.Internal (breakOnSingleQuestionMark)
import Database.PostgreSQL.Simple.Internal (breakOnSingleQuestionMark, connectionMayHaveOrphanedStatement)
import Database.PostgreSQL.Simple.Types(Query(..),Values(..), PGArray(..))
import qualified Database.PostgreSQL.Simple.Transaction as ST

import Control.Applicative
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (withAsync, wait)
import Control.Exception as E
import Control.Monad
import Data.Char
import Data.Foldable (toList)
import Data.List (concat, sort)
import Data.List (concat, sort, isInfixOf)
import Data.IORef
import Data.Monoid ((<>))
import Data.String (fromString)
Expand All @@ -48,6 +50,7 @@ import System.FilePath
import System.Timeout(timeout)
import Data.Time.Compat (getCurrentTime, diffUTCTime)
import System.Environment (getEnvironment)
import qualified System.IO as IO

import Test.Tasty
import Test.Tasty.Golden
Expand Down Expand Up @@ -84,6 +87,10 @@ tests env = testGroup "tests"
, testCase "2-ary generic" . testGeneric2
, testCase "3-ary generic" . testGeneric3
, testCase "Timeout" . testTimeout
, testCase "Expected user exceptions" . testExpectedExceptions
, testCase "Orphaned running query state mgmt" . testOrphanedRunningQueryStateMgmt
, testCase "Async exceptions" . testAsyncExceptionFailure
, testCase "Query canceled" . testCanceledQueryExceptions
]

testBytea :: TestEnv -> TestTree
Expand Down Expand Up @@ -534,6 +541,128 @@ testDouble TestEnv{..} = do
[Only (x :: Double)] <- query_ conn "SELECT '-Infinity'::float8"
x @?= (-1 / 0)

-- | Specifies exceptions thrown by postgresql-simple for certain user errors.
testExpectedExceptions :: TestEnv -> Assertion
testExpectedExceptions TestEnv{..} = do
withConn $ \c -> do
execute_ c "SELECT 1,2" `shouldThrow` (\(e :: QueryError) -> "2-column result" `isInfixOf` show e)
execute_ c "SELECT 1/0" `shouldThrow` (\(e :: SqlError) -> sqlState e == "22012")
(query_ c "SELECT 1, 2, 3" :: IO [(String, Int)]) `shouldThrow` (\(e :: ResultError) -> errSQLType e == "int4" && errHaskellType e == "Text")

shouldThrow :: forall e a. Exception e => IO a -> (e -> Bool) -> IO ()
shouldThrow f pred = do
ea <- try f
assertBool "Exception is as expected" $ case ea of
Right _ -> False
Left (ex :: e) -> pred ex

-- | Ensures that the state associated with there being an orphaned
-- running statement in a connection is updated accordingly.
testOrphanedRunningQueryStateMgmt :: TestEnv -> Assertion
testOrphanedRunningQueryStateMgmt TestEnv{..} = withConn $ \c -> do
-- 1. Connections are created with no orphaned running queries, naturally.
runState c `shouldReturn` False

-- 2. Interrupting a query that is still running should set the state
-- to True.
-- We need to give it enough time to start executing the query
-- before timing out. One second should be more than enough
void $ timeout (1000 * 1000) (execute_ c "SELECT pg_sleep(100)")
runState c `shouldReturn` True

-- 3. Running a new query should clear the state again
[ Only (num13 :: Int) ] <- query c "SELECT 13" ()
num13 @?= 13
runState c `shouldReturn` False

-- 4. Interrupting a query but letting it run until completion shouldn't
-- matter (postgresql-simple has no way of knowing that), but no errors
-- should come out of it
void $ timeout (1000 * 1000) (execute_ c "SELECT pg_sleep(2)")
runState c `shouldReturn` True

-- One second has passed, wait 2 more to ensure the query finished.
-- The state is still True.
threadDelay (1000 * 1000 * 2)
runState c `shouldReturn` True

-- 5. Check that nothing wrong happens if we try to cancel a query
-- that is no longer running (this happens automatically by running another query)
[ Only (num17 :: Int) ] <- query c "SELECT 17" ()
num17 @?= 17
runState c `shouldReturn` False

-- 6. Other errors that are not interruptions don't change the connection's state
execute_ c "SELECT 1/0" `shouldThrow` (\(_ :: SqlError) -> True)
runState c `shouldReturn` False

where
runState = readIORef . connectionMayHaveOrphanedStatement
shouldReturn :: (Eq a, Show a, HasCallStack) => IO a -> a -> IO ()
shouldReturn f expected = do
actual <- f
actual @?= expected


-- | Ensures that asynchronous exceptions thrown while queries are executing
-- are handled properly.
testAsyncExceptionFailure :: TestEnv -> Assertion
testAsyncExceptionFailure TestEnv{..} = withConn $ \c -> do
-- We need to give it enough time to start executing the query
-- before timing out. One second should be more than enough
execute_ c "SET my.setting TO '42'"
testAsyncException c (1000 * 1000) (execute_ c "SELECT pg_sleep(5)")
testAsyncException c (1000 * 1000) $
bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
putCopyData c "1\n"
threadDelay (1000 * 1000 * 60)

where
testAsyncException c timeLimit f = do
tmt <- timeout timeLimit f
tmt @?= Nothing
-- Any other query should work now without errors.
number42 <- query_ c "SELECT current_setting('my.setting')"
number42 @?= [ Only ("42" :: String) ]

-- | Ensures that canceled queries don't invalidate the Connection and specifies how
-- they can be detected.
testCanceledQueryExceptions :: TestEnv -> Assertion
testCanceledQueryExceptions TestEnv{..} = do
withConn $ \c1 -> withConn $ \c2 -> do
[ Only (c1Pid :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
execute_ c1 "SET my.setting TO '42'"

testCancelation c1 c2 c1Pid execPgSleep $ \(ex :: SqlError) -> sqlState ex == "57014"

-- What should we expect when COPY is canceled and putCopyEnd runs? The same SqlError as above, perhaps? Right now,
-- detecting if a query was canceled involves detecting two distinct types of exception.
testCancelation c1 c2 c1Pid execCopy $ \(ex :: IOException) -> "Database.PostgreSQL.Simple.Copy.putCopyEnd: failed to parse command status" `isInfixOf` show ex
&& "ERROR: canceling statement due to user request" `isInfixOf` show ex

-- Any other query should work now without errors.
number42 <- query_ c1 "SELECT current_setting('my.setting')"
number42 @?= [ Only ("42" :: String) ]

where
execPgSleep c = execute_ c "SELECT pg_sleep(5)"
execCopy c =
bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
putCopyData c "1\n"
threadDelay (1000 * 1000 * 2)
-- putCopyEnd will run after pg_cancel_backend due to threadDelays
testCancelation c1 c2 cPid f exPred = withAsync (f c1) $ \longRunningAction -> do
-- We need to give it enough time to start executing the query
-- before canceling it. One second should be more than enough
threadDelay (1000 * 1000)
cancelResult <- query c2 "SELECT pg_cancel_backend(?)" (Only cPid)
cancelResult @?= [ Only True ]
wait longRunningAction `shouldThrow` exPred
-- Connection is still usable after query canceled
[ Only (cPidAgain :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
cPid @?= cPidAgain

testGeneric1 :: TestEnv -> Assertion
testGeneric1 TestEnv{..} = do
Expand Down Expand Up @@ -621,6 +750,8 @@ withTestEnv connstr cb =

main :: IO ()
main = withConnstring $ \connstring -> do
IO.hSetBuffering IO.stdout IO.NoBuffering
IO.hSetBuffering IO.stderr IO.NoBuffering
withTestEnv connstring (defaultMain . tests)

withConnstring :: (BS8.ByteString -> IO ()) -> IO ()
Expand Down