Skip to content

Commit

Permalink
Expand JET test (#782)
Browse files Browse the repository at this point in the history
* Fix JET test

* Add output if a typed varinfo isn't inferred
  • Loading branch information
penelopeysm authored Jan 27, 2025
1 parent 727da63 commit 7a140bc
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Use debug logging below.
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
# They should all result in typed.
@test varinfo isa DynamicPPL.TypedVarInfo
# But let's also make sure that they're not lying.
# Check that the inferred varinfo is indeed suitable for evaluation and sampling
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo
)
Expand All @@ -76,6 +74,21 @@
model, varinfo, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
# For our demo models, they should all result in typed.
is_typed = varinfo isa DynamicPPL.TypedVarInfo
@test is_typed
# If the test failed, check why it didn't infer a typed varinfo
if !is_typed
typed_vi = VarInfo(model)
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, typed_vi
)
JET.test_call(f_eval, argtypes_eval)
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, typed_vi, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
end
end
end
end

0 comments on commit 7a140bc

Please sign in to comment.