Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Super-minimal POC for bounded translation validation #7169

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ option(EMSCRIPTEN_ENABLE_PTHREADS "Enable pthreads in emscripten build" OFF)
# This is useful for debugging, performance analysis, and other testing.
option(EMSCRIPTEN_ENABLE_SINGLE_FILE "Enable SINGLE_FILE mode in emscripten build" ON)

option(BUILD_WASM_VALIDATE_REFINEMENT "Build the wasm-validate-refinement tool" OFF)

# For git users, attempt to generate a more useful version string
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/.git)
find_package(Git QUIET REQUIRED)
Expand Down Expand Up @@ -460,6 +462,7 @@ else()
message(STATUS "Building libbinaryen as shared library.")
add_library(binaryen SHARED ${binaryen_SOURCES} ${binaryen_objs})
endif()

target_link_libraries(binaryen ${CMAKE_THREAD_LIBS_INIT})
if(INSTALL_LIBS OR NOT BUILD_STATIC_LIB)
install(TARGETS binaryen
Expand Down
35 changes: 28 additions & 7 deletions src/passes/GlobalStructInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ struct GlobalStructInference : public Pass {
// refined, which could change the struct.get's type.
refinalize = true;
}
// No need to worry about atomic gets here. We will still read from
// the same memory location as before and preserve all side effects
// (including synchronization) that were previously present. The
// memory location is immutable anyway, so there cannot be any writes
// to synchronize with in the first place.
curr->ref = builder.makeSequence(
builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)),
builder.makeGlobalGet(global, globalType));
Expand Down Expand Up @@ -457,10 +462,18 @@ struct GlobalStructInference : public Pass {
// the early return above) so that only leaves 1 and 2.
if (values.size() == 1) {
// The case of 1 value is simple: trap if the ref is null, and
// otherwise return the value.
replaceCurrent(builder.makeSequence(
builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)),
getReadValue(values[0])));
// otherwise return the value. We must also fence if the get was
// seqcst. No additional work is necessary for an acquire get because
// there cannot have been any writes to this immutable field that it
// would synchronize with.
Expression* replacement =
builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref));
if (curr->order == MemoryOrder::SeqCst) {
replacement =
builder.blockify(replacement, builder.makeAtomicFence());
}
replaceCurrent(
builder.blockify(replacement, getReadValue(values[0])));
return;
}
assert(values.size() == 2);
Expand All @@ -486,11 +499,19 @@ struct GlobalStructInference : public Pass {
// of their execution matters (they may note globals for un-nesting).
auto* left = getReadValue(values[0]);
auto* right = getReadValue(values[1]);
// Note that we must trap on null, so add a ref.as_non_null here.
// Note that we must trap on null, so add a ref.as_non_null here. We
// must also add a fence if this get is seqcst. As before, no extra work
// is necessary for an acquire get because there cannot be a write it
// synchronizes with.
Expression* getGlobal =
builder.makeGlobalGet(checkGlobal, wasm.getGlobal(checkGlobal)->type);
if (curr->order == MemoryOrder::SeqCst) {
getGlobal =
builder.makeSequence(builder.makeAtomicFence(), getGlobal);
}
replaceCurrent(builder.makeSelect(
builder.makeRefEq(builder.makeRefAs(RefAsNonNull, curr->ref),
builder.makeGlobalGet(
checkGlobal, wasm.getGlobal(checkGlobal)->type)),
getGlobal),
left,
right));
}
Expand Down
9 changes: 9 additions & 0 deletions src/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,14 @@ if(NOT BUILD_EMSCRIPTEN_TOOLS_ONLY)
binaryen_add_executable(wasm-fuzz-types "${fuzzing_SOURCES};wasm-fuzz-types.cpp")
binaryen_add_executable(wasm-fuzz-lattices "${fuzzing_SOURCES};wasm-fuzz-lattices.cpp")
endif()
if(BUILD_WASM_VALIDATE_REFINEMENT)
  # wasm-validate-refinement links against the z3 library. The library is
  # located via LIBZ3_LOCATION, which may be preset by the user; otherwise it
  # is searched for with find_library.
  binaryen_add_executable(wasm-validate-refinement wasm-validate-refinement.cpp)
  add_library(z3 SHARED IMPORTED)
  if(NOT LIBZ3_LOCATION)
    find_library(LIBZ3_LOCATION z3)
  endif()
  # find_library leaves LIBZ3_LOCATION set to "LIBZ3_LOCATION-NOTFOUND" on
  # failure. Stop here with a clear message rather than handing that value to
  # IMPORTED_LOCATION and failing later at link time.
  if(NOT LIBZ3_LOCATION)
    message(FATAL_ERROR
      "BUILD_WASM_VALIDATE_REFINEMENT is ON but the z3 library was not found. "
      "Install z3 or pass -DLIBZ3_LOCATION=<path to the z3 library>.")
  endif()
  set_property(TARGET z3 PROPERTY IMPORTED_LOCATION ${LIBZ3_LOCATION})
  target_link_libraries(wasm-validate-refinement z3)
endif()

add_subdirectory(wasm-split)
185 changes: 185 additions & 0 deletions src/tools/wasm-validate-refinement.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "support/command-line.h"
#include "wasm-io.h"

#include <iostream>
#include <z3++.h>

using namespace wasm;

// Translates the body of one wasm function into a z3 expression.
//
// Each supported function parameter becomes a z3 bit-vector constant named
// after the local, and the supported expressions (local.get of a parameter,
// integer constants, and a few binary operations) are translated
// recursively. Anything else hits WASM_UNREACHABLE.
//
// NOTE(review): constants are created by name in the given context, so two
// functions translated in the same context presumably share a parameter only
// when the local names (and widths) agree - confirm this is the intended way
// to relate source and target parameters.
struct ToSMT : UnifiedExpressionVisitor<ToSMT, z3::expr> {
  z3::context& ctx;
  Function* func;
  // One z3 constant per parameter, indexed by local index.
  std::vector<z3::expr> params;

  ToSMT(z3::context& ctx, Function* func) : ctx(ctx), func(func) {
    initParams(func);
  }

  // Creates a bit-vector constant for every parameter of `func`. Only i32,
  // i64 and v128 parameters are supported; any other type is reported as
  // unimplemented.
  void initParams(Function* func) {
    for (Index i = 0; i < func->getNumParams(); ++i) {
      auto type = func->getLocalType(i);
      auto name = func->getLocalNameOrGeneric(i).str.data();
      if (type.isBasic()) {
        switch (type.getBasic()) {
          case Type::none:
          case Type::unreachable:
          case Type::f32:
          case Type::f64:
            // Unsupported: fall through to the unreachable below.
            break;
          case Type::i32:
            params.push_back(ctx.bv_const(name, 32));
            continue;
          case Type::i64:
            params.push_back(ctx.bv_const(name, 64));
            continue;
          case Type::v128:
            params.push_back(ctx.bv_const(name, 128));
            continue;
        }
      }
      WASM_UNREACHABLE("unimplemented param type");
    }
  }

  // Fallback for every expression kind without a dedicated visitor below.
  z3::expr visitExpression(Expression* curr) {
    WASM_UNREACHABLE("unimplemented expression");
  }

  // Only reads of parameters are supported; other locals are a TODO.
  z3::expr visitLocalGet(LocalGet* curr) {
    assert(curr->index < func->getNumParams() && "TODO");
    return params[curr->index];
  }

  // Integer constants become bit-vector literals of the matching width.
  z3::expr visitConst(Const* curr) {
    assert(curr->type.isBasic());
    switch (curr->type.getBasic()) {
      case Type::none:
      case Type::unreachable:
        break;
      case Type::f32:
      case Type::f64:
        WASM_UNREACHABLE("TODO: fp const");
      case Type::i32:
        return ctx.bv_val(curr->value.geti32(), 32);
      case Type::i64:
        return ctx.bv_val(curr->value.geti64(), 64);
      case Type::v128:
        WASM_UNREACHABLE("TODO: v128.const");
    }
    WASM_UNREACHABLE("unexpected type");
  }

  // Translates the operands left-to-right, then combines them according to
  // the operation. Only i32.mul and i32.shl are implemented.
  z3::expr visitBinary(Binary* curr) {
    auto lhs = visit(curr->left);
    auto rhs = visit(curr->right);
    switch (curr->op) {
      case MulInt32:
        return lhs * rhs;
      case ShlInt32:
        // wasm interprets the shift count modulo the bit width (32), whereas
        // a bit-vector shift left produces zero once the count reaches the
        // width. Mask the count so both agree for counts >= 32.
        return z3::shl(lhs, rhs & ctx.bv_val(31, 32));
      default:
        break;
    }
    WASM_UNREACHABLE("unimplemented binary op");
  }
};

// Builds the z3 expression for the body of `func` within `ctx`.
z3::expr funcToSMT(z3::context& ctx, Function* func) {
  ToSMT translator(ctx, func);
  return translator.visit(func->body);
}

// Returns the condition under which `tgt` is considered a refinement of
// `src`. Currently that is plain equality of the two expressions.
z3::expr refinedBy(const z3::expr& src, const z3::expr& tgt) {
  // TODO: Something more complicated!
  z3::expr sameValue = (tgt == src);
  return sameValue;
}

// Tries to prove `conjecture` by asking the solver whether its negation is
// satisfiable, and prints the outcome to stdout:
//  - unsat:   the negation has no model, so the conjecture holds.
//  - sat:     the model of the negation is a counterexample.
//  - unknown: the solver could not decide; no model is requested in that
//             case and the solver's stated reason is printed instead.
void prove(const z3::expr& conjecture) {
  z3::context& ctx = conjecture.ctx();
  z3::solver solver(ctx);
  solver.add(!conjecture);
  std::cout << "Proving conjecture:\n" << conjecture << "\n";
  switch (solver.check()) {
    case z3::unsat:
      std::cout << "proved!\n";
      break;
    case z3::sat:
      std::cout << "counterexample:\n" << solver.get_model() << "\n";
      break;
    case z3::unknown:
      // Neither proved nor refuted; do not report this as a counterexample.
      std::cout << "unknown: " << solver.reason_unknown() << "\n";
      break;
  }
}

// Translates both functions in one fresh z3 context (source first, then
// target) and attempts to prove that the target refines the source.
void checkRefinement(Function* src, Function* tgt) {
  z3::context ctx;
  const z3::expr sourceExpr = funcToSMT(ctx, src);
  const z3::expr targetExpr = funcToSMT(ctx, tgt);
  const z3::expr conjecture = refinedBy(sourceExpr, targetExpr);
  prove(conjecture);
}

// Command-line options for the tool: the paths of the original (--source)
// and transformed (--target) modules. Both are empty until set by parsing.
struct ValidateRefinementOptions : Options {
  std::string source;
  std::string target;

  ValidateRefinementOptions(const std::string& command, const std::string& desc)
    : Options(command, desc) {
    // Each flag takes exactly one argument, which is stored in the
    // corresponding member.
    add("--source",
        "-s",
        "The original module",
        "",
        Arguments::One,
        [this](Options*, const std::string& value) { source = value; });
    add("--target",
        "-t",
        "The transformed module",
        "",
        Arguments::One,
        [this](Options*, const std::string& value) { target = value; });
  }
};

// Entry point: parses the command line, reads the source and target modules,
// and checks every defined (non-imported) source function against the target
// function at the same index.
//
// Returns 1 if either module path is missing or if a defined source function
// has no defined counterpart in the target module; returns 0 otherwise.
int main(int argc, const char* argv[]) {
  ValidateRefinementOptions options(
    "wasm-validate-refinement",
    "Bounded translation validation for WebAssembly");

  options.parse(argc, argv);

  if (options.source.empty()) {
    std::cerr << "Source module must be provided (--source)\n";
    return 1;
  }

  if (options.target.empty()) {
    std::cerr << "Target module must be provided (--target)\n";
    return 1;
  }

  Module src, tgt;

  ModuleReader().read(options.source, src);
  ModuleReader().read(options.target, tgt);

  // TODO: Verify that src and tgt have matching global structures, including
  // function signatures.

  for (size_t i = 0; i < src.functions.size(); ++i) {
    if (src.functions[i]->imported()) {
      continue;
    }

    // This must be a runtime check rather than an assert: asserts are
    // compiled out of release builds, which would leave an out-of-bounds
    // access on tgt.functions when the modules do not line up.
    if (i >= tgt.functions.size() || tgt.functions[i]->imported()) {
      std::cerr << "Function " << i
                << " of the source module has no defined counterpart in the "
                   "target module\n";
      return 1;
    }

    checkRefinement(src.functions[i].get(), tgt.functions[i].get());
  }

  return 0;
}
Loading
Loading