Skip to content

Commit

Permalink
refactor(protocol): refactor SignalService to remove duplicate code (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dantaik authored Apr 5, 2024
1 parent 8747d61 commit 5a148ef
Showing 1 changed file with 102 additions and 117 deletions.
219 changes: 102 additions & 117 deletions packages/protocol/contracts/signal/SignalService.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ contract SignalService is EssentialContract, ISignalService {

uint256[48] private __gap;

struct CacheAction {
bytes32 rootHash;
bytes32 signalRoot;
uint64 chainId;
uint64 blockId;
bool isFullProof;
bool isLastHop;
CacheOption option;
}

error SS_EMPTY_PROOF();
error SS_INVALID_HOPS_WITH_LOOP();
error SS_INVALID_SENDER();
Expand Down Expand Up @@ -91,56 +101,10 @@ contract SignalService is EssentialContract, ISignalService {
validSender(_app)
nonZeroValue(_signal)
{
HopProof[] memory hopProofs = abi.decode(_proof, (HopProof[]));
if (hopProofs.length == 0) revert SS_EMPTY_PROOF();
uint256 lenLessOne;
unchecked {
lenLessOne = hopProofs.length - 1;
}
uint64[] memory trace = new uint64[](lenLessOne);

uint64 chainId = _chainId;
address app = _app;
bytes32 signal = _signal;
bytes32 value = _signal;
address signalService = resolve(chainId, "signal_service", false);

HopProof memory hop;
for (uint256 i; i < hopProofs.length; ++i) {
hop = hopProofs[i];

for (uint256 j; j < i; ++j) {
if (trace[j] == hop.chainId) revert SS_INVALID_HOPS_WITH_LOOP();
}

bytes32 signalRoot = _verifyHopProof(chainId, app, signal, value, hop, signalService);
bool isLastHop = i == lenLessOne;

if (isLastHop) {
if (hop.chainId != block.chainid) revert SS_INVALID_LAST_HOP_CHAINID();
signalService = address(this);
} else {
trace[i] = hop.chainId;

if (hop.chainId == 0 || hop.chainId == block.chainid) {
revert SS_INVALID_MID_HOP_CHAINID();
}
signalService = resolve(hop.chainId, "signal_service", false);
}

bool isFullProof = hop.accountProof.length != 0;

_cacheChainData(hop, chainId, hop.blockId, signalRoot, isFullProof, isLastHop);

bytes32 kind = isFullProof ? LibSignals.STATE_ROOT : LibSignals.SIGNAL_ROOT;
signal = signalForChainData(chainId, kind, hop.blockId);
value = hop.rootHash;
chainId = hop.chainId;
app = signalService;
}
CacheAction[] memory actions = _verifySignalReceived(_chainId, _app, _signal, _proof);

if (value == 0 || value != _loadSignalValue(address(this), signal)) {
revert SS_SIGNAL_NOT_FOUND();
for (uint256 i; i < actions.length; ++i) {
_cache(actions[i]);
}
}

Expand All @@ -157,56 +121,7 @@ contract SignalService is EssentialContract, ISignalService {
validSender(_app)
nonZeroValue(_signal)
{
HopProof[] memory hopProofs = abi.decode(_proof, (HopProof[]));
if (hopProofs.length == 0) revert SS_EMPTY_PROOF();

uint256 lenLessOne;
unchecked {
lenLessOne = hopProofs.length - 1;
}
uint64[] memory trace = new uint64[](lenLessOne);

uint64 chainId = _chainId;
address app = _app;
bytes32 signal = _signal;
bytes32 value = _signal;
address signalService = resolve(chainId, "signal_service", false);

HopProof memory hop;

for (uint256 i; i < hopProofs.length; ++i) {
hop = hopProofs[i];

for (uint256 j; j < i; ++j) {
if (trace[j] == hop.chainId) revert SS_INVALID_HOPS_WITH_LOOP();
}

_verifyHopProof(chainId, app, signal, value, hop, signalService);

if (i == lenLessOne) {
if (hop.chainId != block.chainid) revert SS_INVALID_LAST_HOP_CHAINID();
signalService = address(this);
} else {
trace[i] = hop.chainId;

if (hop.chainId == 0 || hop.chainId == block.chainid) {
revert SS_INVALID_MID_HOP_CHAINID();
}
signalService = resolve(hop.chainId, "signal_service", false);
}

bool isFullProof = hop.accountProof.length != 0;

bytes32 kind = isFullProof ? LibSignals.STATE_ROOT : LibSignals.SIGNAL_ROOT;
signal = signalForChainData(chainId, kind, hop.blockId);
value = hop.rootHash;
chainId = hop.chainId;
app = signalService;
}

if (value == 0 || value != _loadSignalValue(address(this), signal)) {
revert SS_SIGNAL_NOT_FOUND();
}
_verifySignalReceived(_chainId, _app, _signal, _proof);
}

/// @inheritdoc ISignalService
Expand Down Expand Up @@ -345,30 +260,25 @@ contract SignalService is EssentialContract, ISignalService {
emit SignalSent(_app, _signal, slot_, _value);
}

function _cacheChainData(
HopProof memory _hop,
uint64 _chainId,
uint64 _blockId,
bytes32 _signalRoot,
bool _isFullProof,
bool _isLastHop
)
private
{
function _cache(CacheAction memory _action) private {
// cache state root
bool cacheStateRoot = _hop.cacheOption == CacheOption.CACHE_BOTH
|| _hop.cacheOption == CacheOption.CACHE_STATE_ROOT;
bool cacheStateRoot = _action.option == CacheOption.CACHE_BOTH
|| _action.option == CacheOption.CACHE_STATE_ROOT;

if (cacheStateRoot && _isFullProof && !_isLastHop) {
_syncChainData(_chainId, LibSignals.STATE_ROOT, _blockId, _hop.rootHash);
if (cacheStateRoot && _action.isFullProof && !_action.isLastHop) {
_syncChainData(
_action.chainId, LibSignals.STATE_ROOT, _action.blockId, _action.rootHash
);
}

// cache signal root
bool cacheSignalRoot = _hop.cacheOption == CacheOption.CACHE_BOTH
|| _hop.cacheOption == CacheOption.CACHE_SIGNAL_ROOT;
bool cacheSignalRoot = _action.option == CacheOption.CACHE_BOTH
|| _action.option == CacheOption.CACHE_SIGNAL_ROOT;

if (cacheSignalRoot && (_isFullProof || !_isLastHop)) {
_syncChainData(_chainId, LibSignals.SIGNAL_ROOT, _blockId, _signalRoot);
if (cacheSignalRoot && (_action.isFullProof || !_action.isLastHop)) {
_syncChainData(
_action.chainId, LibSignals.SIGNAL_ROOT, _action.blockId, _action.signalRoot
);
}
}

Expand All @@ -387,4 +297,79 @@ contract SignalService is EssentialContract, ISignalService {
value_ := sload(slot)
}
}

function _verifySignalReceived(
uint64 _chainId,
address _app,
bytes32 _signal,
bytes calldata _proof
)
private
view
validSender(_app)
nonZeroValue(_signal)
returns (CacheAction[] memory actions)
{
HopProof[] memory hopProofs = abi.decode(_proof, (HopProof[]));
if (hopProofs.length == 0) revert SS_EMPTY_PROOF();

uint64[] memory trace = new uint64[](hopProofs.length - 1);
actions = new CacheAction[](hopProofs.length);

uint64 chainId = _chainId;
address app = _app;
bytes32 signal = _signal;
bytes32 value = _signal;
address signalService = resolve(chainId, "signal_service", false);

HopProof memory hop;
bytes32 signalRoot;
bool isFullProof;
bool isLastHop;

for (uint256 i; i < hopProofs.length; ++i) {
hop = hopProofs[i];

for (uint256 j; j < i; ++j) {
if (trace[j] == hop.chainId) revert SS_INVALID_HOPS_WITH_LOOP();
}

signalRoot = _verifyHopProof(chainId, app, signal, value, hop, signalService);
isLastHop = i == trace.length;
if (isLastHop) {
if (hop.chainId != block.chainid) revert SS_INVALID_LAST_HOP_CHAINID();
signalService = address(this);
} else {
trace[i] = hop.chainId;

if (hop.chainId == 0 || hop.chainId == block.chainid) {
revert SS_INVALID_MID_HOP_CHAINID();
}
signalService = resolve(hop.chainId, "signal_service", false);
}

isFullProof = hop.accountProof.length != 0;

actions[i] = CacheAction(
hop.rootHash,
signalRoot,
chainId,
hop.blockId,
isFullProof,
isLastHop,
hop.cacheOption
);

signal = signalForChainData(
chainId, isFullProof ? LibSignals.STATE_ROOT : LibSignals.SIGNAL_ROOT, hop.blockId
);
value = hop.rootHash;
chainId = hop.chainId;
app = signalService;
}

if (value == 0 || value != _loadSignalValue(address(this), signal)) {
revert SS_SIGNAL_NOT_FOUND();
}
}
}

0 comments on commit 5a148ef

Please sign in to comment.