Skip to content

Commit c509a42

Browse files
committed
Remove unsafe unwraps in hostcalls.rs
1 parent 5283e57 commit c509a42

File tree

1 file changed

+106
-67
lines changed

1 file changed

+106
-67
lines changed

src/hostcalls.rs

Lines changed: 106 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -156,32 +156,26 @@ fn proxy_get_header_map_pairs(
156156
}
157157

158158
/// Fetches every entry of the map identified by `map_type` as `(name, value)` string pairs.
///
/// In the new version this delegates to `get_map_impl`, converting each value with
/// `String::from_utf8_lossy`, so invalid UTF-8 in a value is replaced with U+FFFD
/// instead of causing a panic. The lines marked `-` below are the removed
/// implementation, which panicked on any status other than `Status::Ok`.
pub fn get_map(map_type: MapType) -> Result<Vec<(String, String)>, Status> {
159-
unsafe {
160-
let mut return_data: *mut u8 = null_mut();
161-
let mut return_size: usize = 0;
162-
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
163-
Status::Ok => {
164-
if !return_data.is_null() {
165-
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
166-
Ok(utils::deserialize_map(&serialized_map))
167-
} else {
168-
Ok(Vec::new())
169-
}
170-
}
171-
status => panic!("unexpected status: {}", status as u32),
172-
}
173-
}
159+
get_map_impl(map_type, |v| String::from_utf8_lossy(&v).into_owned())
174160
}
175161

176162
/// Fetches every entry of the map identified by `map_type`, keeping each value as raw bytes.
///
/// Delegates to `get_map_impl` with an identity mapper, so values are returned
/// exactly as deserialized, without any UTF-8 conversion.
pub fn get_map_bytes(map_type: MapType) -> Result<Vec<(String, Bytes)>, Status> {
163+
get_map_impl(map_type, |v| v)
164+
}
165+
166+
#[inline]
167+
fn get_map_impl<F, V>(map_type: MapType, value_mapper: F) -> Result<Vec<(String, V)>, Status>
168+
where
169+
F: Fn(Vec<u8>) -> V,
170+
{
177171
unsafe {
178172
let mut return_data: *mut u8 = null_mut();
179173
let mut return_size: usize = 0;
180174
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
181175
Status::Ok => {
182176
if !return_data.is_null() {
183177
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
184-
Ok(utils::deserialize_map_bytes(&serialized_map))
178+
Ok(utils::deserialize_map_impl(&serialized_map, value_mapper).unwrap())
185179
} else {
186180
Ok(Vec::new())
187181
}
@@ -243,12 +237,12 @@ pub fn get_map_value(map_type: MapType, key: &str) -> Result<Option<String>, Sta
243237
Status::Ok => {
244238
if !return_data.is_null() {
245239
Ok(Some(
246-
String::from_utf8(Vec::from_raw_parts(
240+
String::from_utf8_lossy(&Vec::from_raw_parts(
247241
return_data,
248242
return_size,
249243
return_size,
250244
))
251-
.unwrap(),
245+
.into_owned(),
252246
))
253247
} else {
254248
Ok(Some(String::new()))
@@ -1206,7 +1200,23 @@ mod tests {
12061200

12071201
mod utils {
12081202
use crate::types::Bytes;
1209-
use std::convert::TryFrom;
1203+
use std::convert::TryInto;
1204+
use std::fmt::Display;
1205+
1206+
/// Failure modes reported by `deserialize_map_impl` when a serialized map is malformed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The buffer ends before the header, a key, or a value that it announces.
    BufferTooShort,
    /// Computing an offset into the buffer overflowed `usize`.
    BufferOverflow,
}

impl Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The messages are kept byte-for-byte; only the dispatch is condensed.
        f.write_str(match self {
            Error::BufferTooShort => "buffer too short",
            Error::BufferOverflow => "buffer overflow",
        })
    }
}

// Lets the type be used as `Box<dyn std::error::Error>` and with `?` in
// functions returning boxed errors; `Debug` + `Display` are already provided.
impl std::error::Error for Error {}
12101220

12111221
pub(super) fn serialize_property_path(path: Vec<&str>) -> Bytes {
12121222
if path.is_empty() {
@@ -1265,49 +1275,64 @@ mod utils {
12651275
bytes
12661276
}
12671277

1268-
pub(super) fn deserialize_map(bytes: &[u8]) -> Vec<(String, String)> {
1278+
#[inline]
1279+
pub(super) fn deserialize_map_impl<F, V>(
1280+
bytes: &[u8],
1281+
value_mapper: F,
1282+
) -> Result<Vec<(String, V)>, Error>
1283+
where
1284+
F: Fn(Vec<u8>) -> V,
1285+
{
12691286
if bytes.is_empty() {
1270-
return Vec::new();
1287+
return Ok(Vec::new());
12711288
}
1272-
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize;
1273-
let mut map = Vec::with_capacity(size);
1274-
let mut p = 4 + size * 8;
1275-
for n in 0..size {
1276-
let s = 4 + n * 8;
1277-
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize;
1278-
let key = bytes[p..p + size].to_vec();
1279-
p += size + 1;
1280-
let size =
1281-
u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize;
1282-
let value = bytes[p..p + size].to_vec();
1283-
p += size + 1;
1284-
map.push((
1285-
String::from_utf8(key).unwrap(),
1286-
String::from_utf8(value).unwrap(),
1287-
));
1288-
}
1289-
map
1290-
}
12911289

1292-
pub(super) fn deserialize_map_bytes(bytes: &[u8]) -> Vec<(String, Bytes)> {
1293-
if bytes.is_empty() {
1294-
return Vec::new();
1290+
if bytes.len() < 4 {
1291+
return Err(Error::BufferTooShort);
12951292
}
1296-
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize;
1293+
1294+
let size = u32::from_le_bytes(bytes[0..4].try_into().map_err(|_| Error::BufferTooShort)?) as usize;
12971295
let mut map = Vec::with_capacity(size);
1298-
let mut p = 4 + size * 8;
1296+
1297+
// check if header is large enough
1298+
let header_size = 4 + size.checked_mul(8).ok_or(Error::BufferOverflow)?;
1299+
if bytes.len() < header_size {
1300+
return Err(Error::BufferTooShort);
1301+
}
1302+
1303+
let mut p = header_size;
1304+
12991305
for n in 0..size {
13001306
let s = 4 + n * 8;
1301-
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize;
1302-
let key = bytes[p..p + size].to_vec();
1303-
p += size + 1;
1304-
let size =
1305-
u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize;
1306-
let value = bytes[p..p + size].to_vec();
1307-
p += size + 1;
1308-
map.push((String::from_utf8(key).unwrap(), value));
1309-
}
1310-
map
1307+
1308+
// read key size
1309+
let key_size = u32::from_le_bytes(bytes[s..s + 4].try_into().unwrap()) as usize;
1310+
let key_end = p.checked_add(key_size).ok_or(Error::BufferOverflow)?;
1311+
if key_end > bytes.len() {
1312+
return Err(Error::BufferTooShort);
1313+
}
1314+
let key = String::from_utf8_lossy(&bytes[p..key_end].to_vec()).into_owned();
1315+
1316+
p = key_end.checked_add(1).ok_or(Error::BufferOverflow)?;
1317+
1318+
// read value size
1319+
let value_size = u32::from_le_bytes(
1320+
bytes[s + 4..s + 8]
1321+
.try_into()
1322+
.map_err(|_| Error::BufferOverflow)?,
1323+
) as usize;
1324+
let value_end = p.checked_add(value_size).ok_or(Error::BufferOverflow)?;
1325+
if value_end > bytes.len() {
1326+
return Err(Error::BufferTooShort);
1327+
}
1328+
let value = bytes[p..value_end].to_vec();
1329+
1330+
p = value_end.checked_add(1).ok_or(Error::BufferOverflow)?;
1331+
1332+
map.push((key, value_mapper(value)));
1333+
}
1334+
1335+
Ok(map)
13111336
}
13121337

13131338
#[cfg(test)]
@@ -1324,6 +1349,9 @@ mod utils {
13241349
("Powered-By", "proxy-wasm"),
13251350
];
13261351

1352+
// Identity mapper: leaves each deserialized value as raw bytes.
static BYTES_MAPPER: fn(Bytes) -> Bytes = |v| v;
1353+
// Lossy string mapper: converts each value with `String::from_utf8_lossy`,
// so invalid UTF-8 is replaced rather than rejected.
static STRING_MAPPER: fn(Bytes) -> String = |v| String::from_utf8_lossy(&v).into_owned();
1354+
13271355
#[rustfmt::skip]
13281356
pub(in crate::hostcalls) static SERIALIZED_MAP: &[u8] = &[
13291357
// num entries
@@ -1354,6 +1382,14 @@ mod utils {
13541382
112, 114, 111, 120, 121, 45, 119, 97, 115, 109, 0,
13551383
];
13561384

1385+
// Test helper: deserializes `bytes` into string values via `STRING_MAPPER`,
// panicking with "deserialize_map failed" if `deserialize_map_impl` returns an error.
fn deserialize_map_strings(bytes: &[u8]) -> Vec<(String, String)> {
1386+
deserialize_map_impl(bytes, STRING_MAPPER).expect("deserialize_map failed")
1387+
}
1388+
1389+
// Test helper: deserializes `bytes` keeping values as raw bytes via `BYTES_MAPPER`,
// panicking with "deserialize_map failed" if `deserialize_map_impl` returns an error.
fn deserialize_map_bytes(bytes: &[u8]) -> Vec<(String, Bytes)> {
1390+
deserialize_map_impl(bytes, BYTES_MAPPER).expect("deserialize_map failed")
1391+
}
1392+
13571393
#[test]
13581394
fn test_serialize_map_empty() {
13591395
let serialized_map = serialize_map(&[]);
@@ -1368,9 +1404,9 @@ mod utils {
13681404

13691405
// Both a zero-length buffer and a buffer holding only a zero entry count
// must deserialize to an empty map.
#[test]
13701406
fn test_deserialize_map_empty() {
1371-
let map = deserialize_map(&[]);
1407+
let map = deserialize_map_strings(&[]);
13721408
assert_eq!(map, []);
1373-
let map = deserialize_map(&[0, 0, 0, 0]);
1409+
let map = deserialize_map_strings(&[0, 0, 0, 0]);
13741410
assert_eq!(map, []);
13751411
}
13761412

@@ -1397,7 +1433,7 @@ mod utils {
13971433

13981434
#[test]
13991435
fn test_deserialize_map() {
1400-
let map = deserialize_map(SERIALIZED_MAP);
1436+
let map = deserialize_map_strings(SERIALIZED_MAP);
14011437
assert_eq!(map.len(), MAP.len());
14021438
for (got, expected) in map.into_iter().zip(MAP) {
14031439
assert_eq!(got.0, expected.0);
@@ -1417,7 +1453,7 @@ mod utils {
14171453

14181454
#[test]
14191455
fn test_deserialize_map_roundtrip() {
1420-
let map = deserialize_map(SERIALIZED_MAP);
1456+
let map = deserialize_map_strings(SERIALIZED_MAP);
14211457
// TODO(v0.3): fix arguments, so that maps can be reused without conversion.
14221458
let map_refs: Vec<(&str, &str)> =
14231459
map.iter().map(|x| (x.0.as_ref(), x.1.as_ref())).collect();
@@ -1440,21 +1476,24 @@ mod utils {
14401476
// 0x00-0x7f are valid single-byte UTF-8 characters.
14411477
for i in 0..0x7f {
14421478
let serialized_src = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 99, 0, i, 0];
1443-
let map = deserialize_map(&serialized_src);
1479+
let map = deserialize_map_strings(&serialized_src);
14441480
// TODO(v0.3): fix arguments, so that maps can be reused without conversion.
14451481
let map_refs: Vec<(&str, &str)> =
14461482
map.iter().map(|x| (x.0.as_ref(), x.1.as_ref())).collect();
14471483
let serialized_map = serialize_map(&map_refs);
1448-
assert_eq!(serialized_map, serialized_src);
1484+
assert_eq!(serialized_map, serialized_src, "Failed at i={}", i);
14491485
}
14501486
// 0x80-0xff are invalid single-byte UTF-8 characters.
14511487
for i in 0x80..0xff {
14521488
let serialized_src = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 99, 0, i, 0];
1453-
std::panic::set_hook(Box::new(|_| {}));
1454-
let result = std::panic::catch_unwind(|| {
1455-
deserialize_map(&serialized_src);
1456-
});
1457-
assert!(result.is_err());
1489+
let map = deserialize_map_strings(&serialized_src);
1490+
1491+
// Invalid UTF-8 bytes should be replaced with the replacement character U+FFFD.
1492+
assert!(
1493+
map[0].1.contains('�'),
1494+
"Expected replacement character for byte 0x{:02x}",
1495+
i
1496+
);
14581497
}
14591498
}
14601499

@@ -1495,7 +1534,7 @@ mod utils {
14951534
// Benchmarks deserialization of `SERIALIZED_MAP` into string values.
// `black_box` is applied to the input on each iteration.
fn bench_deserialize_map(b: &mut Bencher) {
14961535
let serialized_map = SERIALIZED_MAP.to_vec();
14971536
b.iter(|| {
1498-
deserialize_map(test::black_box(&serialized_map));
1537+
deserialize_map_strings(test::black_box(&serialized_map));
14991538
});
15001539
}
15011540

0 commit comments

Comments
 (0)