// pxt-common-packages/libs/wifi---esp32/socket.cpp
//
// 291 lines
// 7.3 KiB
// C++

#include "wifi.h"
#include "esp_tls.h"
#include "esp_crt_bundle.h"
// Tag passed to ESP_LOGI by the LOG() macro below.
#define TAG "ssl"
// Info-level logging shorthand that always logs under TAG.
#define LOG(...) ESP_LOGI(TAG, __VA_ARGS__)
// Size of the socket table. Descriptor 0 is never handed out by socketAlloc()
// and is rejected by GET_SOCK(), so at most MAX_SOCKET - 1 sockets can exist.
#define MAX_SOCKET 16
namespace _wifi {
// TLS configuration passed to every connection attempt; filled in by socket_init().
static esp_tls_cfg_t tls_cfg;
// State kept for each allocated socket descriptor.
struct socket_t {
// esp-tls connection handle; NULL while not connected and after the connection was destroyed.
esp_tls_t *ssl;
// Number of readable bytes recorded by the last update_bytes_avail() call.
int bytesAvailable;
// Singly-linked queue of pending read requests; worker_read() appends at the
// tail and process_reader() removes from the head.
struct read_args *readers;
};
// Worker to which connect/write/read/close requests are dispatched; created by socket_init().
static worker_t worker;
// Socket table indexed by descriptor; a NULL entry is a free slot.
static socket_t *sockets[MAX_SOCKET];
// Defined further down; declared here because flush_ssl() calls it first.
static void process_reader(struct read_args *args);
// Refresh sock->bytesAvailable from the TLS layer.
//
// An event (eventID(), value 1000 + descriptor) is raised when either
//  - the connection had to be torn down (nothing buffered and the probe read
//    reported something other than 0 / WANT_READ), or
//  - readable data appeared while bytesAvailable was previously 0.
// Does nothing for a socket without a TLS handle.
static void update_bytes_avail(socket_t *sock) {
    if (sock->ssl == NULL)
        return;

    // Zero-length probe read; presumably this lets the TLS layer process
    // incoming records so the byte count below is current - TODO confirm.
    const int status = esp_tls_conn_read(sock->ssl, NULL, 0);
    const int pending = esp_tls_get_bytes_avail(sock->ssl);
    int notify = 0;

    const int failed = (status != 0) && (status != ESP_TLS_ERR_SSL_WANT_READ);
    if (pending == 0 && failed) {
        esp_tls_conn_destroy(sock->ssl);
        sock->ssl = NULL;
        notify = 1;
    }
    // DMESG("updbyt: %d r=%d", pending, status);

    // A transition from "nothing to read" to "something to read" is also worth an event.
    if (pending > 0 && sock->bytesAvailable == 0)
        notify = 1;
    sock->bytesAvailable = (pending > 0) ? pending : 0;

    if (!notify)
        return;

    // Translate the pointer back into its descriptor to build the event value.
    for (int slot = 0; slot < MAX_SOCKET; ++slot) {
        if (sockets[slot] != sock)
            continue;
        raiseEvent(eventID(), 1000 + slot);
        return;
    }
}
static void flush_ssl(void *) {
for (int i = 0; i < MAX_SOCKET; ++i) {
auto s = sockets[i];
if (!s || !s->ssl)
continue;
update_bytes_avail(s);
while (s->bytesAvailable && s->readers) {
process_reader(s->readers);
}
}
}
// One-time setup: fills in the shared TLS configuration and creates the
// worker. Later calls are no-ops once the worker exists.
static void socket_init() {
    if (!worker) {
        // Non-blocking sockets verified against the certificate bundle,
        // with a 30 second timeout.
        tls_cfg.timeout_ms = 30000;
        tls_cfg.non_block = true;
        tls_cfg.crt_bundle_attach = esp_crt_bundle_attach;

        // 10 KiB is passed as the second worker_alloc() argument
        // (presumably the worker's stack size - TODO confirm).
        worker = worker_alloc("ssl", 10 * 1024);
        worker_set_idle(worker, flush_ssl, NULL);
    }
}
/** Allocate new socket. */
//%
int socketAlloc() {
socket_init();
for (int i = 1; i < MAX_SOCKET; ++i) {
if (!sockets[i]) {
sockets[i] = new socket_t;
memset(sockets[i], 0, sizeof(*sockets[i]));
return i;
}
}
return -1;
}
// Validates the enclosing function's `fd` and declares a local `sock`.
// Makes the enclosing function return -10 when fd is outside 1..MAX_SOCKET-1
// and -11 when that slot holds no socket.
#define GET_SOCK() \
if (fd <= 0 || fd >= MAX_SOCKET) \
return -10; \
auto sock = sockets[fd]; \
if (!sock) \
return -11;
// Same as GET_SOCK(), and additionally makes the enclosing function return
// -12 when the socket has no TLS handle.
#define GET_SOCK_SSL() \
GET_SOCK(); \
if (!sock->ssl) \
return -12;
// Request record handed from socketConnectTLS() to worker_conn(), which deletes it.
struct conn_args {
// socket whose TLS handle is to be connected
socket_t *sock;
// fiber suspended in socketConnectTLS(); resumed with the connect result
FiberContext *ctx;
// host name as returned by String::getUTF8Data(); not copied
const char *host;
// TCP port, validated by socketConnectTLS() to be in 1..65535
int port;
};
// Error strings for TLS failures. In the code visible here they are referenced
// only from the disabled (#if 0) section of check_error().
PXT_DEF_STRING(sOOM, "ssl: Out of memory")
PXT_DEF_STRING(sHandshake, "ssl: Handshake failed")
PXT_DEF_STRING(sError, "ssl: error")
// Log diagnostic details for a failed esp-tls call.
//
// sock: socket the call was made on (its TLS handle may already be NULL).
// r:    result of the esp-tls call; nothing happens unless it is negative.
//
// Only logs; it does not change the socket and does not throw.
static void check_error(socket_t *sock, int r) {
    if (r >= 0)
        return;

    // Start from known values so the log line never prints indeterminate
    // data when the handle is missing or the lookup does not fill them in.
    int err_code = 0, flags = 0;
    esp_err_t err = 0; // 0 == ESP_OK
    if (sock->ssl)
        err = esp_tls_get_and_clear_last_error(sock->ssl->error_handle, &err_code, &flags);
    DMESG("ESP TLS error: err=%x (%d/%d) res=%d", err, err_code, flags, r);
    vm_stack_trace();
#if 0
    // can't really throw from here
    if (err == ESP_ERR_MBEDTLS_SSL_SETUP_FAILED) {
        pxt::throwValue((TValue)sOOM);
    } else if (err == ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED) {
        pxt::throwValue((TValue)sHandshake);
    } else {
        pxt::throwValue((TValue)sError);
    }
#endif
}
// Worker-side half of socketConnectTLS(): performs the connect with the
// shared tls_cfg, logs failures, resumes the waiting fiber with the raw
// esp_tls_conn_new_sync() result and releases the request record.
static void worker_conn(conn_args *args) {
    socket_t *target = args->sock;
    const char *hostname = args->host;

    const int status =
        esp_tls_conn_new_sync(hostname, strlen(hostname), args->port, &tls_cfg, target->ssl);
    check_error(target, status);

    resumeFiber(args->ctx, fromInt(status));
    delete args;
}
/** Connect with TLS */
//% promise
int socketConnectTLS(int fd, String host, int port) {
memInfo();
GET_SOCK();
if (sock->ssl)
return -2;
if (port <= 0 || port > 0xffff)
return -3;
sock->ssl = esp_tls_init();
auto args = new conn_args;
args->host = host->getUTF8Data();
args->port = port;
args->sock = sock;
args->ctx = suspendFiber();
worker_run(worker, (TaskFunction_t)worker_conn, args);
return 0; // ignored
}
// Request record handed from socketWrite() to worker_write(), which deletes it.
struct write_args {
// socket to write to
socket_t *sock;
// fiber suspended in socketWrite(); resumed with the write result
FiberContext *ctx;
// payload; its data/length fields are passed to esp_tls_conn_write()
Buffer data;
};
// Worker-side half of socketWrite(): writes the buffer, logs failures,
// resumes the waiting fiber with the result and releases the request record.
//
// The fiber receives the esp_tls_conn_write() result, or -12 when the
// connection no longer exists by the time the request runs.
static void worker_write(write_args *args) {
    socket_t *sock = args->sock;
    int r;
    if (!sock->ssl) {
        // socketWrite() checked the handle on the caller's side, but the
        // worker may have destroyed the connection before this request ran.
        r = -12;
    } else {
        r = esp_tls_conn_write(sock->ssl, args->data->data, args->data->length);
        check_error(sock, r);
    }
    resumeFiber(args->ctx, fromInt(r));
    delete args;
}
/** Write to socket */
//% promise
int socketWrite(int fd, Buffer data) {
GET_SOCK_SSL();
auto args = new write_args;
args->data = data;
args->sock = sock;
args->ctx = suspendFiber();
worker_run(worker, (TaskFunction_t)worker_write, args);
return 0; // ignored
}
// A pending read request; also a node in socket_t::readers.
struct read_args {
// socket to read from
socket_t *sock;
// fiber suspended in socketRead(); resumed with an error number or a Buffer
FiberContext *ctx;
// requested byte count; process_reader() overwrites it with the count actually read
int size;
// malloc'ed receive buffer, set by process_reader() and freed after use
void *buf;
// next request in the queue, or NULL
struct read_args *next;
};
// Resume callback used by process_reader(): copies the received bytes into a
// Buffer via mkBuffer(), then releases the temporary storage and the request.
static Buffer mk_read_buffer(read_args *args) {
    void *payload = args->buf;
    Buffer out = mkBuffer(payload, args->size);
    free(payload);
    delete args;
    return out;
}
// Serve one read request. `args` must be the head of its socket's reader
// queue; it is unlinked here and always consumed.
//
// Reads at most min(requested size, sock->bytesAvailable) bytes. The waiting
// fiber is resumed with
//   - a Buffer holding the bytes read (built by mk_read_buffer), or
//   - the negative esp_tls_conn_read() result on failure, or
//   - -21 when the receive buffer could not be allocated.
static void process_reader(read_args *args) {
    auto sock = args->sock;
    sock->readers = args->next;

    int num = args->size;
    if (num > sock->bytesAvailable)
        num = sock->bytesAvailable;

    args->buf = malloc(num);
    if (!args->buf && num > 0) {
        // Never hand a NULL destination to the TLS layer. Nothing was
        // consumed, so bytesAvailable is still accurate and needs no refresh.
        resumeFiber(args->ctx, fromInt(-21));
        delete args;
        return;
    }

    int r = esp_tls_conn_read(sock->ssl, args->buf, num);
    if (r < 0) {
        free(args->buf);
        resumeFiber(args->ctx, fromInt(r));
        delete args;
    } else {
        // mk_read_buffer() takes ownership of args and its buffer.
        args->size = r;
        resumeFiberWithFn(args->ctx, (fiber_resume_t)mk_read_buffer, args);
    }
    // Data was (possibly) consumed: refresh the count for the next reader.
    update_bytes_avail(sock);
}
// Worker-side half of socketRead(): queue the request behind earlier readers,
// or serve it right away when it is first in line and data is available.
//
// If the connection no longer exists, every queued request (and this one) is
// resumed with -12: flush_ssl() skips sockets without a TLS handle, so such
// requests would otherwise wait forever and leak.
static void worker_read(read_args *args) {
    auto sock = args->sock;

    if (!sock->ssl) {
        // Fail older requests first to preserve ordering.
        while (sock->readers) {
            read_args *stale = sock->readers;
            sock->readers = stale->next;
            resumeFiber(stale->ctx, fromInt(-12));
            delete stale;
        }
        resumeFiber(args->ctx, fromInt(-12));
        delete args;
        return;
    }

    auto rd = sock->readers;
    while (rd && rd->next)
        rd = rd->next;
    if (rd) {
        rd->next = args;
        return; // there are other readers in front of us; don't do anything yet
    }

    sock->readers = args;
    update_bytes_avail(sock);
    if (sock->bytesAvailable)
        process_reader(args);
}
/** Read from a socket; the return type is really number|Buffer */
//% promise
int socketRead(int fd, int size) {
GET_SOCK_SSL();
if (size < 0)
return -20;
if (size == 0)
return 0;
auto args = new read_args;
args->size = size;
args->sock = sock;
args->ctx = suspendFiber();
args->next = NULL;
worker_run(worker, (TaskFunction_t)worker_read, args);
return 0; // ignored
}
/** See how many bytes are available for reading */
//%
int socketBytesAvailable(int fd) {
    // Error codes -10/-11/-12 come from the validation macro.
    GET_SOCK_SSL();
    // Cached value; it is refreshed by update_bytes_avail(), not here.
    const int pending = sock->bytesAvailable;
    return pending;
}
static void worker_close(socket_t *sock) {
if (sock->ssl) {
esp_tls_conn_destroy(sock->ssl);
sock->ssl = NULL;
}
}
/** Close the socket if open */
//%
int socketClose(int fd) {
GET_SOCK();
sockets[fd] = NULL;
worker_run(worker, (TaskFunction_t)worker_close, sock);
// wait for the actual close - we only really have memory for one open SSL socket...
while (sock->ssl)
vTaskDelay(5);
free(sock);
return 0;
}
} // namespace _wifi