From abe35e606c21c8e1f12d4082cd9ed668930fa524 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Mon, 25 Oct 2021 16:20:47 -0400 Subject: [PATCH] fix #41546, make `using` thread-safe (#41602) use more precision when handling loading lock, merge with TOML lock (since we typically are needing both, sometimes in unpredictable orders), and unlock before call most user code Co-authored-by: Jameson Nash --- base/loading.jl | 76 ++++++++++++++++++++++++++++++-------------- base/toml_parser.jl | 2 +- src/dump.c | 5 +-- test/threads_exec.jl | 28 ++++++++++++++++ 4 files changed, 84 insertions(+), 27 deletions(-) diff --git a/base/loading.jl b/base/loading.jl index 1760da1efdbdf3..30f7bd25a1160e 100644 --- a/base/loading.jl +++ b/base/loading.jl @@ -1,6 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license # Base.require is the implementation for the `import` statement +const require_lock = ReentrantLock() # Cross-platform case-sensitive path canonicalization @@ -129,6 +130,7 @@ end const ns_dummy_uuid = UUID("fe0723d6-3a44-4c41-8065-ee0f42c8ceab") function dummy_uuid(project_file::String) + @lock require_lock begin cache = LOADING_CACHE[] if cache !== nothing uuid = get(cache.dummy_uuid, project_file, nothing) @@ -144,6 +146,7 @@ function dummy_uuid(project_file::String) cache.dummy_uuid[project_file] = uuid end return uuid + end end ## package path slugs: turning UUID + SHA1 into a pair of 4-byte "slugs" ## @@ -236,8 +239,7 @@ struct TOMLCache end const TOML_CACHE = TOMLCache(TOML.Parser(), Dict{String, Dict{String, Any}}()) -const TOML_LOCK = ReentrantLock() -parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, TOML_LOCK) +parsed_toml(project_file::AbstractString) = parsed_toml(project_file, TOML_CACHE, require_lock) function parsed_toml(project_file::AbstractString, toml_cache::TOMLCache, toml_lock::ReentrantLock) lock(toml_lock) do cache = LOADING_CACHE[] @@ -337,13 +339,15 @@ Use [`dirname`](@ref) to get the directory part and [`basename`](@ref) to get the file name part of the path. """ function pathof(m::Module) - pkgid = get(Base.module_keys, m, nothing) + @lock require_lock begin + pkgid = get(module_keys, m, nothing) pkgid === nothing && return nothing - origin = get(Base.pkgorigins, pkgid, nothing) + origin = get(pkgorigins, pkgid, nothing) origin === nothing && return nothing path = origin.path path === nothing && return nothing return fixup_stdlib_path(path) + end end """ @@ -366,7 +370,7 @@ julia> pkgdir(Foo, "src", "file.jl") The optional argument `paths` requires at least Julia 1.7. """ function pkgdir(m::Module, paths::String...) - rootmodule = Base.moduleroot(m) + rootmodule = moduleroot(m) path = pathof(rootmodule) path === nothing && return nothing return joinpath(dirname(dirname(path)), paths...) @@ -383,6 +387,7 @@ const preferences_names = ("JuliaLocalPreferences.toml", "LocalPreferences.toml" # - `true`: `env` is an implicit environment # - `path`: the path of an explicit project file function env_project_file(env::String)::Union{Bool,String} + @lock require_lock begin cache = LOADING_CACHE[] if cache !== nothing project_file = get(cache.env_project_file, env, nothing) @@ -406,6 +411,7 @@ function env_project_file(env::String)::Union{Bool,String} cache.env_project_file[env] = project_file end return project_file + end end function project_deps_get(env::String, name::String)::Union{Nothing,PkgId} @@ -473,6 +479,7 @@ end # find project file's corresponding manifest file function project_file_manifest_path(project_file::String)::Union{Nothing,String} + @lock require_lock begin cache = LOADING_CACHE[] if cache !== nothing manifest_path = get(cache.project_file_manifest_path, project_file, missing) @@ -501,6 +508,7 @@ function project_file_manifest_path(project_file::String)::Union{Nothing,String} cache.project_file_manifest_path[project_file] = manifest_path end return manifest_path + end end # given a directory (implicit env from LOAD_PATH) and a name, @@ -688,7 +696,7 @@ function implicit_manifest_deps_get(dir::String, where::PkgId, name::String)::Un @assert where.uuid !== nothing project_file = entry_point_and_project_file(dir, where.name)[2] project_file === nothing && return nothing # a project file is mandatory for a package with a uuid - proj = project_file_name_uuid(project_file, where.name, ) + proj = project_file_name_uuid(project_file, where.name) proj == where || return nothing # verify that this is the correct project file # this is the correct project, so stop searching here pkg_uuid = explicit_project_deps_get(project_file, name) @@ -753,19 +761,26 @@ function _include_from_serialized(path::String, depmods::Vector{Any}) if isa(sv, Exception) return sv end - restored = sv[1] - if !isa(restored, Exception) - for M in restored::Vector{Any} - M = M::Module - if isdefined(M, Base.Docs.META) - push!(Base.Docs.modules, M) - end - if parentmodule(M) === M - register_root_module(M) - end + sv = sv::SimpleVector + restored = sv[1]::Vector{Any} + for M in restored + M = M::Module + if isdefined(M, Base.Docs.META) + push!(Base.Docs.modules, M) + end + if parentmodule(M) === M + register_root_module(M) + end + end + inits = sv[2]::Vector{Any} + if !isempty(inits) + unlock(require_lock) # temporarily _unlock_ during these callbacks + try + ccall(:jl_init_restored_modules, Cvoid, (Any,), inits) + finally + lock(require_lock) end end - isassigned(sv, 2) && ccall(:jl_init_restored_modules, Cvoid, (Any,), sv[2]) return restored end @@ -873,7 +888,7 @@ function _require_search_from_serialized(pkg::PkgId, sourcepath::String, depth:: end # to synchronize multiple tasks trying to import/using something -const package_locks = Dict{PkgId,Condition}() +const package_locks = Dict{PkgId,Threads.Condition}() # to notify downstream consumers that a module was successfully loaded # Callbacks take the form (mod::Base.PkgId) -> nothing. @@ -896,7 +911,9 @@ function _include_dependency(mod::Module, _path::AbstractString) path = normpath(joinpath(dirname(prev), _path)) end if _track_dependencies[] + @lock require_lock begin push!(_require_dependencies, (mod, path, mtime(path))) + end end return path, prev end @@ -968,6 +985,7 @@ For more details regarding code loading, see the manual sections on [modules](@r [parallel computing](@ref code-availability). """ function require(into::Module, mod::Symbol) + @lock require_lock begin LOADING_CACHE[] = LoadingCache() try uuidkey = identify_package(into, String(mod)) @@ -1019,6 +1037,7 @@ function require(into::Module, mod::Symbol) finally LOADING_CACHE[] = nothing end + end end mutable struct PkgOrigin @@ -1030,6 +1049,7 @@ PkgOrigin() = PkgOrigin(nothing, nothing) const pkgorigins = Dict{PkgId,PkgOrigin}() function require(uuidkey::PkgId) + @lock require_lock begin if !root_module_exists(uuidkey) cachefile = _require(uuidkey) if cachefile !== nothing @@ -1041,15 +1061,19 @@ function require(uuidkey::PkgId) end end return root_module(uuidkey) + end end const loaded_modules = Dict{PkgId,Module}() const module_keys = IdDict{Module,PkgId}() # the reverse -is_root_module(m::Module) = haskey(module_keys, m) -root_module_key(m::Module) = module_keys[m] +is_root_module(m::Module) = @lock require_lock haskey(module_keys, m) +root_module_key(m::Module) = @lock require_lock module_keys[m] function register_root_module(m::Module) + # n.b. This is called from C after creating a new module in `Base.__toplevel__`, + # instead of adding them to the binding table there. + @lock require_lock begin key = PkgId(m, String(nameof(m))) if haskey(loaded_modules, key) oldm = loaded_modules[key] @@ -1059,6 +1083,7 @@ function register_root_module(m::Module) end loaded_modules[key] = m module_keys[m] = key + end nothing end @@ -1074,12 +1099,13 @@ using Base end # get a top-level Module from the given key -root_module(key::PkgId) = loaded_modules[key] +root_module(key::PkgId) = @lock require_lock loaded_modules[key] root_module(where::Module, name::Symbol) = root_module(identify_package(where, String(name))) +maybe_root_module(key::PkgId) = @lock require_lock get(loaded_modules, key, nothing) -root_module_exists(key::PkgId) = haskey(loaded_modules, key) -loaded_modules_array() = collect(values(loaded_modules)) +root_module_exists(key::PkgId) = @lock require_lock haskey(loaded_modules, key) +loaded_modules_array() = @lock require_lock collect(values(loaded_modules)) function unreference_module(key::PkgId) if haskey(loaded_modules, key) @@ -1098,7 +1124,7 @@ function _require(pkg::PkgId) wait(loading) return end - package_locks[pkg] = Condition() + package_locks[pkg] = Threads.Condition(require_lock) last = toplevel_load[] try @@ -1166,10 +1192,12 @@ function _require(pkg::PkgId) if uuid !== old_uuid ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, uuid) end + unlock(require_lock) try include(__toplevel__, path) return finally + lock(require_lock) if uuid !== old_uuid ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), __toplevel__, old_uuid) end diff --git a/base/toml_parser.jl b/base/toml_parser.jl index 4b2af426429a09..66db0e56955513 100644 --- a/base/toml_parser.jl +++ b/base/toml_parser.jl @@ -104,7 +104,7 @@ function Parser(str::String; filepath=nothing) IdSet{TOMLDict}(), # defined_tables root, filepath, - isdefined(Base, :loaded_modules) ? get(Base.loaded_modules, DATES_PKGID, nothing) : nothing, + isdefined(Base, :maybe_root_module) ? Base.maybe_root_module(DATES_PKGID) : nothing, ) startup(l) return l diff --git a/src/dump.c b/src/dump.c index 84ab325bd08a12..53b9bc3f0a719a 100644 --- a/src/dump.c +++ b/src/dump.c @@ -2228,8 +2228,6 @@ static jl_array_t *jl_finalize_deserializer(jl_serializer_state *s, arraylist_t JL_DLLEXPORT void jl_init_restored_modules(jl_array_t *init_order) { - if (!init_order) - return; int i, l = jl_array_len(init_order); for (i = 0; i < l; i++) { jl_value_t *mod = jl_array_ptr_ref(init_order, i); @@ -2683,6 +2681,9 @@ static jl_value_t *_jl_restore_incremental(ios_t *f, jl_array_t *mod_array) jl_recache_other(); // make all of the other objects identities correct (needs to be after insert methods) htable_free(&uniquing_table); jl_array_t *init_order = jl_finalize_deserializer(&s, tracee_list); // done with f and s (needs to be after recache) + if (init_order == NULL) + init_order = (jl_array_t*)jl_an_empty_vec_any; + assert(jl_isa((jl_value_t*)init_order, jl_array_any_type)); JL_GC_PUSH4(&init_order, &restored, &external_backedges, &external_edges); jl_gc_enable(en); // subtyping can allocate a lot, not valid before recache-other diff --git a/test/threads_exec.jl b/test/threads_exec.jl index f3d2dc9577c64d..b4c28d20b89cd3 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -912,3 +912,31 @@ end @test reproducible_rand(r, 10) == val end end + +# issue #41546, thread-safe package loading +@testset "package loading" begin + ch = Channel{Bool}(nthreads()) + barrier = Base.Event() + old_act_proj = Base.ACTIVE_PROJECT[] + try + pushfirst!(LOAD_PATH, "@") + Base.ACTIVE_PROJECT[] = joinpath(@__DIR__, "TestPkg") + @sync begin + for _ in 1:nthreads() + Threads.@spawn begin + put!(ch, true) + wait(barrier) + @eval using TestPkg + end + end + for _ in 1:nthreads() + take!(ch) + end + notify(barrier) + end + @test Base.root_module(@__MODULE__, :TestPkg) isa Module + finally + Base.ACTIVE_PROJECT[] = old_act_proj + popfirst!(LOAD_PATH) + end +end