1use rand::{distr::Alphanumeric, Rng};
8use std::{
9 io::{Error, ErrorKind, Read, Result},
10 os::{
11 fd::{AsFd, AsRawFd, OwnedFd},
12 unix::ffi::OsStrExt,
13 },
14 path::Path,
15};
16
17use rustix::{
18 fs::{readlinkat, renameat, symlinkat, unlinkat, AtFlags},
19 io::{Errno, Result as ErrnoResult},
20};
21use tokio::io::{AsyncRead, AsyncReadExt};
22
/// Builds the procfs path (`/proc/self/fd/N`) that names `fd` within this process.
///
/// Only the numeric descriptor is recorded in the returned string. Because `fd`
/// is taken by value, passing an owning handle (such as an `OwnedFd` or a
/// `File`) drops — and therefore closes — it before this function returns,
/// leaving a path to a closed (or later reused) descriptor number. Pass a
/// borrow (`&fd` or `fd.as_fd()`) to keep the descriptor alive while the path
/// is in use.
pub(crate) fn proc_self_fd(fd: impl AsFd) -> String {
    let raw_fd = fd.as_fd().as_raw_fd();
    format!("/proc/self/fd/{raw_fd}")
}
30
/// Fills `buf` completely from `reader`, tolerating a clean end-of-stream.
///
/// This behaves like [`Read::read_exact`], except that hitting end-of-stream
/// before *any* byte was read is reported as `Ok(false)` instead of an error.
/// That lets callers tell "no more records" apart from "truncated record"
/// when reading fixed-size chunks back to back.
///
/// # Returns
/// - `Ok(true)` if `buf` was filled completely (always the case for an empty `buf`).
/// - `Ok(false)` if the reader returned end-of-stream before a single byte was read.
///
/// # Errors
/// - `ErrorKind::UnexpectedEof` if end-of-stream arrives after `buf` was
///   partially filled; the filled prefix of `buf` has been overwritten.
/// - Any other error from `reader` is returned as-is. `ErrorKind::Interrupted`
///   is retried rather than returned.
pub fn read_exactish(reader: &mut impl Read, buf: &mut [u8]) -> Result<bool> {
    let total = buf.len();
    let mut remaining: &mut [u8] = buf;

    while !remaining.is_empty() {
        let n = match reader.read(remaining) {
            Ok(n) => n,
            // A signal interrupted the syscall; nothing was consumed, just retry.
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };

        if n == 0 {
            // End-of-stream is only "clean" if we haven't consumed anything yet.
            return if remaining.len() == total {
                Ok(false)
            } else {
                Err(Error::from(ErrorKind::UnexpectedEof))
            };
        }

        remaining = &mut remaining[n..];
    }

    Ok(true)
}
70
71pub async fn read_exactish_async(
76 reader: &mut (impl AsyncRead + Unpin),
77 buf: &mut [u8],
78) -> Result<bool> {
79 let buflen = buf.len();
80 let mut todo: &mut [u8] = buf;
81
82 while !todo.is_empty() {
83 match reader.read(todo).await {
84 Ok(0) => {
85 return match todo.len() {
86 s if s == buflen => Ok(false), _ => Err(ErrorKind::UnexpectedEof.into()),
88 };
89 }
90 Ok(n) => todo = &mut todo[n..],
91 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
92 Err(e) => return Err(e),
93 }
94 }
95
96 Ok(true)
97}
98
99pub type Sha256Digest = [u8; 32];
101
102pub fn parse_sha256(string: impl AsRef<str>) -> Result<Sha256Digest> {
109 let mut value = [0u8; 32];
110 hex::decode_to_slice(string.as_ref(), &mut value)
111 .map_err(|source| Error::new(ErrorKind::InvalidInput, source))?;
112 Ok(value)
113}
114
115pub(crate) trait ErrnoFilter<T> {
116 fn filter_errno(self, ignored: Errno) -> ErrnoResult<Option<T>>;
117}
118
119impl<T> ErrnoFilter<T> for ErrnoResult<T> {
120 fn filter_errno(self, ignored: Errno) -> ErrnoResult<Option<T>> {
121 match self {
122 Ok(result) => Ok(Some(result)),
123 Err(err) if err == ignored => Ok(None),
124 Err(err) => Err(err),
125 }
126 }
127}
128
129fn generate_tmpname(prefix: &str) -> String {
130 let rand_string: String = rand::rng()
131 .sample_iter(&Alphanumeric)
132 .take(12)
133 .map(char::from)
134 .collect();
135 format!("{prefix}{rand_string}")
136}
137
138pub(crate) fn replace_symlinkat(
139 target: impl AsRef<Path>,
140 dirfd: &OwnedFd,
141 name: impl AsRef<Path>,
142) -> ErrnoResult<()> {
143 let name = name.as_ref();
144 let target = target.as_ref();
145
146 if symlinkat(target, dirfd, name)
148 .filter_errno(Errno::EXIST)?
149 .is_some()
150 {
151 return Ok(());
152 };
153
154 if let Some(current_target) = readlinkat(dirfd, name, []).filter_errno(Errno::NOENT)? {
156 if current_target.into_bytes() == target.as_os_str().as_bytes() {
157 return Ok(());
158 }
159 }
160
161 for _ in 0..16 {
163 let tmp_name = generate_tmpname(".symlink-");
164 if symlinkat(target, dirfd, &tmp_name)
165 .filter_errno(Errno::EXIST)?
166 .is_none()
167 {
168 continue;
170 }
171
172 match renameat(dirfd, &tmp_name, dirfd, name) {
173 Ok(_) => return Ok(()),
174 Err(e) => {
175 let _ = unlinkat(dirfd, tmp_name, AtFlags::empty());
176 return Err(e);
177 }
178 }
179 }
180
181 Err(Errno::EXIST)
182}
183
#[cfg(test)]
mod test {
    use similar_asserts::assert_eq;

    use super::*;

    /// Shared `read_exactish` scenarios, run against a reader closure that
    /// requests 9 bytes per call.
    fn read_exactish_common(read9: fn(&mut &[u8]) -> Result<bool>) {
        // Empty input: clean EOF, and it stays that way.
        let mut r = b"" as &[u8];
        assert_eq!(read9(&mut r).unwrap(), false);
        assert_eq!(read9(&mut r).unwrap(), false);

        // Exactly one record, then clean EOF.
        r = b"ninebytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), false);

        // One full record followed by a truncated one.
        r = b"twelve bytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap_err().kind(), ErrorKind::UnexpectedEof);

        // Exactly two records, then clean EOF.
        r = b"eighteen(18) bytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), false);
    }

    #[test]
    fn test_read_exactish() {
        read_exactish_common(|r| read_exactish(r, &mut [0; 9]));
    }

    #[test]
    fn test_read_exactish_broken_reader() {
        struct BrokenReader;
        impl Read for BrokenReader {
            fn read(&mut self, _buffer: &mut [u8]) -> Result<usize> {
                Err(ErrorKind::NetworkDown.into())
            }
        }

        assert_eq!(
            read_exactish(&mut BrokenReader, &mut [0; 9])
                .unwrap_err()
                .kind(),
            ErrorKind::NetworkDown
        );
    }

    #[test]
    fn test_read_exactish_retries_interrupted_and_short_reads() {
        /// Alternates between `Interrupted` and handing out a single byte.
        struct StutteringReader<'a> {
            data: &'a [u8],
            interrupt: bool,
        }
        impl Read for StutteringReader<'_> {
            fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
                self.interrupt = !self.interrupt;
                if self.interrupt {
                    return Err(ErrorKind::Interrupted.into());
                }
                let n = buffer.len().min(self.data.len()).min(1);
                buffer[..n].copy_from_slice(&self.data[..n]);
                self.data = &self.data[n..];
                Ok(n)
            }
        }

        let mut reader = StutteringReader {
            data: b"ninebytes",
            interrupt: false,
        };
        let mut buf = [0u8; 9];
        assert_eq!(read_exactish(&mut reader, &mut buf).unwrap(), true);
        assert_eq!(&buf, b"ninebytes");
        assert_eq!(read_exactish(&mut reader, &mut buf).unwrap(), false);
    }

    #[test]
    fn test_read_exactish_async() {
        read_exactish_common(|r| {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(read_exactish_async(r, &mut [0; 9]))
        });
    }

    #[tokio::test]
    async fn test_read_exactish_broken_reader_async() {
        let mut reader = tokio_test::io::Builder::new()
            .read_error(Error::from(ErrorKind::NetworkDown))
            .build();

        assert_eq!(
            read_exactish_async(&mut reader, &mut [0; 9])
                .await
                .unwrap_err()
                .kind(),
            ErrorKind::NetworkDown
        );
    }

    #[test]
    fn test_parse_sha256() {
        let valid = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let valid_caps = "00112233445566778899AABBCCDDEEFF00112233445566778899AABBCCDDEEFf";
        let valid_weird = "00112233445566778899aABbcCDdeEFf00112233445566778899AaBbCcDdEeFf";
        assert_eq!(hex::encode(parse_sha256(valid).unwrap()), valid);
        assert_eq!(hex::encode(parse_sha256(valid_caps).unwrap()), valid);
        assert_eq!(hex::encode(parse_sha256(valid_weird).unwrap()), valid);

        fn assert_invalid(x: &str) {
            assert_eq!(parse_sha256(x).unwrap_err().kind(), ErrorKind::InvalidInput);
        }

        // Empty input.
        assert_invalid("");
        // Not hex at all.
        assert_invalid("/etc/shadow");
        // One digit too short.
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeef");
        // One digit too long.
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeefff");
        // Right length, but with a non-hex character.
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeefg");
    }

    #[test]
    fn test_filter_errno() {
        let ok: ErrnoResult<u32> = Ok(7);
        assert_eq!(ok.filter_errno(Errno::EXIST), Ok(Some(7)));

        let ignored: ErrnoResult<u32> = Err(Errno::EXIST);
        assert_eq!(ignored.filter_errno(Errno::EXIST), Ok(None));

        let other: ErrnoResult<u32> = Err(Errno::NOENT);
        assert_eq!(other.filter_errno(Errno::EXIST), Err(Errno::NOENT));
    }

    #[test]
    fn test_generate_tmpname() {
        let name = generate_tmpname(".symlink-");
        let suffix = name.strip_prefix(".symlink-").unwrap();
        assert_eq!(suffix.len(), 12);
        assert!(suffix.bytes().all(|b| b.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_proc_self_fd() {
        assert_eq!(proc_self_fd(std::io::stdin()).as_str(), "/proc/self/fd/0");
    }

    #[test]
    fn test_replace_symlinkat() {
        let dir = std::env::temp_dir().join(generate_tmpname("replace-symlinkat-test-"));
        std::fs::create_dir(&dir).unwrap();
        let dirfd = OwnedFd::from(std::fs::File::open(&dir).unwrap());
        let link = dir.join("link");

        // Nothing there yet: plain create.
        replace_symlinkat("a", &dirfd, "link").unwrap();
        assert_eq!(std::fs::read_link(&link).unwrap(), std::path::PathBuf::from("a"));

        // Same target: left as is.
        replace_symlinkat("a", &dirfd, "link").unwrap();
        assert_eq!(std::fs::read_link(&link).unwrap(), std::path::PathBuf::from("a"));

        // Different target: atomically replaced.
        replace_symlinkat("b", &dirfd, "link").unwrap();
        assert_eq!(std::fs::read_link(&link).unwrap(), std::path::PathBuf::from("b"));

        // No temporary `.symlink-*` entries left behind.
        assert_eq!(std::fs::read_dir(&dir).unwrap().count(), 1);

        drop(dirfd);
        std::fs::remove_dir_all(&dir).unwrap();
    }
}