1 | // SPDX-License-Identifier: GPL-3.0-or-later |
2 | |
3 | #include "librpc/rpc_client.h" |
4 | |
5 | #include "librpc/internal.h" |
6 | #include "librpc/rpc.h" |
7 | |
8 | #include <atomic> |
9 | #include <libipc/ipc.h> |
10 | #include <mos/types.h> |
11 | #include <pb_decode.h> |
12 | #include <pb_encode.h> |
13 | #include <stdarg.h> |
14 | |
15 | #if defined(__MOS_KERNEL__) |
16 | #include <mos/lib/sync/mutex.hpp> |
17 | #include <mos_stdio.hpp> |
18 | #include <mos_stdlib.hpp> |
19 | #include <mos_string.hpp> |
20 | #else |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | #include <string.h> |
24 | #endif |
25 | |
26 | #ifdef __MOS_KERNEL__ |
27 | #include "mos/assert.hpp" |
28 | #include "mos/ipc/ipc_io.hpp" |
29 | |
30 | #include <mos/platform/platform.hpp> |
31 | #include <mos/syscall/decl.h> |
32 | #define syscall_ipc_connect(n, s) ipc_connect(n, s).get() |
33 | #define syscall_io_close(fd) io_unref(fd) |
34 | #else |
35 | #include <mos/syscall/usermode.h> |
36 | #endif |
37 | |
38 | #if !defined(__MOS_KERNEL__) |
39 | // fixup for hosted libc |
40 | #include <pthread.h> |
41 | typedef pthread_mutex_t mutex_t; |
42 | #define mutex_acquire(mutex) pthread_mutex_lock(mutex) |
43 | #define mutex_release(mutex) pthread_mutex_unlock(mutex) |
44 | #define mos_warn(...) fprintf(stderr, __VA_ARGS__) |
45 | #define MOS_LIB_UNREACHABLE() __builtin_unreachable() |
46 | #endif |
47 | |
48 | #define RPC_CLIENT_SMH_SIZE MOS_PAGE_SIZE |
49 | |
50 | typedef struct rpc_server_stub |
51 | { |
52 | const char *server_name; |
53 | ipcfd_t fd; |
54 | mutex_t mutex; // only one call at a time |
55 | std::atomic_size_t callid; |
56 | } rpc_server_stub_t; |
57 | |
58 | typedef struct rpc_call |
59 | { |
60 | rpc_server_stub_t *server; |
61 | rpc_request_t *request; |
62 | size_t size; |
63 | mutex_t mutex; |
64 | } rpc_call_t; |
65 | |
66 | rpc_server_stub_t *rpc_client_create(const char *server_name) |
67 | { |
68 | rpc_server_stub_t *client = (rpc_server_stub_t *) calloc(nmemb: 1, size: sizeof(rpc_server_stub_t)); |
69 | client->server_name = server_name; |
70 | client->fd = syscall_ipc_connect(server_name, RPC_CLIENT_SMH_SIZE); |
71 | |
72 | if (IS_ERR_VALUE(client->fd)) |
73 | { |
74 | free(ptr: client); |
75 | return NULL; |
76 | } |
77 | |
78 | return client; |
79 | } |
80 | |
81 | void rpc_client_destroy(rpc_server_stub_t *server) |
82 | { |
83 | mutex_acquire(mutex: &server->mutex); |
84 | syscall_io_close(server->fd); |
85 | free(ptr: server); |
86 | } |
87 | |
88 | rpc_call_t *rpc_call_create(rpc_server_stub_t *server, u32 function_id) |
89 | { |
90 | rpc_call_t *call = (rpc_call_t *) calloc(nmemb: 1, size: sizeof(rpc_call_t)); |
91 | call->request = (rpc_request_t *) calloc(nmemb: 1, size: sizeof(rpc_request_t)); |
92 | call->request->magic = RPC_REQUEST_MAGIC; |
93 | call->request->function_id = function_id; |
94 | call->size = sizeof(rpc_request_t); |
95 | call->server = server; |
96 | |
97 | return call; |
98 | } |
99 | |
100 | void rpc_call_destroy(rpc_call_t *call) |
101 | { |
102 | mutex_acquire(mutex: &call->mutex); |
103 | free(ptr: call->request); |
104 | free(ptr: call); |
105 | } |
106 | |
107 | void rpc_call_arg(rpc_call_t *call, rpc_argtype_t argtype, const void *data, size_t size) |
108 | { |
109 | mutex_acquire(mutex: &call->mutex); |
110 | call->request = (rpc_request_t *) realloc(ptr: call->request, size: call->size + sizeof(rpc_arg_t) + size); |
111 | call->request->args_count += 1; |
112 | |
113 | rpc_arg_t *arg = (rpc_arg_t *) &call->request->args_array[call->size - sizeof(rpc_request_t)]; |
114 | arg->size = size; |
115 | arg->argtype = argtype; |
116 | arg->magic = RPC_ARG_MAGIC; |
117 | memcpy(dest: arg->data, src: data, n: size); |
118 | |
119 | call->size += sizeof(rpc_arg_t) + size; |
120 | mutex_release(mutex: &call->mutex); |
121 | } |
122 | |
123 | #define RPC_CALL_ARG_IMPL(type, TYPE) \ |
124 | void rpc_call_arg_##type(rpc_call_t *call, type arg) \ |
125 | { \ |
126 | rpc_call_arg(call, RPC_ARGTYPE_##TYPE, &arg, sizeof(arg)); \ |
127 | } |
128 | |
129 | RPC_CALL_ARG_IMPL(u8, UINT8) |
130 | RPC_CALL_ARG_IMPL(u32, UINT32) |
131 | RPC_CALL_ARG_IMPL(u64, UINT64) |
132 | RPC_CALL_ARG_IMPL(s8, INT8) |
133 | RPC_CALL_ARG_IMPL(s32, INT32) |
134 | RPC_CALL_ARG_IMPL(s64, INT64) |
135 | |
136 | void rpc_call_arg_string(rpc_call_t *call, const char *arg) |
137 | { |
138 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
139 | } |
140 | |
141 | rpc_result_code_t rpc_call_exec(rpc_call_t *call, void **result_data, size_t *data_size) |
142 | { |
143 | if (result_data && data_size) |
144 | { |
145 | *data_size = 0; |
146 | *result_data = NULL; |
147 | } |
148 | |
149 | mutex_acquire(mutex: &call->mutex); |
150 | mutex_acquire(mutex: &call->server->mutex); |
151 | call->request->call_id = ++call->server->callid; |
152 | |
153 | bool written = ipc_write_as_msg(fd: call->server->fd, data: (char *) call->request, size: call->size); |
154 | if (!written) |
155 | { |
156 | mutex_release(mutex: &call->server->mutex); |
157 | mutex_release(mutex: &call->mutex); |
158 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
159 | } |
160 | |
161 | ipc_msg_t *msg = ipc_read_msg(fd: call->server->fd); |
162 | if (!msg) |
163 | { |
164 | mutex_release(mutex: &call->server->mutex); |
165 | mutex_release(mutex: &call->mutex); |
166 | return RPC_RESULT_CLIENT_READ_FAILED; |
167 | } |
168 | |
169 | if (msg->size < sizeof(rpc_response_t)) |
170 | { |
171 | mutex_release(mutex: &call->server->mutex); |
172 | mutex_release(mutex: &call->mutex); |
173 | return RPC_RESULT_CLIENT_READ_FAILED; |
174 | } |
175 | |
176 | rpc_response_t *response = (rpc_response_t *) msg->data; |
177 | if (response->magic != RPC_RESPONSE_MAGIC) |
178 | { |
179 | mutex_release(mutex: &call->server->mutex); |
180 | mutex_release(mutex: &call->mutex); |
181 | return RPC_RESULT_CLIENT_READ_FAILED; |
182 | } |
183 | |
184 | if (response->call_id != call->request->call_id) |
185 | { |
186 | mutex_release(mutex: &call->server->mutex); |
187 | mutex_release(mutex: &call->mutex); |
188 | return RPC_RESULT_CALLID_MISMATCH; |
189 | } |
190 | |
191 | if (response->result_code != RPC_RESULT_OK) |
192 | { |
193 | mutex_release(mutex: &call->server->mutex); |
194 | mutex_release(mutex: &call->mutex); |
195 | return response->result_code; |
196 | } |
197 | |
198 | if (msg->size < sizeof(rpc_response_t)) |
199 | { |
200 | mutex_release(mutex: &call->server->mutex); |
201 | mutex_release(mutex: &call->mutex); |
202 | return RPC_RESULT_CLIENT_READ_FAILED; |
203 | } |
204 | |
205 | if (result_data && data_size && response->data_size) |
206 | { |
207 | *data_size = response->data_size; |
208 | *result_data = malloc(size: response->data_size); |
209 | memcpy(dest: *result_data, src: response->data, n: response->data_size); |
210 | } |
211 | |
212 | rpc_result_code_t result = response->result_code; |
213 | ipc_msg_destroy(buffer: msg); |
214 | mutex_release(mutex: &call->server->mutex); |
215 | mutex_release(mutex: &call->mutex); |
216 | return result; |
217 | } |
218 | |
219 | rpc_result_code_t rpc_simple_call(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, ...) |
220 | { |
221 | va_list args; |
222 | va_start(args, argspec); |
223 | rpc_result_code_t code = rpc_simple_callv(stub, funcid, result, argspec, args); |
224 | va_end(args); |
225 | return code; |
226 | } |
227 | |
228 | rpc_result_code_t rpc_simple_callv(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, va_list args) |
229 | { |
230 | if (unlikely(!argspec)) |
231 | { |
232 | mos_warn("argspec is NULL" ); |
233 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
234 | } |
235 | |
236 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
237 | |
238 | if (*argspec == 'v') |
239 | { |
240 | if (*++argspec != '\0') |
241 | { |
242 | mos_warn("argspec is not empty after 'v' (void) (argspec='%s')" , argspec); |
243 | rpc_call_destroy(call); |
244 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
245 | } |
246 | goto exec; |
247 | } |
248 | |
249 | for (const char *c = argspec; *c != '\0'; c++) |
250 | { |
251 | switch (*c) |
252 | { |
253 | case 'c': |
254 | { |
255 | u8 arg = va_arg(args, int); |
256 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT8, data: &arg, size: sizeof(arg)); |
257 | break; |
258 | } |
259 | case 'i': |
260 | { |
261 | u32 arg = va_arg(args, int); |
262 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT32, data: &arg, size: sizeof(arg)); |
263 | break; |
264 | } |
265 | case 'l': |
266 | { |
267 | u64 arg = va_arg(args, long long); |
268 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT64, data: &arg, size: sizeof(arg)); |
269 | break; |
270 | } |
271 | case 'f': |
272 | { |
273 | MOS_LIB_UNREACHABLE(); // TODO: implement |
274 | // double arg = va_arg(args, double); |
275 | // rpc_call_arg(call, &arg, sizeof(arg)); |
276 | break; |
277 | } |
278 | case 's': |
279 | { |
280 | const char *arg = va_arg(args, const char *); |
281 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
282 | break; |
283 | } |
284 | default: mos_warn("rpc_call: invalid argspec '%c'" , *c); return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
285 | } |
286 | } |
287 | |
288 | exec: |
289 | rpc_call_exec(call, result_data: &result->data, data_size: &result->size); |
290 | rpc_call_destroy(call); |
291 | |
292 | return RPC_RESULT_OK; |
293 | } |
294 | |
295 | rpc_result_code_t rpc_do_pb_call(rpc_server_stub_t *stub, u32 funcid, const pb_msgdesc_t *reqm, const void *req, const pb_msgdesc_t *respm, void *resp) |
296 | { |
297 | size_t bufsize; |
298 | pb_get_encoded_size(size: &bufsize, fields: reqm, src_struct: req); |
299 | pb_byte_t buf[bufsize]; |
300 | pb_ostream_t wstream = pb_ostream_from_buffer(buf, bufsize); |
301 | if (!pb_encode(stream: &wstream, fields: reqm, src_struct: req)) |
302 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
303 | |
304 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
305 | rpc_call_arg(call, argtype: RPC_ARGTYPE_BUFFER, data: buf, size: wstream.bytes_written); |
306 | |
307 | void *result = NULL; |
308 | size_t result_size = 0; |
309 | rpc_result_code_t result_code = rpc_call_exec(call, result_data: &result, data_size: &result_size); |
310 | rpc_call_destroy(call); |
311 | |
312 | if (!respm || !resp) |
313 | return result_code; // no response expected |
314 | |
315 | if (result_code != RPC_RESULT_OK) |
316 | return result_code; |
317 | |
318 | pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) result, msglen: result_size); |
319 | if (!pb_decode(stream: &stream, fields: respm, dest_struct: resp)) |
320 | return RPC_RESULT_CLIENT_READ_FAILED; |
321 | |
322 | free(ptr: result); |
323 | return RPC_RESULT_OK; |
324 | } |
325 | |