欢迎光临.

使用rust实现的websocket推送服务关键片段,可配置证书

创建:
2024-06-11 17:28
更新:
2024-06-25 09:34
访问:
294
主词:
ws websocket rust 长链接 推送服务 axum tls
描述:
使用rust实现的websocket推送服务,库有axum,tokio,rustls,tracing。可以设置最多链接个数,可以设置tls连接证书进行安全证书验证,axum设置tls时优雅退出等。

配置文件


use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WsConfig {
    pub port: u16,            // 端口号
    pub tls: Option<TlsConf>, // tls配置
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TlsConf {
    pub cert_file: String,       // 证书路径 
    pub key_file: String,        // 私钥路径
    pub ca_file: Option<String>, // ca证书-未使用
}

服务定义结构体

#[derive(Debug)]
pub struct WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>
where
    FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
    pub cfg: WsConfig,                             // webSocket配置
    pub max_conn: i32,                             // 实例最多认证完成的连接数-是否断开连接由上层确定
    pub conns: Mutex<HashMap<String, Weak<Conn>>>, // 已经连接的连接-退出时关闭所有连接即可-持有者是接收协程
    pub on_new: FnOnNew,                           // 新客户
    pub on_text_msg: FnOnTextMsg,                  // 新消息
    pub on_bin_msg: FnOnBinMsg,                    // 新消息
    pub on_close: FnOnClose,                       // 客户关闭离开
}

关闭所有连接-准备退出

    #[allow(dead_code)]
    pub async fn close_all(&self) {
        let mut conns = self.conns.lock().await;
        for key in conns.keys() { // 逐个关闭连接
            if let Some(conn) = conns.get(key).unwrap().upgrade() {
                if let Some(inner_conn) = conn.conn.upgrade() {
                    match inner_conn.lock().await.close().await {
                        Ok(_) => continue,
                        Err(err) => event!(Level::ERROR, "关闭连接失败 err:{}", err),
                    }
                }
            }
        }
        conns.clear();
    }

开启服务

// 开始ws服务并等待
#[allow(dead_code)]
pub async fn ws_begin<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>(svc: Arc<WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>>, end_signal: impl Future<Output = ()> + Send + 'static) -> Result<(), WsErr>
where
    FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
    // 测试
    async fn echo_handler() -> String {
        return "pong".to_string();
    }
    let echo = Router::new().route("/ping", get(echo_handler)); // 连通性测试


    // 开启ws服务
    let ws = Router::new().route("/ws", get(ws_handle::<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>));
    let all_router = echo
        .merge(ws)
        .layer(Extension(svc.clone()))
        .layer(axum_util::log::MyMiddlewareLayer)                              // 自定义日志组件
        .layer(axum_util::cors::get())                                         // cors跨域
        .layer(axum_util::trace::layer::TraceMidLayer {})                      // 日志组件2
        .layer(axum_util::compress::get())                                     // 压缩
        .layer(tower_http::catch_panic::CatchPanicLayer::custom(handle_panic)) // panic处理
        .into_make_service_with_connect_info::<SocketAddr>();                  // 远程地址
获取

    match svc.cfg.tls.clone() {
        None => {
            // 没有tls配置的情况,即不使用wss安全连接
            let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", svc.cfg.port)).await?;
            axum::serve(listener, all_router).with_graceful_shutdown(end_signal).await?;
            return Ok::<(), WsErr>(());
        }
        Some(tls) => {
            // 配置了安全证书的情况下
            let handle = axum_server::Handle::new();
            let handle_for_end = handle.clone();
            tokio::spawn(async move{
                end_signal.await;
                handle_for_end.graceful_shutdown(Some(Duration::from_secs(7))) // 7秒的时间优雅退出
            });
            // configure certificate and private key used by https
            let config = RustlsConfig::from_pem_file(tls.cert_file, tls.key_file).await?; // 读取证书文件内容
            let addr = SocketAddr::from(([0, 0, 0, 0], svc.cfg.port));                    // 设置端口号
            axum_server::bind_rustls(addr, config).handle(handle).serve(all_router).await?; // 开启服务
            return Ok::<(), WsErr>(());
        }
    }
}

服务新连接


#[instrument(skip_all)]
async fn ws_handle<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>(ws: WebSocketUpgrade, svc: Extension<Arc<WsServer<FnOnNew, FnOnTextMsg, FnOnBinMsg, FnOnClose>>>) -> impl IntoResponse
where
    FnOnNew: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnTextMsg: Fn(&Arc<Conn>, &str) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnBinMsg: Fn(&Arc<Conn>, &[u8]) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    FnOnClose: Fn(&Arc<Conn>) -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
    let svc = svc.clone();
    let reply = ws.on_upgrade(move |websocket| {
        // 处理连接
        let svc = svc.clone();
        let (tx, mut rx) = websocket.split();
        let tx = Arc::new(Mutex::new(tx));
        // traceIdSpan 设置日志追踪信息 包括应用名称与git提交号
        let trace_id = nanoid();
        let conn_span = span!(tracing::Level::ERROR, "ws_handle", trace = trace_id, app = get_app_name(), version = get_app_version(), host = host_name());
        // 接收 维持 推送 验证token并解码
        Box::pin(async move {
            // 处理连接
            event!(tracing::Level::INFO, "新连接开始");
            if svc.conns.lock().await.len() >= svc.max_conn as usize {
                event!(tracing::Level::ERROR, "超出最大用户数");
                return ();
            }
            // 处理连接
            let session = crate::idutil::nanoid();
            let conn: Arc<Conn> = Arc::new(Conn {
                session,
                conn_at: Local::now(),
                conn: Arc::downgrade(&tx),
            });
            svc.conns.lock().await.insert(conn.session.clone(), Arc::downgrade(&conn));
            (svc.on_new)(&conn).await;
            loop {
                let raw_msg = rx.next().await;
                // 是不是已经没有消息了
                if raw_msg.is_none() {
                    event!(tracing::Level::INFO, "取下个信息为None,正常退出接收循环");
                    break;
                }
                let raw_msg = raw_msg.unwrap();
                if let Err(err) = raw_msg {
                    event!(tracing::Level::ERROR, "接收ws消息错误 {}", err);
                    break;
                }
                let raw_msg = raw_msg.unwrap();
                event!(tracing::Level::DEBUG, "获取到新消息 {:?}", raw_msg);

                match raw_msg {
                    Message::Text(msg) => {
                        event!(tracing::Level::DEBUG, "获取到text消息 {:?}", msg);
                        (svc.on_text_msg)(&conn, &*msg).await;
                        continue;
                    }
                    Message::Binary(bin) => {
                        event!(tracing::Level::DEBUG, "获取到binary消息 {:?}", bin);
                        (svc.on_bin_msg)(&conn, &*bin).await;
                        continue;
                    }
                    Message::Ping(_) => {
                        use std::borrow::BorrowMut;
                        match tx.lock().await.borrow_mut().send(Message::Pong(Vec::from("pong"))).await {
                            Ok(_) => continue,
                            Err(err) => {
                                event!(tracing::Level::ERROR, "发送消息错误 err:{:?}", err);
                                break;
                            }
                        }
                    }
                    Message::Pong(_) => {}
                    Message::Close(close) => {
                        event!(tracing::Level::INFO, "连接已经关闭 {:?}", close);
                        break;
                    }
                }
            }
            (svc.on_close)(&conn).await;
            svc.conns.lock().await.remove(&*conn.session);
        })
        .instrument(conn_span)
    });
    Ok::<http::Response<axum::body::Body>, WsErr>(reply)
}

 

本篇为原创内容,未经允许,不得转载

繁星树影 @2024
皖ICP备20003857号-2
皖公网安备34132202000234号
14553