Skip to content

Commit

Permalink
add OmniCounter example
Browse files Browse the repository at this point in the history
  • Loading branch information
e00dan committed Jan 9, 2024
1 parent d3aa8dd commit 1839dd1
Show file tree
Hide file tree
Showing 5 changed files with 428 additions and 13 deletions.
277 changes: 277 additions & 0 deletions src/OmniCounter.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.19;

import {OptionsBuilder} from "@layerzerolabs/lz-evm-oapp-v2/contracts/oapp/libs/OptionsBuilder.sol";
import {
MessagingParams,
MessagingReceipt
} from "@layerzerolabs/lz-evm-protocol-v2/contracts/interfaces/ILayerZeroEndpointV2.sol";
import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol";
import {OAppUpgradeable, MessagingFee, Origin} from "@zodomo/oapp-upgradeable/OAppUpgradeable.sol";
import {ILayerZeroComposer} from "@layerzerolabs/lz-evm-protocol-v2/contracts/interfaces/ILayerZeroComposer.sol";

library MsgCodec {
uint8 internal constant VANILLA_TYPE = 1;
uint8 internal constant COMPOSED_TYPE = 2;
uint8 internal constant ABA_TYPE = 3;
uint8 internal constant COMPOSED_ABA_TYPE = 4;

uint8 internal constant MSG_TYPE_OFFSET = 0;
uint8 internal constant SRC_EID_OFFSET = 1;
uint8 internal constant VALUE_OFFSET = 5;

function encode(uint8 _type, uint32 _srcEid) internal pure returns (bytes memory) {
return abi.encodePacked(_type, _srcEid);
}

function encode(uint8 _type, uint32 _srcEid, uint256 _value) internal pure returns (bytes memory) {
return abi.encodePacked(_type, _srcEid, _value);
}

function msgType(bytes calldata _message) internal pure returns (uint8) {
return uint8(bytes1(_message[MSG_TYPE_OFFSET:SRC_EID_OFFSET]));
}

function srcEid(bytes calldata _message) internal pure returns (uint32) {
return uint32(bytes4(_message[SRC_EID_OFFSET:VALUE_OFFSET]));
}

function value(bytes calldata _message) internal pure returns (uint256) {
return uint256(bytes32(_message[VALUE_OFFSET:]));
}
}

contract OmniCounter is ILayerZeroComposer, OAppUpgradeable, UUPSUpgradeable {
using MsgCodec for bytes;
using OptionsBuilder for bytes;

uint256 public count;
uint256 public composedCount;

address public admin;
uint32 public eid;

mapping(uint32 srcEid => mapping(bytes32 sender => uint64 nonce)) private maxReceivedNonce;
bool private orderedNonce;

// for global assertions
mapping(uint32 srcEid => uint256 count) public inboundCount;
mapping(uint32 dstEid => uint256 count) public outboundCount;

constructor() {
_disableInitializers();
}

/**
* @dev Initialize the OApp with the provided endpoint and owner.
* @param _endpoint The address of the LOCAL LayerZero endpoint.
* @param _owner The address of the owner of the OApp.
*/
function initialize(address _endpoint, address _owner) public initializer {
_initializeOApp(_endpoint, _owner);
}

modifier onlyAdmin() {
require(msg.sender == admin, "only admin");
_;
}

// -------------------------------
// Only Admin
function setAdmin(address _admin) external onlyAdmin {
admin = _admin;
}

function withdraw(address payable _to, uint256 _amount) external onlyAdmin {
(bool success,) = _to.call{value: _amount}("");
require(success, "OmniCounter: withdraw failed");
}

// -------------------------------
// Send
function increment(uint32 _eid, uint8 _type, bytes calldata _options) external payable {
// bytes memory options = combineOptions(_eid, _type, _options);
_lzSend(_eid, MsgCodec.encode(_type, eid), _options, MessagingFee(msg.value, 0), payable(msg.sender));
_incrementOutbound(_eid);
}

// this is a broken function to skip incrementing outbound count
// so that preCrime will fail
function brokenIncrement(uint32 _eid, uint8 _type, bytes calldata _options) external payable onlyAdmin {
// bytes memory options = combineOptions(_eid, _type, _options);
_lzSend(_eid, MsgCodec.encode(_type, eid), _options, MessagingFee(msg.value, 0), payable(msg.sender));
}

function batchIncrement(uint32[] calldata _eids, uint8[] calldata _types, bytes[] calldata _options)
external
payable
{
require(_eids.length == _options.length && _eids.length == _types.length, "OmniCounter: length mismatch");

MessagingReceipt memory receipt;
uint256 providedFee = msg.value;
for (uint256 i = 0; i < _eids.length; i++) {
address refundAddress = i == _eids.length - 1 ? msg.sender : address(this);
uint32 dstEid = _eids[i];
uint8 msgType = _types[i];
// bytes memory options = combineOptions(dstEid, msgType, _options[i]);
receipt = _lzSend(
dstEid, MsgCodec.encode(msgType, eid), _options[i], MessagingFee(providedFee, 0), payable(refundAddress)
);
_incrementOutbound(dstEid);
providedFee -= receipt.fee.nativeFee;
}
}

// -------------------------------
// View
function quote(uint32 _eid, uint8 _type, bytes calldata _options)
public
view
returns (uint256 nativeFee, uint256 lzTokenFee)
{
// bytes memory options = combineOptions(_eid, _type, _options);
MessagingFee memory fee = _quote(_eid, MsgCodec.encode(_type, eid), _options, false);
return (fee.nativeFee, fee.lzTokenFee);
}

// -------------------------------
function _lzReceive(
Origin calldata _origin,
bytes32 _guid,
bytes calldata _message,
address, /*_executor*/
bytes calldata /*_extraData*/
) internal override {
_acceptNonce(_origin.srcEid, _origin.sender, _origin.nonce);
uint8 messageType = _message.msgType();

if (messageType == MsgCodec.VANILLA_TYPE) {
count++;

//////////////////////////////// IMPORTANT //////////////////////////////////
/// if you request for msg.value in the options, you should also encode it
/// into your message and check the value received at destination (example below).
/// if not, the executor could potentially provide less msg.value than you requested
/// leading to unintended behavior. Another option is to assert the executor to be
/// one that you trust.
/////////////////////////////////////////////////////////////////////////////
require(msg.value >= _message.value(), "OmniCounter: insufficient value");

_incrementInbound(_origin.srcEid);
} else if (messageType == MsgCodec.COMPOSED_TYPE || messageType == MsgCodec.COMPOSED_ABA_TYPE) {
count++;
_incrementInbound(_origin.srcEid);
endpoint.sendCompose(address(this), _guid, 0, _message);
} else if (messageType == MsgCodec.ABA_TYPE) {
count++;
_incrementInbound(_origin.srcEid);

// send back to the sender
_incrementOutbound(_origin.srcEid);
bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 10);
_lzSend(
_origin.srcEid,
MsgCodec.encode(MsgCodec.VANILLA_TYPE, eid, 10),
options,
MessagingFee(msg.value, 0),
payable(address(this))
);
} else {
revert("invalid message type");
}
}

function _incrementInbound(uint32 _srcEid) internal {
inboundCount[_srcEid]++;
}

function _incrementOutbound(uint32 _dstEid) internal {
outboundCount[_dstEid]++;
}

function lzCompose(address _oApp, bytes32, /*_guid*/ bytes calldata _message, address, bytes calldata)
external
payable
override
{
require(_oApp == address(this), "!oApp");
require(msg.sender == address(endpoint), "!endpoint");

uint8 msgType = _message.msgType();
if (msgType == MsgCodec.COMPOSED_TYPE) {
composedCount += 1;
} else if (msgType == MsgCodec.COMPOSED_ABA_TYPE) {
composedCount += 1;

uint32 srcEid = _message.srcEid();
_incrementOutbound(srcEid);
bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 0);
_lzSend(
srcEid,
MsgCodec.encode(MsgCodec.VANILLA_TYPE, eid),
options,
MessagingFee(msg.value, 0),
payable(address(this))
);
} else {
revert("invalid message type");
}
}

// -------------------------------
// Ordered OApp
// this demonstrates how to build an app that requires execution nonce ordering
// normally an app should decide ordered or not on contract construction
// this is just a demo
function setOrderedNonce(bool _orderedNonce) external onlyOwner {
orderedNonce = _orderedNonce;
}

function _acceptNonce(uint32 _srcEid, bytes32 _sender, uint64 _nonce) internal virtual {
uint64 currentNonce = maxReceivedNonce[_srcEid][_sender];
if (orderedNonce) {
require(_nonce == currentNonce + 1, "OApp: invalid nonce");
}
// update the max nonce anyway. once the ordered mode is turned on, missing early nonces will be rejected
if (_nonce > currentNonce) {
maxReceivedNonce[_srcEid][_sender] = _nonce;
}
}

function nextNonce(uint32 _srcEid, bytes32 _sender) public view virtual override returns (uint64) {
if (orderedNonce) {
return maxReceivedNonce[_srcEid][_sender] + 1;
} else {
return 0; // path nonce starts from 1. if 0 it means that there is no specific nonce enforcement
}
}

// TODO should override oApp version with added ordered nonce increment
// a governance function to skip nonce
function skipInboundNonce(uint32 _srcEid, bytes32 _sender, uint64 _nonce) public virtual onlyOwner {
endpoint.skip(address(this), _srcEid, _sender, _nonce);
if (orderedNonce) {
maxReceivedNonce[_srcEid][_sender]++;
}
}

// @dev Batch send requires overriding this function from OAppSender because the msg.value contains multiple fees
function _payNative(uint256 _nativeFee) internal virtual override returns (uint256 nativeFee) {
if (msg.value < _nativeFee) revert NotEnoughNative(msg.value);
return _nativeFee;
}

// be able to receive ether
receive() external payable virtual {}

fallback() external payable {}

/* ========== UUPS ========== */
//solhint-disable-next-line no-empty-blocks
function _authorizeUpgrade(address) internal override onlyOwner {}

function getImplementation() external view returns (address) {
return _getImplementation();
}
}
13 changes: 12 additions & 1 deletion test/Counter.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ contract CounterTest is ProxyTestHelper {

setUpEndpoints(2, LibraryType.UltraLightNode);

(address[] memory uas,) = setupOAppsProxies(1, 2);
(address[] memory uas,) = setupOAppsProxies(type(Counter).creationCode, 1, 2);
aCounter = Counter(payable(uas[0]));
bCounter = Counter(payable(uas[1]));
}
Expand Down Expand Up @@ -59,4 +59,15 @@ contract CounterTest is ProxyTestHelper {

assertEq(aCounter.count(), counterBefore + 1, "increment assertion failure");
}

// required for test helper to know how to initialize the OApp
function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress)
internal
override
returns (address proxyAddress)
{
UUPSProxy proxy =
new UUPSProxy(implementationAddress, abi.encodeWithSelector(Counter.initialize.selector, _endpoint, _owner));
proxyAddress = address(proxy);
}
}
13 changes: 12 additions & 1 deletion test/CounterUpgradeability.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ contract CounterUpgradeabilityTest is ProxyTestHelper {

setUpEndpoints(2, LibraryType.UltraLightNode);

(address[] memory uas, address implementationAddress) = setupOAppsProxies(1, 2);
(address[] memory uas, address implementationAddress) = setupOAppsProxies(type(Counter).creationCode, 1, 2);

counterImplementation = Counter(implementationAddress);

Expand Down Expand Up @@ -96,4 +96,15 @@ contract CounterUpgradeabilityTest is ProxyTestHelper {
bCounter.increment{value: nativeFee}(aEid, options);
verifyPackets(aEid, addressToBytes32(address(counter)));
}

// required for test helper to know how to initialize the OApp
function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress)
internal
override
returns (address proxyAddress)
{
UUPSProxy proxy =
new UUPSProxy(implementationAddress, abi.encodeWithSelector(Counter.initialize.selector, _endpoint, _owner));
proxyAddress = address(proxy);
}
}
Loading

0 comments on commit 1839dd1

Please sign in to comment.