Skip to content

Commit

Permalink
expression: fix overflow panic in conv (#16970) (#16980)
Browse files Browse the repository at this point in the history
close #16969

fix overflow panic in `conv`

Signed-off-by: gengliqi <gengliqiii@gmail.com>

Co-authored-by: gengliqi <gengliqiii@gmail.com>
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed May 10, 2024
1 parent a0b1254 commit 353aa68
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions components/tidb_query_expr/src/impl_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,20 @@ pub fn conv(n: BytesRef, from_base: &Int, to_base: &Int) -> Result<Option<Bytes>
let s = s.trim();
let from_base = IntWithSign::from_int(*from_base);
let to_base = IntWithSign::from_int(*to_base);
Ok(if is_valid_base(from_base) && is_valid_base(to_base) {
if is_valid_base(from_base) && is_valid_base(to_base) {
if let Some((num_str, is_neg)) = extract_num_str(s, from_base) {
let num = extract_num(num_str.as_ref(), is_neg, from_base);
Some(num.format_to_base(to_base).into_bytes())
match extract_num(num_str.as_ref(), is_neg, from_base) {
Some(num) => Ok(Some(num.format_to_base(to_base).into_bytes())),
None => {
Err(Error::overflow("BIGINT UNSIGNED", format!("conv({})", num_str)).into())
}
}
} else {
Some(b"0".to_vec())
Ok(Some(b"0".to_vec()))
}
} else {
None
})
Ok(None)
}
}

#[inline]
Expand Down Expand Up @@ -566,7 +570,9 @@ impl IntWithSign {
// Shrink num to fit the boundary of i64.
fn shrink_from_signed_uint(num: u64, is_neg: bool) -> IntWithSign {
let value = if is_neg {
num.min(-Int::min_value() as u64)
// Avoid int64 overflow error.
// -int64_min = int64_max + 1
num.min(Int::max_value() as u64 + 1)
} else {
num.min(Int::max_value() as u64)
};
Expand Down Expand Up @@ -594,7 +600,8 @@ impl IntWithSign {
let IntWithSign(value, is_neg) = self;
let IntWithSign(to_base, should_ignore_sign) = to_base;
let mut real_val = value as i64;
if is_neg && !should_ignore_sign {
// real_val > 0 is to avoid overflow issue when value is -int64_min.
if is_neg && !should_ignore_sign && real_val > 0 {
real_val = -real_val;
}
let mut ret = IntWithSign::format_radix(real_val as u64, to_base as u32);
Expand Down Expand Up @@ -629,14 +636,17 @@ fn extract_num_str(s: &str, from_base: IntWithSign) -> Option<(String, bool)> {
}
}

fn extract_num(num_s: &str, is_neg: bool, from_base: IntWithSign) -> IntWithSign {
fn extract_num(num_s: &str, is_neg: bool, from_base: IntWithSign) -> Option<IntWithSign> {
let IntWithSign(from_base, signed) = from_base;
let value = u64::from_str_radix(num_s, from_base as u32).unwrap();
if signed {
let value = match u64::from_str_radix(num_s, from_base as u32) {
Ok(v) => v,
Err(_) => return None,
};
Some(if signed {
IntWithSign::shrink_from_signed_uint(value, is_neg)
} else {
IntWithSign::from_signed_uint(value, is_neg)
}
})
}

// Returns (isize, is_positive): convert an i64 to usize, and whether the input
Expand Down Expand Up @@ -1605,6 +1615,18 @@ mod tests {
("+", 10, 8, "0"),
("-", 10, 8, "0"),
("", 2, 16, "0"),
(
"18446744073709551615",
10,
2,
"1111111111111111111111111111111111111111111111111111111111111111",
),
(
"-18446744073709551615",
-10,
2,
"1000000000000000000000000000000000000000000000000000000000000000",
),
];
for (n, f, t, e) in tests {
let n = Some(n.as_bytes().to_vec());
Expand All @@ -1621,17 +1643,37 @@ mod tests {
}

let invalid_tests = vec![
(None, Some(10), Some(10), None),
(Some(b"a6a".to_vec()), Some(1), Some(8), None),
(None, Some(10), Some(10)),
(Some(b"111".to_vec()), None, Some(7)),
(Some(b"112".to_vec()), Some(10), None),
(None, None, None),
(Some(b"222".to_vec()), Some(2), Some(100)),
(Some(b"333".to_vec()), Some(37), Some(2)),
(Some(b"a6a".to_vec()), Some(1), Some(8)),
];
for (n, f, t, e) in invalid_tests {
for (n, f, t) in invalid_tests {
let got = RpnFnScalarEvaluator::new()
.push_param(n)
.push_param(f)
.push_param(t)
.evaluate::<Bytes>(ScalarFuncSig::Conv)
.unwrap();
assert_eq!(got, e);
assert_eq!(got, None);
}

let error_tests = vec![
("18446744073709551616", Some(10), Some(10)),
("100000000000000000001", Some(10), Some(8)),
("-18446744073709551616", Some(-10), Some(4)),
];
for (n, f, t) in error_tests {
let n = Some(n.as_bytes().to_vec());
let got = RpnFnScalarEvaluator::new()
.push_param(n)
.push_param(f)
.push_param(t)
.evaluate::<Bytes>(ScalarFuncSig::Conv);
got.unwrap_err();
}
}

Expand Down

0 comments on commit 353aa68

Please sign in to comment.