Skip to content

Commit

Permalink
Add Main CLI for quick experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmcl committed Jan 24, 2025
1 parent 50b0c4e commit 3e2cd3f
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 1 deletion.
75 changes: 74 additions & 1 deletion Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,80 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Init.System.IO
import Cli
import TensorLib

def main (_args : List String) : IO UInt32 :=
open Cli
open TensorLib

def format (p : Parsed) : IO UInt32 := do
let shape : Shape := Shape.mk (p.variableArgsAs! Nat).toList
IO.println s!"Got shape {shape}"
let range := Tensor.Element.arange BV16 shape.count
let v := range.reshape! shape
let s := v.format BV16
IO.println s
return 0

def formatCmd := `[Cli|
"format" VIA format;
"Test formatting"

ARGS:
...shape : Nat; "shape to test"
]

def parseNpy (p : Parsed) : IO UInt32 := do
let file := p.positionalArg! "input" |>.as! String
IO.println s!"Parsing {file}..."
let v <- Npy.parseFile file
IO.println (repr v)
if p.hasFlag "write" then do
let new := (System.FilePath.mk file).addExtension "new"
IO.println s!"Writing copy to {new}"
let _ <- v.save! new
-- TensorLib.Npy.save! (arr : Ndarray) (file : System.FilePath) : IO Unit
return 0

def parseNpyCmd := `[Cli|
"parse-npy" VIA parseNpy;
"Parse a .npy file and pretty print the contents"

FLAGS:
write; "Also write the result back to `input`.new to test saving arrays to disk"

ARGS:
input : String; ".npy file to parse"
]

def runTests (_ : Parsed) : IO UInt32 := do
-- Just pytest for now, but add Lean tests here as well
-- pytest will exit nonzero on it's own, so we don't need to check exit code
IO.println "Running PyTest..."
let output <- IO.Process.output { cmd := "pytest" }
IO.println s!"stdout: {output.stdout}"
IO.println s!"stderr: {output.stderr}"
return output.exitCode

def runTestsCmd := `[Cli|
"test" VIA runTests;
"Run tests"
]

def tensorlibCmd : Cmd := `[Cli|
tensorlib NOOP; ["0.0.1"]
"TensorLib is a NumPy-like library for Lean."

SUBCOMMANDS:
formatCmd;
parseNpyCmd;
runTestsCmd
]

def main (args : List String) : IO UInt32 :=
if args.isEmpty then do
IO.println tensorlibCmd.help
return 0
else do
tensorlibCmd.validate args
1 change: 1 addition & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ import TensorLib.Tensor
import TensorLib.Index
import TensorLib.Mgrid
import TensorLib.Ufunc
import TensorLib.Test
57 changes: 57 additions & 0 deletions TensorLib/Test.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Init.System.IO
import TensorLib.Tensor

namespace TensorLib

section Test

private def DEBUG := false
private def debugPrint {a : Type} [Repr a] (s : a) : IO Unit := if DEBUG then IO.print (Std.Format.pretty (repr s)) else return ()
private def debugPrintln {a : Type} [Repr a] (s : a) : IO Unit := do
debugPrint s
if DEBUG then IO.print "\n" else return ()

-- Caller must remove the temp file
private def saveNumpyArray (expr : String) : IO System.FilePath := do
let (_root_, file) <- IO.FS.createTempFile
let expr := s!"import numpy as np; x = {expr}; np.save('{file}', x)"
let output <- IO.Process.output { cmd := "/usr/bin/env", args := ["python3", "-c", expr].toArray }
let _ <- debugPrintln output.stdout
let _ <- debugPrintln output.stderr
-- `np.save` appends `.npy` to the file
return file.addExtension "npy"

private def testTensorElementBV (n : Nat) [Tensor.Element (BitVec n)] (dtype : String) : IO Bool := do
let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype}').reshape(5, 4)"
let npy <- Npy.parseFile file
let arr <- IO.ofExcept (Tensor.ofNpy npy)
let _ <- debugPrintln file
let _ <- debugPrintln arr
let _ <- IO.FS.removeFile file
let expected : List (BitVec n) := (List.range 20).map (BitVec.ofNat n)
let mut actual := []
for i in [0:20] do
match Tensor.Element.getPosition arr i with
| .error msg => IO.throwServerError msg
| .ok v => actual := v :: actual
actual := actual.reverse
let _ <- debugPrintln actual
return expected == actual

-- Sketchy perhaps, but seems to work for testing
private def ioBool (x : IO Bool) : Bool := match x.run () with
| .error _ _ => false
| .ok b _ => b

#guard ioBool (testTensorElementBV 16 "uint16")
#guard ! ioBool (testTensorElementBV 32 "uint16")
#guard ioBool (testTensorElementBV 32 "uint32")
#guard ! ioBool (testTensorElementBV 32 "uint64")

end Test

0 comments on commit 3e2cd3f

Please sign in to comment.