Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix overflow panic in conv (#16970) #16980

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 58 additions & 16 deletions components/tidb_query_expr/src/impl_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,20 @@ pub fn conv(n: BytesRef, from_base: &Int, to_base: &Int) -> Result<Option<Bytes>
let s = s.trim();
let from_base = IntWithSign::from_int(*from_base);
let to_base = IntWithSign::from_int(*to_base);
Ok(if is_valid_base(from_base) && is_valid_base(to_base) {
if is_valid_base(from_base) && is_valid_base(to_base) {
if let Some((num_str, is_neg)) = extract_num_str(s, from_base) {
let num = extract_num(num_str.as_ref(), is_neg, from_base);
Some(num.format_to_base(to_base).into_bytes())
match extract_num(num_str.as_ref(), is_neg, from_base) {
Some(num) => Ok(Some(num.format_to_base(to_base).into_bytes())),
None => {
Err(Error::overflow("BIGINT UNSIGNED", format!("conv({})", num_str)).into())
}
}
} else {
Some(b"0".to_vec())
Ok(Some(b"0".to_vec()))
}
} else {
None
})
Ok(None)
}
}

#[inline]
Expand Down Expand Up @@ -566,7 +570,9 @@ impl IntWithSign {
// Shrink num to fit the boundary of i64.
fn shrink_from_signed_uint(num: u64, is_neg: bool) -> IntWithSign {
let value = if is_neg {
num.min(-Int::min_value() as u64)
// Avoid int64 overflow error.
// -int64_min = int64_max + 1
num.min(Int::max_value() as u64 + 1)
} else {
num.min(Int::max_value() as u64)
};
Expand Down Expand Up @@ -594,7 +600,8 @@ impl IntWithSign {
let IntWithSign(value, is_neg) = self;
let IntWithSign(to_base, should_ignore_sign) = to_base;
let mut real_val = value as i64;
if is_neg && !should_ignore_sign {
// real_val > 0 is to avoid overflow issue when value is -int64_min.
if is_neg && !should_ignore_sign && real_val > 0 {
real_val = -real_val;
}
let mut ret = IntWithSign::format_radix(real_val as u64, to_base as u32);
Expand Down Expand Up @@ -629,14 +636,17 @@ fn extract_num_str(s: &str, from_base: IntWithSign) -> Option<(String, bool)> {
}
}

fn extract_num(num_s: &str, is_neg: bool, from_base: IntWithSign) -> IntWithSign {
fn extract_num(num_s: &str, is_neg: bool, from_base: IntWithSign) -> Option<IntWithSign> {
let IntWithSign(from_base, signed) = from_base;
let value = u64::from_str_radix(num_s, from_base as u32).unwrap();
if signed {
let value = match u64::from_str_radix(num_s, from_base as u32) {
Ok(v) => v,
Err(_) => return None,
};
Some(if signed {
IntWithSign::shrink_from_signed_uint(value, is_neg)
} else {
IntWithSign::from_signed_uint(value, is_neg)
}
})
}

// Returns (isize, is_positive): convert an i64 to usize, and whether the input
Expand Down Expand Up @@ -1605,6 +1615,18 @@ mod tests {
("+", 10, 8, "0"),
("-", 10, 8, "0"),
("", 2, 16, "0"),
(
"18446744073709551615",
10,
2,
"1111111111111111111111111111111111111111111111111111111111111111",
),
(
"-18446744073709551615",
-10,
2,
"1000000000000000000000000000000000000000000000000000000000000000",
),
];
for (n, f, t, e) in tests {
let n = Some(n.as_bytes().to_vec());
Expand All @@ -1621,17 +1643,37 @@ mod tests {
}

let invalid_tests = vec![
(None, Some(10), Some(10), None),
(Some(b"a6a".to_vec()), Some(1), Some(8), None),
(None, Some(10), Some(10)),
(Some(b"111".to_vec()), None, Some(7)),
(Some(b"112".to_vec()), Some(10), None),
(None, None, None),
(Some(b"222".to_vec()), Some(2), Some(100)),
(Some(b"333".to_vec()), Some(37), Some(2)),
(Some(b"a6a".to_vec()), Some(1), Some(8)),
];
for (n, f, t, e) in invalid_tests {
for (n, f, t) in invalid_tests {
let got = RpnFnScalarEvaluator::new()
.push_param(n)
.push_param(f)
.push_param(t)
.evaluate::<Bytes>(ScalarFuncSig::Conv)
.unwrap();
assert_eq!(got, e);
assert_eq!(got, None);
}

let error_tests = vec![
("18446744073709551616", Some(10), Some(10)),
("100000000000000000001", Some(10), Some(8)),
("-18446744073709551616", Some(-10), Some(4)),
];
for (n, f, t) in error_tests {
let n = Some(n.as_bytes().to_vec());
let got = RpnFnScalarEvaluator::new()
.push_param(n)
.push_param(f)
.push_param(t)
.evaluate::<Bytes>(ScalarFuncSig::Conv);
got.unwrap_err();
}
}

Expand Down