diff --git a/contracts/extensions/FeeTaker.sol b/contracts/extensions/FeeTaker.sol index 69c240c0..350ecdc7 100644 --- a/contracts/extensions/FeeTaker.sol +++ b/contracts/extensions/FeeTaker.sol @@ -12,10 +12,11 @@ import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; import { IAmountGetter } from "../interfaces/IAmountGetter.sol"; import { IOrderMixin } from "../interfaces/IOrderMixin.sol"; import { IPostInteraction } from "../interfaces/IPostInteraction.sol"; +import { PostInteractionController } from "../helpers/PostInteractionController.sol"; import { MakerTraits, MakerTraitsLib } from "../libraries/MakerTraitsLib.sol"; /// @title Helper contract that adds feature of collecting fee in takerAsset -contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { +contract FeeTaker is IPostInteraction, IAmountGetter, PostInteractionController, Ownable { using AddressLib for Address; using SafeERC20 for IERC20; using UniERC20 for IERC20; @@ -76,18 +77,30 @@ contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { ) external view returns (uint256 calculatedMakingAmount) { unchecked { (uint256 integratorFee, uint256 resolverFee, bytes calldata tail) = _parseFeeData(extraData, taker); - if (tail.length > 20) { - calculatedMakingAmount = IAmountGetter(address(bytes20(tail))).getMakingAmount( - order, extension, orderHash, taker, takingAmount, remainingMakingAmount, tail[20:] - ); - } else { - calculatedMakingAmount = order.makingAmount; - } + calculatedMakingAmount = this.getCustomMakingAmount(order, extension, orderHash, taker, takingAmount, remainingMakingAmount, tail); calculatedMakingAmount = Math.mulDiv(calculatedMakingAmount, _FEE_BASE, _FEE_BASE + integratorFee + resolverFee, Math.Rounding.Floor); return Math.mulDiv(calculatedMakingAmount, takingAmount, order.takingAmount, Math.Rounding.Floor); } } + function getCustomMakingAmount( + IOrderMixin.Order calldata order, + bytes calldata extension, + bytes32 orderHash, + address taker, + uint256 takingAmount, + uint256 remainingMakingAmount, + bytes calldata tail + ) external view virtual returns (uint256) { + if (tail.length > 20) { + return IAmountGetter(address(bytes20(tail))).getMakingAmount( + order, extension, orderHash, taker, takingAmount, remainingMakingAmount, tail[20:] + ); + } else { + return order.makingAmount; + } + } + /** * @dev Calculate takingAmount with fee. * `extraData` consists of: @@ -107,39 +120,64 @@ contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { ) external view returns (uint256 calculatedTakingAmount) { unchecked { (uint256 integratorFee, uint256 resolverFee, bytes calldata tail) = _parseFeeData(extraData, taker); - if (tail.length > 20) { - calculatedTakingAmount = IAmountGetter(address(bytes20(tail))).getTakingAmount( - order, extension, orderHash, taker, makingAmount, remainingMakingAmount, tail[20:] - ); - } else { - calculatedTakingAmount = order.takingAmount; - } + calculatedTakingAmount = this.getCustomTakingAmount(order, extension, orderHash, taker, makingAmount, remainingMakingAmount, tail); calculatedTakingAmount = Math.mulDiv(calculatedTakingAmount, _FEE_BASE + integratorFee + resolverFee, _FEE_BASE, Math.Rounding.Ceil); return Math.mulDiv(calculatedTakingAmount, makingAmount, order.makingAmount, Math.Rounding.Ceil); } } + function getCustomTakingAmount( + IOrderMixin.Order calldata order, + bytes calldata extension, + bytes32 orderHash, + address taker, + uint256 makingAmount, + uint256 remainingMakingAmount, + bytes calldata tail + ) external view virtual returns (uint256) { + if (tail.length > 20) { + return IAmountGetter(address(bytes20(tail))).getTakingAmount( + order, extension, orderHash, taker, makingAmount, remainingMakingAmount, tail[20:] + ); + } else { + return order.takingAmount; + } + } + + function postInteraction( + IOrderMixin.Order calldata order, + bytes calldata extension, + bytes32 orderHash, + address taker, + uint256 makingAmount, + uint256 takingAmount, + uint256 remainingMakingAmount, + bytes calldata extraData + ) external onlyLimitOrderProtocol { + _postInteraction(order, extension, orderHash, taker, makingAmount, takingAmount, remainingMakingAmount, extraData); + } + /** * @notice See {IPostInteraction-postInteraction}. * @dev Takes the fee in taking tokens and transfers the rest to the maker. * `extraData` consists of: * 2 bytes — integrator fee percentage (in 1e5) * 2 bytes — resolver fee percentage (in 1e5) - * 1 byte - taker whitelist size + * 1 byte - bitmask ABBBBBBB, where A is the receiver flag and B represents the taker whitelist size * (bytes10)[N] — taker whitelist * 20 bytes — fee recipient * 20 bytes — receiver of taking tokens (optional, if not set, maker is used) */ - function postInteraction( + function _postInteraction( IOrderMixin.Order calldata order, - bytes calldata /* extension */, - bytes32 /* orderHash */, + bytes calldata extension, + bytes32 orderHash, address taker, - uint256 /* makingAmount */, + uint256 makingAmount, uint256 takingAmount, - uint256 /* remainingMakingAmount */, + uint256 remainingMakingAmount, bytes calldata extraData - ) external onlyLimitOrderProtocol { + ) internal virtual override { unchecked { (uint256 integratorFee, uint256 resolverFee, bytes calldata tail) = _parseFeeData(extraData, taker); address feeRecipient = address(bytes20(tail)); @@ -149,8 +187,9 @@ contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { uint256 fee = Math.mulDiv(takingAmount, integratorFee, denominator) + Math.mulDiv(takingAmount, resolverFee, denominator); address receiver = order.maker.get(); - if (tail.length > 0) { + if (uint8(extraData[4]) & 0x80 > 0) { // is set receiver of taking tokens receiver = address(bytes20(tail)); + tail = tail[20:]; } if (order.takerAsset.get() == address(_WETH) && order.makerTraits.unwrapWeth()) { @@ -164,6 +203,7 @@ contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { } IERC20(order.takerAsset.get()).safeTransfer(receiver, takingAmount - fee); } + super._postInteraction(order, extension, orderHash, taker, makingAmount, takingAmount, remainingMakingAmount, tail); } } @@ -208,11 +248,11 @@ contract FeeTaker is IPostInteraction, IAmountGetter, Ownable { } } - function _parseFeeData(bytes calldata extraData, address taker) private pure returns (uint256 integratorFee, uint256 resolverFee, bytes calldata tail) { + function _parseFeeData(bytes calldata extraData, address taker) internal virtual view returns (uint256 integratorFee, uint256 resolverFee, bytes calldata tail) { unchecked { integratorFee = uint256(uint16(bytes2(extraData))); resolverFee = uint256(uint16(bytes2(extraData[2:]))); - uint256 whitelistEnd = 5 + 10 * uint256(uint8(extraData[4])); + uint256 whitelistEnd = 5 + 10 * uint256(uint8(extraData[4] & 0x7F)); // & 0x7F - remove receiver of taking tokens flag bytes calldata whitelist = extraData[5:whitelistEnd]; if (!_isWhitelisted(whitelist, taker)) { resolverFee *= 2; diff --git a/contracts/helpers/PostInteractionController.sol b/contracts/helpers/PostInteractionController.sol new file mode 100644 index 00000000..f4044832 --- /dev/null +++ b/contracts/helpers/PostInteractionController.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT + +pragma solidity 0.8.23; + +import { IOrderMixin } from "../interfaces/IOrderMixin.sol"; + +/** + * @title PostInteraction Controller Contract + * @notice Base contract that facilitates inheritance for two other contracts, allowing control to be transferred between them via `super`. + * It enables contracts such as `FeeTaker` to delegate control to the next extension in the contract chain, but can be adapted for any extension as needed. + */ +abstract contract PostInteractionController { + function _postInteraction( + IOrderMixin.Order calldata order, + bytes calldata extension, + bytes32 orderHash, + address taker, + uint256 makingAmount, + uint256 takingAmount, + uint256 remainingMakingAmount, + bytes calldata extraData + ) internal virtual {} // solhint-disable-line no-empty-blocks +} diff --git a/test/FeeTaker.js b/test/FeeTaker.js index cdf5c942..1fc9e40c 100644 --- a/test/FeeTaker.js +++ b/test/FeeTaker.js @@ -91,7 +91,7 @@ describe('FeeTaker', function () { { postInteraction: ethers.solidityPacked( ['address', 'uint16', 'uint16', 'bytes1', 'address', 'address'], - [await feeTaker.getAddress(), fee, fee, '0x00', feeRecipient, makerReceiver], + [await feeTaker.getAddress(), fee, fee, '0x80', feeRecipient, makerReceiver], ), }, ); @@ -237,7 +237,7 @@ describe('FeeTaker', function () { { postInteraction: ethers.solidityPacked( ['address', 'uint16', 'uint16', 'bytes1', 'address', 'address'], - [await feeTaker.getAddress(), fee, 0, '0x00', feeRecipient, makerReceiver], + [await feeTaker.getAddress(), fee, 0, '0x80', feeRecipient, makerReceiver], ), makingAmountData: ethers.solidityPacked( ['address', 'uint16', 'uint16', 'bytes1'], @@ -329,7 +329,7 @@ describe('FeeTaker', function () { { postInteraction: ethers.solidityPacked( ['address', 'uint16', 'uint16', 'bytes1', 'address', 'address'], - [await feeTaker.getAddress(), fee, 0, '0x00', feeRecipient, makerReceiver], + [await feeTaker.getAddress(), fee, 0, '0x80', feeRecipient, makerReceiver], ), makingAmountData: ethers.solidityPacked( ['address', 'uint16', 'uint16', 'bytes1'],