diff --git a/src/Futhark/CLI/REPL.hs b/src/Futhark/CLI/REPL.hs index 3fc3ce217b..46387b0a57 100644 --- a/src/Futhark/CLI/REPL.hs +++ b/src/Futhark/CLI/REPL.hs @@ -406,7 +406,7 @@ typeCommand = genTypeCommand parseExp T.checkExp $ \(ps, e) -> annotate italicized $ "\n\nPolymorphic in" <+> mconcat (intersperse " " $ map pretty ps) - <> "." + <> "." else mempty mtypeCommand :: Command diff --git a/src/Futhark/CodeGen/Backends/GenericPython/AST.hs b/src/Futhark/CodeGen/Backends/GenericPython/AST.hs index 460898a8ee..c125771f6f 100644 --- a/src/Futhark/CodeGen/Backends/GenericPython/AST.hs +++ b/src/Futhark/CodeGen/Backends/GenericPython/AST.hs @@ -126,27 +126,27 @@ instance Pretty PyStmt where pretty (If cond [] []) = "if" <+> pretty cond - <> ":" - indent 2 "pass" + <> ":" + indent 2 "pass" pretty (If cond [] fbranch) = "if" <+> pretty cond - <> ":" - indent 2 "pass" - "else:" - indent 2 (stack $ map pretty fbranch) + <> ":" + indent 2 "pass" + "else:" + indent 2 (stack $ map pretty fbranch) pretty (If cond tbranch []) = "if" <+> pretty cond - <> ":" - indent 2 (stack $ map pretty tbranch) + <> ":" + indent 2 (stack $ map pretty tbranch) pretty (If cond tbranch fbranch) = "if" <+> pretty cond - <> ":" - indent 2 (stack $ map pretty tbranch) - "else:" - indent 2 (stack $ map pretty fbranch) + <> ":" + indent 2 (stack $ map pretty tbranch) + "else:" + indent 2 (stack $ map pretty fbranch) pretty (Try pystms pyexcepts) = "try:" indent 2 (stack $ map pretty pystms) @@ -154,20 +154,20 @@ instance Pretty PyStmt where pretty (While cond body) = "while" <+> pretty cond - <> ":" - indent 2 (stack $ map pretty body) + <> ":" + indent 2 (stack $ map pretty body) pretty (For i what body) = "for" <+> pretty i <+> "in" <+> pretty what - <> ":" - indent 2 (stack $ map pretty body) + <> ":" + indent 2 (stack $ map pretty body) pretty (With what body) = "with" <+> pretty what - <> ":" - indent 2 (stack $ map pretty body) + <> ":" + indent 2 (stack $ map pretty body) pretty (Assign e1 e2) = pretty e1 <+> "=" <+> pretty e2 pretty (AssignOp op e1 e2) = pretty e1 <+> pretty (op ++ "=") <+> pretty e2 pretty (Comment s body) = "#" <> pretty s stack (map pretty body) @@ -188,16 +188,16 @@ instance Pretty PyFunDef where pretty (Def fname params body) = "def" <+> pretty fname - <> parens (commasep $ map pretty params) - <> ":" - indent 2 (stack (map pretty body)) + <> parens (commasep $ map pretty params) + <> ":" + indent 2 (stack (map pretty body)) instance Pretty PyClassDef where pretty (Class cname body) = "class" <+> pretty cname - <> ":" - indent 2 (stack (map pretty body)) + <> ":" + indent 2 (stack (map pretty body)) instance Pretty PyExcept where pretty (Catch pyexp stms) = diff --git a/src/Futhark/CodeGen/ImpCode.hs b/src/Futhark/CodeGen/ImpCode.hs index 125cfa349c..b5655de4ab 100644 --- a/src/Futhark/CodeGen/ImpCode.hs +++ b/src/Futhark/CodeGen/ImpCode.hs @@ -581,8 +581,8 @@ instance (Pretty op) => Pretty (Code op) where <> pretty space <> rangle <> brackets (pretty i) - <+> "<-" - <+> pretty val + <+> "<-" + <+> pretty val where vol' = case vol of Volatile -> "volatile " @@ -591,12 +591,12 @@ instance (Pretty op) => Pretty (Code op) where pretty name <+> "<-" <+> pretty v - <> langle - <> vol' - <> pretty bt - <> pretty space - <> rangle - <> brackets (pretty is) + <> langle + <> vol' + <> pretty bt + <> pretty space + <> rangle + <> brackets (pretty is) where vol' = case vol of Volatile -> "volatile " @@ -614,9 +614,9 @@ instance (Pretty op) => Pretty (Code op) where <> (parens . align) ( foldMap (brackets . pretty) shape <> "," - p dst dstspace dstoffset dststrides - <> "," - p src srcspace srcoffset srcstrides + p dst dstspace dstoffset dststrides + <> "," + p src srcspace srcoffset srcstrides ) where p mem space offset strides = @@ -624,7 +624,7 @@ instance (Pretty op) => Pretty (Code op) where <> pretty space <> "+" <> pretty offset - <+> foldMap (brackets . pretty) strides + <+> foldMap (brackets . pretty) strides pretty (If cond tbranch fbranch) = "if" <+> pretty cond @@ -642,7 +642,7 @@ instance (Pretty op) => Pretty (Code op) where <+> commasep (map pretty dests) <+> "<-" <+> pretty fname - <> parens (commasep $ map pretty args) + <> parens (commasep $ map pretty args) pretty (Comment s code) = "--" <+> pretty s pretty code pretty (DebugPrint desc (Just e)) = diff --git a/src/Futhark/CodeGen/ImpCode/GPU.hs b/src/Futhark/CodeGen/ImpCode/GPU.hs index f303ac0251..1e79504f1a 100644 --- a/src/Futhark/CodeGen/ImpCode/GPU.hs +++ b/src/Futhark/CodeGen/ImpCode/GPU.hs @@ -107,16 +107,16 @@ instance Pretty HostOp where pretty dest <+> "<-" <+> "get_size" - <> parens (commasep [pretty key, pretty size_class]) + <> parens (commasep [pretty key, pretty size_class]) pretty (GetSizeMax dest size_class) = pretty dest <+> "<-" <+> "get_size_max" <> parens (pretty size_class) pretty (CmpSizeLe dest name size_class x) = pretty dest <+> "<-" <+> "get_size" - <> parens (commasep [pretty name, pretty size_class]) - <+> "<" - <+> pretty x + <> parens (commasep [pretty name, pretty size_class]) + <+> "<" + <+> pretty x pretty (CallKernel c) = pretty c @@ -223,17 +223,17 @@ instance Pretty KernelOp where pretty dest <+> "<-" <+> "get_tblock_id" - <> parens (pretty i) + <> parens (pretty i) pretty (GetLocalId dest i) = pretty dest <+> "<-" <+> "get_local_id" - <> parens (pretty i) + <> parens (pretty i) pretty (GetLocalSize dest i) = pretty dest <+> "<-" <+> "get_local_size" - <> parens (pretty i) + <> parens (pretty i) pretty (GetLockstepWidth dest) = pretty dest <+> "<-" @@ -256,68 +256,68 @@ instance Pretty KernelOp where pretty old <+> "<-" <+> "atomic_add_" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicFAdd t old arr ind x)) = pretty old <+> "<-" <+> "atomic_fadd_" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicSMax t old arr ind x)) = pretty old <+> "<-" <+> "atomic_smax" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicSMin t old arr ind x)) = pretty old <+> "<-" <+> "atomic_smin" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicUMax t old arr ind x)) = pretty old <+> "<-" <+> "atomic_umax" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicUMin t old arr ind x)) = pretty old <+> "<-" <+> "atomic_umin" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicAnd t old arr ind x)) = pretty old <+> "<-" <+> "atomic_and" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicOr t old arr ind x)) = pretty old <+> "<-" <+> "atomic_or" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicXor t old arr ind x)) = pretty old <+> "<-" <+> "atomic_xor" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicCmpXchg t old arr ind x y)) = pretty old <+> "<-" <+> "atomic_cmp_xchg" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x, pretty y]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x, pretty y]) pretty (Atomic _ (AtomicXchg t old arr ind x)) = pretty old <+> "<-" <+> "atomic_xchg" - <> pretty t - <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) + <> pretty t + <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicWrite t arr ind x)) = "atomic_write" <> pretty t diff --git a/src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs b/src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs index 756d6d9f49..8b5f30dddf 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs @@ -449,7 +449,7 @@ histKernelGlobalPass map_pes num_tblocks tblock_size space slugs kbody histogram chk_beg .<=. flat_bucket .&&. flat_bucket - .<. (chk_beg + hist_H_chk) + .<. (chk_beg + hist_H_chk) .&&. inBounds (Slice (map DimFix bucket')) dest_shape' vs_params = takeLast (length vs') $ lambdaParams lam @@ -748,9 +748,9 @@ histKernelLocalPass bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape' .&&. chk_beg - .<=. flat_bucket + .<=. flat_bucket .&&. flat_bucket - .<. (chk_beg + tvExp hist_H_chk) + .<. (chk_beg + tvExp hist_H_chk) bucket_is = [sExt64 thread_local_subhisto_i, flat_bucket - chk_beg] vs_params = takeLast (length vs') $ lambdaParams lam @@ -1021,9 +1021,9 @@ localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N _ slugs kbody = .&&. (local_mem_needed .<=. tvExp hist_L) .&&. (hist_S .<=. max_S) .&&. hist_C - .<=. hist_B + .<=. hist_B .&&. tvExp hist_M - .>. 0 + .>. 0 run = do emit $ Imp.DebugPrint "## Using shared memory" Nothing diff --git a/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs b/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs index c8406f7110..2d731e45ec 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs @@ -434,8 +434,8 @@ smallSegmentsReduction (Pat segred_pes) num_tblocks tblock_size _ space segbinop .>. 0 .&&. isActive (init $ zip gtids dims) .&&. ltid - .<. segment_size - * segments_per_block + .<. segment_size + * segments_per_block ) in_bounds out_of_bounds @@ -460,8 +460,8 @@ smallSegmentsReduction (Pat segred_pes) num_tblocks tblock_size _ space segbinop ( sExt64 virttblock_id * segments_per_block + sExt64 ltid - .<. num_segments - .&&. ltid + .<. num_segments + .&&. ltid .<. segments_per_block ) $ forM2_ segred_pes (concat reds_arrs) diff --git a/src/Futhark/IR/GPU/Op.hs b/src/Futhark/IR/GPU/Op.hs index a58a6134bc..6173318af6 100644 --- a/src/Futhark/IR/GPU/Op.hs +++ b/src/Futhark/IR/GPU/Op.hs @@ -114,8 +114,8 @@ instance PP.Pretty KernelGrid where "grid=" <> pretty num_tblocks <> PP.semi - <+> "blocksize=" - <> pretty tblock_size + <+> "blocksize=" + <> pretty tblock_size instance PP.Pretty SegLevel where pretty (SegThread virt grid) = @@ -210,8 +210,8 @@ instance PP.Pretty SizeOp where pretty (CmpSizeLe name size_class x) = "cmp_size" <> parens (commasep [pretty name, pretty size_class]) - <+> "<=" - <+> pretty x + <+> "<=" + <+> pretty x pretty (CalcNumBlocks w max_num_tblocks tblock_size) = "calc_num_tblocks" <> parens (commasep [pretty w, pretty max_num_tblocks, pretty tblock_size]) diff --git a/src/Futhark/IR/Pretty.hs b/src/Futhark/IR/Pretty.hs index 609e18eedb..9871993110 100644 --- a/src/Futhark/IR/Pretty.hs +++ b/src/Futhark/IR/Pretty.hs @@ -197,12 +197,12 @@ instance Pretty BasicOp where _ -> brackets $ commasep $ map pretty es <+> colon <+> "[]" - <> pretty rt + <> pretty rt pretty (ArrayVal vs t) = brackets (commasep $ map pretty vs) <+> colon <+> "[]" - <> pretty t + <> pretty t pretty (BinOp bop x y) = pretty bop <> parens (pretty x <> comma <+> pretty y) pretty (CmpOp op x y) = pretty op <> parens (pretty x <> comma <+> pretty y) pretty (ConvOp conv x) = @@ -280,13 +280,13 @@ instance (PrettyRep rep) => Pretty (Exp rep) where pretty (Match [c] [Case [Just (BoolValue True)] t] f (MatchDec ret ifsort)) = "if" <> info' - <+> pretty c - "then" - <+> maybeNest t - <+> "else" - <+> maybeNest f - colon - <+> ppTupleLines' (map pretty ret) + <+> pretty c + "then" + <+> maybeNest t + <+> "else" + <+> maybeNest f + colon + <+> ppTupleLines' (map pretty ret) where info' = case ifsort of MatchNormal -> mempty @@ -309,9 +309,9 @@ instance (PrettyRep rep) => Pretty (Exp rep) where pretty (Apply fname args ret (safety, _, _)) = applykw <+> pretty (nameToString fname) - <> apply (map (align . prettyArg) args) - colon - <+> braces (commasep $ map prettyRet ret) + <> apply (map (align . prettyArg) args) + colon + <+> braces (commasep $ map prettyRet ret) where prettyArg (arg, Consume) = "*" <> pretty arg prettyArg (arg, _) = pretty arg @@ -331,8 +331,8 @@ instance (PrettyRep rep) => Pretty (Exp rep) where ( pretty i <> ":" <> pretty it - <+> "<" - <+> align (pretty bound) + <+> "<" + <+> align (pretty bound) ) WhileLoop cond -> "while" <+> pretty cond @@ -349,11 +349,11 @@ instance (PrettyRep rep) => Pretty (Exp rep) where parens ( pretty shape <> comma - <+> ppTuple' (map pretty arrs) - <> case op of - Nothing -> mempty - Just (op', nes) -> - comma parens (pretty op' <> comma ppTuple' (map pretty nes)) + <+> ppTuple' (map pretty arrs) + <> case op of + Nothing -> mempty + Just (op', nes) -> + comma parens (pretty op' <> comma ppTuple' (map pretty nes)) ) instance (PrettyRep rep) => Pretty (Lambda rep) where @@ -401,9 +401,9 @@ instance (PrettyRep rep) => Pretty (FunDef rep) where <> pretty p_name <> "\"" <> comma - ppTupleLines' (map pretty p_entry) - <> comma - ppTupleLines' (map pretty ret_entry) + ppTupleLines' (map pretty p_entry) + <> comma + ppTupleLines' (map pretty ret_entry) ) instance Pretty OpaqueType where @@ -418,10 +418,11 @@ instance Pretty OpaqueType where where p (c, ets) = hsep $ "#" <> pretty c : map pretty ets pretty (OpaqueArray r v ts) = - "array" <+> pretty r - <> "d" - <+> dquotes (pretty v) - <+> nestedBlock "{" "}" (stack $ map pretty ts) + "array" + <+> pretty r + <> "d" + <+> dquotes (pretty v) + <+> nestedBlock "{" "}" (stack $ map pretty ts) pretty (OpaqueRecordArray r v fs) = "record_array" <+> pretty r <> "d" <+> dquotes (pretty v) <+> nestedBlock "{" "}" (stack $ map p fs) where diff --git a/src/Futhark/IR/Prop/Types.hs b/src/Futhark/IR/Prop/Types.hs index 510cc6e0fc..1c602ba8c5 100644 --- a/src/Futhark/IR/Prop/Types.hs +++ b/src/Futhark/IR/Prop/Types.hs @@ -384,9 +384,9 @@ subtypeOf (Array t1 shape1 u1) (Array t2 shape2 u2) = u2 <= u1 && t1 - == t2 + == t2 && shape1 - `subShapeOf` shape2 + `subShapeOf` shape2 subtypeOf t1 t2 = t1 == t2 -- | @xs \`subtypesOf\` ys@ is true if @xs@ is the same size as @ys@, diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs index eea5fdab79..8c5f89a576 100644 --- a/src/Futhark/IR/SOACS/SOAC.hs +++ b/src/Futhark/IR/SOACS/SOAC.hs @@ -941,9 +941,9 @@ instance (PrettyRep rep) => PP.Pretty (SOAC rep) where ( PP.align $ pretty lam <> comma - PP.braces (commasep $ map pretty args) - <> comma - PP.braces (commasep $ map pretty vec) + PP.braces (commasep $ map pretty args) + <> comma + PP.braces (commasep $ map pretty vec) ) pretty (JVP lam args vec) = "jvp" @@ -951,9 +951,9 @@ instance (PrettyRep rep) => PP.Pretty (SOAC rep) where ( PP.align $ pretty lam <> comma - PP.braces (commasep $ map pretty args) - <> comma - PP.braces (commasep $ map pretty vec) + PP.braces (commasep $ map pretty args) + <> comma + PP.braces (commasep $ map pretty vec) ) pretty (Stream size arrs acc lam) = ppStream size arrs acc lam @@ -968,31 +968,31 @@ instance (PrettyRep rep) => PP.Pretty (SOAC rep) where <> (parens . align) ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - pretty map_lam + ppTuple' (map pretty arrs) + <> comma + pretty map_lam ) | null scans = "redomap" <> (parens . align) ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) - <> comma - pretty map_lam + ppTuple' (map pretty arrs) + <> comma + PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) + <> comma + pretty map_lam ) | null reds = "scanomap" <> (parens . align) ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) - <> comma - pretty map_lam + ppTuple' (map pretty arrs) + <> comma + PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) + <> comma + pretty map_lam ) pretty (Screma w arrs form) = ppScrema w arrs form @@ -1004,13 +1004,13 @@ ppScrema w arrs (ScremaForm scans reds map_lam) = <> (parens . align) ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) - <> comma - pretty map_lam + ppTuple' (map pretty arrs) + <> comma + PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) + <> comma + PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) + <> comma + pretty map_lam ) -- | Prettyprint the given Stream. @@ -1021,11 +1021,11 @@ ppStream size arrs acc lam = <> (parens . align) ( pretty size <> comma - ppTuple' (map pretty arrs) - <> comma - ppTuple' (map pretty acc) - <> comma - pretty lam + ppTuple' (map pretty arrs) + <> comma + ppTuple' (map pretty acc) + <> comma + pretty lam ) -- | Prettyprint the given Scatter. @@ -1036,11 +1036,11 @@ ppScatter w arrs dests lam = <> (parens . align) ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - commasep (map pretty dests) - <> comma - pretty lam + ppTuple' (map pretty arrs) + <> comma + commasep (map pretty dests) + <> comma + pretty lam ) instance (PrettyRep rep) => Pretty (Scan rep) where @@ -1056,7 +1056,7 @@ instance (PrettyRep rep) => Pretty (Reduce rep) where ppComm comm <> pretty red_lam <> comma - PP.braces (commasep $ map pretty red_nes) + PP.braces (commasep $ map pretty red_nes) -- | Prettyprint the given histogram operation. ppHist :: @@ -1071,20 +1071,20 @@ ppHist w arrs ops bucket_fun = <> parens ( pretty w <> comma - ppTuple' (map pretty arrs) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops) - <> comma - pretty bucket_fun + ppTuple' (map pretty arrs) + <> comma + PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops) + <> comma + pretty bucket_fun ) where ppOp (HistOp dest_w rf dests nes op) = pretty dest_w <> comma - <+> pretty rf - <> comma - <+> PP.braces (commasep $ map pretty dests) - <> comma - ppTuple' (map pretty nes) - <> comma - pretty op + <+> pretty rf + <> comma + <+> PP.braces (commasep $ map pretty dests) + <> comma + ppTuple' (map pretty nes) + <> comma + pretty op diff --git a/src/Futhark/IR/SegOp.hs b/src/Futhark/IR/SegOp.hs index ad5a98f14d..9a3ec128c9 100644 --- a/src/Futhark/IR/SegOp.hs +++ b/src/Futhark/IR/SegOp.hs @@ -870,10 +870,10 @@ instance (PrettyRep rep) => Pretty (SegBinOp rep) where pretty (SegBinOp comm lam nes shape) = PP.braces (PP.commasep $ map pretty nes) <> PP.comma - pretty shape - <> PP.comma - comm' - <> pretty lam + pretty shape + <> PP.comma + comm' + <> pretty lam where comm' = case comm of Commutative -> "commutative " @@ -883,47 +883,47 @@ instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where pretty (SegMap lvl space ts body) = "segmap" <> pretty lvl - PP.align (pretty space) - <+> PP.colon - <+> ppTuple' (map pretty ts) - <+> PP.nestedBlock "{" "}" (pretty body) + PP.align (pretty space) + <+> PP.colon + <+> ppTuple' (map pretty ts) + <+> PP.nestedBlock "{" "}" (pretty body) pretty (SegRed lvl space reds ts body) = "segred" <> pretty lvl - PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty reds) - PP.colon - <+> ppTuple' (map pretty ts) - <+> PP.nestedBlock "{" "}" (pretty body) + PP.align (pretty space) + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty reds) + PP.colon + <+> ppTuple' (map pretty ts) + <+> PP.nestedBlock "{" "}" (pretty body) pretty (SegScan lvl space scans ts body) = "segscan" <> pretty lvl - PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty scans) - PP.colon - <+> ppTuple' (map pretty ts) - <+> PP.nestedBlock "{" "}" (pretty body) + PP.align (pretty space) + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty scans) + PP.colon + <+> ppTuple' (map pretty ts) + <+> PP.nestedBlock "{" "}" (pretty body) pretty (SegHist lvl space ops ts body) = "seghist" <> pretty lvl - PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops) - PP.colon - <+> ppTuple' (map pretty ts) - <+> PP.nestedBlock "{" "}" (pretty body) + PP.align (pretty space) + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops) + PP.colon + <+> ppTuple' (map pretty ts) + <+> PP.nestedBlock "{" "}" (pretty body) where ppOp (HistOp w rf dests nes shape op) = pretty w <> PP.comma - <+> pretty rf - <> PP.comma - PP.braces (PP.commasep $ map pretty dests) - <> PP.comma - PP.braces (PP.commasep $ map pretty nes) - <> PP.comma - pretty shape - <> PP.comma - pretty op + <+> pretty rf + <> PP.comma + PP.braces (PP.commasep $ map pretty dests) + <> PP.comma + PP.braces (PP.commasep $ map pretty nes) + <> PP.comma + pretty shape + <> PP.comma + pretty op instance CanBeAliased (SegOp lvl) where addOpAliases aliases = runIdentity . mapSegOpM alias diff --git a/src/Futhark/IR/Syntax.hs b/src/Futhark/IR/Syntax.hs index 761e4b4813..0e009616ed 100644 --- a/src/Futhark/IR/Syntax.hs +++ b/src/Futhark/IR/Syntax.hs @@ -466,12 +466,12 @@ deriving instance (RepTypes rep) => Ord (Exp rep) -- | For-loop or while-loop? data LoopForm = ForLoop - -- | The loop iterator var VName - -- | The type of the loop iterator var + -- ^ The loop iterator var IntType - -- | The number of iterations. + -- ^ The type of the loop iterator var SubExp + -- ^ The number of iterations. | WhileLoop VName deriving (Eq, Ord, Show) diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index fb9603de50..60b963572c 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -425,10 +425,10 @@ sliceSlice (Slice jslice) (Slice islice) = Slice $ sliceSlice' jslice islice -- | A dimension in a 'FlatSlice'. data FlatDimIndex d = FlatDimIndex - -- | Number of elements in dimension d - -- | Stride of dimension + -- ^ Number of elements in dimension d + -- ^ Stride of dimension deriving (Eq, Ord, Show) instance Traversable FlatDimIndex where diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index d463df93b3..5f9092d9bc 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1729,11 +1729,12 @@ isIntrinsicFunction qname args loc = do old_dim <- I.arraysSize 0 <$> mapM lookupType arrs dim_ok <- letSubExp "dim_ok" <=< toExp $ - pe64 old_dim .==. pe64 n' - * pe64 m' - .&&. pe64 n' + pe64 old_dim + .==. pe64 n' + * pe64 m' + .&&. pe64 n' .>=. 0 - .&&. pe64 m' + .&&. pe64 m' .>=. 0 dim_ok_cert <- assert diff --git a/src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs b/src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs index adf7c10f10..09d65ab90a 100644 --- a/src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs +++ b/src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs @@ -119,14 +119,14 @@ type FreeVarSubsts = M.Map VName (TPrimExp Int64 VName) -- | Coalesced Access Entry data Coalesced = Coalesced - -- | the kind of coalescing CoalescedKind - -- | destination mem_block info @f_m_x[i]@ (must be ArrayMem) + -- ^ the kind of coalescing + ArrayMemBound + -- ^ destination mem_block info @f_m_x[i]@ (must be ArrayMem) -- (Maybe IxFun) -- the inverse ixfun of a coalesced array, such that -- -- ixfuns can be correctly constructed for aliases; - ArrayMemBound - -- | substitutions for free vars in index function FreeVarSubsts + -- ^ substitutions for free vars in index function data CoalsEntry = CoalsEntry { -- | destination memory block @@ -223,26 +223,26 @@ instance Pretty Coalesced where pretty (Coalesced knd mbd _) = "(Kind:" <+> pretty knd - <> ", membds:" - <+> pretty mbd -- <> ", subs:" <+> pretty subs - <> ")" - <+> "\n" + <> ", membds:" + <+> pretty mbd -- <> ", subs:" <+> pretty subs + <> ")" + <+> "\n" instance Pretty CoalsEntry where pretty etry = "{" <+> "Dstmem:" <+> pretty (dstmem etry) - <> ", AliasMems:" - <+> pretty (alsmem etry) - <+> ", optdeps:" - <+> pretty (M.toList $ optdeps etry) - <+> ", memrefs:" - <+> pretty (memrefs etry) - <+> ", vartab:" - <+> pretty (M.toList $ vartab etry) - <+> "}" - <+> "\n" + <> ", AliasMems:" + <+> pretty (alsmem etry) + <+> ", optdeps:" + <+> pretty (M.toList $ optdeps etry) + <+> ", memrefs:" + <+> pretty (memrefs etry) + <+> ", vartab:" + <+> pretty (M.toList $ vartab etry) + <+> "}" + <+> "\n" -- | Compute the union of two 'CoalsEntry'. If two 'CoalsEntry' do not refer to -- the same destination memory and use the same index function, the first diff --git a/src/Futhark/Optimise/BlkRegTiling.hs b/src/Futhark/Optimise/BlkRegTiling.hs index 39b7bb0d1a..07e18feb84 100644 --- a/src/Futhark/Optimise/BlkRegTiling.hs +++ b/src/Futhark/Optimise/BlkRegTiling.hs @@ -147,12 +147,12 @@ kkLoopBody + le64 i + pe64 ry * le64 ltid_y - .<. pe64 height_A - .&&. le64 jjj - + le64 j - + pe64 rx - * le64 ltid_x - .<. pe64 width_B + .<. pe64 height_A + .&&. le64 jjj + + le64 j + + pe64 rx + * le64 ltid_x + .<. pe64 width_B ) ( do a <- a_idx_fn ltid_y i @@ -373,7 +373,7 @@ mmBlkRegTilingAcc env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts le64 full_tiles .==. pe64 rk .||. pe64 common_dim - .==. (pe64 tk * le64 full_tiles + le64 ttt) + .==. (pe64 tk * le64 full_tiles + le64 ttt) ) (resultBodyM $ map Var prologue_res_list) ( do @@ -459,7 +459,7 @@ mmBlkRegTilingAcc env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts le64 gtid_y .<. pe64 height_A .&&. le64 gtid_x - .<. pe64 width_B + .<. pe64 width_B ) ( do addStms code2_subs @@ -609,7 +609,7 @@ mmBlkRegTilingNrm env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts le64 gtid_y .<. pe64 height_A .&&. le64 gtid_x - .<. pe64 width_B + .<. pe64 width_B ) ( do addStms code2' @@ -1257,9 +1257,9 @@ doRegTiling3D (Let pat aux (Op (SegOp old_kernel))) le64 gtid_y .<. pe64 d_Ky .&&. le64 gtid_x - .<. pe64 d_Kx + .<. pe64 d_Kx .&&. le64 gtid_z - .<. pe64 d_M + .<. pe64 d_M ) ( do addStms code2' diff --git a/src/Futhark/Pass/ExtractKernels.hs b/src/Futhark/Pass/ExtractKernels.hs index 40b3e1b43f..17b44873c0 100644 --- a/src/Futhark/Pass/ExtractKernels.hs +++ b/src/Futhark/Pass/ExtractKernels.hs @@ -704,7 +704,7 @@ mayExploitOuter attrs = AttrComp "incremental_flattening" ["no_outer"] `inAttrs` attrs || AttrComp "incremental_flattening" ["only_inner"] - `inAttrs` attrs + `inAttrs` attrs mayExploitIntra :: Attrs -> Bool mayExploitIntra attrs = @@ -712,7 +712,7 @@ mayExploitIntra attrs = AttrComp "incremental_flattening" ["no_intra"] `inAttrs` attrs || AttrComp "incremental_flattening" ["only_inner"] - `inAttrs` attrs + `inAttrs` attrs -- The minimum amount of inner parallelism we require (by default) in -- intra-group versions. Less than this is usually pointless on a GPU diff --git a/src/Futhark/Pass/ExtractKernels/DistributeNests.hs b/src/Futhark/Pass/ExtractKernels/DistributeNests.hs index a805e42ba1..6cbf3dfbdf 100644 --- a/src/Futhark/Pass/ExtractKernels/DistributeNests.hs +++ b/src/Futhark/Pass/ExtractKernels/DistributeNests.hs @@ -662,7 +662,8 @@ distributeSingleUnaryStm acc stm stm_arr f = | map resSubExp res == map Var (patNames $ stmPat stm), (outer, _) <- nest, [(_, arr)] <- loopNestingParamsAndArrs outer, - boundInKernelNest nest `namesIntersection` freeIn stm + boundInKernelNest nest + `namesIntersection` freeIn stm == oneName stm_arr, perfectlyMapped arr nest -> do addPostStms kernels diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index ae01db51ed..7d8f041c6e 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -263,8 +263,8 @@ type Value = Language.Futhark.Interpreter.Values.Value EvalM -- TODO: Could be better - perhaps convert to interpreter type first asInteger :: Value -> Integer asInteger (ValueAD d v) = case AD.primitive $ AD.primal $ AD.Variable d v of - P.IntValue v' -> P.valueIntegral v' - _ -> error $ "Unexpectedly not an integer: " <> show v + P.IntValue v' -> P.valueIntegral v' + _ -> error $ "Unexpectedly not an integer: " <> show v asInteger (ValuePrim (SignedValue v)) = P.valueIntegral v asInteger (ValuePrim (UnsignedValue v)) = toInteger (P.valueIntegral (P.doZExt v Int64) :: Word64) @@ -276,8 +276,8 @@ asInt = fromIntegral . asInteger -- TODO: Could be better - perhaps convert to interpreter type first asSigned :: Value -> IntValue asSigned (ValueAD d v) = case AD.primitive $ AD.primal $ AD.Variable d v of - P.IntValue v' -> v' - _ -> error $ "Unexpectedly not a signed integer: " <> show v + P.IntValue v' -> v' + _ -> error $ "Unexpectedly not a signed integer: " <> show v asSigned (ValuePrim (SignedValue v)) = v asSigned v = error $ "Unexpectedly not a signed integer: " <> show v @@ -287,8 +287,8 @@ asInt64 = fromIntegral . asInteger -- TODO: Could be better - perhaps convert to interpreter type first asBool :: Value -> Bool asBool (ValueAD d v) = case AD.primitive $ AD.primal $ AD.Variable d v of - P.BoolValue v' -> v' - _ -> error $ "Unexpectedly not a boolean: " <> show v + P.BoolValue v' -> v' + _ -> error $ "Unexpectedly not a boolean: " <> show v asBool (ValuePrim (BoolValue x)) = x asBool v = error $ "Unexpectedly not a boolean: " <> show v @@ -1291,13 +1291,13 @@ initialCtx = types = M.mapMaybeWithKey (const . tdef . baseString) intrinsics sintOp f = - [ (getS, putS, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), + [ (getS, putS, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), (getS, putS, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)), (getS, putS, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)), (getS, putS, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64)) ] uintOp f = - [ (getU, putU, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), + [ (getU, putU, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)), (getU, putU, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)), (getU, putU, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)), (getU, putU, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64)) @@ -1312,13 +1312,13 @@ initialCtx = flipCmps = map (\(f, g, h, o) -> (f, g, flip h, flip o)) sintCmp f = - [ (getS, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), + [ (getS, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), (getS, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)), (getS, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)), (getS, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64)) ] uintCmp f = - [ (getU, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), + [ (getU, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)), (getU, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)), (getU, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)), (getU, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64)) @@ -1446,7 +1446,7 @@ initialCtx = <+> dquotes (prettyValue x) <+> "and" <+> dquotes (prettyValue y) - <> "." + <> "." where bopDef' (valf, retf, op, _) (x, y) = do x' <- valf x @@ -1469,7 +1469,7 @@ initialCtx = bad noLoc mempty . docText $ "Cannot apply function to argument" <+> dquotes (prettyValue x) - <> "." + <> "." where unopDef' (valf, retf, op, _) x = do x' <- valf x @@ -1494,20 +1494,20 @@ initialCtx = bad noLoc mempty . docText $ "Cannot apply operator to argument" <+> dquotes (prettyValue v) - <> "." + <> "." def "!" = Just $ unopDef - [ (getS, putS, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), + [ (getS, putS, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), (getS, putS, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16), (getS, putS, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32), (getS, putS, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64), - (getU, putU, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), + (getU, putU, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8), (getU, putU, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16), (getU, putU, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32), (getU, putU, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64), - (getB, putB, P.doUnOp P.Not, adUnOp $ AD.OpUn P.Not) + (getB, putB, P.doUnOp P.Not, adUnOp $ AD.OpUn P.Not) ] def "+" = arithOp (`P.Add` P.OverflowWrap) P.FAdd def "-" = arithOp (`P.Sub` P.OverflowWrap) P.FSub @@ -1796,13 +1796,14 @@ initialCtx = -- TODO: The above code is identical. Share the code ( ValueAcc shape op acc_arr, adv@(ValueAD {}) - ) | Just (SignedValue (Int64Value i')) <- putV . AD.primitive <$> getAD adv -> - if i' >= 0 && i' < arrayLength acc_arr - then do - let x = acc_arr ! fromIntegral i' - res <- op x v - pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)] - else pure acc + ) + | Just (SignedValue (Int64Value i')) <- putV . AD.primitive <$> getAD adv -> + if i' >= 0 && i' < arrayLength acc_arr + then do + let x = acc_arr ! fromIntegral i' + res <- op x v + pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)] + else pure acc _ -> error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v) -- @@ -1987,49 +1988,64 @@ initialCtx = fun3 $ \f v s -> do -- Get the depth depth <- length <$> stacktrace - + -- Augment the values - let v' = expectJust ("vjp: invalid values " ++ show v) $ - modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . AD.TapeID i <$> getAD lv) v + let v' = + expectJust ("vjp: invalid values " ++ show v) $ + modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . AD.TapeID i <$> getAD lv) v -- Turn the seeds into a list of ADValues - let s' = expectJust ("vjp: invalid seeds " ++ show s) $ - mapM getAD $ fst $ valueAccum (\a b -> (b : a, b)) [] s + let s' = + expectJust ("vjp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s -- Run the function, and turn its outputs into a list of Values o <- apply noLoc mempty f v' let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o -- For each output.. - let m = expectJust "vjp: differentiation failed" $ zipWithM (\on sn -> case on of - -- If it is a VJP variable of the correct depth, run deriveTape on it- and its corresponding seed - (ValueAD d (AD.VJP (AD.VJPValue t))) | d == depth -> (putAD $ AD.tapePrimal t,) <$> AD.deriveTape t sn - -- Otherwise, its partial derivatives are all 0 - _ -> Just (on, M.empty) - ) o' s' + let m = + expectJust "vjp: differentiation failed" $ + zipWithM + ( \on sn -> case on of + -- If it is a VJP variable of the correct depth, run deriveTape on it- and its corresponding seed + (ValueAD d (AD.VJP (AD.VJPValue t))) | d == depth -> (putAD $ AD.tapePrimal t,) <$> AD.deriveTape t sn + -- Otherwise, its partial derivatives are all 0 + _ -> Just (on, M.empty) + ) + o' + s' -- Add together every derivative let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m -- Extract the output values, and the partial derivatives let ov = modifyValue (\i _ -> fst $ m !! i) o - let od = expectJust "vjp: differentiation failed" $ - modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v - + let od = + expectJust "vjp: differentiation failed" $ + modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v + -- Return a tuple of the output values, and partial derivatives pure $ toTuple [ov, od] where modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v - modifyValueM f v = snd <$> valueAccumLM (\a b -> do - b' <- f a b - pure (a + 1, b')) 0 v + modifyValueM f v = + snd + <$> valueAccumLM + ( \a b -> do + b' <- f a b + pure (a + 1, b') + ) + 0 + v expectJust _ (Just v) = v expectJust s Nothing = error s - + -- TODO: Perhaps this could be fully abstracted by AD? -- Making addFor private would be nice.. add x y = expectJust "TODO" $ AD.doOp (AD.OpBin $ AD.addFor $ P.primValueType $ AD.primitive x) [x, y] - def "jvp2" = Just $ -- TODO: This could be much better. Currently, it is very inefficient -- Perhaps creating JVPValues could be abstracted into a function @@ -2037,39 +2053,56 @@ initialCtx = fun3 $ \f v s -> do -- Get the depth depth <- length <$> stacktrace - + -- Turn the seeds into a list of ADValues - let s' = expectJust ("jvp: invalid seeds " ++ show s) $ - mapM getAD $ fst $ valueAccum (\a b -> (b : a, b)) [] s + let s' = + expectJust ("jvp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s -- Augment the values - let v' = expectJust ("jvp: invalid values " ++ show v) $ - modifyValueM (\i lv -> do - lv' <- getAD lv - pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i)) v + let v' = + expectJust ("jvp: invalid values " ++ show v) $ + modifyValueM + ( \i lv -> do + lv' <- getAD lv + pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i) + ) + v -- Run the function, and turn its outputs into a list of Values o <- apply noLoc mempty f v' let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o -- For each output.. - let m = expectJust "jvp: differentiation failed" $ mapM (\on -> case on of - -- If it is a JVP variable of the correct depth, return its primal and derivative - (ValueAD d (AD.JVP (AD.JVPValue pv dv))) | d == depth -> Just (putAD pv, putAD dv) - -- Otherwise, its partial derivatives are all 0 - _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on - ) o' + let m = + expectJust "jvp: differentiation failed" $ + mapM + ( \on -> case on of + -- If it is a JVP variable of the correct depth, return its primal and derivative + (ValueAD d (AD.JVP (AD.JVPValue pv dv))) | d == depth -> Just (putAD pv, putAD dv) + -- Otherwise, its partial derivatives are all 0 + _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on + ) + o' -- Extract the output values, and the partial derivatives let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o od = modifyValue (\i _ -> snd $ m !! (length m - 1 - i)) o - + -- Return a tuple of the output values, and partial derivatives pure $ toTuple [ov, od] where modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v - modifyValueM f v = snd <$> valueAccumLM (\a b -> do - b' <- f a b - pure (a + 1, b')) 0 v + modifyValueM f v = + snd + <$> valueAccumLM + ( \a b -> do + b' <- f a b + pure (a + 1, b') + ) + 0 + v expectJust _ (Just v) = v expectJust s Nothing = error s @@ -2151,7 +2184,7 @@ checkEntryArgs entry args entry_t "Entry point " <> dquotes (prettyName entry) <> " expects input of type(s)" - indent 2 (stack (map pretty param_ts)) + indent 2 (stack (map pretty param_ts)) -- | Execute the named function on the given arguments; may fail -- horribly if these are ill-typed. diff --git a/src/Language/Futhark/Interpreter/AD.hs b/src/Language/Futhark/Interpreter/AD.hs index c9f8551037..f8d7e79eb5 100644 --- a/src/Language/Futhark/Interpreter/AD.hs +++ b/src/Language/Futhark/Interpreter/AD.hs @@ -1,69 +1,62 @@ module Language.Futhark.Interpreter.AD - ( - Op (..), + ( Op (..), ADVariable (..), ADValue (..), Tape (..), VJPValue (..), JVPValue (..), - doOp, addFor, - primal, tapePrimal, primitive, - deriveTape, ) where -import Control.Monad (zipWithM, foldM) +import Control.Monad (foldM, zipWithM) import Data.Either (isRight) import Data.List (find) import Data.Map qualified as M - -- These are needed to reuse the definitions of the derivatives -- used by the compiler import Futhark.AD.Derivatives -import Futhark.Analysis.PrimExp (PrimExp(..)) - +import Futhark.Analysis.PrimExp (PrimExp (..)) import Language.Futhark.Core (VName (..), nameFromString) - -- As the mathematical functions of Futhark are implemented -- for the primitive type used in the compiler, said type is -- used for AD for simplicity import Language.Futhark.Primitive ---Mathematical operations, and type checking-- +-- Mathematical operations, and type checking-- -- A mathematical operation data Op - = OpBin BinOp - | OpCmp CmpOp - | OpUn UnOp - | OpFn String + = OpBin BinOp + | OpCmp CmpOp + | OpUn UnOp + | OpFn String | OpConv ConvOp - deriving Show + deriving (Show) -- Checks if an operation matches the types of its operands opTypeMatch :: Op -> [PrimType] -> Bool -opTypeMatch (OpBin op) p = all (\x -> binOpType op == x) p -opTypeMatch (OpCmp op) p = all (\x -> cmpOpType op == x) p -opTypeMatch (OpUn op) p = all (\x -> unOpType op == x) p +opTypeMatch (OpBin op) p = all (\x -> binOpType op == x) p +opTypeMatch (OpCmp op) p = all (\x -> cmpOpType op == x) p +opTypeMatch (OpUn op) p = all (\x -> unOpType op == x) p opTypeMatch (OpConv op) p = all (\x -> fst (convOpType op) == x) p -opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of - Just (t, _, _) -> and $ zipWith (==) t p - Nothing -> error "TODO: IMPOSSIBLE" -- It is assumed that the function exists +opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of + Just (t, _, _) -> and $ zipWith (==) t p + Nothing -> error "TODO: IMPOSSIBLE" -- It is assumed that the function exists -- Gets the return type of an operation opReturnType :: Op -> PrimType -opReturnType (OpBin op) = binOpType op -opReturnType (OpCmp op) = cmpOpType op -opReturnType (OpUn op) = unOpType op +opReturnType (OpBin op) = binOpType op +opReturnType (OpCmp op) = cmpOpType op +opReturnType (OpUn op) = unOpType op opReturnType (OpConv op) = snd $ convOpType op -opReturnType (OpFn fn) = case M.lookup fn primFuns of - Just (_, t, _) -> t - Nothing -> error "TODO: IMPOSSIBLE" -- It is assumed that the function exists +opReturnType (OpFn fn) = case M.lookup fn primFuns of + Just (_, t, _) -> t + Nothing -> error "TODO: IMPOSSIBLE" -- It is assumed that the function exists -- Returns the operation which performs addition (or an -- equivalent operation) on the given type @@ -81,23 +74,23 @@ multiplyFor (FloatType t) = FMul t multiplyFor Bool = LogAnd multiplyFor t = error $ "No notion of multiplication exists for type `" ++ show t ++ "`" ---Types and utility functions-- +-- Types and utility functions-- -- When taking the partial derivative of a function, we -- must differentiate between the values which are kept -- constant, and those which are not data ADValue = Variable Int ADVariable | Constant PrimValue - deriving Show + deriving (Show) -- When performing automatic differentiation, each derived -- variable must be augmented with additional data. This -- value holds the primitive value of the variable, as well --- as its data +-- as its data data ADVariable = VJP VJPValue | JVP JVPValue - deriving Show + deriving (Show) depth :: ADValue -> Int depth (Variable d _) = d @@ -112,7 +105,7 @@ primitive :: ADValue -> PrimValue primitive v@(Variable _ _) = primitive $ primal v primitive (Constant v) = v ---Code reuse from compiler AD-- +-- Code reuse from compiler AD-- -- Evaluates a PrimExp using doOp evalPrimExp :: PrimExp VName -> M.Map VName ADValue -> Maybe ADValue evalPrimExp (LeafExp n _) m = M.lookup n m @@ -141,71 +134,69 @@ lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName] lookupPDs (OpBin op) [x, y] = Just $ do let (a, b) = pdBinOp op x y [a, b] -lookupPDs (OpUn op) [x] = Just [pdUnOp op x] -lookupPDs (OpFn fn) p = pdBuiltin (nameFromString fn) p +lookupPDs (OpUn op) [x] = Just [pdUnOp op x] +lookupPDs (OpFn fn) p = pdBuiltin (nameFromString fn) p lookupPDs _ _ = Nothing ---Shared AD logic-- +-- Shared AD logic-- -- This function performs a mathematical operation on a -- list of operands, performing automatic differentiation -- if one or more operands is a Variable (of depth > 0) doOp :: Op -> [ADValue] -> Maybe ADValue -doOp op o = +doOp op o = let dep = case op of OpCmp _ -> 0 -- AD is not well-defined for comparason operations - -- There are no derivatives for those written in - -- PrimExp (check lookupPDs) + -- There are no derivatives for those written in + -- PrimExp (check lookupPDs) _ -> maximum (map depth o) - - in if dep == 0 then - -- In this case, every value is a constant, and - -- the mathematical operation can be applied as - -- it would be otherwise - - -- First, we make sure that the types of the - -- operands match those of the operation - let o' = map primitive o - in if opTypeMatch op (map primValueType o') then do - -- If they do, we perform the operation, and - -- return a Constant - Constant <$> case (op, o') of - (OpBin op', [x, y]) -> doBinOp op' x y - (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y - (OpUn op', [x]) -> doUnOp op' x - (OpConv op', [x]) -> doConvOp op' x - (OpFn fn, _) -> do - (_, _, f) <- M.lookup fn primFuns - f o' - _ -> error "TODO: IMPOSSIBLE" -- This is needed due to the fact that the function - -- takes an array, yet some of the operations have a - -- fixed number of operands - - -- If the types do not match, we return Nothing - else Nothing - - else do - -- In this case, some values are variables. - -- We therefore have to perform the necessary - -- steps for AD - - -- First, we calculate the value for the - -- previous depth - let oprev = map primal o - vprev <- doOp op oprev - - -- Then we separate the values of the maximum - -- depth from those of a lower depth - let o' = map (divideDepths dep) o - -- Then we find out what type of AD is being - -- performed - case find isRight o' of - -- Finally, we perform the necessary steps - -- for the given type of AD - Just (Right (VJP {})) -> Just . Variable dep . VJP . VJPValue $ vjpHandleOp op (map extractVJP o') vprev - Just (Right (JVP {})) -> Variable dep . JVP . JVPValue vprev <$> jvpHandleFn op (map extractJVP o') - - _ -> error "TODO: IMPOSSIBLE" -- Since the maximum depth is non-zero, there must - -- be at least one variable of depth > 0 + in if dep == 0 + then -- In this case, every value is a constant, and + -- the mathematical operation can be applied as + -- it would be otherwise + + -- First, we make sure that the types of the + -- operands match those of the operation + + let o' = map primitive o + in if opTypeMatch op (map primValueType o') + then do + -- If they do, we perform the operation, and + -- return a Constant + Constant <$> case (op, o') of + (OpBin op', [x, y]) -> doBinOp op' x y + (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y + (OpUn op', [x]) -> doUnOp op' x + (OpConv op', [x]) -> doConvOp op' x + (OpFn fn, _) -> do + (_, _, f) <- M.lookup fn primFuns + f o' + _ -> error "TODO: IMPOSSIBLE" -- This is needed due to the fact that the function + -- takes an array, yet some of the operations have a + -- fixed number of operands + else -- If the types do not match, we return Nothing + Nothing + else do + -- In this case, some values are variables. + -- We therefore have to perform the necessary + -- steps for AD + + -- First, we calculate the value for the + -- previous depth + let oprev = map primal o + vprev <- doOp op oprev + + -- Then we separate the values of the maximum + -- depth from those of a lower depth + let o' = map (divideDepths dep) o + -- Then we find out what type of AD is being + -- performed + case find isRight o' of + -- Finally, we perform the necessary steps + -- for the given type of AD + Just (Right (VJP {})) -> Just . Variable dep . VJP . VJPValue $ vjpHandleOp op (map extractVJP o') vprev + Just (Right (JVP {})) -> Variable dep . JVP . JVPValue vprev <$> jvpHandleFn op (map extractJVP o') + _ -> error "TODO: IMPOSSIBLE" -- Since the maximum depth is non-zero, there must + -- be at least one variable of depth > 0 where divideDepths :: Int -> ADValue -> Either ADValue ADVariable divideDepths _ v@(Constant {}) = Left v @@ -228,7 +219,7 @@ doOp op o = calculatePDs :: Op -> [ADValue] -> Maybe [ADValue] calculatePDs op p = do -- Create a unique VName for each operand - let n = map (\i -> VName (nameFromString $ "x" ++ show i) i) [1..length p] + let n = map (\i -> VName (nameFromString $ "x" ++ show i) i) [1 .. length p] -- Put the operands in the environment let m = M.fromList $ zip n p @@ -237,27 +228,27 @@ calculatePDs op p = do pde <- lookupPDs op $ map (`LeafExp` opReturnType op) n mapM (`evalPrimExp` m) pde - ---VJP / Reverse mode automatic differentiation-- +-- VJP / Reverse mode automatic differentiation-- -- In reverse mode AD, the entire computation -- leading up to a variable must be saved -- This is represented as a Tape newtype VJPValue = VJPValue Tape - deriving Show + deriving (Show) + -- Represents a computation tree, as well as every -- intermediate value in its evaluation -- TODO: Consider making this a graph data Tape - -- This represents a variable. Each variable is given - -- a unique ID, and has an initial value - = TapeID Int ADValue - -- this represents a constant - | TapeConst ADValue - -- This represents the application of a mathematical - -- operation. Each parameter is given by its Tape, and - -- the return value of the operation is saved - | TapeOp Op [Tape] ADValue - deriving Show + = -- This represents a variable. Each variable is given + -- a unique ID, and has an initial value + TapeID Int ADValue + | -- this represents a constant + TapeConst ADValue + | -- This represents the application of a mathematical + -- operation. Each parameter is given by its Tape, and + -- the return value of the operation is saved + TapeOp Op [Tape] ADValue + deriving (Show) -- Returns the primal value of a Tape tapePrimal :: Tape -> ADValue @@ -297,7 +288,6 @@ deriveTape (TapeOp op p _) s = do pd <- zipWithM deriveTape p s'' -- Add up the results Just $ foldl (M.unionWith add) M.empty pd - where add x y = expectJust "TODO: Remove me" $ doOp (OpBin $ addFor $ opReturnType op) [x, y] mul x y = doOp (OpBin $ multiplyFor $ opReturnType op) [x, y] @@ -309,13 +299,12 @@ deriveTape (TapeOp op p _) s = do expectJust _ (Just v) = v expectJust e Nothing = error e - ---JVP / Forward mode automatic differentiation-- +-- JVP / Forward mode automatic differentiation-- -- In JVP, the derivative of the variable must be saved. -- This is represented as a second value data JVPValue = JVPValue ADValue ADValue - deriving Show + deriving (Show) -- This calculates the derivative part of the JVPValue -- resulting from the application of a mathematical @@ -333,7 +322,6 @@ jvpHandleFn op p = do pds <- calculatePDs op $ map primal' p vs <- zipWithM mul pds $ map derivative p foldM add (Constant $ blankPrimValue $ opReturnType op) vs - where primal' (Left v) = v primal' (Right (JVPValue v _)) = v diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 4a01ecf781..8b63f5e7d7 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -30,7 +30,7 @@ module Language.Futhark.Interpreter.Values where import Data.Array -import Data.Bifunctor (Bifunctor(second)) +import Data.Bifunctor (Bifunctor (second)) import Data.List (genericLength) import Data.Map qualified as M import Data.Maybe @@ -152,7 +152,7 @@ prettyValueWith pprPrim = pprPrec 0 pprPrec _ ValueAcc {} = "#" pprPrec p (ValueSum _ n vs) = parensIf (p > (0 :: Int)) $ "#" <> sep (pretty n : map (pprPrec 1) vs) - -- TODO: This could be prettier. Perhaps add pretty printing for ADVariable / ADValues + -- TODO: This could be prettier. Perhaps add pretty printing for ADVariable / ADValues pprPrec _ (ValueAD d v) = pretty $ "d[" ++ show d ++ "]" ++ show v pprElem v@ValueArray {} = pprPrec 0 v pprElem v = group $ pprPrec 0 v diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 2ddc8039dc..e1c4287ef4 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -262,10 +262,10 @@ prettyAppExp p (LetPat sizes pat e body _) = prettyAppExp _ (LetFun fname (tparams, params, retdecl, rettype, e) body _) = "let" <+> hsep (prettyName fname : map pretty tparams ++ map pretty params) - <> retdecl' - <+> equals - indent 2 (pretty e) - letBody body + <> retdecl' + <+> equals + indent 2 (pretty e) + letBody body where retdecl' = case (pretty <$> unAnnot rettype) `mplus` (pretty <$> retdecl) of Just rettype' -> colon <+> align rettype' @@ -274,10 +274,10 @@ prettyAppExp _ (LetWith dest src idxs ve body _) | dest == src = "let" <+> pretty dest - <> list (map pretty idxs) - <+> equals - <+> align (pretty ve) - letBody body + <> list (map pretty idxs) + <+> equals + <+> align (pretty ve) + letBody body | otherwise = "let" <+> pretty dest @@ -381,8 +381,8 @@ prettyExp p (Lambda params body rettype _ _) = "\\" <> hsep (map pretty params) <> ppAscription rettype - <+> "->" - indent 2 (align (pretty body)) + <+> "->" + indent 2 (align (pretty body)) prettyExp _ (OpSection binop _ _) = parens $ pretty binop prettyExp _ (OpSectionLeft binop _ x _ _ _) = @@ -406,7 +406,7 @@ prettyExp i (AppExp e res) not $ null ext = parens (prettyAppExp i e) "@" - <> parens (pretty t <> "," <+> brackets (commasep $ map prettyName ext)) + <> parens (pretty t <> "," <+> brackets (commasep $ map prettyName ext)) | otherwise = prettyAppExp i e instance (Eq vn, IsName vn, Annot f) => Pretty (ExpBase f vn) where @@ -491,8 +491,8 @@ prettyModExp p (ModLambda param maybe_sig body _) = "\\" <> pretty param <> maybe_sig' - <+> "->" - indent 2 (pretty body) + <+> "->" + indent 2 (pretty body) where maybe_sig' = case maybe_sig of Nothing -> mempty @@ -510,9 +510,9 @@ instance (Eq vn, IsName vn, Annot f) => Pretty (TypeBindBase f vn) where pretty (TypeBind name l params te rt _ _) = "type" <> pretty l - <+> hsep (prettyName name : map pretty params) - <+> equals - <+> maybe (pretty te) pretty (unAnnot rt) + <+> hsep (prettyName name : map pretty params) + <+> equals + <+> maybe (pretty te) pretty (unAnnot rt) instance (Eq vn, IsName vn) => Pretty (TypeParamBase vn) where pretty (TypeParamDim name _) = brackets $ prettyName name @@ -522,16 +522,16 @@ instance (Eq vn, IsName vn, Annot f) => Pretty (ValBindBase f vn) where pretty (ValBind entry name retdecl rettype tparams args body _ attrs _) = mconcat (map ((<> line) . prettyAttr) attrs) <> fun - <+> align - ( sep - ( prettyName name - : map pretty tparams - ++ map pretty args - ++ retdecl' - ++ ["="] - ) - ) - indent 2 (pretty body) + <+> align + ( sep + ( prettyName name + : map pretty tparams + ++ map pretty args + ++ retdecl' + ++ ["="] + ) + ) + indent 2 (pretty body) where fun | isJust entry = "entry" diff --git a/src/Language/Futhark/Semantic.hs b/src/Language/Futhark/Semantic.hs index 2d2d83d810..5a93d1364b 100644 --- a/src/Language/Futhark/Semantic.hs +++ b/src/Language/Futhark/Semantic.hs @@ -179,9 +179,9 @@ instance Pretty Env where renderTypeBind (name, TypeAbbr l tps tp) = p l <+> prettyName name - <> mconcat (map ((" " <>) . pretty) tps) - <> " =" - <+> pretty tp + <> mconcat (map ((" " <>) . pretty) tps) + <> " =" + <+> pretty tp where p Lifted = "type^" p SizeLifted = "type~" @@ -189,9 +189,9 @@ instance Pretty Env where renderValBind (name, BoundV tps t) = "val" <+> prettyName name - <> mconcat (map ((" " <>) . pretty) tps) - <> " =" - <+> pretty t + <> mconcat (map ((" " <>) . pretty) tps) + <> " =" + <+> pretty t renderModType (name, _sig) = "module type" <+> prettyName name renderMod (name, mod) = diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 38fbb13147..4050f0c574 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -165,10 +165,10 @@ dupDefinitionError space name loc1 loc2 = "Duplicate definition of" <+> pretty space <+> prettyName name - <> "." - "Previously defined at" - <+> pretty (locStr loc2) - <> "." + <> "." + "Previously defined at" + <+> pretty (locStr loc2) + <> "." checkForDuplicateDecs :: [DecBase NoInfo Name] -> TypeM () checkForDuplicateDecs = @@ -235,7 +235,7 @@ checkSpecs (ValSpec name tparams vtype NoInfo doc loc : specs) = do "All function parameters must have non-anonymous sizes." "Hint: add size parameters to" <+> dquotes (pretty name) - <> "." + <> "." pure (tparams', vtype', vtype_t) @@ -597,7 +597,7 @@ checkTypeBind (TypeBind name l tps te NoInfo doc loc) = "Non-lifted type abbreviations may not use existential sizes in their definition." "Hint: use 'type~' or add size parameters to" <+> dquotes (prettyName name) - <> "." + <> "." _ -> pure () bindSpaced1 Type name loc $ \name' -> do diff --git a/src/Language/Futhark/TypeChecker/Consumption.hs b/src/Language/Futhark/TypeChecker/Consumption.hs index b44d9baea5..47bbea8da5 100644 --- a/src/Language/Futhark/TypeChecker/Consumption.hs +++ b/src/Language/Futhark/TypeChecker/Consumption.hs @@ -169,7 +169,7 @@ returnAliased name loc = addError loc mempty . withIndexLink "return-aliased" $ "Unique-typed return value is aliased to" <+> dquotes (prettyName name) - <> ", which is not consumable." + <> ", which is not consumable." uniqueReturnAliased :: SrcLoc -> CheckM () uniqueReturnAliased loc = @@ -307,9 +307,9 @@ checkIfConsumed rloc als = do addError rloc mempty . withIndexLink "use-after-consume" $ "Using" <+> v' - <> ", but this was consumed at" - <+> pretty (locStrRel rloc wloc) - <> ". (Possibly through aliases.)" + <> ", but this was consumed at" + <+> pretty (locStrRel rloc wloc) + <> ". (Possibly through aliases.)" consumed :: Consumed -> CheckM () consumed vs = modify $ \s -> s {stateConsumed = stateConsumed s <> vs} @@ -509,7 +509,7 @@ checkArg prev p_t e = do indent 2 (pretty prev_arg) "at" <+> pretty (locTextRel (locOf e) (locOf prev_arg)) - <> "." + <> "." pure (e', e_als) where prevAlias v = @@ -604,7 +604,7 @@ convergeLoopParam loop_loc param body_cons body_als = do <+> dquotes (prettyName pat_v) <+> "aliases" <+> dquotes (prettyName v) - <> "." + <> "." (cons, obs) <- get unless (S.null $ aliases t `S.intersection` cons) $ lift . addError loop_loc mempty $ @@ -674,8 +674,8 @@ checkLoop loop_loc (param, arg, form, body) = do addError loop_loc mempty $ "Loop body uses" <+> v' - <> " (or an alias)," - "but this is consumed by the initial loop argument." + <> " (or an alias)," + "but this is consumed by the initial loop argument." v <- VName "internal_loop_result" <$> incCounter modify $ \s -> s {stateNames = M.insert v (NameLoopRes (srclocOf loop_loc)) $ stateNames s} @@ -969,9 +969,9 @@ checkGlobalAliases loc params body_t = do "Function result aliases the free variable " <> dquotes (prettyName v) <> "." - "Use" - <+> dquotes "copy" - <+> "to break the aliasing." + "Use" + <+> dquotes "copy" + <+> "to break the aliasing." -- | Type-check a value definition. This also infers a new return -- type that may be more unique than previously. diff --git a/src/Language/Futhark/TypeChecker/Modules.hs b/src/Language/Futhark/TypeChecker/Modules.hs index 9d252c55ad..79248f8373 100644 --- a/src/Language/Futhark/TypeChecker/Modules.hs +++ b/src/Language/Futhark/TypeChecker/Modules.hs @@ -212,9 +212,9 @@ refineEnv loc tset env tname ps t typeError loc mempty $ "Cannot refine a type having" <+> tpMsg ps - <> " with a type having " - <> tpMsg cur_ps - <> "." + <> " with a type having " + <> tpMsg cur_ps + <> "." | otherwise = typeError loc mempty $ dquotes (pretty tname) <+> "is not an abstract type in the module type." where @@ -288,7 +288,7 @@ resolveAbsTypes mod_abs mod sig_abs loc = do indent 2 (ppTypeAbbr abs name mod_t) "but module type requires" <+> what - <> "." + <> "." where what = case name_l of Unlifted -> "a non-lifted type" @@ -381,14 +381,14 @@ ppTypeAbbr abs name (l, ps, RetType [] (Scalar (TypeVar _ tn args))) map typeParamToArg ps == args = "type" <> pretty l - <+> pretty name - <+> hsep (map pretty ps) + <+> pretty name + <+> hsep (map pretty ps) ppTypeAbbr _ name (l, ps, t) = "type" <> pretty l - <+> hsep (pretty name : map pretty ps) - <+> equals - <+> nest 2 (align (pretty t)) + <+> hsep (pretty name : map pretty ps) + <+> equals + <+> nest 2 (align (pretty t)) -- | Return new renamed/abstracted env, as well as a mapping from -- names in the signature to names in the new env. This is used for diff --git a/src/Language/Futhark/TypeChecker/Monad.hs b/src/Language/Futhark/TypeChecker/Monad.hs index 6c6d2e10dd..099eda554f 100644 --- a/src/Language/Futhark/TypeChecker/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Monad.hs @@ -155,7 +155,7 @@ underscoreUse loc name = typeError loc mempty $ "Use of" <+> dquotes (pretty name) - <> ": variables prefixed with underscore may not be accessed." + <> ": variables prefixed with underscore may not be accessed." -- | A mapping from import import names to 'Env's. This is used to -- resolve @import@ declarations. diff --git a/src/Language/Futhark/TypeChecker/Names.hs b/src/Language/Futhark/TypeChecker/Names.hs index 6044f73f60..4da954838f 100644 --- a/src/Language/Futhark/TypeChecker/Names.hs +++ b/src/Language/Futhark/TypeChecker/Names.hs @@ -50,7 +50,7 @@ checkForDuplicateNamesInType = check mempty <+> dquotes (pretty v) <+> "also bound at" <+> pretty (locStr prev_loc) - <> "." + <> "." check seen (TEArrow (Just v) t1 t2 loc) | Just prev_loc <- M.lookup v seen = @@ -110,7 +110,7 @@ checkForDuplicateNames tps pats = (`evalStateT` mempty) $ do <+> dquotes (pretty v) <+> "also bound at" <+> pretty (locStr prev_loc) - <> "." + <> "." Nothing -> modify $ M.insert (ns, v) loc @@ -432,7 +432,7 @@ resolveTypeParams ps m = <+> dquotes (pretty v) <+> "previously defined at" <+> pretty (locStr prev) - <> "." + <> "." Nothing -> do modify $ M.insert (ns, v) loc lift $ checkName ns v loc diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 1dd640ff3b..6cd42fdc98 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -387,7 +387,7 @@ checkExp (RecordLit fs loc) = <+> dquotes (pretty f) <+> "previously defined at" <+> pretty (locStrRel rloc sloc) - <> "." + <> "." Nothing -> pure () -- No need to type check this, as these are only produced by the -- parser if the elements are monomorphic and all match. @@ -1025,15 +1025,15 @@ checkApply loc (fname, prev_applied) ftype argexp = do "Cannot apply" <+> fname' <+> "to argument #" - <> pretty (prev_applied + 1) - <+> dquotes (shorten $ group $ pretty argexp) - <> "," - "as" - <+> fname' - <+> "only takes" - <+> pretty prev_applied - <+> arguments - <> "." + <> pretty (prev_applied + 1) + <+> dquotes (shorten $ group $ pretty argexp) + <> "," + "as" + <+> fname' + <+> "only takes" + <+> pretty prev_applied + <+> arguments + <> "." where arguments | prev_applied == 1 = "argument" @@ -1192,20 +1192,20 @@ causalityCheck binding_body = do <+> dquotes (prettyName d) <+> "needed for type of" <+> what - <> colon - indent 2 (pretty t) - "But" - <+> dquotes (prettyName d) - <+> "is computed at" - <+> pretty (locStrRel loc dloc) - <> "." - "" - "Hint:" - <+> align - ( textwrap "Bind the expression producing" - <+> dquotes (prettyName d) - <+> "with 'let' beforehand." - ) + <> colon + indent 2 (pretty t) + "But" + <+> dquotes (prettyName d) + <+> "is computed at" + <+> pretty (locStrRel loc dloc) + <> "." + "" + "Hint:" + <+> align + ( textwrap "Bind the expression producing" + <+> dquotes (prettyName d) + <+> "with 'let' beforehand." + ) mustBeIrrefutable :: (MonadTypeChecker f) => Pat StructType -> f () mustBeIrrefutable p = do @@ -1366,8 +1366,8 @@ fixOverloadedTypes tyvars_at_toplevel = typeError usage mempty . withIndexLink "ambiguous-type" $ "Type is ambiguous (could be one of" <+> commasep (map pretty ots) - <> ")." - "Add a type annotation to disambiguate the type." + <> ")." + "Add a type annotation to disambiguate the type." fixOverloaded (v, NoConstraint _ usage) = do -- See #1552. unify usage (Scalar (TypeVar mempty (qualName v) [])) $ @@ -1389,8 +1389,8 @@ fixOverloadedTypes tyvars_at_toplevel = typeError usage mempty . withIndexLink "ambiguous-type" $ "Type is ambiguous (must be a sum type with constructors:" <+> pretty (Sum cs) - <> ")." - "Add a type annotation to disambiguate the type." + <> ")." + "Add a type annotation to disambiguate the type." fixOverloaded (v, Size Nothing (Usage Nothing loc)) = typeError loc mempty . withIndexLink "ambiguous-size" $ "Ambiguous size" <+> dquotes (prettyName v) <> "." @@ -1512,13 +1512,13 @@ verifyFunctionParams fname params = <+> dquotes (pretty p) "refers to size" <+> dquotes (prettyName d) - <> comma - textwrap "which will not be accessible to the caller" - <> comma - textwrap "possibly because it is nested in a tuple or record." - textwrap "Consider ascribing an explicit type that does not reference " - <> dquotes (prettyName d) - <> "." + <> comma + textwrap "which will not be accessible to the caller" + <> comma + textwrap "possibly because it is nested in a tuple or record." + textwrap "Consider ascribing an explicit type that does not reference " + <> dquotes (prettyName d) + <> "." | otherwise = verifyParams forbidden' ps where forbidden' = @@ -1610,8 +1610,8 @@ closeOverTypes defname defloc tparams paramts ret substs = do <+> dquotes (prettyName k) <+> "in parameter of" <+> dquotes (prettyName defname) - <> ", which is inferred as:" - indent 2 (pretty t) + <> ", which is inferred as:" + indent 2 (pretty t) | k `S.member` produced_sizes = pure $ Just $ Right k closeOver (_, _) = diff --git a/src/Language/Futhark/TypeChecker/Terms/Loop.hs b/src/Language/Futhark/TypeChecker/Terms/Loop.hs index a44f5ec522..b1bb3fb9ea 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Loop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Loop.hs @@ -121,7 +121,7 @@ checkForImpossible loc known_before pat_t = do <+> dquotes (prettyName v) <+> "is an existential size created inside the loop body at" <+> pretty (locStrRel loc v_loc) - <> "." + <> "." case mapMaybe bad $ S.toList $ fvVars $ freeInType pat_t of problem : _ -> problem [] -> pure () diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index aa95a3cbad..c903678f9b 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -114,13 +114,13 @@ instance Pretty Checking where Nothing -> "Cannot apply function to" <+> dquotes (shorten $ group $ pretty e) - <> " (invalid type)." + <> " (invalid type)." Just fname -> "Cannot apply" <+> dquotes (pretty fname) <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) - <> " (invalid type)." + <> " (invalid type)." pretty (CheckingReturn expected actual) = "Function body does not have expected type." "Expected:" @@ -161,25 +161,25 @@ instance Pretty Checking where pretty (CheckingRecordUpdate fs expected actual) = "Type mismatch when updating record field" <+> dquotes fs' - <> "." - "Existing:" - <+> align (pretty expected) - "New: " - <+> align (pretty actual) + <> "." + "Existing:" + <+> align (pretty expected) + "New: " + <+> align (pretty actual) where fs' = mconcat $ punctuate "." $ map pretty fs pretty (CheckingRequired [expected] actual) = "Expression must have type" <+> pretty expected - <> "." - "Actual type:" - <+> align (pretty actual) + <> "." + "Actual type:" + <+> align (pretty actual) pretty (CheckingRequired expected actual) = "Type of expression must be one of " <+> expected' - <> "." - "Actual type:" - <+> align (pretty actual) + <> "." + "Actual type:" + <+> align (pretty actual) where expected' = commasep (map pretty expected) pretty (CheckingBranches t1 t2) = diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 5dfa5e622f..f86b587f62 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -158,8 +158,8 @@ checkPat' _ (RecordPat p_fs _) _ typeError fp mempty $ "Underscore-prefixed fields are not allowed." "Did you mean" - <> dquotes (pretty (drop 1 (nameToString f)) <> "=_") - <> "?" + <> dquotes (pretty (drop 1 (nameToString f)) <> "=_") + <> "?" checkPat' sizes p@(RecordPat p_fs loc) (Ascribed t) | Scalar (Record t_fs) <- t, sort (map fst p_fs) == sort (M.keys t_fs) = @@ -212,10 +212,10 @@ checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs))) "Pattern #" <> pretty n <> " expects" - <+> pretty (length ps) - <+> "constructor arguments, but type provides" - <+> pretty (length ts) - <+> "arguments." + <+> pretty (length ps) + <+> "constructor arguments, but type provides" + <+> pretty (length ts) + <+> "arguments." ps' <- zipWithM (checkPat' sizes) ps $ map Ascribed ts pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed t) = do diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 9316c1b62f..c854476a92 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -254,7 +254,7 @@ evalTypeExp df ote@TEApply {} = do <+> pretty (length ps) <+> "arguments, but provided" <+> pretty (length targs) - <> "." + <> "." else do (targs', dims, substs) <- unzip3 <$> zipWithM checkArgApply ps targs pure @@ -306,7 +306,7 @@ evalTypeExp df ote@TEApply {} = do <+> pretty a <+> "not valid for a type parameter" <+> pretty p - <> "." + <> "." -- | Check a type expression, producing: -- diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 6b245221b0..ffa375deec 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -59,7 +59,7 @@ instance Pretty BreadCrumb where pretty (MatchingFields fields) = "When matching types of record field" <+> dquotes (mconcat $ punctuate "." $ map pretty fields) - <> dot + <> dot pretty (MatchingConstructor c) = "When matching types of constructor" <+> dquotes (pretty c) <> dot pretty (Matching s) = @@ -188,13 +188,13 @@ prettySource :: Loc -> Loc -> RigidSource -> Doc () prettySource ctx loc (RigidRet Nothing) = "is unknown size returned by function at" <+> pretty (locStrRel ctx loc) - <> "." + <> "." prettySource ctx loc (RigidRet (Just fname)) = "is unknown size returned by" <+> dquotes (pretty fname) <+> "at" <+> pretty (locStrRel ctx loc) - <> "." + <> "." prettySource ctx loc (RigidArg fname arg) = "is value of argument" indent 2 (shorten (pretty arg)) @@ -202,16 +202,16 @@ prettySource ctx loc (RigidArg fname arg) = <+> fname' <+> "at" <+> pretty (locStrRel ctx loc) - <> "." + <> "." where fname' = maybe "function" (dquotes . pretty) fname prettySource ctx loc (RigidSlice d slice) = "is size produced by slice" indent 2 (shorten (pretty slice)) d_desc - <> "at" - <+> pretty (locStrRel ctx loc) - <> "." + <> "at" + <+> pretty (locStrRel ctx loc) + <> "." where d_desc = case d of Just d' -> "of dimension of size " <> dquotes (pretty d') <> " " @@ -226,19 +226,19 @@ prettySource ctx loc (RigidOutOfScope boundloc v) = <> " going out of scope at " <> pretty (locStrRel ctx loc) <> "." - "Originally bound at " - <> pretty (locStrRel ctx boundloc) - <> "." + "Originally bound at " + <> pretty (locStrRel ctx boundloc) + <> "." prettySource _ _ RigidUnify = "is an artificial size invented during unification of functions with anonymous sizes." prettySource ctx loc (RigidCond t1 t2) = "is unknown due to conditional expression at " <> pretty (locStrRel ctx loc) <> "." - "One branch returns array of type: " - <> align (pretty t1) - "The other an array of type: " - <> align (pretty t2) + "One branch returns array of type: " + <> align (pretty t1) + "The other an array of type: " + <> align (pretty t2) -- | Retrieve notes describing the purpose or origin of the given -- t'Size'. The location is used as the *current* location, for the @@ -578,7 +578,7 @@ occursCheck usage bcs vn tp = <+> prettyName vn <+> "with" <+> pretty tp - <> "." + <> "." scopeCheck :: (MonadUnify m) => @@ -692,7 +692,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do <+> commasep (map pretty ts) "due to" <+> pretty old_usage - <> "." + <> "." Just (HasFields l required_fields old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp case tp of @@ -735,7 +735,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do indent 2 (pretty (Record required_fields)) "due to" <+> pretty old_usage - <> "." + <> "." -- See Note [Linking variables to sum types] Just (HasConstrs l required_cs old_usage) -> do when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp @@ -853,8 +853,8 @@ mustBeOneOf ts usage t = do "Cannot unify type" <+> dquotes (pretty t) <+> "with any of " - <> commasep (map pretty ts) - <> "." + <> commasep (map pretty ts) + <> "." linkVarToTypes :: (MonadUnify m) => Usage -> VName -> [PrimType] -> m () linkVarToTypes usage vn ts = do @@ -870,22 +870,22 @@ linkVarToTypes usage vn ts = do <+> commasep (map pretty vn_ts) <+> "due to" <+> pretty vn_usage - <> "." + <> "." ts' -> modifyConstraints $ M.insert vn (lvl, Overloaded ts' usage) Just (_, HasConstrs _ _ vn_usage) -> unifyError usage mempty noBreadCrumbs $ "Type constrained to one of" <+> commasep (map pretty ts) - <> ", but also inferred to be sum type due to" - <+> pretty vn_usage - <> "." + <> ", but also inferred to be sum type due to" + <+> pretty vn_usage + <> "." Just (_, HasFields _ _ vn_usage) -> unifyError usage mempty noBreadCrumbs $ "Type constrained to one of" <+> commasep (map pretty ts) - <> ", but also inferred to be record due to" - <+> pretty vn_usage - <> "." + <> ", but also inferred to be record due to" + <+> pretty vn_usage + <> "." Just (lvl, _) -> modifyConstraints $ M.insert vn (lvl, Overloaded ts usage) Nothing -> unifyError usage mempty noBreadCrumbs $ @@ -1115,7 +1115,7 @@ mustHaveFieldWith onDims usage bound bcs l t = do <+> dquotes (pretty l) <+> " of value of type" <+> pretty (toStructural t) - <> "." + <> "." _ -> do unify usage t $ Scalar $ Record $ M.singleton l l_type pure l_type diff --git a/unittests/Futhark/IR/Mem/IxFunTests.hs b/unittests/Futhark/IR/Mem/IxFunTests.hs index c5925f21df..e1ed79b85d 100644 --- a/unittests/Futhark/IR/Mem/IxFunTests.hs +++ b/unittests/Futhark/IR/Mem/IxFunTests.hs @@ -60,20 +60,20 @@ compareIxFuns (Just ixfunLMAD) ixfunAlg = T.unpack . docText $ "lmad ixfun: " <> pretty ixfunLMAD - "alg ixfun: " - <> pretty ixfunAlg - "lmad shape: " - <> pretty lmadShape - "alg shape: " - <> pretty algShape - "lmad points length: " - <> pretty (length resLMAD) - "alg points length: " - <> pretty (length resAlg) - "lmad points: " - <> pretty resLMAD - "alg points: " - <> pretty resAlg + "alg ixfun: " + <> pretty ixfunAlg + "lmad shape: " + <> pretty lmadShape + "alg shape: " + <> pretty algShape + "lmad points length: " + <> pretty (length resLMAD) + "alg points length: " + <> pretty (length resAlg) + "lmad points: " + <> pretty resLMAD + "alg points: " + <> pretty resAlg in (lmadShape == algShape && resLMAD == resAlg) @? errorMessage compareIxFuns Nothing ixfunAlg = assertFailure $ @@ -91,9 +91,9 @@ compareOpsFailure (Just ixfunLMAD, ixfunAlg) = assertFailure . T.unpack . docText $ "Not supposed to be representable as LMAD." "lmad ixfun: " - <> pretty ixfunLMAD - "alg ixfun: " - <> pretty ixfunAlg + <> pretty ixfunLMAD + "alg ixfun: " + <> pretty ixfunAlg -- XXX: Clean this up. n :: Int