-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement truncation_error keyword arg for truncate! #99
base: main
Are you sure you want to change the base?
implement truncation_error keyword arg for truncate! #99
Conversation
One thing I noticed while playing around with this new feature is that sometimes the truncation error exceeds the cutoff. But I thought it should be that The following MWE illustrates this. I'd be curious what you think. If you want I can open a separate issue about this if/once this has been merged. d = 2
N = 7
n = d^N
sinds = siteinds(d, N)
x = collect(LinRange(0, 1, n))
y = @. 2 * x + 3 + sin(2π * x)
χ = 8
truncation_error = Ref{Float64}()
cutoff_ = 1E-4
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)
truncation_error[]
truncation_error[] > cutoff_ # returns true!
cutoff_ = 1E-5
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)
truncation_error[]
truncation_error[] > cutoff_ # returns false, as expected |
The cutoff refers to the truncation of each SVD performed, if multiple SVDs are performed the total error could add up to a value larger than the cutoff. It could make sense to store a truncation error for each bond of the MPS. That's one reason why I'm a bit hesitant about this PR, since I'd prefer to think about this more generally in terms of what kinds of other information we might want to output and how it should get output. |
Ah, duh. Yeah so I will push an update here shortly and make it store per bond truncation error which is the more sensible thing to do. Thanks for the feedback. |
This commit refactors the original implemenation. `truncate!` now expects the user to pass a pointer to a vector of floats with as many elements as there are bonds in the MPS. It will then store the truncation error of each bond in the vector. The corresponding test was updated. The package was re-tested with the same results described in the PR (75165 passing and 33 broken for total of 75198 tests). The docstring of `truncate!` was updated to reflect the new behavior.
Any other comments or changes you'd like me to make? Or to discuss? I don't mind if a different system is opted for but I do think it is useful to be able to inspect the truncation error. |
test/base/test_mps.jl
Outdated
nbonds = N - 1 | ||
M = basicRandomMPS(10; dim=10) | ||
truncation_errors = Ref{Vector{Float64}}() | ||
truncation_errors[] = fill(-1.0, nbonds) # set to something other than zero for test. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to design the code so that users don't have to initialize the list ahead of time. We could add a line:
if isnothing(truncation_errors!)
truncation_errors![] = Vector{real(scalartype(M))}(undef, length(M) - 1)
end
at the top of truncate!(...)
(note the change to truncation_errors!
based on my comment above).
Change new kwarg name to `truncation_errors!` from `truncation_errors` Remove usage of `enumerate` Update docstring Co-authored-by: Matt Fishman <[email protected]>
I updated the PR. Could you please take a look and let me know what you think? |
Any update here? If I need to rewrite I don't mind. Just looking for some feedback. Or if there are no plans to expose the truncation information then I can just close this PR. I still think it would be useful but I can also get it via the local branch I've modified for this PR. |
Sorry for the slow response. I want to make sure we use a design that is as "future proof" as possible, by which I mean it:
I would say getting all of these properties is almost a "code research" question, from the perspective that we are still evolving in the way we are thinking about how to do this, and also I haven't seen many approaches I am happy with that ticks all of these boxes. So, given that broader context and design considerations, here is my latest proposal for how to design this feature: function truncate!(
::Algorithm"frobenius",
M::AbstractMPS;
site_range=1:length(M),
(callback!)=Returns(nothing),
kwargs...,
)
N = length(M)
# Left-orthogonalize all tensors to make
# truncations controlled
orthogonalize!(M, last(site_range))
# Perform truncations in a right-to-left sweep
for j in reverse((first(site_range) + 1):last(site_range))
rinds = uniqueinds(M[j], M[j - 1])
ltags = tags(commonind(M[j], M[j - 1]))
U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)
M[j] = U
M[j - 1] *= (S * V)
setrightlim!(M, j)
callback!(; link=(j => j - 1), truncation_error=spec.truncerr)
end
return M
end Then, a user can save the truncation error from each link/bond like this: using ITensorMPS: maxlinkdim, random_mps, siteinds, truncate!
s = siteinds("S=1/2", 10)
ψ = random_mps(s; linkdims=6)
@show maxlinkdim(ψ)
truncation_errors = Dict{Pair{Int,Int},Float64}()
function callback!(; link, truncation_error, kwargs...)
truncation_errors[link] = truncation_error
end
ψ′ = truncate!(copy(ψ); maxdim=3, callback!)
@show maxlinkdim(ψ′) which when run outputs something like this: julia> truncation_errors
Dict{Pair{Int64, Int64}, Float64} with 9 entries:
5=>4 => 0.104644
8=>7 => 0.122511
2=>1 => 0.0
9=>8 => 0.153664
10=>9 => 0.0
7=>6 => 0.096218
6=>5 => 0.0617276
4=>3 => 0.0620809
3=>2 => 0.0378573 You can see this design would allow for a lot of customizability, since in the future we could decide to pass more data from within |
That makes good sense and I agree those points are all important and balancing them is basically a research question as you say. I like this pattern though and I will update the PR later today to reflect this (including updated tests). Thank you! |
I've updated the PR. Please let me know if you have any comments or questions. |
truncation_errors[bond_no] = truncation_error | ||
return nothing | ||
end | ||
truncate!(ψ, maxdim=5, cutoff=1E-7, (callback!)=callback!) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
truncate!(ψ, maxdim=5, cutoff=1E-7, (callback!)=callback!) | |
truncate!(ψ, maxdim=5, cutoff=1E-7, callback!) |
@testset "truncate! with callback!" begin | ||
nsites = 10 | ||
nbonds = nsites - 1 | ||
mps_ = basicRandomMPS(nsites; dim=10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mps_ = basicRandomMPS(nsites; dim=10) | |
mps_ = random_mps(nsites; dim=10) |
Was there a reason to use basicRandomMPS
as opposed to random_mps
? I don't remember what basicRandomMPS
was defined for in the first place.
Co-authored-by: Matt Fishman <[email protected]>
The purpose of this commit is to allow the user to access the truncation error that is internally calculated in a call to
truncate!
. This PR implements #96.I have implemented this by allowing the user to pass a Ref object to the call to
truncate!
. The result of each SVD performed during a call totruncate!
is then accumulated in that Ref.I also added a new test for this functionality. I then re-ran all tests. There were 75165 passing and 33 broken for a total of 75198 tests. I also updated the docstring for
truncate!
to reflect this new keyword argument.Here is an example of the new functionality.
A few differences with what was suggested in #96:
!
in my keyword argument. Was that just for stylistic reasons or is there some other convention for doing that (I'm aware of the convention in function names but not in variable names)?