// SPDX-License-Identifier: MIT
pragma solidity 0.8.17;

import { IDiamondCut } from "../Interfaces/IDiamondCut.sol";
import { LibUtil } from "../Libraries/LibUtil.sol";
import { OnlyContractOwner } from "../Errors/GenericErrors.sol";

/// Implementation of EIP-2535 Diamond Standard
/// https://eips.ethereum.org/EIPS/eip-2535
library LibDiamond {
    /// @dev Storage slot at which the DiamondStorage struct is anchored.
    bytes32 internal constant DIAMOND_STORAGE_POSITION =
        keccak256("diamond.standard.diamond.storage");

    // Diamond specific errors
    error IncorrectFacetCutAction();
    error NoSelectorsInFace();
    error FunctionAlreadyExists();
    error FacetAddressIsZero();
    error FacetAddressIsNotZero();
    error FacetContainsNoCode();
    error FunctionDoesNotExist();
    error FunctionIsImmutable();
    error InitZeroButCalldataNotEmpty();
    error CalldataEmptyButInitNotZero();
    error InitReverted();

    // ----------------

    struct FacetAddressAndPosition {
        address facetAddress;
        uint96 functionSelectorPosition; // position in facetFunctionSelectors.functionSelectors array
    }

    struct FacetFunctionSelectors {
        bytes4[] functionSelectors;
        uint256 facetAddressPosition; // position of facetAddress in facetAddresses array
    }

    // NOTE: the field order below defines the storage layout; do not reorder.
    struct DiamondStorage {
        // maps function selector to the facet address and
        // the position of the selector in the facetFunctionSelectors.selectors array
        mapping(bytes4 => FacetAddressAndPosition) selectorToFacetAndPosition;
        // maps facet addresses to function selectors
        mapping(address => FacetFunctionSelectors) facetFunctionSelectors;
        // facet addresses
        address[] facetAddresses;
        // Used to query if a contract implements an interface.
        // Used to implement ERC-165.
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the diamond storage struct.
    /// @return ds Storage reference located at DIAMOND_STORAGE_POSITION
    function diamondStorage()
        internal
        pure
        returns (DiamondStorage storage ds)
    {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        // solhint-disable-next-line no-inline-assembly
        assembly {
            ds.slot := position
        }
    }

    event OwnershipTransferred(
        address indexed previousOwner,
        address indexed newOwner
    );

    /// @notice Stores a new contract owner and emits OwnershipTransferred.
    /// @dev Performs no access control and no zero-address check itself.
    /// @param _newOwner Address written to DiamondStorage.contractOwner
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @notice Reads the current contract owner from diamond storage.
    /// @return contractOwner_ The stored owner address
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @notice Reverts with OnlyContractOwner unless msg.sender is the stored owner.
    function enforceIsContractOwner() internal view {
        if (msg.sender != diamondStorage().contractOwner)
            revert OnlyContractOwner();
    }

    event DiamondCut(
        IDiamondCut.FacetCut[] _diamondCut,
        address _init,
        bytes _calldata
    );

    /// @notice Internal function version of diamondCut.
    /// @dev Applies every cut in order, emits DiamondCut, then runs the
    ///      optional initializer. Reverts with IncorrectFacetCutAction for an
    ///      action that is none of Add, Replace or Remove.
    /// @param _diamondCut Facet cuts (facet address, action, selectors) to apply
    /// @param _init Address that is delegatecalled with `_calldata` (zero address to skip)
    /// @param _calldata Calldata for the initializer call
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        for (uint256 facetIndex; facetIndex < _diamondCut.length; ) {
            IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action;
            if (action == IDiamondCut.FacetCutAction.Add) {
                addFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Replace) {
                replaceFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Remove) {
                removeFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else {
                revert IncorrectFacetCutAction();
            }
            unchecked {
                ++facetIndex;
            }
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Registers new selectors for a facet.
    /// @dev Reverts with NoSelectorsInFace for an empty selector list, with
    ///      FacetAddressIsZero for a zero facet address, and with
    ///      FunctionAlreadyExists if any selector is already mapped to a facet.
    /// @param _facetAddress Facet that will serve the selectors
    /// @param _functionSelectors Selectors to add
    function addFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        if (_functionSelectors.length == 0) {
            revert NoSelectorsInFace();
        }
        DiamondStorage storage ds = diamondStorage();
        if (LibUtil.isZeroAddress(_facetAddress)) {
            revert FacetAddressIsZero();
        }
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;

        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            if (!LibUtil.isZeroAddress(oldFacetAddress)) {
                revert FunctionAlreadyExists();
            }
            addFunction(ds, selector, selectorPosition, _facetAddress);
            unchecked {
                ++selectorPosition;
                ++selectorIndex;
            }
        }
    }

    /// @notice Re-points existing selectors to a different facet.
    /// @dev Reverts with NoSelectorsInFace for an empty selector list, with
    ///      FacetAddressIsZero for a zero facet address, and with
    ///      FunctionAlreadyExists if a selector is already served by
    ///      `_facetAddress`. Removing the old mapping goes through
    ///      removeFunction, so its reverts apply as well.
    /// @param _facetAddress Facet that will serve the selectors from now on
    /// @param _functionSelectors Selectors to replace
    function replaceFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        if (_functionSelectors.length == 0) {
            revert NoSelectorsInFace();
        }
        DiamondStorage storage ds = diamondStorage();
        if (LibUtil.isZeroAddress(_facetAddress)) {
            revert FacetAddressIsZero();
        }
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;

        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            if (oldFacetAddress == _facetAddress) {
                revert FunctionAlreadyExists();
            }
            removeFunction(ds, oldFacetAddress, selector);
            addFunction(ds, selector, selectorPosition, _facetAddress);
            unchecked {
                ++selectorPosition;
                ++selectorIndex;
            }
        }
    }

    /// @notice Unregisters selectors from whichever facet currently serves them.
    /// @dev Reverts with NoSelectorsInFace for an empty selector list and with
    ///      FacetAddressIsNotZero when `_facetAddress` is not the zero address.
    ///      Each removal goes through removeFunction, so its reverts apply too.
    /// @param _facetAddress Must be the zero address for a remove action
    /// @param _functionSelectors Selectors to remove
    function removeFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        if (_functionSelectors.length == 0) {
            revert NoSelectorsInFace();
        }
        DiamondStorage storage ds = diamondStorage();
        // a remove action must be submitted with the zero address as facet;
        // the facet actually holding each selector is looked up below
        if (!LibUtil.isZeroAddress(_facetAddress)) {
            revert FacetAddressIsNotZero();
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;

        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            removeFunction(ds, oldFacetAddress, selector);
            unchecked {
                ++selectorIndex;
            }
        }
    }

    /// @notice Appends a facet to the facetAddresses array and records its index.
    /// @dev Reverts with FacetContainsNoCode if the address has no code.
    /// @param ds Diamond storage reference
    /// @param _facetAddress Facet to register
    function addFacet(
        DiamondStorage storage ds,
        address _facetAddress
    ) internal {
        enforceHasContractCode(_facetAddress);
        ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds
            .facetAddresses
            .length;
        ds.facetAddresses.push(_facetAddress);
    }

    /// @notice Maps a selector to a facet and appends it to that facet's selector list.
    /// @param ds Diamond storage reference
    /// @param _selector Selector being registered
    /// @param _selectorPosition Index the selector takes in the facet's selector array
    /// @param _facetAddress Facet that serves the selector
    function addFunction(
        DiamondStorage storage ds,
        bytes4 _selector,
        uint96 _selectorPosition,
        address _facetAddress
    ) internal {
        ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition = _selectorPosition;
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.push(
            _selector
        );
        ds.selectorToFacetAndPosition[_selector].facetAddress = _facetAddress;
    }

    /// @notice Removes a selector from a facet, dropping the facet once it has no selectors left.
    /// @dev Reverts with FunctionDoesNotExist when `_facetAddress` is zero and
    ///      with FunctionIsImmutable when it is the diamond itself. Both arrays
    ///      are compacted by moving the last element into the freed position
    ///      and popping, so element order is not preserved.
    /// @param ds Diamond storage reference
    /// @param _facetAddress Facet currently serving `_selector`
    /// @param _selector Selector to remove
    function removeFunction(
        DiamondStorage storage ds,
        address _facetAddress,
        bytes4 _selector
    ) internal {
        if (LibUtil.isZeroAddress(_facetAddress)) {
            revert FunctionDoesNotExist();
        }
        // an immutable function is a function defined directly in a diamond
        if (_facetAddress == address(this)) {
            revert FunctionIsImmutable();
        }
        // replace selector with last selector, then delete last selector
        uint256 selectorPosition = ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition;
        uint256 lastSelectorPosition = ds
            .facetFunctionSelectors[_facetAddress]
            .functionSelectors
            .length - 1;
        // if not the same then replace _selector with lastSelector
        if (selectorPosition != lastSelectorPosition) {
            bytes4 lastSelector = ds
                .facetFunctionSelectors[_facetAddress]
                .functionSelectors[lastSelectorPosition];
            ds.facetFunctionSelectors[_facetAddress].functionSelectors[
                selectorPosition
            ] = lastSelector;
            ds
                .selectorToFacetAndPosition[lastSelector]
                .functionSelectorPosition = uint96(selectorPosition);
        }
        // delete the last selector
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.pop();
        delete ds.selectorToFacetAndPosition[_selector];

        // if no more selectors for facet address then delete the facet address
        if (lastSelectorPosition == 0) {
            // replace facet address with last facet address and delete last facet address
            uint256 lastFacetAddressPosition = ds.facetAddresses.length - 1;
            uint256 facetAddressPosition = ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
            if (facetAddressPosition != lastFacetAddressPosition) {
                address lastFacetAddress = ds.facetAddresses[
                    lastFacetAddressPosition
                ];
                ds.facetAddresses[facetAddressPosition] = lastFacetAddress;
                ds
                    .facetFunctionSelectors[lastFacetAddress]
                    .facetAddressPosition = facetAddressPosition;
            }
            ds.facetAddresses.pop();
            delete ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
        }
    }

    /// @notice Runs the optional initializer of a diamond cut via delegatecall.
    /// @dev Reverts with InitZeroButCalldataNotEmpty when `_init` is zero but
    ///      calldata was supplied, and with CalldataEmptyButInitNotZero when
    ///      `_init` is set but calldata is empty. When `_init` is not the
    ///      diamond itself it must contain code (FacetContainsNoCode otherwise).
    ///      If the delegatecall fails, its return data is re-thrown unchanged
    ///      so the initializer's own custom error / revert reason reaches the
    ///      caller; a failure without return data reverts with InitReverted.
    /// @param _init Address to delegatecall (zero address to skip initialization)
    /// @param _calldata Calldata passed to the delegatecall
    function initializeDiamondCut(
        address _init,
        bytes memory _calldata
    ) internal {
        if (LibUtil.isZeroAddress(_init)) {
            if (_calldata.length != 0) {
                revert InitZeroButCalldataNotEmpty();
            }
        } else {
            if (_calldata.length == 0) {
                revert CalldataEmptyButInitNotZero();
            }
            if (_init != address(this)) {
                enforceHasContractCode(_init);
            }
            // solhint-disable-next-line avoid-low-level-calls
            (bool success, bytes memory returnData) = _init.delegatecall(
                _calldata
            );
            if (!success) {
                if (returnData.length > 0) {
                    // bubble up the error exactly as the initializer produced it;
                    // wrapping the bytes in revert(string(...)) would re-encode
                    // them as the payload of a new Error(string) and make
                    // custom errors undecodable for the caller
                    // solhint-disable-next-line no-inline-assembly
                    assembly {
                        revert(add(returnData, 0x20), mload(returnData))
                    }
                } else {
                    revert InitReverted();
                }
            }
        }
    }

    /// @notice Reverts with FacetContainsNoCode when the address has no deployed code.
    /// @param _contract Address whose code size is checked
    function enforceHasContractCode(address _contract) internal view {
        uint256 contractSize;
        // solhint-disable-next-line no-inline-assembly
        assembly {
            contractSize := extcodesize(_contract)
        }
        if (contractSize == 0) {
            revert FacetContainsNoCode();
        }
    }
}