Commit 2f80be17 authored by Peter Kasting's avatar Peter Kasting Committed by Chromium LUCI CQ

Avoid UB in checked math implementation.

Various functions were unconditionally casting non-representable values.

Bug: 1124595
Change-Id: I811f44164a7c9f5802287e1d826cf403ff2e08e6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2585669
Commit-Queue: Tom Sepez <tsepez@chromium.org>
Reviewed-by: default avatarTom Sepez <tsepez@chromium.org>
Auto-Submit: Peter Kasting <pkasting@chromium.org>
Cr-Commit-Position: refs/heads/master@{#836217}
parent b1053f97
...@@ -32,12 +32,12 @@ constexpr bool CheckedAddImpl(T x, T y, T* result) { ...@@ -32,12 +32,12 @@ constexpr bool CheckedAddImpl(T x, T y, T* result) {
const UnsignedDst uresult = static_cast<UnsignedDst>(ux + uy); const UnsignedDst uresult = static_cast<UnsignedDst>(ux + uy);
// Addition is valid if the sign of (x + y) is equal to either that of x or // Addition is valid if the sign of (x + y) is equal to either that of x or
// that of y. // that of y.
const bool is_valid = if (std::is_signed<T>::value
std::is_signed<T>::value ? static_cast<SignedDst>((uresult ^ ux) & (uresult ^ uy)) < 0
? static_cast<SignedDst>((uresult ^ ux) & (uresult ^ uy)) >= 0 : uresult < uy) // Unsigned is either valid or underflow.
: uresult >= uy; // Unsigned is either valid or underflow. return false;
*result = static_cast<T>(uresult); *result = static_cast<T>(uresult);
return is_valid; return true;
} }
template <typename T, typename U, class Enable = void> template <typename T, typename U, class Enable = void>
...@@ -77,9 +77,10 @@ struct CheckedAddOp<T, ...@@ -77,9 +77,10 @@ struct CheckedAddOp<T,
is_valid = CheckedAddImpl(static_cast<Promotion>(x), is_valid = CheckedAddImpl(static_cast<Promotion>(x),
static_cast<Promotion>(y), &presult); static_cast<Promotion>(y), &presult);
} }
is_valid &= IsValueInRangeForNumericType<V>(presult); if (!is_valid || !IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<V>(presult); *result = static_cast<V>(presult);
return is_valid; return true;
} }
}; };
...@@ -95,12 +96,12 @@ constexpr bool CheckedSubImpl(T x, T y, T* result) { ...@@ -95,12 +96,12 @@ constexpr bool CheckedSubImpl(T x, T y, T* result) {
const UnsignedDst uresult = static_cast<UnsignedDst>(ux - uy); const UnsignedDst uresult = static_cast<UnsignedDst>(ux - uy);
// Subtraction is valid if either x and y have same sign, or (x-y) and x have // Subtraction is valid if either x and y have same sign, or (x-y) and x have
// the same sign. // the same sign.
const bool is_valid = if (std::is_signed<T>::value
std::is_signed<T>::value ? static_cast<SignedDst>((uresult ^ ux) & (ux ^ uy)) < 0
? static_cast<SignedDst>((uresult ^ ux) & (ux ^ uy)) >= 0 : x < y)
: x >= y; return false;
*result = static_cast<T>(uresult); *result = static_cast<T>(uresult);
return is_valid; return true;
} }
template <typename T, typename U, class Enable = void> template <typename T, typename U, class Enable = void>
...@@ -140,9 +141,10 @@ struct CheckedSubOp<T, ...@@ -140,9 +141,10 @@ struct CheckedSubOp<T,
is_valid = CheckedSubImpl(static_cast<Promotion>(x), is_valid = CheckedSubImpl(static_cast<Promotion>(x),
static_cast<Promotion>(y), &presult); static_cast<Promotion>(y), &presult);
} }
is_valid &= IsValueInRangeForNumericType<V>(presult); if (!is_valid || !IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<V>(presult); *result = static_cast<V>(presult);
return is_valid; return true;
} }
}; };
...@@ -161,11 +163,11 @@ constexpr bool CheckedMulImpl(T x, T y, T* result) { ...@@ -161,11 +163,11 @@ constexpr bool CheckedMulImpl(T x, T y, T* result) {
// We have a fast out for unsigned identity or zero on the second operand. // We have a fast out for unsigned identity or zero on the second operand.
// After that it's an unsigned overflow check on the absolute value, with // After that it's an unsigned overflow check on the absolute value, with
// a +1 bound for a negative result. // a +1 bound for a negative result.
const bool is_valid = if (uy > UnsignedDst(!std::is_signed<T>::value || is_negative) &&
uy <= UnsignedDst(!std::is_signed<T>::value || is_negative) || ux > (std::numeric_limits<T>::max() + UnsignedDst(is_negative)) / uy)
ux <= (std::numeric_limits<T>::max() + UnsignedDst(is_negative)) / uy; return false;
*result = is_negative ? 0 - uresult : uresult; *result = is_negative ? 0 - uresult : uresult;
return is_valid; return true;
} }
template <typename T, typename U, class Enable = void> template <typename T, typename U, class Enable = void>
...@@ -202,9 +204,10 @@ struct CheckedMulOp<T, ...@@ -202,9 +204,10 @@ struct CheckedMulOp<T,
is_valid = CheckedMulImpl(static_cast<Promotion>(x), is_valid = CheckedMulImpl(static_cast<Promotion>(x),
static_cast<Promotion>(y), &presult); static_cast<Promotion>(y), &presult);
} }
is_valid &= IsValueInRangeForNumericType<V>(presult); if (!is_valid || !IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<V>(presult); *result = static_cast<V>(presult);
return is_valid; return true;
} }
}; };
...@@ -244,9 +247,10 @@ struct CheckedDivOp<T, ...@@ -244,9 +247,10 @@ struct CheckedDivOp<T,
} }
const Promotion presult = Promotion(x) / Promotion(y); const Promotion presult = Promotion(x) / Promotion(y);
const bool is_valid = IsValueInRangeForNumericType<V>(presult); if (!IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<V>(presult); *result = static_cast<V>(presult);
return is_valid; return true;
} }
}; };
...@@ -277,9 +281,10 @@ struct CheckedModOp<T, ...@@ -277,9 +281,10 @@ struct CheckedModOp<T,
const Promotion presult = const Promotion presult =
static_cast<Promotion>(x) % static_cast<Promotion>(y); static_cast<Promotion>(x) % static_cast<Promotion>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(presult); if (!IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<Promotion>(presult); *result = static_cast<Promotion>(presult);
return is_valid; return true;
} }
}; };
...@@ -337,9 +342,10 @@ struct CheckedRshOp<T, ...@@ -337,9 +342,10 @@ struct CheckedRshOp<T,
} }
const T tmp = x >> shift; const T tmp = x >> shift;
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -358,9 +364,10 @@ struct CheckedAndOp<T, ...@@ -358,9 +364,10 @@ struct CheckedAndOp<T,
static constexpr bool Do(T x, U y, V* result) { static constexpr bool Do(T x, U y, V* result) {
const result_type tmp = const result_type tmp =
static_cast<result_type>(x) & static_cast<result_type>(y); static_cast<result_type>(x) & static_cast<result_type>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -379,9 +386,10 @@ struct CheckedOrOp<T, ...@@ -379,9 +386,10 @@ struct CheckedOrOp<T,
static constexpr bool Do(T x, U y, V* result) { static constexpr bool Do(T x, U y, V* result) {
const result_type tmp = const result_type tmp =
static_cast<result_type>(x) | static_cast<result_type>(y); static_cast<result_type>(x) | static_cast<result_type>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -400,9 +408,10 @@ struct CheckedXorOp<T, ...@@ -400,9 +408,10 @@ struct CheckedXorOp<T,
static constexpr bool Do(T x, U y, V* result) { static constexpr bool Do(T x, U y, V* result) {
const result_type tmp = const result_type tmp =
static_cast<result_type>(x) ^ static_cast<result_type>(y); static_cast<result_type>(x) ^ static_cast<result_type>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -423,9 +432,10 @@ struct CheckedMaxOp< ...@@ -423,9 +432,10 @@ struct CheckedMaxOp<
const result_type tmp = IsGreater<T, U>::Test(x, y) const result_type tmp = IsGreater<T, U>::Test(x, y)
? static_cast<result_type>(x) ? static_cast<result_type>(x)
: static_cast<result_type>(y); : static_cast<result_type>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -446,9 +456,10 @@ struct CheckedMinOp< ...@@ -446,9 +456,10 @@ struct CheckedMinOp<
const result_type tmp = IsLess<T, U>::Test(x, y) const result_type tmp = IsLess<T, U>::Test(x, y)
? static_cast<result_type>(x) ? static_cast<result_type>(x)
: static_cast<result_type>(y); : static_cast<result_type>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(tmp); if (!IsValueInRangeForNumericType<V>(tmp))
return false;
*result = static_cast<V>(tmp); *result = static_cast<V>(tmp);
return is_valid; return true;
} }
}; };
...@@ -465,9 +476,10 @@ struct CheckedMinOp< ...@@ -465,9 +476,10 @@ struct CheckedMinOp<
static constexpr bool Do(T x, U y, V* result) { \ static constexpr bool Do(T x, U y, V* result) { \
using Promotion = typename MaxExponentPromotion<T, U>::type; \ using Promotion = typename MaxExponentPromotion<T, U>::type; \
const Promotion presult = x OP y; \ const Promotion presult = x OP y; \
const bool is_valid = IsValueInRangeForNumericType<V>(presult); \ if (!IsValueInRangeForNumericType<V>(presult)) \
return false; \
*result = static_cast<V>(presult); \ *result = static_cast<V>(presult); \
return is_valid; \ return true; \
} \ } \
}; };
......
...@@ -36,9 +36,10 @@ struct CheckedMulFastAsmOp { ...@@ -36,9 +36,10 @@ struct CheckedMulFastAsmOp {
Promotion presult; Promotion presult;
presult = static_cast<Promotion>(x) * static_cast<Promotion>(y); presult = static_cast<Promotion>(x) * static_cast<Promotion>(y);
const bool is_valid = IsValueInRangeForNumericType<V>(presult); if (!IsValueInRangeForNumericType<V>(presult))
return false;
*result = static_cast<V>(presult); *result = static_cast<V>(presult);
return is_valid; return true;
} }
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment