DifferentialEquations.jl <diffeqpy?> PyTorch/torchdyn #67

Open
Zymrael opened this issue Jul 27, 2020 · 5 comments
@Zymrael

Zymrael commented Jul 27, 2020

We're the devs of torchdyn, the PyTorch-based younger brother of JuliaDiffEq. Since we don't have the bandwidth (nor is it our objective at the moment) to improve or reimplement the differential equation solvers available in our ecosystem, I was toying with the idea of utilizing the fully-featured DifferentialEquations.jl and Julia as an additional backend option for torchdyn models. It is my understanding that diffeqpy wouldn't work out of the box for us, since it'd be interfacing with PyTorch.

It'd be nice to pick your brain about this direction, @ChrisRackauckas. Do you think the overheads would be small enough such that a switch to a Julia diffeq solvers backend would benefit the PyTorch userbase?

I'll be glad to keep you up to date with our attempts if you're interested :)

@ChrisRackauckas
Member

Hey, that's a pretty cool project. We'd love to help out. I was actually just running the torchsde benchmarks today to see how it turned out and those benchmarks convinced me that we should really make sure to contribute more to the Python community. So I was planning to try and get some things up and running during this year's JuliaCon.

One of the main things we'd like to do is make the installation more automatic. @christopher-dG do you know much about Python build systems? I am wondering if we can somehow get pyjulia vendoring Julia itself, kind of like Conda.jl, so diffeqpy could be a full installation from pip. I was looking to see if PackageCompiler can do a static compilation of the ODE solvers too, but let's ignore that for now and look at the lower-hanging fruit.

@Zymrael do you know much about direct definitions of adjoints for PyTorch? We refactored the Julia side a bit ago in preparation for this combination, so what it looks like is this. solve has a quick step that lowers to solve_up:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L98-L103

and then the adjoint is defined as:

https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L262-L266

The internal function _solve_adjoint takes care of the rest of the plumbing in DiffEqSensitivity.jl, doing:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L33-L136

for example as the "standard" adjoint (covering the big 3 we will want: QuadratureAdjoint, InterpolatingAdjoint, and BacksolveAdjoint) with all of the keyword argument and saving handling. So given how that's refactored, I think we could do this in like just 2 function definitions in PyTorch, I just need to figure out that interface.
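To make the "two function definitions" idea concrete, here is a framework-free sketch of a forward solve paired with a backsolve-style adjoint. Everything in it is an assumption for illustration: the toy dynamics du/dt = p*u, the fixed-step RK4 integrator, and the names `rk4`, `solve_forward`, and `solve_backward` are hypothetical and are not the DiffEqSensitivity.jl implementation. In PyTorch, the two functions would correspond to the `forward` and `backward` static methods of a `torch.autograd.Function` subclass.

```python
def rk4(f, y, t0, t1, n_steps):
    """Integrate dy/dt = f(y, t) from t0 to t1 with fixed-step RK4.

    Passing t1 < t0 integrates backward in time (the step becomes negative).
    """
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(y, t)
        k2 = f([yi + 0.5 * h * ki for yi, ki in zip(y, k1)], t + 0.5 * h)
        k3 = f([yi + 0.5 * h * ki for yi, ki in zip(y, k2)], t + 0.5 * h)
        k4 = f([yi + h * ki for yi, ki in zip(y, k3)], t + h)
        y = [yi + h / 6.0 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y


def solve_forward(u0, p, T, n_steps=200):
    """Forward pass: solve the toy ODE du/dt = p*u on [0, T] and return u(T)."""
    return rk4(lambda y, t: [p * y[0]], [u0], 0.0, T, n_steps)[0]


def solve_backward(uT, p, T, grad_uT, n_steps=200):
    """Backward pass: integrate state, adjoint, and parameter gradient from T to 0.

    Augmented system (toy case f(u, p) = p*u):
        du/dt   =  p * u           (re-solves the state backward)
        dlam/dt = -lam * df/du     = -lam * p
        dg/dt   = -lam * df/dp     = -lam * u
    Returns (dL/du0, dL/dp) given grad_uT = dL/du(T).
    """
    def augmented(y, t):
        u, lam, _ = y
        return [p * u, -lam * p, -lam * u]

    _, grad_u0, grad_p = rk4(augmented, [uT, grad_uT, 0.0], T, 0.0, n_steps)
    return grad_u0, grad_p


u0, p, T = 2.0, 0.5, 1.0
uT = solve_forward(u0, p, T)
grad_u0, grad_p = solve_backward(uT, p, T, grad_uT=1.0)
```

For this toy problem the exact answers are known (u(T) = u0*exp(p*T), so du(T)/du0 = exp(p*T) and du(T)/dp = u0*T*exp(p*T)), which makes it easy to check the plumbing before swapping in a real solver.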

Now the next difficulty involved is going to be that these adjoints will use Zygote for the vjp calculations by default. The defaults are handled at the top of that file:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L7-L26

Since this is using Zygote or ReverseDiff vjps, this might not directly work at first. So step one I think we'd test this out where the function is eval'd to live in Julia and get that working, and then try to get torch working in the vjps (which has been demonstrated before). That should be all it takes.

> Do you think the overheads would be small enough such that a switch to a Julia diffeq solvers backend would benefit the PyTorch userbase?

I think there's a two-pronged approach we can take. For small problems, problems which are non-heterogeneous, like chemical combustion or quantitative systems pharmacology models, torch JIT doesn't seem to do very well. On these problems, what I want to do is hijack the functional form via ModelingToolkit.jl (https://github.com/SciML/ModelingToolkit.jl) to then directly compile the version in Julia with sparsity and all of that jazz (see my coming JuliaCon talk for more details on this system). This is the aspect I've been working on the most in terms of pyjulia performance.

That would cover a lot of scientific modeling, but I don't think that is necessary for your case, which is more big heterogeneous matmul neural ODE models. In that case, the overhead should be minimal to non-existent since pyjulia passes to Julia by reference and not by copying, so as long as the two AD systems connect well for the vjp we should be in the asymptotically large matmul case. For GPUs I think we might need to connect to https://github.com/TuringLang/ThArrays.jl to give it the right overloads in Julia, but that shouldn't be too difficult either (pinging @KDr2 who may be interested in helping).
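As a rough illustration of why passing by reference keeps the overhead low, here is a standard-library-only sketch contrasting a shared buffer with an explicit copy. It does not use pyjulia; `scale_in_place` is a hypothetical stand-in for a foreign solver that writes into the caller's memory.

```python
from array import array


def scale_in_place(buf, factor):
    """Hypothetical stand-in for a foreign routine that mutates the caller's buffer."""
    for i in range(len(buf)):
        buf[i] *= factor


state = array("d", [1.0, 2.0, 3.0])

# A memoryview shares memory with `state`: no data is copied,
# and writes through the view are visible to the owner.
view = memoryview(state)
scale_in_place(view, 2.0)

# An explicit copy, for contrast: mutating it leaves `state` untouched,
# and creating it costs time and memory proportional to the array size.
copied = array("d", state)
scale_in_place(copied, 10.0)
```

With a shared buffer the cost of crossing the language boundary does not grow with the size of the array, so for large models the runtime is dominated by the actual numerical work.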

@christopher-dG

While I can't claim to be an expert, I do know some things. I could certainly look into helping out here.

@KDr2

KDr2 commented Jul 28, 2020

It's my pleasure to help too if necessary.

@Zymrael
Author

Zymrael commented Jul 28, 2020

Thank you for the detailed response!

A first huge step would be being able to access the .jl solvers for solve steps within our current API. Before redesigning everything to include additional options / kwargs (e.g. callbacks for event handling), we should verify that the two AD systems can connect well.
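One simple way to run that verification is a finite-difference gradient check: compare the gradient coming out of the combined pipeline against a numerical derivative of the loss. The sketch below is a minimal, framework-free version; `central_difference`, `gradients_agree`, `loss`, and `loss_grad` are hypothetical names, and the toy loss stands in for "solve the ODE, then evaluate the loss".

```python
import math


def central_difference(f, x, eps=1e-6):
    """Numerical derivative of a scalar function f at x."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def gradients_agree(f, grad_f, x, rtol=1e-5):
    """Return True if grad_f(x) matches the central-difference derivative of f at x."""
    expected = central_difference(f, x)
    actual = grad_f(x)
    return abs(actual - expected) <= rtol * max(1.0, abs(expected))


def loss(p):
    """Toy loss: u(T) for du/dt = p*u with u0 = 2 and T = 1 (closed form)."""
    return 2.0 * math.exp(p)


def loss_grad(p):
    """Gradient the AD pipeline would be expected to return for the toy loss."""
    return 2.0 * math.exp(p)


ok = gradients_agree(loss, loss_grad, 0.5)
broken = gradients_agree(loss, lambda p: 0.0, 0.5)  # a disconnected gradient fails
```

In a real test, `loss` would call the Julia-backed solve and `loss_grad` would come from the framework's backward pass; a gradient that is silently zero or detached shows up as a failed check.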

We also have an adjoint class in torchdyn (https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/sensitivity/adjoint.py). It shouldn't be too big of a problem for us to refactor the API (or provide another option), modifying it to follow DiffEqBase.jl conventions.

@ChrisRackauckas
Member

The MTK performance discussion can continue in #57, and the discussion about vendoring Julia can move to JuliaPy/pyjulia#118.
