// SPDX-License-Identifier: MIT pragma solidity ^0.8.7; import { IDiamondCut } from "../Interfaces/IDiamondCut.sol"; library LibDiamond { bytes32 internal constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage"); struct FacetAddressAndPosition { address facetAddress; uint96 functionSelectorPosition; // position in facetFunctionSelectors.functionSelectors array } struct FacetFunctionSelectors { bytes4[] functionSelectors; uint256 facetAddressPosition; // position of facetAddress in facetAddresses array } struct DiamondStorage { // maps function selector to the facet address and // the position of the selector in the facetFunctionSelectors.selectors array mapping(bytes4 => FacetAddressAndPosition) selectorToFacetAndPosition; // maps facet addresses to function selectors mapping(address => FacetFunctionSelectors) facetFunctionSelectors; // facet addresses address[] facetAddresses; // Used to query if a contract implements an interface. // Used to implement ERC-165. mapping(bytes4 => bool) supportedInterfaces; // owner of the contract address contractOwner; } function diamondStorage() internal pure returns (DiamondStorage storage ds) { bytes32 position = DIAMOND_STORAGE_POSITION; // solhint-disable-next-line no-inline-assembly assembly { ds.slot := position } } event OwnershipTransferred(address indexed previousOwner, address indexed newOwner); function setContractOwner(address _newOwner) internal { DiamondStorage storage ds = diamondStorage(); address previousOwner = ds.contractOwner; ds.contractOwner = _newOwner; emit OwnershipTransferred(previousOwner, _newOwner); } function contractOwner() internal view returns (address contractOwner_) { contractOwner_ = diamondStorage().contractOwner; } function enforceIsContractOwner() internal view { require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner"); } event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata); // Internal function version of diamondCut function diamondCut( IDiamondCut.FacetCut[] memory _diamondCut, address _init, bytes memory _calldata ) internal { for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) { IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action; if (action == IDiamondCut.FacetCutAction.Add) { addFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors); } else if (action == IDiamondCut.FacetCutAction.Replace) { replaceFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors); } else if (action == IDiamondCut.FacetCutAction.Remove) { removeFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors); } else { revert("LibDiamondCut: Incorrect FacetCutAction"); } } emit DiamondCut(_diamondCut, _init, _calldata); initializeDiamondCut(_init, _calldata); } function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); DiamondStorage storage ds = diamondStorage(); require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)"); uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length); // add new facet address if it does not exist if (selectorPosition == 0) { addFacet(ds, _facetAddress); } for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists"); addFunction(ds, selector, selectorPosition, _facetAddress); selectorPosition++; } } function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); DiamondStorage storage ds = diamondStorage(); require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)"); uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length); // add new facet address if it does not exist if (selectorPosition == 0) { addFacet(ds, _facetAddress); } for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; require(oldFacetAddress != _facetAddress, "LibDiamondCut: Can't replace function with same function"); removeFunction(ds, oldFacetAddress, selector); addFunction(ds, selector, selectorPosition, _facetAddress); selectorPosition++; } } function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); DiamondStorage storage ds = diamondStorage(); // if function does not exist then do nothing and return require(_facetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)"); for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; removeFunction(ds, oldFacetAddress, selector); } } function addFacet(DiamondStorage storage ds, address _facetAddress) internal { enforceHasContractCode(_facetAddress, "LibDiamondCut: New facet has no code"); ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds.facetAddresses.length; ds.facetAddresses.push(_facetAddress); } function addFunction( DiamondStorage storage ds, bytes4 _selector, uint96 _selectorPosition, address _facetAddress ) internal { ds.selectorToFacetAndPosition[_selector].functionSelectorPosition = _selectorPosition; ds.facetFunctionSelectors[_facetAddress].functionSelectors.push(_selector); ds.selectorToFacetAndPosition[_selector].facetAddress = _facetAddress; } function removeFunction( DiamondStorage storage ds, address _facetAddress, bytes4 _selector ) internal { require(_facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist"); // an immutable function is a function defined directly in a diamond require(_facetAddress != address(this), "LibDiamondCut: Can't remove immutable function"); // replace selector with last selector, then delete last selector uint256 selectorPosition = ds.selectorToFacetAndPosition[_selector].functionSelectorPosition; uint256 lastSelectorPosition = ds.facetFunctionSelectors[_facetAddress].functionSelectors.length - 1; // if not the same then replace _selector with lastSelector if (selectorPosition != lastSelectorPosition) { bytes4 lastSelector = ds.facetFunctionSelectors[_facetAddress].functionSelectors[lastSelectorPosition]; ds.facetFunctionSelectors[_facetAddress].functionSelectors[selectorPosition] = lastSelector; ds.selectorToFacetAndPosition[lastSelector].functionSelectorPosition = uint96(selectorPosition); } // delete the last selector ds.facetFunctionSelectors[_facetAddress].functionSelectors.pop(); delete ds.selectorToFacetAndPosition[_selector]; // if no more selectors for facet address then delete the facet address if (lastSelectorPosition == 0) { // replace facet address with last facet address and delete last facet address uint256 lastFacetAddressPosition = ds.facetAddresses.length - 1; uint256 facetAddressPosition = ds.facetFunctionSelectors[_facetAddress].facetAddressPosition; if (facetAddressPosition != lastFacetAddressPosition) { address lastFacetAddress = ds.facetAddresses[lastFacetAddressPosition]; ds.facetAddresses[facetAddressPosition] = lastFacetAddress; ds.facetFunctionSelectors[lastFacetAddress].facetAddressPosition = facetAddressPosition; } ds.facetAddresses.pop(); delete ds.facetFunctionSelectors[_facetAddress].facetAddressPosition; } } function initializeDiamondCut(address _init, bytes memory _calldata) internal { if (_init == address(0)) { require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty"); } else { require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)"); if (_init != address(this)) { enforceHasContractCode(_init, "LibDiamondCut: _init address has no code"); } // solhint-disable-next-line avoid-low-level-calls (bool success, bytes memory error) = _init.delegatecall(_calldata); if (!success) { if (error.length > 0) { // bubble up the error revert(string(error)); } else { revert("LibDiamondCut: _init function reverted"); } } } } function enforceHasContractCode(address _contract, string memory _errorMessage) internal view { uint256 contractSize; // solhint-disable-next-line no-inline-assembly assembly { contractSize := extcodesize(_contract) } require(contractSize > 0, _errorMessage); } }