1use futures::{stream::FuturesOrdered, StreamExt};
4use jsonrpsee::{
5 batch_response_error,
6 core::{server::helpers::prepare_error, JsonRawValue},
7 server::middleware::rpc::RpcServiceT,
8 types::{
9 error::{reject_too_big_request, ErrorCode},
10 ErrorObject, Id, InvalidRequest, Notification, Request,
11 },
12 BatchResponseBuilder, MethodResponse,
13};
14use std::sync::Arc;
15use tokio::sync::OwnedSemaphorePermit;
16use tokio_util::either::Either;
17use tracing::instrument;
18
/// A JSON-RPC notification whose `params` are kept as an optional borrowed raw JSON value
/// rather than being deserialized.
///
/// Used to recognise notifications (which must not be answered) in both the single and the
/// batch request paths.
type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>;
20
/// Everything needed to process one batch request.
///
/// `data` holds the raw request bytes, expected to be a JSON array. `rpc_service` is the
/// service each method call in the batch is dispatched to.
#[derive(Clone, Debug)]
pub(crate) struct Batch<S> {
    /// Raw, still-unparsed request body.
    data: Vec<u8>,
    /// Service that executes every method call found in `data`.
    rpc_service: S,
}
26
27#[instrument(name = "batch", skip(b), level = "TRACE")]
31pub(crate) async fn process_batch_request<S>(
32 b: Batch<S>,
33 max_response_body_size: usize,
34) -> Option<String>
35where
36 S: RpcServiceT<MethodResponse = MethodResponse> + Send,
37{
38 let Batch { data, rpc_service } = b;
39
40 if let Ok(batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(&data) {
41 let mut got_notif = false;
42 let mut batch_response = BatchResponseBuilder::new_with_limit(max_response_body_size);
43
44 let mut pending_calls: FuturesOrdered<_> = batch
45 .into_iter()
46 .filter_map(|v| {
47 if let Ok(req) = serde_json::from_str::<Request<'_>>(v.get()) {
48 Some(Either::Right(rpc_service.call(req)))
49 } else if let Ok(_notif) = serde_json::from_str::<Notif<'_>>(v.get()) {
50 got_notif = true;
52 None
53 } else {
54 let id = match serde_json::from_str::<InvalidRequest<'_>>(v.get()) {
56 Ok(err) => err.id,
57 Err(_) => Id::Null,
58 };
59
60 Some(Either::Left(async {
61 MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest))
62 }))
63 }
64 })
65 .collect();
66
67 while let Some(response) = pending_calls.next().await {
68 if let Err(too_large) = batch_response.append(response) {
69 return Some(too_large.to_json().to_string())
70 }
71 }
72
73 if got_notif && batch_response.is_empty() {
74 None
75 } else {
76 let batch_resp = batch_response.finish();
77 Some(MethodResponse::from_batch(batch_resp).to_json().to_string())
78 }
79 } else {
80 Some(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::ParseError)).to_string())
81 }
82}
83
84pub(crate) async fn process_single_request<S>(
85 data: Vec<u8>,
86 rpc_service: &S,
87) -> Option<MethodResponse>
88where
89 S: RpcServiceT<MethodResponse = MethodResponse> + Send,
90{
91 if let Ok(req) = serde_json::from_slice::<Request<'_>>(&data) {
92 Some(execute_call_with_tracing(req, rpc_service).await)
93 } else if serde_json::from_slice::<Notif<'_>>(&data).is_ok() {
94 None
95 } else {
96 let (id, code) = prepare_error(&data);
97 Some(MethodResponse::error(id, ErrorObject::from(code)))
98 }
99}
100
101#[instrument(name = "method_call", fields(method = req.method.as_ref()), skip(req, rpc_service), level = "TRACE")]
102pub(crate) async fn execute_call_with_tracing<'a, S>(
103 req: Request<'a>,
104 rpc_service: &S,
105) -> MethodResponse
106where
107 S: RpcServiceT<MethodResponse = MethodResponse> + Send,
108{
109 rpc_service.call(req).await
110}
111
112pub(crate) async fn call_with_service<S>(
113 request: String,
114 rpc_service: S,
115 max_response_body_size: usize,
116 max_request_body_size: usize,
117 conn: Arc<OwnedSemaphorePermit>,
118) -> Option<String>
119where
120 S: RpcServiceT<MethodResponse = MethodResponse> + Send,
121{
122 enum Kind {
123 Single,
124 Batch,
125 }
126
127 let request_kind = request
128 .chars()
129 .find_map(|c| match c {
130 '{' => Some(Kind::Single),
131 '[' => Some(Kind::Batch),
132 _ => None,
133 })
134 .unwrap_or(Kind::Single);
135
136 let data = request.into_bytes();
137 if data.len() > max_request_body_size {
138 return Some(
139 batch_response_error(Id::Null, reject_too_big_request(max_request_body_size as u32))
140 .to_string(),
141 )
142 }
143
144 let res = if matches!(request_kind, Kind::Single) {
146 let response = process_single_request(data, &rpc_service).await;
147 match response {
148 Some(response) if response.is_method_call() => Some(response.to_json().to_string()),
149 _ => {
150 None
153 }
154 }
155 } else {
156 process_batch_request(Batch { data, rpc_service }, max_response_body_size).await
157 };
158
159 drop(conn);
160
161 res
162}