1use anyhow::Result;
5use canon_json::CanonJsonSerialize;
6use schemars::JsonSchema;
7use serde::Serialize;
8use std::borrow::Cow;
9use std::os::fd::{FromRawFd, OwnedFd, RawFd};
10use std::str::FromStr;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::io::{AsyncWriteExt, BufWriter};
14use tokio::net::unix::pipe::Sender;
15use tokio::sync::Mutex;
16
/// Maximum rate (writes per second) for non-required ("lossy") progress updates; the
/// minimum spacing between such writes is `1000 / REFRESH_HZ` milliseconds.
const REFRESH_HZ: u16 = 5;

/// Version string carried by the `Start` event that opens every progress stream.
const API_VERSION: &str = "0.1.0";
22
23#[derive(
27 Debug, serde::Serialize, serde::Deserialize, Default, Clone, JsonSchema, PartialEq, Eq,
28)]
29#[serde(rename_all = "camelCase")]
30pub struct SubTaskBytes<'t> {
31 #[serde(borrow)]
34 pub subtask: Cow<'t, str>,
35 #[serde(borrow)]
38 pub description: Cow<'t, str>,
39 #[serde(borrow)]
42 pub id: Cow<'t, str>,
43 pub bytes_cached: u64,
45 pub bytes: u64,
47 pub bytes_total: u64,
49}
50
51#[derive(
53 Debug, serde::Serialize, serde::Deserialize, Default, Clone, JsonSchema, PartialEq, Eq,
54)]
55#[serde(rename_all = "camelCase")]
56pub struct SubTaskStep<'t> {
57 #[serde(borrow)]
60 pub subtask: Cow<'t, str>,
61 #[serde(borrow)]
64 pub description: Cow<'t, str>,
65 #[serde(borrow)]
68 pub id: Cow<'t, str>,
69 pub completed: bool,
71}
72
73#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema, PartialEq, Eq)]
75#[serde(
76 tag = "type",
77 rename_all = "PascalCase",
78 rename_all_fields = "camelCase"
79)]
80pub enum Event<'t> {
81 Start {
82 #[serde(borrow)]
84 version: Cow<'t, str>,
85 },
86 ProgressBytes {
88 #[serde(borrow)]
91 task: Cow<'t, str>,
92 #[serde(borrow)]
94 description: Cow<'t, str>,
95 #[serde(borrow)]
99 id: Cow<'t, str>,
100 bytes_cached: u64,
102 bytes: u64,
104 bytes_total: u64,
106 steps_cached: u64,
108 steps: u64,
110 steps_total: u64,
112 subtasks: Vec<SubTaskBytes<'t>>,
114 },
115 ProgressSteps {
117 #[serde(borrow)]
120 task: Cow<'t, str>,
121 #[serde(borrow)]
123 description: Cow<'t, str>,
124 #[serde(borrow)]
128 id: Cow<'t, str>,
129 steps_cached: u64,
131 steps: u64,
133 steps_total: u64,
135 subtasks: Vec<SubTaskStep<'t>>,
137 },
138}
139
140#[derive(Debug, Clone, PartialEq, Eq)]
141pub(crate) struct RawProgressFd(RawFd);
142
143impl FromStr for RawProgressFd {
144 type Err = anyhow::Error;
145
146 fn from_str(s: &str) -> Result<Self> {
147 let fd = s.parse::<u32>()?;
148 if matches!(fd, 0..=2) {
150 anyhow::bail!("Cannot use fd {fd} for progress JSON")
151 }
152 Ok(Self(fd.try_into()?))
153 }
154}
155
156#[derive(Debug)]
157struct ProgressWriterInner {
158 sent_start: bool,
160 last_write: Option<std::time::Instant>,
161 fd: BufWriter<Sender>,
162}
163
164#[derive(Clone, Debug, Default)]
165pub(crate) struct ProgressWriter {
166 inner: Arc<Mutex<Option<ProgressWriterInner>>>,
167}
168
169impl TryFrom<OwnedFd> for ProgressWriter {
170 type Error = anyhow::Error;
171
172 fn try_from(value: OwnedFd) -> Result<Self> {
173 let value = Sender::from_owned_fd(value)?;
174 Ok(Self::from(value))
175 }
176}
177
178impl From<Sender> for ProgressWriter {
179 fn from(value: Sender) -> Self {
180 let inner = ProgressWriterInner {
181 sent_start: false,
182 last_write: None,
183 fd: BufWriter::new(value),
184 };
185 Self {
186 inner: Arc::new(Some(inner).into()),
187 }
188 }
189}
190
191impl TryFrom<RawProgressFd> for ProgressWriter {
192 type Error = anyhow::Error;
193
194 #[allow(unsafe_code)]
195 fn try_from(fd: RawProgressFd) -> Result<Self> {
196 unsafe { OwnedFd::from_raw_fd(fd.0) }.try_into()
197 }
198}
199
200impl ProgressWriter {
201 async fn send_impl_inner<T: Serialize>(inner: &mut ProgressWriterInner, v: T) -> Result<()> {
203 let buf = v.to_canon_json_vec()?;
205 inner.fd.write_all(&buf).await?;
206 inner.fd.write_all(b"\n").await?;
208 inner.fd.flush().await?;
210 Ok(())
211 }
212
213 pub(crate) async fn send_impl<T: Serialize>(&self, v: T, required: bool) -> Result<()> {
215 let mut guard = self.inner.lock().await;
216 let Some(inner) = guard.as_mut() else {
218 return Ok(());
219 };
220
221 if !inner.sent_start {
223 inner.sent_start = true;
224 let start = Event::Start {
225 version: API_VERSION.into(),
226 };
227 Self::send_impl_inner(inner, &start).await?;
228 }
229
230 let now = Instant::now();
233 if !required {
234 const REFRESH_MS: u32 = 1000 / REFRESH_HZ as u32;
235 if let Some(elapsed) = inner.last_write.map(|w| now.duration_since(w)) {
236 if elapsed.as_millis() < REFRESH_MS.into() {
237 return Ok(());
238 }
239 }
240 }
241
242 Self::send_impl_inner(inner, &v).await?;
243 inner.last_write = Some(now);
245 Ok(())
246 }
247
248 pub(crate) async fn send(&self, event: Event<'_>) {
250 if let Err(e) = self.send_impl(event, true).await {
251 eprintln!("Failed to write to jsonl: {e}");
252 let _ = self.inner.lock().await.take();
255 }
256 }
257
258 pub(crate) async fn send_lossy(&self, event: Event<'_>) {
260 if let Err(e) = self.send_impl(event, false).await {
261 eprintln!("Failed to write to jsonl: {e}");
262 let _ = self.inner.lock().await.take();
265 }
266 }
267
268 #[allow(dead_code)]
270 pub(crate) async fn into_inner(self) -> Result<Option<Sender>> {
271 let mut mutex = self.inner.lock().await;
273 if let Some(inner) = mutex.take() {
274 Ok(Some(inner.fd.into_inner()))
275 } else {
276 Ok(None)
277 }
278 }
279}
280
#[cfg(test)]
mod test {
    use tokio::io::{AsyncBufReadExt, BufReader};

    use super::*;

    /// Round-trip events through a pipe. The reader must see the `Start` event followed by
    /// exactly the events that were sent, in order, and nothing more or less.
    #[tokio::test]
    async fn test_jsonl() -> Result<()> {
        let testvalues = [
            Event::ProgressSteps {
                task: "sometask".into(),
                description: "somedesc".into(),
                id: "someid".into(),
                steps_cached: 0,
                steps: 0,
                steps_total: 3,
                subtasks: Vec::new(),
            },
            Event::ProgressBytes {
                task: "sometask".into(),
                description: "somedesc".into(),
                id: "someid".into(),
                bytes_cached: 0,
                bytes: 11,
                bytes_total: 42,
                steps_cached: 0,
                steps: 0,
                steps_total: 3,
                subtasks: Vec::new(),
            },
        ];
        let (send, recv) = tokio::net::unix::pipe::pipe()?;
        let testvalues_sender = testvalues.iter().cloned();
        let sender = async move {
            let w = ProgressWriter::try_from(send)?;
            for value in testvalues_sender {
                w.send(value).await;
            }
            // `w` is dropped when this block ends, closing the write end so the reader
            // reaches EOF.
            anyhow::Ok(())
        };
        let testvalues = &testvalues;
        let receiver = async move {
            let tf = BufReader::new(recv);
            let mut expected = testvalues.iter();
            let mut lines = tf.lines();
            let mut got_first = false;
            while let Some(line) = lines.next_line().await? {
                let found: Event = serde_json::from_str(&line)?;
                let expected_value = if !got_first {
                    got_first = true;
                    &Event::Start {
                        version: API_VERSION.into(),
                    }
                } else {
                    expected
                        .next()
                        .expect("received more events than were sent")
                };
                assert_eq!(&found, expected_value);
            }
            // Without these checks, a writer that silently dropped events (for example
            // after disabling itself on a write error) would still pass.
            assert!(got_first, "no Start event was received");
            assert!(
                expected.next().is_none(),
                "fewer events were received than were sent"
            );
            anyhow::Ok(())
        };
        tokio::try_join!(sender, receiver)?;
        Ok(())
    }
}