From eebd8658c6ef149e9ef08879d139297f4fd63b6a Mon Sep 17 00:00:00 2001 From: Sean McLaughlin Date: Thu, 2 Jan 2025 14:53:32 -0800 Subject: [PATCH] Main --- Main.lean | 76 ++++++++++++++++++++++++++++++++++++++++++- TensorLib/Basic.lean | 1 + TensorLib/Tensor.lean | 3 +- TensorLib/Test.lean | 57 ++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 TensorLib/Test.lean diff --git a/Main.lean b/Main.lean index 4679268..e9e92e1 100644 --- a/Main.lean +++ b/Main.lean @@ -4,7 +4,81 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin -/ +import Init.System.IO +import Cli import TensorLib -def main (_args : List String) : IO UInt32 := +open Cli +open TensorLib + +def format (p : Parsed) : IO UInt32 := do + let shape : List Nat := (p.variableArgsAs! Nat).toList + let n := natProd shape + IO.println s!"Got shape {shape}" + let range := Tensor.Element.arange BV16 n + let v := range.reshape! shape + let s := v.format BV16 + IO.println s return 0 + +def formatCmd := `[Cli| + "format" VIA format; + "Test formatting" + + ARGS: + ...shape : Nat; "shape to test" +] + +def parseNpy (p : Parsed) : IO UInt32 := do + let file := p.positionalArg! "input" |>.as! String + IO.println s!"Parsing {file}..." + let v <- Npy.parseFile file + IO.println (repr v) + if p.hasFlag "write" then do + let new := (System.FilePath.mk file).addExtension "new" + IO.println s!"Writing copy to {new}" + let _ <- v.save! new + -- TensorLib.Npy.save! (arr : Ndarray) (file : System.FilePath) : IO Unit + return 0 + +def parseNpyCmd := `[Cli| + "parse-npy" VIA parseNpy; + "Parse a .npy file and pretty print the contents" + + FLAGS: + write; "Also write the result back to `input`.new to test saving arrays to disk" + + ARGS: + input : String; ".npy file to parse" +] + +def runTests (_ : Parsed) : IO UInt32 := do + -- Just pytest for now, but add Lean tests here as well + -- pytest will exit nonzero on it's own, so we don't need to check exit code + IO.println "Running PyTest..." + let output <- IO.Process.output { cmd := "pytest" } + IO.println s!"stdout: {output.stdout}" + IO.println s!"stderr: {output.stderr}" + return output.exitCode + +def runTestsCmd := `[Cli| + "test" VIA runTests; + "Run tests" +] + +def tensorlibCmd : Cmd := `[Cli| + tensorlib NOOP; ["0.0.1"] + "TensorLib is a NumPy-like library for Lean." + + SUBCOMMANDS: + formatCmd; + parseNpyCmd; + runTestsCmd +] + +def main (args : List String) : IO UInt32 := + if args.isEmpty then do + IO.println tensorlibCmd.help + return 0 + else do + tensorlibCmd.validate args diff --git a/TensorLib/Basic.lean b/TensorLib/Basic.lean index d813c0e..75a0656 100644 --- a/TensorLib/Basic.lean +++ b/TensorLib/Basic.lean @@ -13,3 +13,4 @@ import TensorLib.Tensor import TensorLib.Index import TensorLib.Mgrid import TensorLib.Ufunc +import TensorLib.Test diff --git a/TensorLib/Tensor.lean b/TensorLib/Tensor.lean index 2450646..b11e7c2 100644 --- a/TensorLib/Tensor.lean +++ b/TensorLib/Tensor.lean @@ -316,8 +316,7 @@ def setDimIndex [Element a] (x : Tensor) (index : DimIndex) (v : a): Err Tensor -- TODO: remove `Err` by proving all indices are within range def toList (a : Type) [Tensor.Element a] (x : Tensor) : Err (List a) := - let traverseFn ind : Err a := getDimIndex x ind - x.shape.allDimIndices.traverse traverseFn + x.shape.allDimIndices.traverse (getDimIndex x) def toList! (a : Type) [Tensor.Element a] (x : Tensor) : List a := match toList a x with | .error _ => [] diff --git a/TensorLib/Test.lean b/TensorLib/Test.lean new file mode 100644 index 0000000..82c05f7 --- /dev/null +++ b/TensorLib/Test.lean @@ -0,0 +1,57 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin +-/ + +import Init.System.IO +import TensorLib.Tensor + +namespace TensorLib + +section Test + +private def DEBUG := false +private def debugPrint {a : Type} [Repr a] (s : a) : IO Unit := if DEBUG then IO.print (Std.Format.pretty (repr s)) else return () +private def debugPrintln {a : Type} [Repr a] (s : a) : IO Unit := do + debugPrint s + if DEBUG then IO.print "\n" else return () + +-- Caller must remove the temp file +private def saveNumpyArray (expr : String) : IO System.FilePath := do + let (_root_, file) <- IO.FS.createTempFile + let expr := s!"import numpy as np; x = {expr}; np.save('{file}', x)" + let output <- IO.Process.output { cmd := "/usr/bin/env", args := ["python3", "-c", expr].toArray } + let _ <- debugPrintln output.stdout + let _ <- debugPrintln output.stderr + -- `np.save` appends `.npy` to the file + return file.addExtension "npy" + +private def testTensorElementBV (n : Nat) [Tensor.Element (BitVec n)] (dtype : String) : IO Bool := do + let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype}').reshape(5, 4)" + let npy <- Npy.parseFile file + let arr <- IO.ofExcept (Tensor.ofNpy npy) + let _ <- debugPrintln file + let _ <- debugPrintln arr + let _ <- IO.FS.removeFile file + let expected : List (BitVec n) := (List.range 20).map (BitVec.ofNat n) + let mut actual := [] + for i in [0:20] do + match Tensor.Element.getPosition arr i with + | .error msg => IO.throwServerError msg + | .ok v => actual := v :: actual + actual := actual.reverse + let _ <- debugPrintln actual + return expected == actual + +-- Sketchy perhaps, but seems to work for testing +private def ioBool (x : IO Bool) : Bool := match x.run () with +| .error _ _ => false +| .ok b _ => b + +#guard ioBool (testTensorElementBV 16 "uint16") +#guard ! ioBool (testTensorElementBV 32 "uint16") +#guard ioBool (testTensorElementBV 32 "uint32") +#guard ! ioBool (testTensorElementBV 32 "uint64") + +end Test