Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a CLI and some light testing #24

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,80 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Init.System.IO
import Cli
import TensorLib

def main (_args : List String) : IO UInt32 :=
open Cli
open TensorLib

def format (p : Parsed) : IO UInt32 := do
let shape : Shape := Shape.mk (p.variableArgsAs! Nat).toList
IO.println s!"Got shape {shape}"
let range := Tensor.Element.arange BV16 shape.count
let v := range.reshape! shape
let s := v.format BV16
IO.println s
return 0

def formatCmd := `[Cli|
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is neat. We should use it in NKL as well.

"format" VIA format;
"Test formatting"

ARGS:
...shape : Nat; "shape to test"
]

def parseNpy (p : Parsed) : IO UInt32 := do
let file := p.positionalArg! "input" |>.as! String
IO.println s!"Parsing {file}..."
let v <- Npy.parseFile file
IO.println (repr v)
if p.hasFlag "write" then do
let new := (System.FilePath.mk file).addExtension "new"
IO.println s!"Writing copy to {new}"
let _ <- v.save! new
-- TensorLib.Npy.save! (arr : Ndarray) (file : System.FilePath) : IO Unit
return 0

def parseNpyCmd := `[Cli|
"parse-npy" VIA parseNpy;
"Parse a .npy file and pretty print the contents"

FLAGS:
write; "Also write the result back to `input`.new to test saving arrays to disk"

ARGS:
input : String; ".npy file to parse"
]

def runTests (_ : Parsed) : IO UInt32 := do
-- Just pytest for now, but add Lean tests here as well
-- pytest will exit nonzero on it's own, so we don't need to check exit code
IO.println "Running PyTest..."
let output <- IO.Process.output { cmd := "pytest" }
IO.println s!"stdout: {output.stdout}"
IO.println s!"stderr: {output.stderr}"
return output.exitCode

def runTestsCmd := `[Cli|
"test" VIA runTests;
"Run tests"
]

def tensorlibCmd : Cmd := `[Cli|
tensorlib NOOP; ["0.0.1"]
"TensorLib is a NumPy-like library for Lean."

SUBCOMMANDS:
formatCmd;
parseNpyCmd;
runTestsCmd
]

def main (args : List String) : IO UInt32 :=
if args.isEmpty then do
IO.println tensorlibCmd.help
return 0
else do
tensorlibCmd.validate args
1 change: 1 addition & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ import TensorLib.Tensor
import TensorLib.Index
import TensorLib.Mgrid
import TensorLib.Ufunc
import TensorLib.Test
57 changes: 57 additions & 0 deletions TensorLib/Test.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Init.System.IO
import TensorLib.Tensor

namespace TensorLib

section Test

private def DEBUG := false
private def debugPrint {a : Type} [Repr a] (s : a) : IO Unit := if DEBUG then IO.print (Std.Format.pretty (repr s)) else return ()
private def debugPrintln {a : Type} [Repr a] (s : a) : IO Unit := do
debugPrint s
if DEBUG then IO.print "\n" else return ()

-- Caller must remove the temp file
private def saveNumpyArray (expr : String) : IO System.FilePath := do
let (_root_, file) <- IO.FS.createTempFile
let expr := s!"import numpy as np; x = {expr}; np.save('{file}', x)"
let output <- IO.Process.output { cmd := "/usr/bin/env", args := ["python3", "-c", expr].toArray }
let _ <- debugPrintln output.stdout
let _ <- debugPrintln output.stderr
-- `np.save` appends `.npy` to the file
return file.addExtension "npy"

private def testTensorElementBV (n : Nat) [Tensor.Element (BitVec n)] (dtype : String) : IO Bool := do
let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype}').reshape(5, 4)"
let npy <- Npy.parseFile file
let arr <- IO.ofExcept (Tensor.ofNpy npy)
let _ <- debugPrintln file
let _ <- debugPrintln arr
let _ <- IO.FS.removeFile file
let expected : List (BitVec n) := (List.range 20).map (BitVec.ofNat n)
let mut actual := []
for i in [0:20] do
match Tensor.Element.getPosition arr i with
| .error msg => IO.throwServerError msg
| .ok v => actual := v :: actual
actual := actual.reverse
let _ <- debugPrintln actual
return expected == actual

-- Sketchy perhaps, but seems to work for testing
private def ioBool (x : IO Bool) : Bool := match x.run () with
| .error _ _ => false
| .ok b _ => b

#guard ioBool (testTensorElementBV 16 "uint16")
#guard ! ioBool (testTensorElementBV 32 "uint16")
#guard ioBool (testTensorElementBV 32 "uint32")
#guard ! ioBool (testTensorElementBV 32 "uint64")

end Test
Loading