使用rust实现的websocket推送服务关键片段,可配置证书
创建:
2024-06-11 17:28
更新:
2024-06-25 09:34
访问:
437
主词:
ws websocket rust 长链接 推送服务 axum tls
描述:
使用rust实现的websocket推送服务,库有axum,tokio,rustls,tracing。可以设置最多链接个数,可以设置tls连接证书进行安全证书验证,axum设置tls时优雅退出等。
配置文件
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WsConfig {
pub port: u16, // 端口号
pub tls: Option<TlsConf>, // tls配置
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TlsConf {
pub cert_file: String, // 证书路径
pub key_file: String, // 私钥路径
pub ca_file: Option<String>, // ca证书-未使用
}
服务定义结构体
#[derive(Debug)]
pub struct WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>
where
FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
pub cfg: WsConfig, // webSocket配置
pub max_conn: i32, // 实例最多认证完成的连接数-是否断开连接由上层确定
pub conns: Mutex<HashMap<String, Weak<Conn>>>, // 已经连接的连接-退出时关闭所有连接即可-持有者是接收协程
pub on_new: FnOnNew, // 新客户
pub on_text_msg: FnOnTextMsg, // 新消息
pub on_bin_msg: FnOnBinMsg, // 新消息
pub on_close: FnOnClose, // 客户关闭离开
}
关闭所有连接-准备退出
#[allow(dead_code)]
pub async fn close_all(&self) {
let mut conns = self.conns.lock().await;
for key in conns.keys() { // 逐个关闭连接
if let Some(conn) = conns.get(key).unwrap().upgrade() {
if let Some(inner_conn) = conn.conn.upgrade() {
match inner_conn.lock().await.close().await {
Ok(_) => continue,
Err(err) => event!(Level::ERROR, "关闭连接失败 err:{}", err),
}
}
}
}
conns.clear();
}
开启服务
// 开始ws服务并等待
#[allow(dead_code)]
pub async fn ws_begin<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>(svc: Arc<WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>>, end_signal: impl Future<Output = ()> + Send + 'static) -> Result<(), WsErr>
where
FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
// 测试
async fn echo_handler() -> String {
return "pong".to_string();
}
let echo = Router::new().route("/ping", get(echo_handler)); // 连通性测试
// 开启ws服务
let ws = Router::new().route("/ws", get(ws_handle::<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>));
let all_router = echo
.merge(ws)
.layer(Extension(svc.clone()))
.layer(axum_util::log::MyMiddlewareLayer) // 自定义日志组件
.layer(axum_util::cors::get()) // cors跨域
.layer(axum_util::trace::layer::TraceMidLayer {}) // 日志组件2
.layer(axum_util::compress::get()) // 压缩
.layer(tower_http::catch_panic::CatchPanicLayer::custom(handle_panic)) // panic处理
.into_make_service_with_connect_info::<SocketAddr>(); // 远程地址
获取
match svc.cfg.tls.clone() {
None => {
// 没有tls配置的情况,即不使用wss安全连接
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", svc.cfg.port)).await?;
axum::serve(listener, all_router).with_graceful_shutdown(end_signal).await?;
return Ok::<(), WsErr>(());
}
Some(tls) => {
// 配置了安全证书的情况下
let handle = axum_server::Handle::new();
let handle_for_end = handle.clone();
tokio::spawn(async move{
end_signal.await;
handle_for_end.graceful_shutdown(Some(Duration::from_secs(7))) // 7秒的时间优雅退出
});
// configure certificate and private key used by https
let config = RustlsConfig::from_pem_file(tls.cert_file, tls.key_file).await?; // 读取证书文件内容
let addr = SocketAddr::from(([0, 0, 0, 0], svc.cfg.port)); // 设置端口号
axum_server::bind_rustls(addr, config).handle(handle).serve(all_router).await?; // 开启服务
return Ok::<(), WsErr>(());
}
}
}
服务新连接
#[instrument(skip_all)]
async fn ws_handle<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>(ws: WebSocketUpgrade, svc: Extension<Arc<WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>>>) -> impl IntoResponse
where
FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
let svc = svc.clone();
let reply = ws.on_upgrade(move |websocket| {
// 处理连接
let svc = svc.clone();
let (tx, mut rx) = websocket.split();
let tx = Arc::new(Mutex::new(tx));
// traceIdSpan 设置日志追踪信息 包括应用名称与git提交号
let trace_id = nanoid();
let conn_span = span!(tracing::Level::ERROR, "ws_handle", trace = trace_id, app = get_app_name(), version = get_app_version(), host = host_name());
// 接收 维持 推送 验证token并解码
Box::pin(async move {
// 处理连接
event!(tracing::Level::INFO, "新连接开始");
if svc.conns.lock().await.len() >= svc.max_conn as usize {
event!(tracing::Level::ERROR, "超出最大用户数");
return ();
}
// 处理连接
let session = crate::idutil::nanoid();
let conn: Arc<Conn> = Arc::new(Conn {
session,
conn_at: Local::now(),
conn: Arc::downgrade(&tx),
});
svc.conns.lock().await.insert(conn.session.clone(), Arc::downgrade(&conn));
(svc.on_new)(&conn).await;
loop {
let raw_msg = rx.next().await;
// 是不是已经没有消息了
if raw_msg.is_none() {
event!(tracing::Level::INFO, "取下个信息为None,正常退出接收循环");
break;
}
let raw_msg = raw_msg.unwrap();
if let Err(err) = raw_msg {
event!(tracing::Level::ERROR, "接收ws消息错误 {}", err);
break;
}
let raw_msg = raw_msg.unwrap();
event!(tracing::Level::DEBUG, "获取到新消息 {:?}", raw_msg);
match raw_msg {
Message::Text(msg) => {
event!(tracing::Level::DEBUG, "获取到text消息 {:?}", msg);
(svc.on_text_msg)(&conn, &*msg).await;
continue;
}
Message::Binary(bin) => {
event!(tracing::Level::DEBUG, "获取到binary消息 {:?}", bin);
(svc.on_bin_msg)(&conn, &*bin).await;
continue;
}
Message::Ping(_) => {
use std::borrow::BorrowMut;
match tx.lock().await.borrow_mut().send(Message::Pong(Vec::from("pong"))).await {
Ok(_) => continue,
Err(err) => {
event!(tracing::Level::ERROR, "发送消息错误 err:{:?}", err);
break;
}
}
}
Message::Pong(_) => {}
Message::Close(close) => {
event!(tracing::Level::INFO, "连接已经关闭 {:?}", close);
break;
}
}
}
(svc.on_close)(&conn).await;
svc.conns.lock().await.remove(&*conn.session);
})
.instrument(conn_span)
});
Ok::<http::Response<axum::body::Body>, WsErr>(reply)
}
本篇为原创内容,未经允许,不得转载