Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hls-explicit-record-fields-plugin] Expand used fields only #3386

Merged
merged 20 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions ghcide/src/Development/IDE/GHC/Compat/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ module Development.IDE.GHC.Compat.Core (
noLocA,
unLocA,
LocatedAn,
LocatedA,
#if MIN_VERSION_ghc(9,2,0)
GHC.AnnListItem(..),
GHC.NameAnn(..),
Expand Down Expand Up @@ -482,8 +483,9 @@ module Development.IDE.GHC.Compat.Core (
#if !MIN_VERSION_ghc_boot_th(9,4,1)
Extension(.., NamedFieldPuns),
#else
Extension(..)
Extension(..),
#endif
UniqFM,
) where

import qualified GHC
Expand Down Expand Up @@ -518,7 +520,8 @@ import GHC.Core.DataCon hiding (dataConExTyCoVars)
import qualified GHC.Core.DataCon as DataCon
import GHC.Core.FamInstEnv hiding (pprFamInst)
import GHC.Core.InstEnv
import GHC.Types.Unique.FM
import GHC.Types.Unique.FM hiding (UniqFM)
import qualified GHC.Types.Unique.FM as UniqFM
#if MIN_VERSION_ghc(9,3,0)
import qualified GHC.Driver.Config.Tidy as GHC
import qualified GHC.Data.Strict as Strict
Expand Down Expand Up @@ -741,7 +744,8 @@ import Type
import TysPrim
import TysWiredIn
import Unify
import UniqFM
import UniqFM hiding (UniqFM)
import qualified UniqFM
import UniqSupply
import Var (Var (varName), setTyVarUnique,
setVarUnique, varType)
Expand Down Expand Up @@ -1038,6 +1042,12 @@ type LocatedAn a = GHC.LocatedAn a
type LocatedAn a = GHC.Located
#endif

#if MIN_VERSION_ghc(9,2,0)
type LocatedA = GHC.LocatedA
#else
type LocatedA = GHC.Located
#endif

#if MIN_VERSION_ghc(9,2,0)
locA :: SrcSpanAnn' a -> SrcSpan
locA = GHC.locA
Expand Down Expand Up @@ -1165,3 +1175,9 @@ pattern HsFieldBind {hfbAnn, hfbLHS, hfbRHS, hfbPun} <- HsRecField hfbAnn (SrcLo
pattern NamedFieldPuns :: Extension
pattern NamedFieldPuns = RecordPuns
#endif

#if MIN_VERSION_ghc(9,0,0)
type UniqFM = UniqFM.UniqFM
#else
type UniqFM k = UniqFM.UniqFM
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ library
, transformers
, ghc-boot-th
, unordered-containers
, containers
hs-source-dirs: src
default-language: Haskell2010

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
Expand All @@ -17,6 +18,7 @@ module Ide.Plugin.ExplicitFields
import Control.Lens ((^.))
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Except (ExceptT)
import Data.Functor ((<&>))
import Data.Generics (GenericQ, everything, extQ,
mkQ)
import qualified Data.HashMap.Strict as HashMap
Expand All @@ -38,11 +40,15 @@ import Development.IDE.GHC.Compat (HsConDetails (RecCon),
import Development.IDE.GHC.Compat.Core (Extension (NamedFieldPuns),
GhcPass,
HsExpr (RecordCon, rcon_flds),
LHsExpr, Pass (..), Pat (..),
RealSrcSpan, conPatDetails,
hfbPun, hs_valds,
mapConPatDetail, mapLoc,
pattern RealSrcSpan)
HsRecField, LHsExpr, LocatedA,
Name, Pass (..), Pat (..),
RealSrcSpan, UniqFM,
conPatDetails, emptyUFM,
hfbPun, hfbRHS, hs_valds,
lookupUFM, mapConPatDetail,
mapLoc, pattern RealSrcSpan,
plusUFM_C, ufmToIntMap,
unitUFM)
import Development.IDE.GHC.Util (getExtensions,
printOutputable)
import Development.IDE.Graph (RuleResult)
Expand Down Expand Up @@ -89,7 +95,7 @@ instance Pretty Log where
descriptor :: Recorder (WithPriority Log) -> PluginId -> PluginDescriptor IdeState
descriptor recorder plId = (defaultPluginDescriptor plId)
{ pluginHandlers = mkPluginHandler STextDocumentCodeAction codeActionProvider
, pluginRules = collectRecordsRule recorder
, pluginRules = collectRecordsRule recorder *> collectNamesRule
}

codeActionProvider :: PluginMethodHandler IdeState 'TextDocumentCodeAction
Expand Down Expand Up @@ -137,15 +143,21 @@ codeActionProvider ideState pId (CodeActionParams _ _ docId range _) = pluginRes
title = "Expand record wildcard"

collectRecordsRule :: Recorder (WithPriority Log) -> Rules ()
collectRecordsRule recorder = define (cmapWithPrio LogShake recorder) $ \CollectRecords nfp -> do
tmr <- use TypeCheck nfp
let exts = getEnabledExtensions <$> tmr
recs = concat $ maybeToList (getRecords <$> tmr)
logWith recorder Debug (LogCollectedRecords recs)
let renderedRecs = traverse renderRecordInfo recs
recMap = RangeMap.fromList (realSrcSpanToRange . renderedSrcSpan) <$> renderedRecs
logWith recorder Debug (LogRenderedRecords (concat renderedRecs))
pure ([], CRR <$> recMap <*> exts)
collectRecordsRule recorder = define (cmapWithPrio LogShake recorder) $ \CollectRecords nfp ->
use TypeCheck nfp >>= \case
Nothing -> pure ([], Nothing)
Just tmr -> do
let exts = getEnabledExtensions tmr
recs = getRecords tmr
logWith recorder Debug (LogCollectedRecords recs)
use CollectNames nfp >>= \case
Nothing -> pure ([], Nothing)
Just (CNR names) -> do
let renderedRecs = traverse (renderRecordInfo names) recs
recMap = RangeMap.fromList (realSrcSpanToRange . renderedSrcSpan) <$> renderedRecs
logWith recorder Debug (LogRenderedRecords (concat renderedRecs))
pure ([], CRR <$> recMap <*> Just exts)

where
getEnabledExtensions :: TcModuleResult -> [GhcExtension]
getEnabledExtensions = map GhcExtension . getExtensions . tmrParsed
Expand All @@ -154,6 +166,17 @@ getRecords :: TcModuleResult -> [RecordInfo]
getRecords (tmrRenamed -> (hs_valds -> valBinds,_,_,_)) =
collectRecords valBinds

collectNamesRule :: Rules ()
collectNamesRule = define mempty $ \CollectNames nfp ->
use TypeCheck nfp <&> \case
Nothing -> ([], Nothing)
Just tmr -> ([], Just (CNR (getNames tmr)))

-- | Collects all 'Name's of a given source file, to be used
ozkutuk marked this conversation as resolved.
Show resolved Hide resolved
-- in the variable usage analysis.
ozkutuk marked this conversation as resolved.
Show resolved Hide resolved
getNames :: TcModuleResult -> NameMap
getNames (tmrRenamed -> (group,_,_,_)) = NameMap (collectNames group)

data CollectRecords = CollectRecords
deriving (Eq, Show, Generic)

Expand All @@ -173,13 +196,36 @@ instance Show CollectRecordsResult where

type instance RuleResult CollectRecords = CollectRecordsResult

data CollectNames = CollectNames
deriving (Eq, Show, Generic)

instance Hashable CollectNames
instance NFData CollectNames

data CollectNamesResult = CNR NameMap
deriving (Generic)

instance NFData CollectNamesResult

instance Show CollectNamesResult where
show _ = "<CollectNamesResult>"

type instance RuleResult CollectNames = CollectNamesResult

-- `Extension` is wrapped so that we can provide an `NFData` instance
-- (without resorting to creating an orphan instance).
newtype GhcExtension = GhcExtension { unExt :: Extension }

instance NFData GhcExtension where
rnf x = x `seq` ()

-- As with `GhcExtension`, this newtype exists mostly to attach
-- an `NFData` instance to `UniqFM`.
newtype NameMap = NameMap (UniqFM Name [Name])

instance NFData NameMap where
rnf (NameMap (ufmToIntMap -> m)) = rnf m

data RecordInfo
= RecordInfoPat RealSrcSpan (Pat (GhcPass 'Renamed))
| RecordInfoCon RealSrcSpan (HsExpr (GhcPass 'Renamed))
Expand All @@ -199,10 +245,48 @@ instance Pretty RenderedRecordInfo where

instance NFData RenderedRecordInfo

renderRecordInfo :: RecordInfo -> Maybe RenderedRecordInfo
renderRecordInfo (RecordInfoPat ss pat) = RenderedRecordInfo ss <$> showRecordPat pat
renderRecordInfo (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordCon expr

renderRecordInfo :: NameMap -> RecordInfo -> Maybe RenderedRecordInfo
renderRecordInfo names (RecordInfoPat ss pat) = RenderedRecordInfo ss <$> showRecordPat names pat
renderRecordInfo _ (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordCon expr

-- | Checks if a 'Name' is referenced in the given map of names. The
-- 'hasNonBindingOcc' check is necessary in order to make sure that only the
-- references at the use-sites are considered (i.e. the binding occurence
-- is excluded). For more information regarding the structure of the map,
-- refer to the documentation of 'collectNames'.
referencedIn :: Name -> NameMap -> Bool
referencedIn name (NameMap names) = maybe True hasNonBindingOcc $ lookupUFM names name
where
hasNonBindingOcc :: [Name] -> Bool
hasNonBindingOcc = (> 1) . length

-- Default to leaving the element in if somehow a name can't be extracted (i.e.
-- `getName` returns `Nothing`).
filterReferenced :: (a -> Maybe Name) -> NameMap -> [a] -> [a]
filterReferenced getName names = filter (\x -> maybe True (`referencedIn` names) (getName x))

preprocessRecordPat
:: p ~ GhcPass 'Renamed
=> NameMap
-> HsRecFields p (LPat p)
-> HsRecFields p (LPat p)
preprocessRecordPat = preprocessRecord (getFieldName . unLoc)
where
getFieldName x = case unLoc (hfbRHS x) of
VarPat _ x' -> Just $ unLoc x'
_ -> Nothing

-- No need to check the name usage in the record construction case
preprocessRecordCon :: HsRecFields (GhcPass c) arg -> HsRecFields (GhcPass c) arg
preprocessRecordCon = preprocessRecord (const Nothing) (NameMap emptyUFM)

-- This function does two things:
-- 1) Tweak the AST type so that the pretty-printed record is in the
-- expanded form
-- 2) Determine the unused record fields so that they are filtered out
-- of the final output
--
-- Regarding first point:
-- We make use of the `Outputable` instances on AST types to pretty-print
-- the renamed and expanded records back into source form, to be substituted
-- with the original record later. However, `Outputable` instance of
Expand All @@ -212,8 +296,13 @@ renderRecordInfo (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordC
-- as we want to print the records in their fully expanded form.
-- Here `rec_dotdot` is set to `Nothing` so that fields are printed without
-- such post-processing.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment could do with some tweaks. It suggests that this is primarily about printing the records which makes it sounds like you always print them as they are, but now you're adding some additional logic to print something else depending on the fields in use.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have revised this and a bunch of other comments, can you take another look to see if it looks better?

preprocessRecord :: HsRecFields (GhcPass c) arg -> HsRecFields (GhcPass c) arg
preprocessRecord flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
preprocessRecord
:: p ~ GhcPass c
=> (LocatedA (HsRecField p arg) -> Maybe Name)
-> NameMap
-> HsRecFields p arg
-> HsRecFields p arg
preprocessRecord getName names flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
where
no_pun_count = maybe (length (rec_flds flds)) unLoc (rec_dotdot flds)
-- Field binds of the explicit form (e.g. `{ a = a' }`) should be
Expand All @@ -223,29 +312,47 @@ preprocessRecord flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
-- puns (since there is similar mechanism in the `Outputable` instance as
-- explained above).
puns' = map (mapLoc (\fld -> fld { hfbPun = True })) puns
rec_flds' = no_puns <> puns'

showRecordPat :: Outputable (Pat (GhcPass c)) => Pat (GhcPass c) -> Maybe Text
showRecordPat = fmap printOutputable . mapConPatDetail (\case
RecCon flds -> Just $ RecCon (preprocessRecord flds)
-- Unused fields are filtered out so that they don't end up in the expanded
-- form.
punsUsed = filterReferenced getName names puns'
rec_flds' = no_puns <> punsUsed

showRecordPat :: Outputable (Pat (GhcPass 'Renamed)) => NameMap -> Pat (GhcPass 'Renamed) -> Maybe Text
showRecordPat names = fmap printOutputable . mapConPatDetail (\case
RecCon flds -> Just $ RecCon (preprocessRecordPat names flds)
_ -> Nothing)

showRecordCon :: Outputable (HsExpr (GhcPass c)) => HsExpr (GhcPass c) -> Maybe Text
showRecordCon expr@(RecordCon _ _ flds) =
Just $ printOutputable $
expr { rcon_flds = preprocessRecord flds }
expr { rcon_flds = preprocessRecordCon flds }
showRecordCon _ = Nothing

collectRecords :: GenericQ [RecordInfo]
collectRecords = everything (<>) (maybeToList . (Nothing `mkQ` getRecPatterns `extQ` getRecCons))

-- | Collect 'Name's into a map, indexed by the names' unique identifiers.
-- The 'Eq' instance of 'Name's makes use of their unique identifiers, hence
-- any 'Name' referring to the same entity is considered equal. In effect,
-- each individual list of names contains the binding occurence, along with
-- all the occurences at the use-sites (if there are any).
--
-- @UniqFM Name [Name]@ is morally the same as @Map Unique [Name]@.
-- Using 'UniqFM' gains us a bit of performance (in theory) since it
-- internally uses 'IntMap', and saves us rolling our own newtype wrapper over
-- 'Unique' (since 'Unique' doesn't have an 'Ord' instance, it can't be used
-- as 'Map' key as is). More information regarding 'UniqFM' can be found in
-- the GHC source.
collectNames :: GenericQ (UniqFM Name [Name])
collectNames = everything (plusUFM_C (<>)) (emptyUFM `mkQ` (\x -> unitUFM x [x]))

getRecCons :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo
getRecCons e@(unLoc -> RecordCon _ _ flds)
| isJust (rec_dotdot flds) = mkRecInfo e
where
mkRecInfo :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo
mkRecInfo expr = listToMaybe
[ RecordInfoCon realSpan (unLoc expr) | RealSrcSpan realSpan _ <- [ getLoc expr ]]
[ RecordInfoCon realSpan' (unLoc expr) | RealSrcSpan realSpan' _ <- [ getLoc expr ]]
getRecCons _ = Nothing

getRecPatterns :: LPat (GhcPass 'Renamed) -> Maybe RecordInfo
Expand All @@ -254,7 +361,7 @@ getRecPatterns conPat@(conPatDetails . unLoc -> Just (RecCon flds))
where
mkRecInfo :: LPat (GhcPass 'Renamed) -> Maybe RecordInfo
mkRecInfo pat = listToMaybe
[ RecordInfoPat realSpan (unLoc pat) | RealSrcSpan realSpan _ <- [ getLoc pat ]]
[ RecordInfoPat realSpan' (unLoc pat) | RealSrcSpan realSpan' _ <- [ getLoc pat ]]
getRecPatterns _ = Nothing

collectRecords' :: MonadIO m => IdeState -> NormalizedFilePath -> ExceptT String m CollectRecordsResult
Expand Down
4 changes: 3 additions & 1 deletion plugins/hls-explicit-record-fields-plugin/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ plugin = mkPluginTestDescriptor ExplicitFields.descriptor "explicit-fields"
test :: TestTree
test = testGroup "explicit-fields"
[ mkTest "WildcardOnly" "WildcardOnly" 12 10 12 20
, mkTest "Unused" "Unused" 12 10 12 20
, mkTest "Unused2" "Unused2" 12 10 12 20
, mkTest "WithPun" "WithPun" 13 10 13 25
, mkTest "WithExplicitBind" "WithExplicitBind" 12 10 12 32
, mkTest "Mixed" "Mixed" 13 10 13 37
, mkTest "Mixed" "Mixed" 14 10 14 37
, mkTest "Construction" "Construction" 16 5 16 15
, mkTestNoAction "ExplicitBinds" "ExplicitBinds" 11 10 11 52
, mkTestNoAction "Puns" "Puns" 12 10 12 31
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
, quux :: Double
}

convertMe :: MyRec -> String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
, quux :: Double
}

convertMe :: MyRec -> String
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NamedFieldPuns #-}

module Unused where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {foo, bar} = show foo ++ show bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}

module Unused where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {..} = show foo ++ show bar
Loading