1 | // SPDX-License-Identifier: GPL-3.0-or-later |
2 | |
3 | #include "librpc/rpc_client.h" |
4 | |
5 | #include "librpc/internal.h" |
6 | #include "librpc/rpc.h" |
7 | |
8 | #include <libipc/ipc.h> |
9 | #include <mos/types.h> |
10 | #include <pb_decode.h> |
11 | #include <pb_encode.h> |
12 | #include <stdarg.h> |
13 | |
14 | #if defined(__MOS_KERNEL__) |
15 | #include <mos/lib/sync/mutex.h> |
16 | #include <mos_stdio.h> |
17 | #include <mos_stdlib.h> |
18 | #include <mos_string.h> |
19 | #else |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | #include <string.h> |
23 | #endif |
24 | |
25 | #ifdef __MOS_KERNEL__ |
26 | #include "mos/ipc/ipc_io.h" |
27 | |
28 | #include <mos/platform/platform.h> |
29 | #include <mos/syscall/decl.h> |
30 | #define syscall_ipc_connect(n, s) ipc_connect(n, s) |
31 | #define syscall_io_close(fd) io_unref(fd) |
32 | #else |
33 | #include <mos/syscall/usermode.h> |
34 | #endif |
35 | |
#if !defined(__MOS_KERNEL__)
// Hosted (userspace) stand-ins for the kernel primitives used in this file.
#include <pthread.h>

// A pthread mutex replaces the kernel mutex. NOTE(review): mutexes in this file
// are zero-filled with memzero() instead of pthread_mutex_init(); that appears
// to rely on the all-zero bit pattern being a valid unlocked mutex — confirm
// for the target libc.
typedef pthread_mutex_t mutex_t;

#define memzero(ptr, size)   memset((ptr), 0, (size))
#define mutex_acquire(mutex) pthread_mutex_lock(mutex)
#define mutex_release(mutex) pthread_mutex_unlock(mutex)
#define mos_warn(...)        fprintf(stderr, __VA_ARGS__)

// Trap rather than __builtin_unreachable(): if control ever reaches this
// marker the program must stop deterministically instead of running into
// undefined behaviour.
#define MOS_LIB_UNREACHABLE() __builtin_trap()
#endif
46 | |
47 | #define RPC_CLIENT_SMH_SIZE MOS_PAGE_SIZE |
48 | |
49 | typedef struct rpc_server_stub |
50 | { |
51 | const char *server_name; |
52 | ipcfd_t fd; |
53 | mutex_t mutex; // only one call at a time |
54 | atomic_t callid; |
55 | } rpc_server_stub_t; |
56 | |
57 | typedef struct rpc_call |
58 | { |
59 | rpc_server_stub_t *server; |
60 | rpc_request_t *request; |
61 | size_t size; |
62 | mutex_t mutex; |
63 | } rpc_call_t; |
64 | |
65 | rpc_server_stub_t *rpc_client_create(const char *server_name) |
66 | { |
67 | rpc_server_stub_t *client = malloc(sizeof(rpc_server_stub_t)); |
68 | memzero(s: client, n: sizeof(rpc_server_stub_t)); |
69 | client->server_name = server_name; |
70 | client->fd = syscall_ipc_connect(server_name, RPC_CLIENT_SMH_SIZE); |
71 | |
72 | if (IS_ERR_VALUE(client->fd)) |
73 | { |
74 | free(client); |
75 | return NULL; |
76 | } |
77 | |
78 | return client; |
79 | } |
80 | |
81 | void rpc_client_destroy(rpc_server_stub_t *server) |
82 | { |
83 | mutex_acquire(mutex: &server->mutex); |
84 | syscall_io_close(server->fd); |
85 | free(server); |
86 | } |
87 | |
88 | rpc_call_t *rpc_call_create(rpc_server_stub_t *server, u32 function_id) |
89 | { |
90 | rpc_call_t *call = malloc(sizeof(rpc_call_t)); |
91 | memzero(s: call, n: sizeof(rpc_call_t)); |
92 | |
93 | call->request = malloc(sizeof(rpc_request_t)); |
94 | memzero(s: call->request, n: sizeof(rpc_request_t)); |
95 | call->request->magic = RPC_REQUEST_MAGIC; |
96 | call->request->function_id = function_id; |
97 | call->size = sizeof(rpc_request_t); |
98 | call->server = server; |
99 | |
100 | return call; |
101 | } |
102 | |
103 | void rpc_call_destroy(rpc_call_t *call) |
104 | { |
105 | mutex_acquire(mutex: &call->mutex); |
106 | free(call->request); |
107 | free(call); |
108 | } |
109 | |
110 | void rpc_call_arg(rpc_call_t *call, rpc_argtype_t argtype, const void *data, size_t size) |
111 | { |
112 | mutex_acquire(mutex: &call->mutex); |
113 | call->request = realloc(call->request, call->size + sizeof(rpc_arg_t) + size); |
114 | call->request->args_count += 1; |
115 | |
116 | rpc_arg_t *arg = (rpc_arg_t *) &call->request->args_array[call->size - sizeof(rpc_request_t)]; |
117 | arg->size = size; |
118 | arg->argtype = argtype; |
119 | arg->magic = RPC_ARG_MAGIC; |
120 | memcpy(dest: arg->data, src: data, n: size); |
121 | |
122 | call->size += sizeof(rpc_arg_t) + size; |
123 | mutex_release(mutex: &call->mutex); |
124 | } |
125 | |
126 | #define RPC_CALL_ARG_IMPL(type, TYPE) \ |
127 | void rpc_call_arg_##type(rpc_call_t *call, type arg) \ |
128 | { \ |
129 | rpc_call_arg(call, RPC_ARGTYPE_##TYPE, &arg, sizeof(arg)); \ |
130 | } |
131 | |
132 | RPC_CALL_ARG_IMPL(u8, UINT8) |
133 | RPC_CALL_ARG_IMPL(u32, UINT32) |
134 | RPC_CALL_ARG_IMPL(u64, UINT64) |
135 | RPC_CALL_ARG_IMPL(s8, INT8) |
136 | RPC_CALL_ARG_IMPL(s32, INT32) |
137 | RPC_CALL_ARG_IMPL(s64, INT64) |
138 | |
139 | void rpc_call_arg_string(rpc_call_t *call, const char *arg) |
140 | { |
141 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
142 | } |
143 | |
144 | rpc_result_code_t rpc_call_exec(rpc_call_t *call, void **result_data, size_t *data_size) |
145 | { |
146 | if (result_data && data_size) |
147 | { |
148 | *data_size = 0; |
149 | *result_data = NULL; |
150 | } |
151 | |
152 | mutex_acquire(mutex: &call->mutex); |
153 | mutex_acquire(mutex: &call->server->mutex); |
154 | call->request->call_id = ++call->server->callid; |
155 | |
156 | bool written = ipc_write_as_msg(fd: call->server->fd, data: (char *) call->request, size: call->size); |
157 | if (!written) |
158 | { |
159 | mutex_release(mutex: &call->server->mutex); |
160 | mutex_release(mutex: &call->mutex); |
161 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
162 | } |
163 | |
164 | ipc_msg_t *msg = ipc_read_msg(fd: call->server->fd); |
165 | if (!msg) |
166 | { |
167 | mutex_release(mutex: &call->server->mutex); |
168 | mutex_release(mutex: &call->mutex); |
169 | return RPC_RESULT_CLIENT_READ_FAILED; |
170 | } |
171 | |
172 | if (msg->size < sizeof(rpc_response_t)) |
173 | { |
174 | mutex_release(mutex: &call->server->mutex); |
175 | mutex_release(mutex: &call->mutex); |
176 | return RPC_RESULT_CLIENT_READ_FAILED; |
177 | } |
178 | |
179 | rpc_response_t *response = (rpc_response_t *) msg->data; |
180 | if (response->magic != RPC_RESPONSE_MAGIC) |
181 | { |
182 | mutex_release(mutex: &call->server->mutex); |
183 | mutex_release(mutex: &call->mutex); |
184 | return RPC_RESULT_CLIENT_READ_FAILED; |
185 | } |
186 | |
187 | if (response->call_id != call->request->call_id) |
188 | { |
189 | mutex_release(mutex: &call->server->mutex); |
190 | mutex_release(mutex: &call->mutex); |
191 | return RPC_RESULT_CALLID_MISMATCH; |
192 | } |
193 | |
194 | if (response->result_code != RPC_RESULT_OK) |
195 | { |
196 | mutex_release(mutex: &call->server->mutex); |
197 | mutex_release(mutex: &call->mutex); |
198 | return response->result_code; |
199 | } |
200 | |
201 | if (msg->size < sizeof(rpc_response_t)) |
202 | { |
203 | mutex_release(mutex: &call->server->mutex); |
204 | mutex_release(mutex: &call->mutex); |
205 | return RPC_RESULT_CLIENT_READ_FAILED; |
206 | } |
207 | |
208 | if (result_data && data_size && response->data_size) |
209 | { |
210 | *data_size = response->data_size; |
211 | *result_data = malloc(response->data_size); |
212 | memcpy(dest: *result_data, src: response->data, n: response->data_size); |
213 | } |
214 | |
215 | rpc_result_code_t result = response->result_code; |
216 | ipc_msg_destroy(buffer: msg); |
217 | mutex_release(mutex: &call->server->mutex); |
218 | mutex_release(mutex: &call->mutex); |
219 | return result; |
220 | } |
221 | |
222 | rpc_result_code_t rpc_simple_call(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, ...) |
223 | { |
224 | va_list args; |
225 | va_start(args, argspec); |
226 | rpc_result_code_t code = rpc_simple_callv(stub, funcid, result, argspec, args); |
227 | va_end(args); |
228 | return code; |
229 | } |
230 | |
231 | rpc_result_code_t rpc_simple_callv(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, va_list args) |
232 | { |
233 | if (unlikely(!argspec)) |
234 | { |
235 | mos_warn("argspec is NULL" ); |
236 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
237 | } |
238 | |
239 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
240 | |
241 | if (*argspec == 'v') |
242 | { |
243 | if (*++argspec != '\0') |
244 | { |
245 | mos_warn("argspec is not empty after 'v' (void) (argspec='%s')" , argspec); |
246 | rpc_call_destroy(call); |
247 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
248 | } |
249 | goto exec; |
250 | } |
251 | |
252 | for (const char *c = argspec; *c != '\0'; c++) |
253 | { |
254 | switch (*c) |
255 | { |
256 | case 'c': |
257 | { |
258 | u8 arg = va_arg(args, int); |
259 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT8, data: &arg, size: sizeof(arg)); |
260 | break; |
261 | } |
262 | case 'i': |
263 | { |
264 | u32 arg = va_arg(args, int); |
265 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT32, data: &arg, size: sizeof(arg)); |
266 | break; |
267 | } |
268 | case 'l': |
269 | { |
270 | u64 arg = va_arg(args, long long); |
271 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT64, data: &arg, size: sizeof(arg)); |
272 | break; |
273 | } |
274 | case 'f': |
275 | { |
276 | MOS_LIB_UNREACHABLE(); // TODO: implement |
277 | // double arg = va_arg(args, double); |
278 | // rpc_call_arg(call, &arg, sizeof(arg)); |
279 | break; |
280 | } |
281 | case 's': |
282 | { |
283 | const char *arg = va_arg(args, const char *); |
284 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
285 | break; |
286 | } |
287 | default: mos_warn("rpc_call: invalid argspec '%c'" , *c); return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
288 | } |
289 | } |
290 | |
291 | exec: |
292 | rpc_call_exec(call, result_data: &result->data, data_size: &result->size); |
293 | rpc_call_destroy(call); |
294 | |
295 | return RPC_RESULT_OK; |
296 | } |
297 | |
298 | rpc_result_code_t rpc_do_pb_call(rpc_server_stub_t *stub, u32 funcid, const pb_msgdesc_t *reqm, const void *req, const pb_msgdesc_t *respm, void *resp) |
299 | { |
300 | size_t bufsize; |
301 | pb_get_encoded_size(size: &bufsize, fields: reqm, src_struct: req); |
302 | pb_byte_t buf[bufsize]; |
303 | pb_ostream_t wstream = pb_ostream_from_buffer(buf, bufsize); |
304 | if (!pb_encode(stream: &wstream, fields: reqm, src_struct: req)) |
305 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
306 | |
307 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
308 | rpc_call_arg(call, argtype: RPC_ARGTYPE_BUFFER, data: buf, size: wstream.bytes_written); |
309 | |
310 | void *result = NULL; |
311 | size_t result_size = 0; |
312 | rpc_result_code_t result_code = rpc_call_exec(call, result_data: &result, data_size: &result_size); |
313 | rpc_call_destroy(call); |
314 | |
315 | if (!respm || !resp) |
316 | return result_code; // no response expected |
317 | |
318 | if (result_code != RPC_RESULT_OK) |
319 | return result_code; |
320 | |
321 | pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) result, msglen: result_size); |
322 | if (!pb_decode(stream: &stream, fields: respm, dest_struct: resp)) |
323 | return RPC_RESULT_CLIENT_READ_FAILED; |
324 | |
325 | free(result); |
326 | return RPC_RESULT_OK; |
327 | } |
328 | |