// SPDX-License-Identifier: MIT
pragma solidity >=0.8.19;

// Common.sol
//
// Common mathematical functions needed by both SD59x18 and UD60x18. Note that these global functions do not
// always operate with SD59x18 and UD60x18 numbers.

/*//////////////////////////////////////////////////////////////////////////
                                CUSTOM ERRORS
//////////////////////////////////////////////////////////////////////////*/

/// @notice Thrown when the resultant value in {mulDiv} overflows uint256.
/// @param x The multiplicand that was passed to {mulDiv}.
/// @param y The multiplier that was passed to {mulDiv}.
/// @param denominator The divisor that was passed to {mulDiv}.
error PRBMath_MulDiv_Overflow(uint256 x, uint256 y, uint256 denominator);

/// @notice Thrown when the resultant value in {mulDiv18} overflows uint256.
/// @param x The multiplicand that was passed to {mulDiv18}.
/// @param y The multiplier that was passed to {mulDiv18}.
error PRBMath_MulDiv18_Overflow(uint256 x, uint256 y);

/// @notice Thrown when one of the inputs passed to {mulDivSigned} is `type(int256).min`.
error PRBMath_MulDivSigned_InputTooSmall();

/// @notice Thrown when the resultant value in {mulDivSigned} overflows int256.
/// @param x The multiplicand that was passed to {mulDivSigned}.
/// @param y The multiplier that was passed to {mulDivSigned}.
error PRBMath_MulDivSigned_Overflow(int256 x, int256 y);

/*//////////////////////////////////////////////////////////////////////////
                                    CONSTANTS
//////////////////////////////////////////////////////////////////////////*/

/// @dev The maximum value a uint128 number can have.
uint128 constant MAX_UINT128 = type(uint128).max;

/// @dev The maximum value a uint40 number can have.
uint40 constant MAX_UINT40 = type(uint40).max;

/// @dev The unit number, which gives the decimal precision of the fixed-point types.
uint256 constant UNIT = 1e18;

/// @dev The unit number inverted mod 2^256.
/// NOTE(review): `UNIT` is even, so it has no inverse mod 2^256 by itself. Judging by how {mulDiv18} first divides
/// by `UNIT_LPOTD` and then multiplies by this constant, this is presumably the inverse of the odd part
/// `UNIT / UNIT_LPOTD` - confirm.
uint256 constant UNIT_INVERSE = 78156646155174841979727994598816262306175212592076161876661_508869554232690281;

/// @dev The largest power of two that divides the decimal value of `UNIT`. The base-2 logarithm of this value is the
/// index of the least significant set bit in the binary representation of `UNIT`.
uint256 constant UNIT_LPOTD = 262144;

/*//////////////////////////////////////////////////////////////////////////
                                    FUNCTIONS
//////////////////////////////////////////////////////////////////////////*/

/// @notice Calculates the binary exponent of x using the binary fraction method.
///
/// @dev Has to use 192.64-bit fixed-point numbers. See https://ethereum.stackexchange.com/a/96594/24693.
///
/// Notes:
/// - The low 64 bits of x are treated as the fractional part and `x >> 64` as the integer part.
/// - NOTE(review): there is no range check in this function. The final shift amount `191 - (x >> 64)` is computed
/// inside an `unchecked` block, so it wraps around when the integer part exceeds 191. Presumably the callers bound
/// x before calling - confirm.
///
/// @param x The exponent as an unsigned 192.64-bit fixed-point number.
/// @return result The result as an unsigned 60.18-decimal fixed-point number.
/// @custom:smtchecker abstract-function-nondet
function exp2(uint256 x) pure returns (uint256 result) {
    unchecked {
        // Start from 0.5 in the 192.64-bit fixed-point format.
        result = 0x800000000000000000000000000000000000000000000000;

        // The following logic walks over the 64 fractional bits of x, from the most significant one (worth 2^-1)
        // down to the least significant one (worth 2^-64). When the bit worth $2^{-i}$ is set, the result is
        // multiplied by $2^{2^{-i}}$, stored as a 64-bit fixed-point magic factor (the first factor, for instance,
        // is $\sqrt{2} \cdot 2^{64}$). Key points:
        //
        // 1. Intermediate results will not overflow, as the starting point is 2^191 and all magic factors are under 2^65.
        // 2. The rationale for organizing the if statements into groups of 8 is gas savings. If the result of performing
        // a bitwise AND operation between x and any value in the array [0x80; 0x40; 0x20; 0x10; 0x08; 0x04; 0x02; 0x01]
        // is non-zero, we know that `x & 0xFF` is also non-zero, so a whole group can be skipped with a single test.
        // 3. In Solidity, `&` binds tighter than `>`, so `x & mask > 0` reads as `(x & mask) > 0`.

        // Fractional bits 63..56.
        if (x & 0xFF00000000000000 > 0) {
            if (x & 0x8000000000000000 > 0) {
                result = (result * 0x16A09E667F3BCC909) >> 64;
            }
            if (x & 0x4000000000000000 > 0) {
                result = (result * 0x1306FE0A31B7152DF) >> 64;
            }
            if (x & 0x2000000000000000 > 0) {
                result = (result * 0x1172B83C7D517ADCE) >> 64;
            }
            if (x & 0x1000000000000000 > 0) {
                result = (result * 0x10B5586CF9890F62A) >> 64;
            }
            if (x & 0x800000000000000 > 0) {
                result = (result * 0x1059B0D31585743AE) >> 64;
            }
            if (x & 0x400000000000000 > 0) {
                result = (result * 0x102C9A3E778060EE7) >> 64;
            }
            if (x & 0x200000000000000 > 0) {
                result = (result * 0x10163DA9FB33356D8) >> 64;
            }
            if (x & 0x100000000000000 > 0) {
                result = (result * 0x100B1AFA5ABCBED61) >> 64;
            }
        }

        // Fractional bits 55..48.
        if (x & 0xFF000000000000 > 0) {
            if (x & 0x80000000000000 > 0) {
                result = (result * 0x10058C86DA1C09EA2) >> 64;
            }
            if (x & 0x40000000000000 > 0) {
                result = (result * 0x1002C605E2E8CEC50) >> 64;
            }
            if (x & 0x20000000000000 > 0) {
                result = (result * 0x100162F3904051FA1) >> 64;
            }
            if (x & 0x10000000000000 > 0) {
                result = (result * 0x1000B175EFFDC76BA) >> 64;
            }
            if (x & 0x8000000000000 > 0) {
                result = (result * 0x100058BA01FB9F96D) >> 64;
            }
            if (x & 0x4000000000000 > 0) {
                result = (result * 0x10002C5CC37DA9492) >> 64;
            }
            if (x & 0x2000000000000 > 0) {
                result = (result * 0x1000162E525EE0547) >> 64;
            }
            if (x & 0x1000000000000 > 0) {
                result = (result * 0x10000B17255775C04) >> 64;
            }
        }

        // Fractional bits 47..40.
        if (x & 0xFF0000000000 > 0) {
            if (x & 0x800000000000 > 0) {
                result = (result * 0x1000058B91B5BC9AE) >> 64;
            }
            if (x & 0x400000000000 > 0) {
                result = (result * 0x100002C5C89D5EC6D) >> 64;
            }
            if (x & 0x200000000000 > 0) {
                result = (result * 0x10000162E43F4F831) >> 64;
            }
            if (x & 0x100000000000 > 0) {
                result = (result * 0x100000B1721BCFC9A) >> 64;
            }
            if (x & 0x80000000000 > 0) {
                result = (result * 0x10000058B90CF1E6E) >> 64;
            }
            if (x & 0x40000000000 > 0) {
                result = (result * 0x1000002C5C863B73F) >> 64;
            }
            if (x & 0x20000000000 > 0) {
                result = (result * 0x100000162E430E5A2) >> 64;
            }
            if (x & 0x10000000000 > 0) {
                result = (result * 0x1000000B172183551) >> 64;
            }
        }

        // Fractional bits 39..32.
        if (x & 0xFF00000000 > 0) {
            if (x & 0x8000000000 > 0) {
                result = (result * 0x100000058B90C0B49) >> 64;
            }
            if (x & 0x4000000000 > 0) {
                result = (result * 0x10000002C5C8601CC) >> 64;
            }
            if (x & 0x2000000000 > 0) {
                result = (result * 0x1000000162E42FFF0) >> 64;
            }
            if (x & 0x1000000000 > 0) {
                result = (result * 0x10000000B17217FBB) >> 64;
            }
            if (x & 0x800000000 > 0) {
                result = (result * 0x1000000058B90BFCE) >> 64;
            }
            if (x & 0x400000000 > 0) {
                result = (result * 0x100000002C5C85FE3) >> 64;
            }
            if (x & 0x200000000 > 0) {
                result = (result * 0x10000000162E42FF1) >> 64;
            }
            if (x & 0x100000000 > 0) {
                result = (result * 0x100000000B17217F8) >> 64;
            }
        }

        // Fractional bits 31..24.
        if (x & 0xFF000000 > 0) {
            if (x & 0x80000000 > 0) {
                result = (result * 0x10000000058B90BFC) >> 64;
            }
            if (x & 0x40000000 > 0) {
                result = (result * 0x1000000002C5C85FE) >> 64;
            }
            if (x & 0x20000000 > 0) {
                result = (result * 0x100000000162E42FF) >> 64;
            }
            if (x & 0x10000000 > 0) {
                result = (result * 0x1000000000B17217F) >> 64;
            }
            if (x & 0x8000000 > 0) {
                result = (result * 0x100000000058B90C0) >> 64;
            }
            if (x & 0x4000000 > 0) {
                result = (result * 0x10000000002C5C860) >> 64;
            }
            if (x & 0x2000000 > 0) {
                result = (result * 0x1000000000162E430) >> 64;
            }
            if (x & 0x1000000 > 0) {
                result = (result * 0x10000000000B17218) >> 64;
            }
        }

        // Fractional bits 23..16.
        if (x & 0xFF0000 > 0) {
            if (x & 0x800000 > 0) {
                result = (result * 0x1000000000058B90C) >> 64;
            }
            if (x & 0x400000 > 0) {
                result = (result * 0x100000000002C5C86) >> 64;
            }
            if (x & 0x200000 > 0) {
                result = (result * 0x10000000000162E43) >> 64;
            }
            if (x & 0x100000 > 0) {
                result = (result * 0x100000000000B1721) >> 64;
            }
            if (x & 0x80000 > 0) {
                result = (result * 0x10000000000058B91) >> 64;
            }
            if (x & 0x40000 > 0) {
                result = (result * 0x1000000000002C5C8) >> 64;
            }
            if (x & 0x20000 > 0) {
                result = (result * 0x100000000000162E4) >> 64;
            }
            if (x & 0x10000 > 0) {
                result = (result * 0x1000000000000B172) >> 64;
            }
        }

        // Fractional bits 15..8.
        if (x & 0xFF00 > 0) {
            if (x & 0x8000 > 0) {
                result = (result * 0x100000000000058B9) >> 64;
            }
            if (x & 0x4000 > 0) {
                result = (result * 0x10000000000002C5D) >> 64;
            }
            if (x & 0x2000 > 0) {
                result = (result * 0x1000000000000162E) >> 64;
            }
            if (x & 0x1000 > 0) {
                result = (result * 0x10000000000000B17) >> 64;
            }
            if (x & 0x800 > 0) {
                result = (result * 0x1000000000000058C) >> 64;
            }
            if (x & 0x400 > 0) {
                result = (result * 0x100000000000002C6) >> 64;
            }
            if (x & 0x200 > 0) {
                result = (result * 0x10000000000000163) >> 64;
            }
            if (x & 0x100 > 0) {
                result = (result * 0x100000000000000B1) >> 64;
            }
        }

        // Fractional bits 7..0.
        if (x & 0xFF > 0) {
            if (x & 0x80 > 0) {
                result = (result * 0x10000000000000059) >> 64;
            }
            if (x & 0x40 > 0) {
                result = (result * 0x1000000000000002C) >> 64;
            }
            if (x & 0x20 > 0) {
                result = (result * 0x10000000000000016) >> 64;
            }
            if (x & 0x10 > 0) {
                result = (result * 0x1000000000000000B) >> 64;
            }
            if (x & 0x8 > 0) {
                result = (result * 0x10000000000000006) >> 64;
            }
            if (x & 0x4 > 0) {
                result = (result * 0x10000000000000003) >> 64;
            }
            if (x & 0x2 > 0) {
                result = (result * 0x10000000000000001) >> 64;
            }
            if (x & 0x1 > 0) {
                result = (result * 0x10000000000000001) >> 64;
            }
        }

        // In the code snippet below, two operations are executed simultaneously:
        //
        // 1. The result is scaled by $2^{ip}$, where $ip$ denotes the integer part of x (`x >> 64`), and the initial
        // guess of 0.5 is compensated for. This is achieved by subtracting from 191 instead of 192.
        // 2. The result is then converted to an unsigned 60.18-decimal fixed-point format by multiplying by `UNIT`.
        //
        // The underlying logic is that shifting right by $191 - ip$ is the same as multiplying by $2^{ip}$ and then
        // dividing by $2^{191}$.
        result *= UNIT;
        result >>= (191 - (x >> 64));
    }
}

/// @notice Finds the zero-based index of the first 1 in the binary representation of x.
///
/// @dev See the note on "msb" in this Wikipedia article: https://en.wikipedia.org/wiki/Find_first_set
///
/// The index is found with a binary search over the bit width, written without conditional jumps. Every step of
/// the search is equivalent to this high-level code:
///
/// ```solidity
/// if (x >= 2 ** 128) {
///     x >>= 128;
///     result += 128;
/// }
/// ```
///
/// Where 128 is replaced with each respective power of two factor. See the full high-level implementation here:
/// https://gist.github.com/PaulRBerg/f932f8693f2733e30c4d479e8e980948
///
/// The Yul instructions used below are:
///
/// - "gt" is "greater than"
/// - "or" is the OR bitwise operator
/// - "shl" is "shift left"
/// - "shr" is "shift right"
///
/// Notes:
/// - For x = 0 every comparison is false, so the result is 0.
///
/// @param x The uint256 number for which to find the index of the most significant bit.
/// @return result The index of the most significant bit as a uint256.
/// @custom:smtchecker abstract-function-nondet
function msb(uint256 x) pure returns (uint256 result) {
    assembly ("memory-safe") {
        // `shift` is either 0 or the power of two under test: "gt" yields 0 or 1, and "shl" scales it up.
        // The shifts found at each step occupy distinct bits, so OR-ing them into `result` adds them up.

        // 2^128
        let shift := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^64
        shift := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^32
        shift := shl(5, gt(x, 0xFFFFFFFF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^16
        shift := shl(4, gt(x, 0xFFFF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^8
        shift := shl(3, gt(x, 0xFF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^4
        shift := shl(2, gt(x, 0xF))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^2
        shift := shl(1, gt(x, 0x3))
        x := shr(shift, x)
        result := or(result, shift)

        // 2^1
        // Last step: x is not needed afterwards, so it is not shifted any more.
        result := or(result, gt(x, 0x1))
    }
}

/// @notice Calculates x*y÷denominator with 512-bit precision.
///
/// @dev Credits to Remco Bloemen under MIT license https://xn--2-umb.com/21/muldiv.
///
/// Notes:
/// - The result is rounded toward zero.
///
/// Requirements:
/// - The denominator must not be zero.
/// - The result must fit in uint256; otherwise the call reverts with {PRBMath_MulDiv_Overflow}.
///
/// @param x The multiplicand as a uint256.
/// @param y The multiplier as a uint256.
/// @param denominator The divisor as a uint256.
/// @return result The result as a uint256.
/// @custom:smtchecker abstract-function-nondet
function mulDiv(uint256 x, uint256 y, uint256 denominator) pure returns (uint256 result) {
    // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then
    // use the Chinese Remainder Theorem to reconstruct the 512-bit result. The result is stored in two 256-bit
    // variables such that product = prod1 * 2^256 + prod0.
    uint256 prod0; // Least significant 256 bits of the product
    uint256 prod1; // Most significant 256 bits of the product
    assembly ("memory-safe") {
        // `not(0)` is 2^256 - 1, so `mm` is the product mod 2^256 - 1, while `mul` wraps mod 2^256.
        let mm := mulmod(x, y, not(0))
        prod0 := mul(x, y)
        // The high word is the difference of the two residues, minus a borrow when `mm < prod0`.
        prod1 := sub(sub(mm, prod0), lt(mm, prod0))
    }

    // Handle non-overflow cases, 256 by 256 division.
    if (prod1 == 0) {
        // A zero denominator still makes this division revert: `unchecked` does not disable the
        // division-by-zero check in Solidity.
        unchecked {
            return prod0 / denominator;
        }
    }

    // Make sure the result is less than 2^256. Also prevents denominator == 0.
    if (prod1 >= denominator) {
        revert PRBMath_MulDiv_Overflow(x, y, denominator);
    }

    ////////////////////////////////////////////////////////////////////////////
    // 512 by 256 division
    ////////////////////////////////////////////////////////////////////////////

    // Make division exact by subtracting the remainder from [prod1 prod0].
    uint256 remainder;
    assembly ("memory-safe") {
        // Compute remainder using the mulmod Yul instruction.
        remainder := mulmod(x, y, denominator)

        // Subtract 256 bit number from 512-bit number. The high word absorbs the borrow of the low word.
        prod1 := sub(prod1, gt(remainder, prod0))
        prod0 := sub(prod0, remainder)
    }

    unchecked {
        // Calculate the largest power of two divisor of the denominator using the unary operator ~. This operation
        // cannot overflow because the denominator cannot be zero at this point in the function execution. The
        // result is always >= 1. For more detail, see https://cs.stackexchange.com/q/138556/92363.
        uint256 lpotdod = denominator & (~denominator + 1);
        uint256 flippedLpotdod;

        assembly ("memory-safe") {
            // Factor powers of two out of denominator.
            denominator := div(denominator, lpotdod)

            // Divide [prod1 prod0] by lpotdod.
            prod0 := div(prod0, lpotdod)

            // Get the flipped value `2^256 / lpotdod`. If the `lpotdod` is zero, the flipped value is one.
            // `sub(0, lpotdod)` wraps around to `2^256 - lpotdod`, i.e. the two's complement of `lpotdod`. Dividing
            // that by `lpotdod` and adding one gives `2^256 / lpotdod`, because `lpotdod` is a power of two.
            // Note that `div` interprets its operands as unsigned values:
            // https://ethereum.stackexchange.com/q/147168/24693
            flippedLpotdod := add(div(sub(0, lpotdod), lpotdod), 1)
        }

        // Shift in bits from prod1 into prod0.
        prod0 |= prod1 * flippedLpotdod;

        // Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
        // that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
        // four bits. That is, denominator * inv = 1 mod 2^4.
        uint256 inverse = (3 * denominator) ^ 2;

        // Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
        // in modular arithmetic, doubling the correct bits in each step.
        inverse *= 2 - denominator * inverse; // inverse mod 2^8
        inverse *= 2 - denominator * inverse; // inverse mod 2^16
        inverse *= 2 - denominator * inverse; // inverse mod 2^32
        inverse *= 2 - denominator * inverse; // inverse mod 2^64
        inverse *= 2 - denominator * inverse; // inverse mod 2^128
        inverse *= 2 - denominator * inverse; // inverse mod 2^256

        // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
        // This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
        // less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
        // is no longer required.
        result = prod0 * inverse;
    }
}

/// @notice Calculates x*y÷1e18 with 512-bit precision.
///
/// @dev A variant of {mulDiv} with constant folding, i.e. in which the denominator is hard coded to 1e18.
///
/// Notes:
/// - The body is only lightly commented; to understand how this works, see the documentation in {mulDiv}.
/// - The result is rounded toward zero.
/// - We take as an axiom that the result cannot be `MAX_UINT256` when x and y solve the following system of equations:
///
/// $$
/// \begin{cases}
///     x * y = MAX\_UINT256 * UNIT \\
///     (x * y) \% UNIT \geq \frac{UNIT}{2}
/// \end{cases}
/// $$
///
/// Requirements:
/// - Refer to the requirements in {mulDiv}.
/// - The result must fit in uint256; otherwise the call reverts with {PRBMath_MulDiv18_Overflow}.
///
/// @param x The multiplicand as an unsigned 60.18-decimal fixed-point number.
/// @param y The multiplier as an unsigned 60.18-decimal fixed-point number.
/// @return result The result as an unsigned 60.18-decimal fixed-point number.
/// @custom:smtchecker abstract-function-nondet
function mulDiv18(uint256 x, uint256 y) pure returns (uint256 result) {
    // 512-bit product [prod1 prod0] = x * y, computed exactly as in {mulDiv}.
    uint256 prod0;
    uint256 prod1;
    assembly ("memory-safe") {
        let mm := mulmod(x, y, not(0))
        prod0 := mul(x, y)
        prod1 := sub(sub(mm, prod0), lt(mm, prod0))
    }

    // The product fits in 256 bits: a plain division is enough.
    if (prod1 == 0) {
        unchecked {
            return prod0 / UNIT;
        }
    }

    // The quotient would not fit in 256 bits.
    if (prod1 >= UNIT) {
        revert PRBMath_MulDiv18_Overflow(x, y);
    }

    // Same steps as the 512 by 256 division in {mulDiv}, folded into a single expression, with `UNIT_LPOTD` used
    // as the power-of-two factor of the divisor and `UNIT_INVERSE` as the precomputed modular inverse.
    uint256 remainder;
    assembly ("memory-safe") {
        remainder := mulmod(x, y, UNIT)
        result :=
            mul(
                or(
                    div(sub(prod0, remainder), UNIT_LPOTD),
                    mul(sub(prod1, gt(remainder, prod0)), add(div(sub(0, UNIT_LPOTD), UNIT_LPOTD), 1))
                ),
                UNIT_INVERSE
            )
    }
}

/// @notice Calculates x*y÷denominator with 512-bit precision.
///
/// @dev This is an extension of {mulDiv} for signed numbers, which works by computing the signs and the absolute values separately.
///
/// Notes:
/// - The result is rounded toward zero.
///
/// Requirements:
/// - Refer to the requirements in {mulDiv}.
/// - None of the inputs can be `type(int256).min`; otherwise the call reverts with
/// {PRBMath_MulDivSigned_InputTooSmall}.
/// - The result must fit in int256; otherwise the call reverts with {PRBMath_MulDivSigned_Overflow}.
///
/// @param x The multiplicand as an int256.
/// @param y The multiplier as an int256.
/// @param denominator The divisor as an int256.
/// @return result The result as an int256.
/// @custom:smtchecker abstract-function-nondet
function mulDivSigned(int256 x, int256 y, int256 denominator) pure returns (int256 result) {
    // The magnitude of `type(int256).min` cannot be represented as a positive int256, so it is rejected up front.
    int256 minInt = type(int256).min;
    if (x == minInt || y == minInt || denominator == minInt) {
        revert PRBMath_MulDivSigned_InputTooSmall();
    }

    // The quotient is negative exactly when an odd number of the inputs (1 or 3) is negative. XOR-ing the three
    // operands also XORs their sign bits, so the combined value is negative precisely in that case.
    bool isNegative = (x ^ y ^ denominator) < 0;

    // Magnitudes of x, y and the denominator. The negations are safe because the minimum value was excluded above.
    uint256 absX;
    uint256 absY;
    uint256 absDenominator;
    unchecked {
        absX = uint256(x < 0 ? -x : x);
        absY = uint256(y < 0 ? -y : y);
        absDenominator = uint256(denominator < 0 ? -denominator : denominator);
    }

    // Magnitude of x*y÷denominator. It has to fit in int256 for either sign.
    uint256 absResult = mulDiv(absX, absY, absDenominator);
    if (absResult > uint256(type(int256).max)) {
        revert PRBMath_MulDivSigned_Overflow(x, y);
    }

    // Re-apply the sign to the magnitude.
    unchecked {
        if (isNegative) {
            result = -int256(absResult);
        } else {
            result = int256(absResult);
        }
    }
}

/// @notice Calculates the square root of x using the Babylonian method.
///
/// @dev See https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method.
///
/// Notes:
/// - If x is not a perfect square, the result is rounded down.
/// - Credits to OpenZeppelin for the explanations in the comments below.
/// - An input of zero returns zero immediately.
///
/// @param x The uint256 number for which to calculate the square root.
/// @return result The result as a uint256.
/// @custom:smtchecker abstract-function-nondet
function sqrt(uint256 x) pure returns (uint256 result) {
    if (x == 0) {
        return 0;
    }

    // The first guess is the biggest power of two that does not exceed the square root of x.
    //
    // Let $k$ be the index of the most significant bit of x, i.e. $k = \lfloor log_2(x) \rfloor$. Then:
    //
    // $$
    // 2^k \leq x < 2^{k+1} \\
    // \sqrt{2^k} \leq \sqrt{x} < \sqrt{2^{k+1}} \\
    // 2^{k/2} \leq \sqrt{x} < 2^{(k+1)/2} \leq 2^{(k/2)+1}
    // $$
    //
    // Consequently, $2^{\lfloor k/2 \rfloor}$ is a good first approximation of $\sqrt{x}$ with at least one
    // correct bit. The binary search below builds it up: every time `remaining` has at least 2m significant bits
    // to drop, the estimate gains m bits.
    uint256 remaining = x;
    result = 1;
    if (remaining >= (1 << 128)) {
        remaining >>= 128;
        result <<= 64;
    }
    if (remaining >= (1 << 64)) {
        remaining >>= 64;
        result <<= 32;
    }
    if (remaining >= (1 << 32)) {
        remaining >>= 32;
        result <<= 16;
    }
    if (remaining >= (1 << 16)) {
        remaining >>= 16;
        result <<= 8;
    }
    if (remaining >= (1 << 8)) {
        remaining >>= 8;
        result <<= 4;
    }
    if (remaining >= (1 << 4)) {
        remaining >>= 4;
        result <<= 2;
    }
    if (remaining >= (1 << 2)) {
        // Last step: `remaining` is not read afterwards, so only the estimate is updated.
        result <<= 1;
    }

    // At this point, `result` is an estimation with at least one bit of precision. The true value has at most
    // 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision
    // doubles at every iteration), so at most 7 iterations are needed to turn the partial result with one bit of
    // precision into the expected uint128 result.
    unchecked {
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;
        result = (result + x / result) >> 1;

        // If x is not a perfect square, round the result toward zero by keeping the smaller of the estimate and
        // its matching quotient.
        uint256 quotient = x / result;
        if (quotient <= result) {
            result = quotient;
        }
    }
}