网络协议是网络通信的基础,Asio提供了构建各种网络协议的底层框架。本教程将详细介绍如何使用Asio实现常见的网络协议,包括HTTP、WebSocket、自定义TCP协议和UDP协议。

一、HTTP协议实现

HTTP是现代Web应用程序的基础协议。下面我们将学习如何使用Asio实现HTTP服务器和客户端。

1. 基础HTTP服务器

以下是一个简单的HTTP服务器实现,它能够处理基本的GET请求。每个连接只处理一个请求,响应发送完毕后即关闭连接:

#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <asio.hpp>

using asio::ip::tcp;

// Handles one HTTP/1.1 request on an accepted socket, then shuts the socket
// down. Instances must be owned by std::shared_ptr: every asynchronous
// operation captures shared_from_this() so the object outlives it.
class HttpConnection : public std::enable_shared_from_this<HttpConnection> {
public:
    explicit HttpConnection(tcp::socket socket)
        : socket_(std::move(socket)) {
    }

    // Begins reading the request; the response and shutdown follow asynchronously.
    void start() {
        read_request();
    }

private:
    // Upper bound on buffered request-line + header bytes. Without a limit a
    // client that never sends "\r\n\r\n" could make buffer_ grow without
    // bound; async_read_until reports an error once the limit is reached.
    static constexpr std::size_t kMaxRequestHeaderBytes = 8192;

    // Reads up to the blank line that ends the headers, parses the request
    // line and dispatches to handle_request(). Header fields are skipped.
    void read_request() {
        auto self = shared_from_this();
        asio::async_read_until(socket_, buffer_, "\r\n\r\n",
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    return;  // peer closed, I/O error, or header too large: drop the connection
                }

                std::istream request_stream(&self->buffer_);
                std::string request_line;
                std::getline(request_stream, request_line);

                // "METHOD PATH VERSION"; operator>> also discards the trailing '\r'.
                std::string method, path, version;
                std::istringstream request_line_stream(request_line);
                request_line_stream >> method >> path >> version;

                // Skip header lines up to the empty "\r" line that ends them.
                std::string header;
                while (std::getline(request_stream, header) && header != "\r") {
                }

                self->handle_request(method, path);
            });
    }

    // Routes GET /, /index.html and /about; other GET paths get 404 and any
    // other method gets 405.
    void handle_request(const std::string& method, const std::string& path) {
        if (method != "GET") {
            send_response(make_response("405 Method Not Allowed",
                "<html><body><h1>405 Method Not Allowed</h1></body></html>"));
            return;
        }

        if (path == "/" || path == "/index.html") {
            send_response(make_response("200 OK",
                "<html><body><h1>Hello, World!</h1></body></html>"));
        } else if (path == "/about") {
            send_response(make_response("200 OK",
                "<html><body><h1>About Page</h1></body></html>"));
        } else {
            send_response(make_response("404 Not Found",
                "<html><body><h1>404 Page Not Found</h1></body></html>"));
        }
    }

    // Builds a complete HTML response for `status` (e.g. "200 OK").
    // Content-Length is computed from `body` so it always matches the bytes
    // sent, and "Connection: close" tells the client that the server shuts
    // the socket down after this response.
    static std::string make_response(const std::string& status, const std::string& body) {
        std::string response = "HTTP/1.1 " + status + "\r\n";
        response += "Content-Type: text/html\r\n";
        response += "Content-Length: " + std::to_string(body.size()) + "\r\n";
        response += "Connection: close\r\n";
        response += "\r\n";
        response += body;
        return response;
    }

    // Writes `response`, then shuts the socket down. asio::buffer does not
    // copy the data, so the string is kept in response_ until async_write
    // has completed.
    void send_response(std::string response) {
        response_ = std::move(response);
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(response_),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    asio::error_code ignored_ec;
                    self->socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
                }
            });
    }

    tcp::socket socket_;
    asio::streambuf buffer_{kMaxRequestHeaderBytes};
    std::string response_;  // owns the bytes of the in-flight async_write
};

// HTTP服务器类
class HttpServer {
public:
    HttpServer(asio::io_context& io_context, short port)
        : acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
        do_accept();
    }
    
private:
    // 接受新连接
    void do_accept() {
        acceptor_.async_accept(
            [this](std::error_code ec, tcp::socket socket) {
                if (!ec) {
                    std::make_shared<HttpConnection>(std::move(socket))->start();
                }
                do_accept();
            });
    }
    
    tcp::acceptor acceptor_;
};

// Entry point: http_server <port>. Serves requests until the process is killed.
// Returns 1 on bad arguments or when setting up/running the server throws.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        std::cerr << "Usage: http_server <port>" << std::endl;
        return 1;
    }

    // std::atoi yields 0 for non-numeric input; reject it instead of silently
    // binding an ephemeral port.
    const int port = std::atoi(argv[1]);
    if (port <= 0 || port > 65535) {
        std::cerr << "Invalid port: " << argv[1] << std::endl;
        return 1;
    }

    try {
        asio::io_context io_context;
        HttpServer server(io_context, port);

        std::cout << "HTTP server running on port " << argv[1] << std::endl;
        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

三、自定义TCP协议实现

除了标准协议外,Asio还非常适合实现自定义网络协议。下面我们将学习如何设计和实现一个简单的自定义TCP协议。

1. 自定义协议设计

设计一个名为"SimpleMessageProtocol"的自定义协议,具有以下特点:

  • 每个消息都有固定的头部,包含消息长度和消息类型
  • 支持文本和二进制两种消息类型
  • 使用小端字节序进行数据传输

协议格式如下:

+-------------+-------------+------------------+
| 消息长度(4字节) | 消息类型(1字节) |      消息体      |
+-------------+-------------+------------------+

其中:

  • 消息长度:不包括长度字段自身的4个字节,仅表示消息类型(1字节)与消息体的总长度,因此消息体的实际长度为"消息长度 − 1"。例如,消息体为5字节的文本消息,其长度字段的值为6
  • 消息类型:0表示文本消息,1表示二进制消息
  • 消息体:根据消息类型包含文本或二进制数据

2. 自定义协议消息类

下面是实现自定义协议消息的类:

#include <iostream>
#include <string>
#include <vector>
#include <cstdint>
#include <algorithm>
#include <asio.hpp>

// One SimpleMessageProtocol message: a type tag (TEXT or BINARY) and a body.
//
// Wire format produced by serialize() and accepted by from_buffer():
//   [length: 4 bytes, little-endian][type: 1 byte][body: length - 1 bytes]
// The length field counts the type byte plus the body, but not itself.
//
// from_buffer() returns std::optional, so this class also needs <optional>.
class SimpleMessage {
public:
    enum MessageType {
        TEXT = 0,
        BINARY = 1
    };

    // Size of the fixed header: 4-byte length field + 1-byte type field.
    static constexpr std::size_t kHeaderSize = 5;

    // Creates a TEXT message whose body holds the bytes of `text`.
    SimpleMessage(const std::string& text)
        : type_(TEXT), body_(text.begin(), text.end()) {
    }

    // Creates a BINARY message whose body is a copy of `binary_data`.
    SimpleMessage(const std::vector<uint8_t>& binary_data)
        : type_(BINARY), body_(binary_data) {
    }

    // Decodes the frame at the start of `buffer`; bytes after it are ignored.
    // Returns std::nullopt when the buffer does not yet hold a complete frame,
    // when the length field is 0 (it must at least cover the type byte), or
    // when the type byte is neither TEXT nor BINARY.
    static std::optional<SimpleMessage> from_buffer(const std::vector<uint8_t>& buffer) {
        if (buffer.size() < kHeaderSize) {
            return std::nullopt;
        }

        uint32_t length = 0;
        for (int i = 0; i < 4; ++i) {
            length |= static_cast<uint32_t>(buffer[i]) << (i * 8);
        }
        if (length == 0) {
            return std::nullopt;
        }

        // The length field includes the type byte, so the body is one byte shorter.
        const std::size_t body_size = length - 1;
        // Written as a subtraction so kHeaderSize + body_size cannot overflow.
        if (buffer.size() - kHeaderSize < body_size) {
            return std::nullopt;
        }

        const uint8_t type = buffer[kHeaderSize - 1];
        if (type != TEXT && type != BINARY) {
            return std::nullopt;
        }

        const auto body_begin = buffer.begin() + static_cast<std::ptrdiff_t>(kHeaderSize);
        const auto body_end = body_begin + static_cast<std::ptrdiff_t>(body_size);
        return SimpleMessage(static_cast<MessageType>(type),
                             std::vector<uint8_t>(body_begin, body_end));
    }

    // Encodes this message in the wire format described above.
    // NOTE: the length is stored as uint32_t, so bodies of 2^32 - 1 bytes or
    // more would not be representable.
    std::vector<uint8_t> serialize() const {
        // Type byte + body; the length field does not count itself.
        const uint32_t length = static_cast<uint32_t>(1 + body_.size());

        std::vector<uint8_t> serialized;
        serialized.reserve(kHeaderSize + body_.size());
        for (int i = 0; i < 4; ++i) {
            serialized.push_back(static_cast<uint8_t>((length >> (i * 8)) & 0xFF));
        }
        serialized.push_back(static_cast<uint8_t>(type_));
        serialized.insert(serialized.end(), body_.begin(), body_.end());
        return serialized;
    }

    // Returns the message type.
    MessageType get_type() const {
        return type_;
    }

    // Returns the body as a string for TEXT messages, "" otherwise.
    std::string get_text() const {
        if (type_ == TEXT) {
            return std::string(body_.begin(), body_.end());
        }
        return "";
    }

    // Returns a copy of the body for BINARY messages, an empty vector otherwise.
    std::vector<uint8_t> get_binary() const {
        if (type_ == BINARY) {
            return body_;
        }
        return {};
    }

    // Number of bytes serialize() produces (header + body).
    size_t get_serialized_size() const {
        return kHeaderSize + body_.size();
    }

private:
    // Used by from_buffer() once the type byte has been validated.
    SimpleMessage(MessageType type, std::vector<uint8_t> body)
        : type_(type), body_(std::move(body)) {
    }

    MessageType type_;
    std::vector<uint8_t> body_;
};

3. 自定义协议服务器

下面是实现自定义协议服务器的代码:

#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <cstdint>
#include <asio.hpp>

// SimpleMessage类的定义(如前所述)
// ...

// 连接管理类
class Connection : public std::enable_shared_from_this<Connection> {
public:
    Connection(asio::io_context& io_context)
        : socket_(io_context) {
    }
    
    // 获取socket引用
    asio::ip::tcp::socket& socket() {
        return socket_;
    }
    
    // 开始处理连接
    void start() {
        read_header();
    }
    
    // 发送消息
    void send_message(const SimpleMessage& message) {
        auto self = shared_from_this();
        
        // 序列化消息
        std::vector<uint8_t> serialized = message.serialize();
        
        // 将消息添加到发送队列
        bool write_in_progress = !write_queue_.empty();
        write_queue_.push_back(serialized);
        
        // 如果当前没有写入操作正在进行,则开始写入
        if (!write_in_progress) {
            do_write();
        }
    }
    
    // 设置消息接收回调
    void set_message_handler(std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> handler) {
        message_handler_ = handler;
    }
    
    // 设置连接关闭回调
    void set_close_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
        close_handler_ = handler;
    }
    
    // 关闭连接
    void close() {
        asio::error_code ec;
        socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ec);
        socket_.close(ec);
    }
    
private:
    // 读取消息头部
    void read_header() {
        auto self = shared_from_this();
        asio::async_read(socket_, asio::buffer(read_buffer_, SimpleMessage::kHeaderSize),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    // 尝试从头部解析消息长度
                    uint32_t body_size = 0;
                    for (int i = 0; i < 4; ++i) {
                        body_size |= static_cast<uint32_t>(self->read_buffer_[i]) << (i * 8);
                    }
                    
                    // 调整读取缓冲区大小以容纳完整消息
                    self->read_buffer_.resize(SimpleMessage::kHeaderSize + body_size);
                    
                    // 读取消息体
                    self->read_body(body_size);
                } else {
                    // 发生错误,关闭连接
                    if (self->close_handler_) {
                        self->close_handler_(self);
                    }
                }
            });
    }
    
    // 读取消息体
    void read_body(uint32_t body_size) {
        auto self = shared_from_this();
        asio::async_read(socket_, 
            asio::buffer(read_buffer_.data() + SimpleMessage::kHeaderSize, body_size),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    // 尝试解析完整消息
                    auto message = SimpleMessage::from_buffer(self->read_buffer_);
                    if (message) {
                        // 调用消息处理回调
                        if (self->message_handler_) {
                            self->message_handler_(self, *message);
                        }
                    }
                    
                    // 准备读取下一条消息
                    self->read_buffer_.resize(SimpleMessage::kHeaderSize);
                    self->read_header();
                } else {
                    // 发生错误,关闭连接
                    if (self->close_handler_) {
                        self->close_handler_(self);
                    }
                }
            });
    }
    
    // 写入消息
    void do_write() {
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(write_queue_.front()),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    // 从队列中移除已发送的消息
                    self->write_queue_.pop_front();
                    
                    // 如果还有消息要发送,则继续
                    if (!self->write_queue_.empty()) {
                        self->do_write();
                    }
                } else {
                    // 发生错误,关闭连接
                    if (self->close_handler_) {
                        self->close_handler_(self);
                    }
                }
            });
    }
    
    asio::ip::tcp::socket socket_;
    std::vector<uint8_t> read_buffer_ = std::vector<uint8_t>(SimpleMessage::kHeaderSize);
    std::deque<std::vector<uint8_t>> write_queue_;
    
    std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> message_handler_;
    std::function<void(std::shared_ptr<Connection>)> close_handler_;
};

// 自定义协议服务器类
class SimpleMessageServer {
public:
    SimpleMessageServer(asio::io_context& io_context, short port)
        : io_context_(io_context),
          acceptor_(io_context, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), port)) {
        do_accept();
    }
    
    // 设置消息处理回调
    void set_message_handler(std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> handler) {
        message_handler_ = handler;
    }
    
    // 设置连接处理回调
    void set_connection_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
        connection_handler_ = handler;
    }
    
    // 设置断开连接处理回调
    void set_disconnection_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
        disconnection_handler_ = handler;
    }
    
private:
    // 接受新连接
    void do_accept() {
        auto connection = std::make_shared<Connection>(io_context_);
        
        acceptor_.async_accept(connection->socket(),
            [this, connection](std::error_code ec) {
                if (!ec) {
                    // 设置连接的回调函数
                    connection->set_message_handler(message_handler_);
                    connection->set_close_handler([this](std::shared_ptr<Connection> conn) {
                        if (disconnection_handler_) {
                            disconnection_handler_(conn);
                        }
                        // 从连接列表中移除
                        connections_.erase(std::remove(connections_.begin(), connections_.end(), conn),
                                         connections_.end());
                    });
                    
                    // 将连接添加到列表
                    connections_.push_back(connection);
                    
                    // 调用连接回调
                    if (connection_handler_) {
                        connection_handler_(connection);
                    }
                    
                    // 开始处理连接
                    connection->start();
                }
                
                // 继续接受下一个连接
                do_accept();
            });
    }
    
    asio::io_context& io_context_;
    asio::ip::tcp::acceptor acceptor_;
    std::vector<std::shared_ptr<Connection>> connections_;
    
    std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> message_handler_;
    std::function<void(std::shared_ptr<Connection>)> connection_handler_;
    std::function<void(std::shared_ptr<Connection>)> disconnection_handler_;
};

// Example server: simple_message_server <port>.
// Greets each new connection, echoes TEXT messages prefixed with
// "Server received: " and echoes BINARY messages unchanged.
// Returns 1 on bad arguments or when the server throws.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        std::cerr << "Usage: simple_message_server <port>" << std::endl;
        return 1;
    }

    // std::atoi yields 0 for non-numeric input; reject it instead of binding an ephemeral port.
    const int port = std::atoi(argv[1]);
    if (port <= 0 || port > 65535) {
        std::cerr << "Invalid port: " << argv[1] << std::endl;
        return 1;
    }

    try {
        asio::io_context io_context;
        SimpleMessageServer server(io_context, port);

        server.set_message_handler([](std::shared_ptr<Connection> connection, const SimpleMessage& message) {
            const bool is_text = message.get_type() == SimpleMessage::TEXT;
            std::cout << "Received message of type: " << (is_text ? "TEXT" : "BINARY") << std::endl;

            if (is_text) {
                // get_text() returns a copy; fetch it once.
                const std::string text = message.get_text();
                std::cout << "Message content: " << text << std::endl;
                connection->send_message(SimpleMessage("Server received: " + text));
            } else {
                const std::vector<uint8_t> data = message.get_binary();
                std::cout << "Binary message size: " << data.size() << " bytes" << std::endl;
                connection->send_message(SimpleMessage(data));
            }
        });

        server.set_connection_handler([](std::shared_ptr<Connection> connection) {
            std::cout << "New connection established" << std::endl;
            connection->send_message(SimpleMessage("Welcome to Simple Message Protocol Server!"));
        });

        server.set_disconnection_handler([](std::shared_ptr<Connection>) {
            std::cout << "Connection closed" << std::endl;
        });

        std::cout << "Simple Message Protocol Server running on port " << argv[1] << std::endl;
        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

4. 自定义协议客户端

下面是实现自定义协议客户端的代码:

#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <cstdint>
#include <asio.hpp>

// SimpleMessage类的定义(如前所述)
// ...

// Asynchronous client for SimpleMessageProtocol.
//
// Frames are [length: 4 bytes little-endian][type: 1 byte][body], where the
// length counts the type byte plus the body, so the body is length - 1 bytes.
//
// Instances must be owned by std::shared_ptr (every asynchronous operation
// captures shared_from_this()). connect() starts I/O directly, so call it
// before io_context.run() or from the io_context's thread. send_text(),
// send_binary() and close() only post work via asio::post, which is why the
// example main() can call them from another thread.
// Besides the headers included above, this class uses std::deque (<deque>).
class SimpleMessageClient : public std::enable_shared_from_this<SimpleMessageClient> {
public:
    // Largest length field accepted from the server. Larger frames are treated
    // as a protocol error instead of triggering a huge allocation.
    static constexpr std::uint32_t kMaxFrameLength = 16u * 1024u * 1024u;

    explicit SimpleMessageClient(asio::io_context& io_context)
        : io_context_(io_context),
          socket_(io_context),
          resolver_(io_context) {
    }

    // Resolves host:port, connects, fires connection_handler_ and starts the
    // read loop. Failures are logged and passed to error_handler_.
    void connect(const std::string& host, const std::string& port) {
        auto self = shared_from_this();
        resolver_.async_resolve(host, port,
            [self](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
                if (ec) {
                    self->report_failure("Resolution failed: ", ec);
                    return;
                }
                asio::async_connect(self->socket_, results,
                    [self](std::error_code connect_ec, asio::ip::tcp::endpoint) {
                        if (connect_ec) {
                            self->report_failure("Connection failed: ", connect_ec);
                            return;
                        }

                        std::cout << "Connected to server" << std::endl;
                        if (self->connection_handler_) {
                            self->connection_handler_();
                        }
                        self->read_header();
                    });
            });
    }

    // Queues a TEXT message (executed on the io_context).
    void send_text(const std::string& text) {
        auto self = shared_from_this();
        asio::post(io_context_, [self, text]() {
            self->send_message(SimpleMessage(text));
        });
    }

    // Queues a BINARY message (executed on the io_context).
    void send_binary(const std::vector<uint8_t>& binary_data) {
        auto self = shared_from_this();
        asio::post(io_context_, [self, binary_data]() {
            self->send_message(SimpleMessage(binary_data));
        });
    }

    // Closes the connection on the io_context; disconnection_handler_ fires once.
    void close() {
        auto self = shared_from_this();
        asio::post(io_context_, [self]() {
            self->do_close();
        });
    }

    void set_message_handler(std::function<void(const SimpleMessage&)> handler) {
        message_handler_ = std::move(handler);
    }

    void set_connection_handler(std::function<void()> handler) {
        connection_handler_ = std::move(handler);
    }

    void set_disconnection_handler(std::function<void()> handler) {
        disconnection_handler_ = std::move(handler);
    }

    void set_error_handler(std::function<void(const std::error_code&)> handler) {
        error_handler_ = std::move(handler);
    }

private:
    // Appends a message to the write queue; starts writing if idle.
    // Messages sent after the connection was closed are dropped.
    void send_message(const SimpleMessage& message) {
        if (closed_) {
            return;
        }
        const bool write_in_progress = !write_queue_.empty();
        write_queue_.push_back(message.serialize());
        if (!write_in_progress) {
            do_write();
        }
    }

    // Writes the front of the queue, then continues with the rest.
    void do_write() {
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(write_queue_.front()),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    self->on_io_error("Write error: ", ec);
                    return;
                }
                self->write_queue_.pop_front();
                if (!self->write_queue_.empty()) {
                    self->do_write();
                }
            });
    }

    // Reads the fixed-size header and validates the length field.
    void read_header() {
        auto self = shared_from_this();
        read_buffer_.resize(SimpleMessage::kHeaderSize);
        asio::async_read(socket_, asio::buffer(read_buffer_),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    self->on_io_error("Read header error: ", ec);
                    return;
                }

                std::uint32_t frame_length = 0;
                for (int i = 0; i < 4; ++i) {
                    frame_length |= static_cast<std::uint32_t>(self->read_buffer_[i]) << (i * 8);
                }
                // The length must at least cover the type byte.
                if (frame_length == 0 || frame_length > kMaxFrameLength) {
                    self->on_io_error("Read header error: ",
                                      std::make_error_code(std::errc::protocol_error));
                    return;
                }

                const std::uint32_t body_size = frame_length - 1;  // minus the type byte
                self->read_buffer_.resize(SimpleMessage::kHeaderSize + body_size);
                self->read_body(body_size);
            });
    }

    // Reads exactly `body_size` body bytes, delivers the message and loops.
    void read_body(std::uint32_t body_size) {
        auto self = shared_from_this();
        asio::async_read(socket_,
            asio::buffer(read_buffer_.data() + SimpleMessage::kHeaderSize, body_size),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    self->on_io_error("Read body error: ", ec);
                    return;
                }
                self->deliver_message();
                self->read_header();
            });
    }

    // Builds a SimpleMessage from the complete frame in read_buffer_ using its
    // public constructors and passes it to message_handler_. Frames with an
    // unknown type byte are dropped; the stream stays in sync either way.
    void deliver_message() {
        if (!message_handler_) {
            return;
        }
        const auto body_begin = read_buffer_.cbegin() + SimpleMessage::kHeaderSize;
        const auto body_end = read_buffer_.cend();
        switch (read_buffer_[SimpleMessage::kHeaderSize - 1]) {
            case SimpleMessage::TEXT:
                message_handler_(SimpleMessage(std::string(body_begin, body_end)));
                break;
            case SimpleMessage::BINARY:
                message_handler_(SimpleMessage(std::vector<std::uint8_t>(body_begin, body_end)));
                break;
            default:
                break;
        }
    }

    // Logs `ec` with `context` and forwards it to error_handler_, unless the
    // client was already closed (then the error is just the expected
    // cancellation of pending operations).
    void report_failure(const char* context, const std::error_code& ec) {
        if (closed_) {
            return;
        }
        std::cerr << context << ec.message() << std::endl;
        if (error_handler_) {
            error_handler_(ec);
        }
    }

    // Reports an error on an established connection, then closes it.
    void on_io_error(const char* context, const std::error_code& ec) {
        if (closed_) {
            return;
        }
        report_failure(context, ec);
        do_close();
    }

    // Closes the socket and fires disconnection_handler_; effective only once.
    void do_close() {
        if (closed_) {
            return;
        }
        closed_ = true;
        asio::error_code ec;
        socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ec);
        socket_.close(ec);

        if (disconnection_handler_) {
            disconnection_handler_();
        }
    }

    asio::io_context& io_context_;
    asio::ip::tcp::socket socket_;
    asio::ip::tcp::resolver resolver_;

    std::vector<uint8_t> read_buffer_ = std::vector<uint8_t>(SimpleMessage::kHeaderSize);
    std::deque<std::vector<uint8_t>> write_queue_;
    bool closed_ = false;  // set by do_close(); later errors are not reported

    std::function<void(const SimpleMessage&)> message_handler_;
    std::function<void()> connection_handler_;
    std::function<void()> disconnection_handler_;
    std::function<void(const std::error_code&)> error_handler_;
};

// Example client: connects to localhost:8080, sends one TEXT and one BINARY
// message, prints replies for a few seconds, then closes and exits.
// The io_context runs on a separate thread; send_*() and close() are safe to
// call from here because they post their work to the io_context.
// Besides the headers included above, this program uses <thread> and <chrono>.
int main() {
    try {
        asio::io_context io_context;
        auto client = std::make_shared<SimpleMessageClient>(io_context);

        client->set_message_handler([](const SimpleMessage& message) {
            std::cout << "Received message of type: " << (message.get_type() == SimpleMessage::TEXT ? "TEXT" : "BINARY") << std::endl;

            if (message.get_type() == SimpleMessage::TEXT) {
                std::cout << "Message content: " << message.get_text() << std::endl;
            } else {
                std::cout << "Binary message size: " << message.get_binary().size() << " bytes" << std::endl;
            }
        });

        client->set_connection_handler([]() {
            std::cout << "Connection established successfully" << std::endl;
        });

        client->set_disconnection_handler([]() {
            std::cout << "Disconnected from server" << std::endl;
        });

        client->set_error_handler([](const std::error_code& ec) {
            std::cerr << "Error: " << ec.message() << std::endl;
        });

        client->connect("localhost", "8080");

        // Keeps run() from returning while this thread may still post work.
        auto work = asio::make_work_guard(io_context);

        // An exception escaping a thread function calls std::terminate, so
        // report it here instead.
        std::thread io_thread([&io_context]() {
            try {
                io_context.run();
            } catch (const std::exception& e) {
                std::cerr << "Exception: " << e.what() << std::endl;
            }
        });

        // Crude synchronisation for the demo: give the connection time to come up.
        std::this_thread::sleep_for(std::chrono::seconds(1));

        client->send_text("Hello, Simple Message Protocol Server!");

        std::vector<uint8_t> binary_data = {0x01, 0x02, 0x03, 0x04, 0x05};
        client->send_binary(binary_data);

        // Give the server time to answer.
        std::this_thread::sleep_for(std::chrono::seconds(5));

        client->close();

        // Release the work guard instead of calling io_context.stop(): stop()
        // could discard the close() just posted. run() returns once the close
        // and the cancelled operations have been processed.
        work.reset();
        io_thread.join();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

四、UDP协议实现

UDP是一种无连接的传输协议,适用于对实时性要求高但对可靠性要求相对较低的场景。下面我们将学习如何使用Asio实现UDP服务器和客户端。

1. 基础UDP回显服务器

以下是一个简单的UDP回显服务器实现:

#include <iostream>
#include <string>
#include <asio.hpp>

using asio::ip::udp;

class UDPServer {
public:
    UDPServer(asio::io_context& io_context, short port)
        : socket_(io_context, udp::endpoint(udp::v4(), port)) {
        do_receive();
    }
    
private:
    // 接收数据
    void do_receive() {
        socket_.async_receive_from(
            asio::buffer(data_, max_length), sender_endpoint_,
            [this](std::error_code ec, std::size_t bytes_recvd) {
                if (!ec && bytes_recvd > 0) {
                    std::cout << "Received " << bytes_recvd << " bytes from " 
                              << sender_endpoint_.address().to_string() << ":" 
                              << sender_endpoint_.port() << std::endl;
                    
                    // 回显收到的数据
                    do_send(bytes_recvd);
                }
                
                // 继续接收下一个数据包
                do_receive();
            });
    }
    
    // 发送数据
    void do_send(std::size_t length) {
        socket_.async_send_to(
            asio::buffer(data_, length), sender_endpoint_,
            [this](std::error_code ec, std::size_t /*bytes_sent*/) {
                if (ec) {
                    std::cerr << "Error sending response: " << ec.message() << std::endl;
                }
            });
    }
    
    udp::socket socket_;
    udp::endpoint sender_endpoint_;
    enum { max_length = 1024 };
    char data_[max_length];
};

// Entry point: udp_server <port>. Echoes datagrams until interrupted.
// Returns 1 on bad arguments or when the server throws.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        std::cerr << "Usage: udp_server <port>" << std::endl;
        return 1;
    }

    // std::atoi yields 0 for non-numeric input; reject it instead of binding an ephemeral port.
    const int port = std::atoi(argv[1]);
    if (port <= 0 || port > 65535) {
        std::cerr << "Invalid port: " << argv[1] << std::endl;
        return 1;
    }

    try {
        asio::io_context io_context;
        UDPServer server(io_context, port);

        std::cout << "UDP echo server running on port " << argv[1] << std::endl;
        std::cout << "Press Ctrl+C to exit" << std::endl;

        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

2. 基础UDP客户端

以下是一个简单的UDP客户端实现:

#include <iostream>
#include <string>
#include <asio.hpp>

using asio::ip::udp;

int main(int argc, char* argv[]) {
    try {
        if (argc != 3) {
            std::cerr << "Usage: udp_client <host> <port>" << std::endl;
            return 1;
        }
        
        asio::io_context io_context;
        
        // 解析服务器地址
        udp::resolver resolver(io_context);
        udp::endpoint endpoint = *resolver.resolve(udp::v4(), argv[1], argv[2]).begin();
        
        // 创建socket并连接到服务器
        udp::socket socket(io_context);
        socket.open(udp::v4());
        
        // 发送消息到服务器
        std::string message = "Hello, UDP Server!";
        std::cout << "Sending: " << message << std::endl;
        socket.send_to(asio::buffer(message), endpoint);
        
        // 等待并接收服务器响应
        udp::endpoint sender_endpoint;
        enum { max_length = 1024 };
        char reply[max_length];
        std::size_t reply_length = socket.receive_from(
            asio::buffer(reply, max_length), sender_endpoint);
        
        std::cout << "Received reply: ";
        std::cout.write(reply, reply_length);
        std::cout << std::endl;
        
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    
    return 0;
}

3. 高级UDP协议实现

下面是一个更复杂的UDP协议实现,包含以下特性:

  • 消息封装(头部+数据)
  • 简单的错误检测
  • 超时重传机制
  • 异步处理模型

#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <map>
#include <cstdint>
#include <chrono>
#include <asio.hpp>

using asio::ip::udp;
using asio::steady_timer;

// 自定义UDP消息类
// Datagram format shared by the reliable-UDP client and server:
//   bytes 0-1  message ID   (big-endian; the peers in this file use ID 0 for ACKs)
//   bytes 2-3  payload size (big-endian)
//   bytes 4-5  checksum     (big-endian, see calculate_checksum)
//   bytes 6..  payload
class UdpMessage {
public:
    // Header size in bytes: ID(2) + Size(2) + Checksum(2).
    static constexpr size_t kHeaderSize = 6;

    // Largest payload the 16-bit size field can describe.
    static constexpr size_t kMaxDataSize = 0xFFFF;

    // Builds a message and computes its checksum.
    // Throws std::length_error when `data` is longer than kMaxDataSize: the
    // 16-bit size field would otherwise be silently truncated by to_raw() and
    // every receiver would reject the datagram.
    UdpMessage(uint16_t message_id, std::vector<uint8_t> data)
        : message_id_(message_id), data_(std::move(data)) {
        if (data_.size() > kMaxDataSize) {
            throw std::length_error("UdpMessage: payload larger than 65535 bytes");
        }
        checksum_ = calculate_checksum();
    }

    // Parses a received datagram. Returns std::nullopt if it is shorter than
    // the header, if its length disagrees with the size field, or if the
    // checksum does not match the received ID/size/payload.
    static std::optional<UdpMessage> from_raw(const std::vector<uint8_t>& raw_data) {
        if (raw_data.size() < kHeaderSize) {
            return std::nullopt;
        }

        const uint16_t message_id = read_u16(raw_data, 0);
        const uint16_t data_size = read_u16(raw_data, 2);
        const uint16_t checksum = read_u16(raw_data, 4);

        if (raw_data.size() != kHeaderSize + data_size) {
            return std::nullopt;
        }

        // Rebuilding the message recomputes the checksum for comparison.
        UdpMessage message(message_id,
                           std::vector<uint8_t>(raw_data.begin() + kHeaderSize, raw_data.end()));
        if (message.checksum_ != checksum) {
            return std::nullopt;
        }
        return message;
    }

    // Serializes header + payload into a new byte vector.
    std::vector<uint8_t> to_raw() const {
        std::vector<uint8_t> raw;
        raw.reserve(kHeaderSize + data_.size());
        append_u16(raw, message_id_);
        append_u16(raw, static_cast<uint16_t>(data_.size())); // fits: checked in the constructor
        append_u16(raw, checksum_);
        raw.insert(raw.end(), data_.begin(), data_.end());
        return raw;
    }

    // Message ID from the header.
    uint16_t get_message_id() const { return message_id_; }

    // Payload bytes.
    const std::vector<uint8_t>& get_data() const { return data_; }

    // Checksum computed for this message.
    uint16_t get_checksum() const { return checksum_; }

private:
    // Reads a big-endian 16-bit value at `offset` (caller guarantees bounds).
    static uint16_t read_u16(const std::vector<uint8_t>& bytes, size_t offset) {
        return static_cast<uint16_t>((static_cast<uint16_t>(bytes[offset]) << 8) | bytes[offset + 1]);
    }

    // Appends `value` as two big-endian bytes.
    static void append_u16(std::vector<uint8_t>& bytes, uint16_t value) {
        bytes.push_back(static_cast<uint8_t>((value >> 8) & 0xFF));
        bytes.push_back(static_cast<uint8_t>(value & 0xFF));
    }

    // 16-bit one's-complement checksum over the message ID, the payload size
    // and the payload. The payload is read as big-endian 16-bit words, and an
    // odd trailing byte is padded with 0. The sum is accumulated in 32 bits so
    // carries out of bit 15 survive and can be folded back in; summing directly
    // in a uint16_t would silently discard them. Overflow of the 32-bit sum is
    // impossible because the payload is capped at 65535 bytes.
    uint16_t calculate_checksum() const {
        uint32_t sum = message_id_;
        sum += static_cast<uint32_t>(data_.size());

        for (size_t i = 0; i < data_.size(); i += 2) {
            uint32_t word = static_cast<uint32_t>(data_[i]) << 8;
            if (i + 1 < data_.size()) {
                word |= data_[i + 1];
            }
            sum += word;
        }

        // Fold the carries back into the low 16 bits.
        while (sum >> 16) {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }

        return static_cast<uint16_t>(~sum & 0xFFFF);
    }

    uint16_t message_id_;
    std::vector<uint8_t> data_;
    uint16_t checksum_ = 0;
};

// 高级UDP客户端类
class ReliableUdpClient : public std::enable_shared_from_this<ReliableUdpClient> {
public:
    ReliableUdpClient(asio::io_context& io_context)
        : io_context_(io_context),
          socket_(io_context, udp::v4()),
          resolver_(io_context),
          next_message_id_(1),
          timeout_(500) { // 默认超时时间500ms
        socket_.set_option(udp::socket::reuse_address(true));
    }
    
    // 连接到服务器
    void connect(const std::string& host, const std::string& port) {
        server_endpoint_ = *resolver_.resolve(udp::v4(), host, port).begin();
        
        // 开始接收数据
        do_receive();
        
        std::cout << "Connected to UDP server at " << server_endpoint_.address().to_string() 
                  << ":" << server_endpoint_.port() << std::endl;
    }
    
    // 发送消息(带确认和重传)
    void send_message(const std::vector<uint8_t>& data, 
                     std::function<void(bool)> completion_handler = nullptr) {
        auto self = shared_from_this();
        uint16_t message_id = next_message_id_++;
        
        // 创建消息
        UdpMessage message(message_id, data);
        std::vector<uint8_t> raw_message = message.to_raw();
        
        // 保存消息和回调
        OutgoingMessage outgoing;
        outgoing.raw_data = raw_message;
        outgoing.attempts = 0;
        outgoing.max_attempts = 3;
        outgoing.completion_handler = completion_handler;
        
        outgoing_messages_[message_id] = outgoing;
        
        // 发送消息
        send_with_retry(message_id);
    }
    
    // 设置消息接收回调
    void set_message_handler(std::function<void(const UdpMessage&)> handler) {
        message_handler_ = handler;
    }
    
    // 设置超时时间(毫秒)
    void set_timeout(uint32_t milliseconds) {
        timeout_ = milliseconds;
    }
    
    // 关闭客户端
    void close() {
        socket_.close();
        
        // 取消所有定时器
        for (auto& pair : outgoing_messages_) {
            if (pair.second.timer) {
                pair.second.timer->cancel();
            }
        }
    }
    
private:
    // 待发送消息结构
    struct OutgoingMessage {
        std::vector<uint8_t> raw_data;
        int attempts;
        int max_attempts;
        std::shared_ptr<steady_timer> timer;
        std::function<void(bool)> completion_handler;
    };
    
    // 发送消息并重试
    void send_with_retry(uint16_t message_id) {
        auto self = shared_from_this();
        auto it = outgoing_messages_.find(message_id);
        
        if (it == outgoing_messages_.end()) {
            return; // 消息已被处理
        }
        
        OutgoingMessage& outgoing = it->second;
        
        // 检查重传次数
        if (outgoing.attempts >= outgoing.max_attempts) {
            std::cout << "Message " << message_id << " failed after " 
                      << outgoing.attempts << " attempts" << std::endl;
            
            // 调用完成回调,指示失败
            if (outgoing.completion_handler) {
                outgoing.completion_handler(false);
            }
            
            // 移除消息
            outgoing_messages_.erase(it);
            return;
        }
        
        // 增加尝试次数
        outgoing.attempts++;
        
        std::cout << "Sending message " << message_id << " (attempt " 
                  << outgoing.attempts << ")" << std::endl;
        
        // 发送消息
        socket_.async_send_to(
            asio::buffer(outgoing.raw_data), server_endpoint_,
            [self, message_id](std::error_code ec, std::size_t bytes_sent) {
                if (!ec) {
                    // 设置超时定时器
                    self->setup_retry_timer(message_id);
                } else {
                    std::cerr << "Error sending message: " << ec.message() << std::endl;
                    
                    // 立即重试
                    self->send_with_retry(message_id);
                }
            });
    }
    
    // 设置重传定时器
    void setup_retry_timer(uint16_t message_id) {
        auto self = shared_from_this();
        auto it = outgoing_messages_.find(message_id);
        
        if (it == outgoing_messages_.end()) {
            return;
        }
        
        // 创建或重置定时器
        if (!it->second.timer) {
            it->second.timer = std::make_shared<steady_timer>(io_context_);
        }
        
        it->second.timer->expires_after(std::chrono::milliseconds(timeout_));
        it->second.timer->async_wait(
            [self, message_id](std::error_code ec) {
                if (!ec) { // 定时器未被取消
                    self->send_with_retry(message_id);
                }
            });
    }
    
    // 接收数据
    void do_receive() {
        auto self = shared_from_this();
        socket_.async_receive_from(
            asio::buffer(recv_buffer_), sender_endpoint_,
            [this, self](std::error_code ec, std::size_t bytes_recvd) {
                if (!ec && bytes_recvd > 0) {
                    // 将接收到的数据转换为vector
                    std::vector<uint8_t> raw_data(recv_buffer_.begin(), recv_buffer_.begin() + bytes_recvd);
                    
                    // 尝试解析消息
                    auto message = UdpMessage::from_raw(raw_data);
                    if (message) {
                        // 检查是否是确认消息(假设消息ID为0表示确认)
                        if (message->get_message_id() == 0) {
                            handle_acknowledgment(*message);
                        } else {
                            // 处理普通消息
                            handle_message(*message);
                        }
                    }
                }
                
                // 继续接收下一个数据包
                do_receive();
            });
    }
    
    // 处理确认消息
    void handle_acknowledgment(const UdpMessage& message) {
        // 假设确认消息的数据部分包含被确认的消息ID
        if (message.get_data().size() >= 2) {
            uint16_t acknowledged_id = (static_cast<uint16_t>(message.get_data()[0]) << 8) | message.get_data()[1];
            
            auto it = outgoing_messages_.find(acknowledged_id);
            if (it != outgoing_messages_.end()) {
                std::cout << "Message " << acknowledged_id << " acknowledged" << std::endl;
                
                // 取消定时器
                if (it->second.timer) {
                    it->second.timer->cancel();
                }
                
                // 调用完成回调,指示成功
                if (it->second.completion_handler) {
                    it->second.completion_handler(true);
                }
                
                // 移除消息
                outgoing_messages_.erase(it);
            }
        }
    }
    
    // 处理接收到的消息
    void handle_message(const UdpMessage& message) {
        std::cout << "Received message with ID: " << message.get_message_id() 
                  << ", data size: " << message.get_data().size() << " bytes" << std::endl;
        
        // 发送确认消息
        std::vector<uint8_t> ack_data;
        ack_data.push_back(static_cast<uint8_t>((message.get_message_id() >> 8) & 0xFF));
        ack_data.push_back(static_cast<uint8_t>(message.get_message_id() & 0xFF));
        
        UdpMessage ack_message(0, ack_data); // ID为0表示确认消息
        std::vector<uint8_t> raw_ack = ack_message.to_raw();
        
        socket_.async_send_to(
            asio::buffer(raw_ack), server_endpoint_,
            [](std::error_code /*ec*/, std::size_t /*bytes_sent*/) {
                // 忽略发送确认的错误
            });
        
        // 调用消息处理回调
        if (message_handler_) {
            message_handler_(message);
        }
    }
    
    asio::io_context& io_context_;
    udp::socket socket_;
    udp::resolver resolver_;
    udp::endpoint server_endpoint_;
    udp::endpoint sender_endpoint_;
    
    std::vector<uint8_t> recv_buffer_{4096}; // 接收缓冲区
    std::map<uint16_t, OutgoingMessage> outgoing_messages_; // 待发送消息
    uint16_t next_message_id_; // 下一个消息ID
    uint32_t timeout_; // 超时时间(毫秒)
    
    std::function<void(const UdpMessage&)> message_handler_; // 消息处理回调
};

// 客户端示例
int main() {
    try {
        asio::io_context io_context;
        auto client = std::make_shared<ReliableUdpClient>(io_context);
        
        // 设置消息接收回调
        client->set_message_handler([](const UdpMessage& message) {
            std::cout << "Received message content: ";
            for (uint8_t byte : message.get_data()) {
                std::cout << static_cast<char>(byte);
            }
            std::cout << std::endl;
        });
        
        // 设置超时时间(毫秒)
        client->set_timeout(1000);
        
        // 连接到服务器
        client->connect("localhost", "8080");
        
        // 创建一个工作对象以防止io_context在没有事件时退出
        asio::executor_work_guard<asio::io_context::executor_type> work = 
            asio::make_work_guard(io_context);
        
        // 在单独的线程中运行io_context
        std::thread io_thread([&io_context]() {
            io_context.run();
        });
        
        // 发送一些测试消息
        std::string text_message = "Hello, Reliable UDP Server!";
        std::vector<uint8_t> message_data(text_message.begin(), text_message.end());
        
        for (int i = 0; i < 3; ++i) {
            client->send_message(message_data, [i](bool success) {
                std::cout << "Message " << (i+1) << " send " 
                          << (success ? "succeeded" : "failed") << std::endl;
            });
            
            // 间隔发送消息
            std::this_thread::sleep_for(std::chrono::seconds(2));
        }
        
        // 等待一段时间
        std::this_thread::sleep_for(std::chrono::seconds(5));
        
        // 关闭客户端
        client->close();
        
        // 停止io_context
        io_context.stop();
        if (io_thread.joinable()) {
            io_thread.join();
        }
        
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    
    return 0;
}

4. 高级UDP服务器实现

下面是与上面客户端对应的高级UDP服务器实现:

#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <map>
#include <cstdint>
#include <asio.hpp>

using asio::ip::udp;

// UdpMessage类的定义(如前所述)
// ...

// 高级UDP服务器类
class ReliableUdpServer {
public:
    ReliableUdpServer(asio::io_context& io_context, short port)
        : io_context_(io_context),
          socket_(io_context, udp::endpoint(udp::v4(), port)),
          next_message_id_(1) {
        socket_.set_option(udp::socket::reuse_address(true));
        do_receive();
    }
    
    // 设置消息处理回调
    void set_message_handler(std::function<void(const udp::endpoint&, const UdpMessage&)> handler) {
        message_handler_ = handler;
    }
    
private:
    // 接收数据
    void do_receive() {
        auto self = shared_from_this();
        socket_.async_receive_from(
            asio::buffer(recv_buffer_), remote_endpoint_,
            [this](std::error_code ec, std::size_t bytes_recvd) {
                if (!ec && bytes_recvd > 0) {
                    // 将接收到的数据转换为vector
                    std::vector<uint8_t> raw_data(recv_buffer_.begin(), recv_buffer_.begin() + bytes_recvd);
                    
                    // 尝试解析消息
                    auto message = UdpMessage::from_raw(raw_data);
                    if (message) {
                        // 检查是否是确认消息(假设消息ID为0表示确认)
                        if (message->get_message_id() == 0) {
                            handle_acknowledgment(remote_endpoint_, *message);
                        } else {
                            // 处理普通消息
                            handle_message(remote_endpoint_, *message);
                        }
                    }
                }
                
                // 继续接收下一个数据包
                do_receive();
            });
    }
    
    // 处理确认消息
    void handle_acknowledgment(const udp::endpoint& sender, const UdpMessage& message) {
        // 在此实现中,服务器可能不需要处理确认消息
        // 因为服务器发送的消息可能不需要客户端确认
        // 但如果需要,可以按照与客户端类似的方式实现
    }
    
    // 处理接收到的消息
    void handle_message(const udp::endpoint& sender, const UdpMessage& message) {
        std::cout << "Received message from " << sender.address().to_string() 
                  << ":" << sender.port() << " with ID: " << message.get_message_id() 
                  << ", data size: " << message.get_data().size() << " bytes" << std::endl;
        
        // 发送确认消息
        std::vector<uint8_t> ack_data;
        ack_data.push_back(static_cast<uint8_t>((message.get_message_id() >> 8) & 0xFF));
        ack_data.push_back(static_cast<uint8_t>(message.get_message_id() & 0xFF));
        
        UdpMessage ack_message(0, ack_data); // ID为0表示确认消息
        std::vector<uint8_t> raw_ack = ack_message.to_raw();
        
        socket_.async_send_to(
            asio::buffer(raw_ack), sender,
            [](std::error_code /*ec*/, std::size_t /*bytes_sent*/) {
                // 忽略发送确认的错误
            });
        
        // 调用消息处理回调
        if (message_handler_) {
            message_handler_(sender, message);
        }
    }
    
    // 向客户端发送消息
    void send_message(const udp::endpoint& client, const std::vector<uint8_t>& data) {
        uint16_t message_id = next_message_id_++;
        UdpMessage message(message_id, data);
        std::vector<uint8_t> raw_message = message.to_raw();
        
        socket_.async_send_to(
            asio::buffer(raw_message), client,
            [message_id](std::error_code ec, std::size_t bytes_sent) {
                if (!ec) {
                    std::cout << "Sent message with ID: " << message_id 
                              << ", size: " << bytes_sent << " bytes" << std::endl;
                } else {
                    std::cerr << "Error sending message: " << ec.message() << std::endl;
                }
            });
    }
    
    asio::io_context& io_context_;
    udp::socket socket_;
    udp::endpoint remote_endpoint_;
    
    std::vector<uint8_t> recv_buffer_{4096}; // 接收缓冲区
    uint16_t next_message_id_; // 下一个消息ID
    
    std::function<void(const udp::endpoint&, const UdpMessage&)> message_handler_; // 消息处理回调
};

// 服务器示例
int main(int argc, char* argv[]) {
    try {
        if (argc != 2) {
            std::cerr << "Usage: reliable_udp_server <port>" << std::endl;
            return 1;
        }

        asio::io_context io_context;
        ReliableUdpServer server(io_context, std::atoi(argv[1]));

        // Print each data message and echo its payload back to the sender.
        server.set_message_handler([&server](const udp::endpoint& client, const UdpMessage& message) {
            std::cout << "Received message content: ";
            for (uint8_t byte : message.get_data()) {
                std::cout << static_cast<char>(byte);
            }
            std::cout << std::endl;

            server.send_message(client, message.get_data());
        });

        std::cout << "Reliable UDP server running on port " << argv[1] << std::endl;
        std::cout << "Press Ctrl+C to exit" << std::endl;

        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1; // report the failure through the exit status
    }

    return 0;
}

五、协议实现最佳实践

在使用Asio实现网络协议时,以下是一些最佳实践:

  1. 错误处理

    • 始终检查异步操作返回的错误码
    • 为网络错误提供适当的恢复机制
    • 使用异常处理来捕获不可恢复的错误
  2. 内存管理

    • 使用智能指针管理长期存在的对象
    • 避免不必要的内存分配和复制
    • 使用预分配的缓冲区来减少内存碎片
  3. 性能优化

    • 使用asio::streambuf来高效处理可变大小的数据
    • 对大数据使用零拷贝技术
    • 合理设置缓冲区大小以平衡内存使用和性能
  4. 安全性考虑

    • 验证所有输入数据的长度和格式
    • 实施适当的认证和授权机制
    • 保护敏感数据的传输(使用SSL/TLS)
  5. 可扩展性设计

    • 将协议逻辑与业务逻辑分离
    • 使用回调或观察者模式处理事件
    • 设计模块化的组件以支持未来的扩展
  6. 测试策略

    • 为协议实现单元测试
    • 使用模拟对象测试网络交互
    • 进行性能测试以确定瓶颈

六、总结

本教程详细介绍了如何使用Asio实现各种网络协议,包括:

  • HTTP协议:实现了基本和高级的HTTP服务器和客户端,支持请求路由、静态文件服务等功能
  • WebSocket协议:实现了完整的WebSocket握手和消息处理,支持文本和二进制消息
  • 自定义TCP协议:设计并实现了一个简单但功能完整的自定义协议,包括消息封装、解析和错误处理
  • UDP协议:实现了基础的UDP回显服务器/客户端,以及更高级的带确认和重传机制的可靠UDP实现(注意:该实现不处理重复消息和乱序——ACK丢失时服务器会再次处理同一条消息,因此并非完整的可靠传输协议)

通过这些实现,我们学习了Asio的核心概念和技术,包括异步I/O模型、缓冲区管理、错误处理等。这些知识将帮助你构建高性能、可靠的网络应用程序。

在实际项目中,你可以根据具体需求选择或设计合适的协议,并结合Asio提供的强大功能来实现。记住,良好的协议设计是构建可靠网络应用的基础,而Asio则为你提供了实现这些协议的理想框架。


### 2. 高级HTTP服务器功能

下面是一个更高级的HTTP服务器实现,包含以下功能:
- 请求路由
- 查询参数解析
- 表单数据处理
- 文件服务

```cpp
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <map>
#include <fstream>
#include <sstream>
#include <filesystem>
#include <regex>
#include <asio.hpp>

namespace fs = std::filesystem;
using asio::ip::tcp;

// HTTP请求
// A parsed HTTP request, filled in by the server's connection parser.
struct HttpRequest {
    std::string method{};                               // method token, e.g. "GET"
    std::string path{};                                 // request target without its "?query" part
    std::string version{};                              // protocol token, e.g. "HTTP/1.1"
    std::map<std::string, std::string> headers{};       // header name -> value
    std::map<std::string, std::string> query_params{};  // query key -> value
    std::string body{};                                 // raw request body; empty when none was sent
};

// HTTP响应
// An HTTP response under construction; serialized with to_string().
struct HttpResponse {
    std::string version = "HTTP/1.1";
    int status_code = 200;
    std::string status_message = "OK";
    std::map<std::string, std::string> headers;
    std::string body;

    // Serializes the status line, headers and body.
    // Content-Length is always derived from body.size() and overrides any
    // value present in `headers`. Headers are emitted in the map's (sorted) key
    // order. The function is const so it can be called through a
    // `const HttpResponse&` (as send_response does); `headers` is not modified.
    std::string to_string() const {
        std::map<std::string, std::string> all_headers = headers;
        all_headers["Content-Length"] = std::to_string(body.size());

        std::string out = version + " " + std::to_string(status_code) + " " + status_message + "\r\n";
        for (const auto& [key, value] : all_headers) {
            out += key;
            out += ": ";
            out += value;
            out += "\r\n";
        }
        out += "\r\n";
        out += body;
        return out;
    }
};

// HTTP路由处理器类型
// Route callback: inspects the parsed request and fills in the response to send.
using HttpHandler = std::function<void(const HttpRequest& request, HttpResponse& response)>;

// HTTP服务器类
class AdvancedHttpServer {
public:
    AdvancedHttpServer(asio::io_context& io_context, short port)
        : acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
        do_accept();
    }
    
    // 注册GET路由
    void get(const std::string& path, const HttpHandler& handler) {
        routes_["GET"][path] = handler;
    }
    
    // 注册POST路由
    void post(const std::string& path, const HttpHandler& handler) {
        routes_["POST"][path] = handler;
    }
    
    // 设置静态文件目录
    void set_static_dir(const std::string& dir) {
        static_dir_ = dir;
    }
    
private:
    // HTTP连接类
    class Connection : public std::enable_shared_from_this<Connection> {
    public:
        Connection(tcp::socket socket, AdvancedHttpServer& server)
            : socket_(std::move(socket)), server_(server) {
        }
        
        void start() {
            read_request_line();
        }
        
    private:
        // 读取请求行
        void read_request_line() {
            auto self = shared_from_this();
            asio::async_read_until(socket_, buffer_, "\r\n",
                [self](std::error_code ec, std::size_t bytes_transferred) {
                    if (!ec) {
                        std::istream request_stream(&self->buffer_);
                        std::string request_line;
                        std::getline(request_stream, request_line);
                        
                        // 移除回车符
                        if (!request_line.empty() && request_line.back() == '\r') {
                            request_line.pop_back();
                        }
                        
                        self->parse_request_line(request_line);
                        self->read_headers();
                    }
                });
        }
        
        // 解析请求行
        void parse_request_line(const std::string& request_line) {
            std::istringstream request_line_stream(request_line);
            request_line_stream >> request_.method >> request_.path >> request_.version;
            
            // 解析查询参数
            size_t query_pos = request_.path.find('?');
            if (query_pos != std::string::npos) {
                std::string path_part = request_.path.substr(0, query_pos);
                std::string query_part = request_.path.substr(query_pos + 1);
                request_.path = path_part;
                parse_query_params(query_part);
            }
        }
        
        // 解析查询参数
        void parse_query_params(const std::string& query_part) {
            std::stringstream query_stream(query_part);
            std::string param;
            
            while (std::getline(query_stream, param, '&')) {
                size_t eq_pos = param.find('=');
                if (eq_pos != std::string::npos) {
                    std::string key = param.substr(0, eq_pos);
                    std::string value = param.substr(eq_pos + 1);
                    request_.query_params[key] = value;
                }
            }
        }
        
        // 读取请求头
        void read_headers() {
            auto self = shared_from_this();
            asio::async_read_until(socket_, buffer_, "\r\n\r\n",
                [self](std::error_code ec, std::size_t bytes_transferred) {
                    if (!ec) {
                        std::istream request_stream(&self->buffer_);
                        std::string header_line;
                        
                        while (std::getline(request_stream, header_line) && header_line != "\r") {
                            // 移除回车符
                            if (!header_line.empty() && header_line.back() == '\r') {
                                header_line.pop_back();
                            }
                            
                            size_t colon_pos = header_line.find(':');
                            if (colon_pos != std::string::npos) {
                                std::string key = header_line.substr(0, colon_pos);
                                std::string value = header_line.substr(colon_pos + 2); // 跳过冒号和空格
                                self->request_.headers[key] = value;
                            }
                        }
                        
                        // 检查是否有请求体
                        auto content_length_it = self->request_.headers.find("Content-Length");
                        if (content_length_it != self->request_.headers.end()) {
                            self->read_body(std::stoi(content_length_it->second));
                        } else {
                            self->handle_request();
                        }
                    }
                });
        }
        
        // 读取请求体
        void read_body(size_t content_length) {
            auto self = shared_from_this();
            asio::async_read(socket_, buffer_, asio::transfer_exactly(content_length),
                [self](std::error_code ec, std::size_t bytes_transferred) {
                    if (!ec) {
                        // 读取请求体
                        std::istream request_stream(&self->buffer_);
                        std::vector<char> body_content(content_length);
                        request_stream.read(body_content.data(), content_length);
                        self->request_.body = std::string(body_content.data(), content_length);
                        
                        self->handle_request();
                    }
                });
        }
        
        // 处理请求
        void handle_request() {
            HttpResponse response;
            
            // 检查是否为静态文件请求
            if (!server_.static_dir_.empty() && request_.method == "GET") {
                fs::path file_path = fs::path(server_.static_dir_) / fs::path(request_.path.substr(1));
                if (fs::exists(file_path) && !fs::is_directory(file_path)) {
                    serve_static_file(file_path.string(), response);
                    send_response(response);
                    return;
                }
            }
            
            // 查找路由处理器
            auto method_it = server_.routes_.find(request_.method);
            if (method_it != server_.routes_.end()) {
                auto path_it = method_it->second.find(request_.path);
                if (path_it != method_it->second.end()) {
                    // 调用路由处理器
                    path_it->second(request_, response);
                    send_response(response);
                    return;
                }
            }
            
            // 未找到路由,返回404
            response.status_code = 404;
            response.status_message = "Not Found";
            response.body = "<html><body><h1>404 Not Found</h1></body></html>";
            response.headers["Content-Type"] = "text/html";
            send_response(response);
        }
        
        // 提供静态文件
        void serve_static_file(const std::string& file_path, HttpResponse& response) {
            std::ifstream file(file_path, std::ios::binary);
            if (file) {
                // 读取文件内容
                std::stringstream buffer;
                buffer << file.rdbuf();
                response.body = buffer.str();
                
                // 设置内容类型
                std::string extension = fs::path(file_path).extension().string();
                response.headers["Content-Type"] = get_mime_type(extension);
            } else {
                response.status_code = 404;
                response.status_message = "Not Found";
                response.body = "<html><body><h1>404 File Not Found</h1></body></html>";
                response.headers["Content-Type"] = "text/html";
            }
        }
        
        // 获取MIME类型
        std::string get_mime_type(const std::string& extension) {
            if (extension == ".html" || extension == ".htm") return "text/html";
            if (extension == ".css") return "text/css";
            if (extension == ".js") return "application/javascript";
            if (extension == ".json") return "application/json";
            if (extension == ".png") return "image/png";
            if (extension == ".jpg" || extension == ".jpeg") return "image/jpeg";
            if (extension == ".gif") return "image/gif";
            if (extension == ".svg") return "image/svg+xml";
            if (extension == ".txt") return "text/plain";
            return "application/octet-stream";
        }
        
        // 发送响应
        void send_response(const HttpResponse& response) {
            auto self = shared_from_this();
            std::string response_str = response.to_string();
            
            asio::async_write(socket_, asio::buffer(response_str),
                [self](std::error_code ec, std::size_t) {
                    if (!ec) {
                        // 响应发送完成后关闭连接
                        asio::error_code ignored_ec;
                        self->socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
                    }
                });
        }
        
        tcp::socket socket_;
        asio::streambuf buffer_;
        HttpRequest request_;
        AdvancedHttpServer& server_;
    };
    
    // 接受新连接
    void do_accept() {
        acceptor_.async_accept(
            [this](std::error_code ec, tcp::socket socket) {
                if (!ec) {
                    std::make_shared<Connection>(std::move(socket), *this)->start();
                }
                do_accept();
            });
    }
    
    tcp::acceptor acceptor_;
    std::map<std::string, std::map<std::string, HttpHandler>> routes_;
    std::string static_dir_;
};

int main(int argc, char* argv[]) {
    try {
        if (argc != 2) {
            std::cerr << "Usage: advanced_http_server <port>" << std::endl;
            return 1;
        }

        asio::io_context io_context;
        AdvancedHttpServer server(io_context, std::atoi(argv[1]));

        // Serve static files from ./public.
        server.set_static_dir("./public");

        // Escapes text for embedding inside a JSON string literal, so user-supplied
        // query values cannot break (or inject into) the JSON response.
        auto json_escape = [](const std::string& text) {
            static const char kHex[] = "0123456789abcdef";
            std::string out;
            out.reserve(text.size());
            for (const char ch : text) {
                const auto c = static_cast<unsigned char>(ch);
                switch (c) {
                    case '"':  out += "\\\""; break;
                    case '\\': out += "\\\\"; break;
                    case '\n': out += "\\n"; break;
                    case '\r': out += "\\r"; break;
                    case '\t': out += "\\t"; break;
                    default:
                        if (c < 0x20) {
                            out += "\\u00";
                            out += kHex[c >> 4];
                            out += kHex[c & 0x0F];
                        } else {
                            out += ch;
                        }
                }
            }
            return out;
        };

        // GET /api/hello[?name=...]
        server.get("/api/hello", [json_escape](const HttpRequest& request, HttpResponse& response) {
            std::string name = "World";
            // `request` is const, so use find(): map::operator[] is not available on a const map.
            const auto it = request.query_params.find("name");
            if (it != request.query_params.end()) {
                name = it->second;
            }
            response.body = "{\"message\": \"Hello, " + json_escape(name) + "!\"}";
            response.headers["Content-Type"] = "application/json";
        });

        // POST /api/data: report how many body bytes were received.
        server.post("/api/data", [](const HttpRequest& request, HttpResponse& response) {
            response.body = "{\"received\": true, \"data_length\": " +
                           std::to_string(request.body.size()) + "}";
            response.headers["Content-Type"] = "application/json";
        });

        std::cout << "Advanced HTTP server running on port " << argv[1] << std::endl;
        std::cout << "Static files served from ./public" << std::endl;
        std::cout << "API endpoints: /api/hello, /api/data" << std::endl;
        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1; // report the failure through the exit status
    }

    return 0;
}

```

### 3. HTTP客户端实现

下面是一个简单的HTTP客户端实现,可以发送GET和POST请求:

#include <iostream>
#include <string>
#include <functional>
#include <asio.hpp>

using asio::ip::tcp;

// HTTP客户端类
class HttpClient {
public:
    HttpClient(asio::io_context& io_context)
        : resolver_(io_context), socket_(io_context) {
    }
    
    // 发送GET请求
    void get(const std::string& host, const std::string& port, const std::string& path,
             std::function<void(const std::string&)> callback) {
        callback_ = callback;
        
        // 解析服务器地址
        resolver_.async_resolve(host, port, 
            [this, host, path](std::error_code ec, tcp::resolver::results_type results) {
                if (!ec) {
                    // 连接到服务器
                    asio::async_connect(socket_, results, 
                        [this, host, path](std::error_code ec, tcp::endpoint) {
                            if (!ec) {
                                // 构建HTTP GET请求
                                std::string request = "GET " + path + " HTTP/1.1\r\n";
                                request += "Host: " + host + "\r\n";
                                request += "Connection: close\r\n";
                                request += "\r\n";
                                
                                // 发送请求
                                asio::async_write(socket_, asio::buffer(request),
                                    [this](std::error_code ec, std::size_t) {
                                        if (!ec) {
                                            // 读取响应
                                            read_response();
                                        }
                                    });
                            }
                        });
                }
            });
    }
    
    // 发送POST请求
    void post(const std::string& host, const std::string& port, const std::string& path, 
              const std::string& body, const std::string& content_type,
              std::function<void(const std::string&)> callback) {
        callback_ = callback;
        
        // 解析服务器地址
        resolver_.async_resolve(host, port, 
            [this, host, path, body, content_type](std::error_code ec, tcp::resolver::results_type results) {
                if (!ec) {
                    // 连接到服务器
                    asio::async_connect(socket_, results, 
                        [this, host, path, body, content_type](std::error_code ec, tcp::endpoint) {
                            if (!ec) {
                                // 构建HTTP POST请求
                                std::string request = "POST " + path + " HTTP/1.1\r\n";
                                request += "Host: " + host + "\r\n";
                                request += "Content-Type: " + content_type + "\r\n";
                                request += "Content-Length: " + std::to_string(body.size()) + "\r\n";
                                request += "Connection: close\r\n";
                                request += "\r\n";
                                request += body;
                                
                                // 发送请求
                                asio::async_write(socket_, asio::buffer(request),
                                    [this](std::error_code ec, std::size_t) {
                                        if (!ec) {
                                            // 读取响应
                                            read_response();
                                        }
                                    });
                            }
                        });
                }
            });
    }
    
private:
    // 读取HTTP响应
    void read_response() {
        auto self(shared_from_this());
        asio::async_read(socket_, response_buffer_, 
            [this](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    // 读取完整响应
                    std::string response(asio::buffers_begin(response_buffer_.data()),
                                         asio::buffers_end(response_buffer_.data()));
                    
                    // 调用回调函数处理响应
                    callback_(response);
                }
            });
    }
    
    tcp::resolver resolver_;
    tcp::socket socket_;
    asio::streambuf response_buffer_;
    std::function<void(const std::string&)> callback_;
};

// HTTP client demo: GETs http://www.example.com/ and then POSTs a JSON body to
// http://httpbin.org/post, printing the first 500 characters of each response.
// Returns 1 if an exception escapes, 0 otherwise.
int main() {
    try {
        // First request (GET). run() returns once this context has no more work.
        asio::io_context io_context;
        auto client = std::make_shared<HttpClient>(io_context);
        client->get("www.example.com", "80", "/", 
            [](const std::string& response) {
                std::cout << "GET Response: " << std::endl;
                std::cout << response.substr(0, 500) << "..." << std::endl; // only the first 500 characters
            });
        io_context.run();
        
        // Second request (POST) on a fresh io_context with a fresh client.
        asio::io_context io_context2;
        auto client2 = std::make_shared<HttpClient>(io_context2);
        const std::string json_body = "{\"name\": \"Asio\", \"type\": \"HTTP Client\"}";
        client2->post("httpbin.org", "80", "/post", json_body, "application/json",
            [](const std::string& response) {
                std::cout << "\nPOST Response: " << std::endl;
                std::cout << response.substr(0, 500) << "..." << std::endl; // only the first 500 characters
            });
        io_context2.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }
    
    return 0;
}

二、WebSocket协议实现

WebSocket提供了双向通信的能力,适用于需要实时通信的应用程序。下面我们将学习如何使用Asio实现WebSocket服务器和客户端。

1. WebSocket握手

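WebSocket通信的第一步是握手:客户端发送带有 `Upgrade: websocket` 和 `Sec-WebSocket-Key` 头的HTTP GET请求,服务器返回 `101 Switching Protocols` 响应(其中的 `Sec-WebSocket-Accept` 由客户端密钥拼接固定GUID后做SHA-1再Base64编码得到),从而完成协议升级。下面的服务器在握手成功后还会继续解析并回显WebSocket帧: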

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

#include <asio.hpp>

using asio::ip::tcp;

// Base64-encodes `input` with the standard alphabet and '=' padding
// (RFC 4648 section 4). Each 3-byte group becomes 4 characters; a trailing
// group of 1 or 2 bytes is padded with "==" or "=" respectively.
std::string base64_encode(const std::string& input) {
    static const char kAlphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";

    // Read bytes as unsigned so values >= 0x80 are not sign-extended. All
    // arithmetic is unsigned and bounded to 24 bits, which avoids the
    // signed-overflow UB of an ever-growing int accumulator.
    auto byte_at = [&input](std::size_t index) {
        return static_cast<unsigned int>(static_cast<unsigned char>(input[index]));
    };

    std::string encoded;
    encoded.reserve((input.size() + 2) / 3 * 4);

    std::size_t i = 0;
    for (; i + 3 <= input.size(); i += 3) {
        const unsigned int group = (byte_at(i) << 16) | (byte_at(i + 1) << 8) | byte_at(i + 2);
        encoded.push_back(kAlphabet[(group >> 18) & 0x3F]);
        encoded.push_back(kAlphabet[(group >> 12) & 0x3F]);
        encoded.push_back(kAlphabet[(group >> 6) & 0x3F]);
        encoded.push_back(kAlphabet[group & 0x3F]);
    }

    const std::size_t remaining = input.size() - i;
    if (remaining > 0) {
        unsigned int group = byte_at(i) << 16;
        if (remaining == 2) {
            group |= byte_at(i + 1) << 8;
        }
        encoded.push_back(kAlphabet[(group >> 18) & 0x3F]);
        encoded.push_back(kAlphabet[(group >> 12) & 0x3F]);
        encoded.push_back(remaining == 2 ? kAlphabet[(group >> 6) & 0x3F] : '=');
        encoded.push_back('=');
    }

    return encoded;
}

// Computes the SHA-1 digest of `str` (FIPS 180-4) and returns the 20 raw
// digest bytes (not hex), which is the form Sec-WebSocket-Accept needs before
// base64 encoding. SHA-1 is used here only for the WebSocket handshake; it
// must not be relied on for new security-sensitive purposes.
std::string sha1(const std::string& str) {
    auto rotl = [](std::uint32_t value, int bits) -> std::uint32_t {
        return (value << bits) | (value >> (32 - bits));
    };

    std::uint32_t h[5] = {0x67452301u, 0xEFCDAB89u, 0x98BADCFEu, 0x10325476u, 0xC3D2E1F0u};

    // Padding: append 0x80, then zeros until the length is 56 mod 64, then the
    // original length in bits as a 64-bit big-endian integer.
    std::string message = str;
    const std::uint64_t bit_length = static_cast<std::uint64_t>(str.size()) * 8u;
    message.push_back(static_cast<char>(0x80));
    while (message.size() % 64 != 56) {
        message.push_back('\0');
    }
    for (int shift = 56; shift >= 0; shift -= 8) {
        message.push_back(static_cast<char>((bit_length >> shift) & 0xFFu));
    }

    // Process each 64-byte (512-bit) block.
    for (std::size_t block = 0; block < message.size(); block += 64) {
        std::uint32_t w[80];
        for (int t = 0; t < 16; ++t) {
            const std::size_t base = block + static_cast<std::size_t>(t) * 4;
            w[t] = (static_cast<std::uint32_t>(static_cast<unsigned char>(message[base])) << 24) |
                   (static_cast<std::uint32_t>(static_cast<unsigned char>(message[base + 1])) << 16) |
                   (static_cast<std::uint32_t>(static_cast<unsigned char>(message[base + 2])) << 8) |
                   static_cast<std::uint32_t>(static_cast<unsigned char>(message[base + 3]));
        }
        for (int t = 16; t < 80; ++t) {
            w[t] = rotl(w[t - 3] ^ w[t - 8] ^ w[t - 14] ^ w[t - 16], 1);
        }

        std::uint32_t a = h[0];
        std::uint32_t b = h[1];
        std::uint32_t c = h[2];
        std::uint32_t d = h[3];
        std::uint32_t e = h[4];
        for (int t = 0; t < 80; ++t) {
            std::uint32_t f;
            std::uint32_t k;
            if (t < 20) {
                f = (b & c) | (~b & d);
                k = 0x5A827999u;
            } else if (t < 40) {
                f = b ^ c ^ d;
                k = 0x6ED9EBA1u;
            } else if (t < 60) {
                f = (b & c) | (b & d) | (c & d);
                k = 0x8F1BBCDCu;
            } else {
                f = b ^ c ^ d;
                k = 0xCA62C1D6u;
            }
            const std::uint32_t temp = rotl(a, 5) + f + e + k + w[t];
            e = d;
            d = c;
            c = rotl(b, 30);
            b = a;
            a = temp;
        }
        h[0] += a;
        h[1] += b;
        h[2] += c;
        h[3] += d;
        h[4] += e;
    }

    // Serialize the five state words big-endian.
    std::string digest;
    digest.reserve(20);
    for (std::uint32_t word : h) {
        for (int shift = 24; shift >= 0; shift -= 8) {
            digest.push_back(static_cast<char>((word >> shift) & 0xFFu));
        }
    }
    return digest;
}

// WebSocket服务器类
class WebSocketServer {
public:
    WebSocketServer(asio::io_context& io_context, short port)
        : acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
        do_accept();
    }
    
private:
    // WebSocket连接类
    class Connection : public std::enable_shared_from_this<Connection> {
    public:
        Connection(tcp::socket socket)
            : socket_(std::move(socket)) {
        }
        
        void start() {
            perform_handshake();
        }
        
    private:
        // 执行WebSocket握手
        void perform_handshake() {
            auto self = shared_from_this();
            asio::async_read_until(socket_, buffer_, "\r\n\r\n",
                [self](std::error_code ec, std::size_t bytes_transferred) {
                    if (!ec) {
                        std::istream request_stream(&self->buffer_);
                        std::string request_line;
                        std::getline(request_stream, request_line);
                        
                        // 解析请求头
                        std::map<std::string, std::string> headers;
                        std::string header_line;
                        while (std::getline(request_stream, header_line) && header_line != "\r") {
                            if (!header_line.empty() && header_line.back() == '\r') {
                                header_line.pop_back();
                            }
                            
                            size_t colon_pos = header_line.find(':');
                            if (colon_pos != std::string::npos) {
                                std::string key = header_line.substr(0, colon_pos);
                                // 去除键中的空格并转为小写
                                key.erase(std::remove_if(key.begin(), key.end(), ::isspace), key.end());
                                std::transform(key.begin(), key.end(), key.begin(), ::tolower);
                                
                                std::string value = header_line.substr(colon_pos + 1);
                                // 去除值前面的空格
                                size_t start_pos = value.find_first_not_of(" ");
                                if (start_pos != std::string::npos) {
                                    value = value.substr(start_pos);
                                }
                                
                                headers[key] = value;
                            }
                        }
                        
                        // 验证WebSocket握手请求
                        if (self->validate_handshake(headers)) {
                            // 发送WebSocket握手响应
                            self->send_handshake_response(headers);
                            // 握手完成后,开始处理WebSocket消息
                            self->read_frame();
                        } else {
                            std::cerr << "Invalid WebSocket handshake request" << std::endl;
                            self->socket_.close();
                        }
                    }
                });
        }
        
        // 验证WebSocket握手请求
        bool validate_handshake(const std::map<std::string, std::string>& headers) {
            // 检查必要的头信息
            if (headers.find("upgrade") == headers.end() || 
                headers.find("connection") == headers.end() ||
                headers.find("sec-websocket-key") == headers.end() ||
                headers.find("sec-websocket-version") == headers.end()) {
                return false;
            }
            
            // 检查Upgrade头的值是否为websocket
            std::string upgrade = headers.at("upgrade");
            std::transform(upgrade.begin(), upgrade.end(), upgrade.begin(), ::tolower);
            if (upgrade != "websocket") {
                return false;
            }
            
            // 检查Connection头是否包含Upgrade
            std::string connection = headers.at("connection");
            std::transform(connection.begin(), connection.end(), connection.begin(), ::tolower);
            if (connection.find("upgrade") == std::string::npos) {
                return false;
            }
            
            // 检查WebSocket版本
            if (headers.at("sec-websocket-version") != "13") {
                return false;
            }
            
            return true;
        }
        
        // 发送WebSocket握手响应
        void send_handshake_response(const std::map<std::string, std::string>& headers) {
            // 获取Sec-WebSocket-Key
            const std::string& key = headers.at("sec-websocket-key");
            
            // 计算Sec-WebSocket-Accept
            // 实际实现中应该使用正确的SHA-1哈希计算
            std::string accept_key = base64_encode(sha1(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
            
            // 构建HTTP响应
            std::string response = "HTTP/1.1 101 Switching Protocols\r\n";
            response += "Upgrade: websocket\r\n";
            response += "Connection: Upgrade\r\n";
            response += "Sec-WebSocket-Accept: " + accept_key + "\r\n";
            response += "\r\n";
            
            // 发送响应
            auto self = shared_from_this();
            asio::async_write(socket_, asio::buffer(response),
                [self](std::error_code ec, std::size_t) {
                    if (ec) {
                        std::cerr << "Failed to send handshake response: " << ec.message() << std::endl;
                        self->socket_.close();
                    }
                });
        }
        
        // 读取WebSocket帧
        void read_frame() {
            auto self = shared_from_this();
            asio::async_read(socket_, asio::buffer(&frame_header_, 2),
                [self](std::error_code ec, std::size_t) {
                    if (!ec) {
                        // 解析帧头部
                        bool fin = (frame_header_ & 0x80) != 0;
                        uint8_t opcode = frame_header_ & 0x0F;
                        bool mask = (frame_header_[1] & 0x80) != 0;
                        uint64_t payload_length = frame_header_[1] & 0x7F;
                        
                        // 处理不同的操作码
                        if (opcode == 0x08) { // 连接关闭帧
                            self->handle_close_frame();
                        } else if (opcode == 0x09) { // PING帧
                            self->handle_ping_frame();
                        } else if (opcode == 0x0A) { // PONG帧
                            self->handle_pong_frame();
                        } else if (opcode == 0x01 || opcode == 0x02) { // 文本或二进制帧
                            self->read_payload_length(payload_length, opcode, fin, mask);
                        } else {
                            std::cerr << "Unknown WebSocket opcode: " << static_cast<int>(opcode) << std::endl;
                            self->socket_.close();
                        }
                    }
                });
        }
        
        // 读取有效载荷长度
        void read_payload_length(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
            auto self = shared_from_this();
            
            if (payload_length == 126) {
                // 16位长度
                asio::async_read(socket_, asio::buffer(length_bytes_, 2),
                    [self, opcode, fin, mask](std::error_code ec, std::size_t) {
                        if (!ec) {
                            uint64_t extended_length = (static_cast<uint64_t>(self->length_bytes_[0]) << 8) | 
                                                      static_cast<uint64_t>(self->length_bytes_[1]);
                            self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
                        }
                    });
            } else if (payload_length == 127) {
                // 64位长度
                asio::async_read(socket_, asio::buffer(length_bytes_, 8),
                    [self, opcode, fin, mask](std::error_code ec, std::size_t) {
                        if (!ec) {
                            uint64_t extended_length = 0;
                            for (int i = 0; i < 8; ++i) {
                                extended_length = (extended_length << 8) | self->length_bytes_[i];
                            }
                            self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
                        }
                    });
            } else {
                // 7位长度
                read_masking_key_and_payload(payload_length, opcode, fin, mask);
            }
        }
        
        // 读取掩码键和有效载荷
        void read_masking_key_and_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
            auto self = shared_from_this();
            
            if (mask) {
                // 读取掩码键
                asio::async_read(socket_, asio::buffer(masking_key_, 4),
                    [self, payload_length, opcode, fin, mask](std::error_code ec, std::size_t) {
                        if (!ec) {
                            self->read_payload(payload_length, opcode, fin, mask);
                        }
                    });
            } else {
                // 没有掩码键
                read_payload(payload_length, opcode, fin, mask);
            }
        }
        
        // 读取有效载荷
        void read_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
            auto self = shared_from_this();
            payload_.resize(payload_length);
            
            asio::async_read(socket_, asio::buffer(payload_),
                [self, opcode, fin, mask](std::error_code ec, std::size_t bytes_transferred) {
                    if (!ec) {
                        // 处理掩码
                        if (mask) {
                            for (size_t i = 0; i < bytes_transferred; ++i) {
                                self->payload_[i] ^= self->masking_key_[i % 4];
                            }
                        }
                        
                        // 处理消息
                        self->handle_message(opcode, fin);
                        
                        // 继续读取下一帧
                        self->read_frame();
                    }
                });
        }
        
        // 处理WebSocket消息
        void handle_message(uint8_t opcode, bool fin) {
            if (opcode == 0x01) { // 文本消息
                std::string message(payload_.begin(), payload_.end());
                std::cout << "Received text message: " << message << std::endl;
                
                // 回显消息
                send_text_message(message);
            } else if (opcode == 0x02) { // 二进制消息
                std::cout << "Received binary message of size: " << payload_.size() << " bytes" << std::endl;
                
                // 回显二进制消息
                send_binary_message(payload_);
            }
        }
        
        // 发送文本消息
        void send_text_message(const std::string& message) {
            // 构建WebSocket帧
            std::vector<uint8_t> frame;
            
            // 帧头部
            frame.push_back(0x81); // FIN + TEXT opcode
            
            // 有效载荷长度
            size_t message_size = message.size();
            if (message_size <= 125) {
                frame.push_back(static_cast<uint8_t>(message_size));
            } else if (message_size <= 65535) {
                frame.push_back(126);
                frame.push_back(static_cast<uint8_t>((message_size >> 8) & 0xFF));
                frame.push_back(static_cast<uint8_t>(message_size & 0xFF));
            } else {
                frame.push_back(127);
                for (int i = 7; i >= 0; --i) {
                    frame.push_back(static_cast<uint8_t>((message_size >> (i * 8)) & 0xFF));
                }
            }
            
            // 有效载荷
            frame.insert(frame.end(), message.begin(), message.end());
            
            // 发送帧
            auto self = shared_from_this();
            asio::async_write(socket_, asio::buffer(frame),
                [self](std::error_code ec, std::size_t) {
                    if (ec) {
                        std::cerr << "Failed to send message: " << ec.message() << std::endl;
                        self->socket_.close();
                    }
                });
        }
        
        // 发送二进制消息
        void send_binary_message(const std::vector<uint8_t>& data) {
            // 构建WebSocket帧
            std::vector<uint8_t> frame;
            
            // 帧头部
            frame.push_back(0x82); // FIN + BINARY opcode
            
            // 有效载荷长度
            size_t data_size = data.size();
            if (data_size <= 125) {
                frame.push_back(static_cast<uint8_t>(data_size));
            } else if (data_size <= 65535) {
                frame.push_back(126);
                frame.push_back(static_cast<uint8_t>((data_size >> 8) & 0xFF));
                frame.push_back(static_cast<uint8_t>(data_size & 0xFF));
            } else {
                frame.push_back(127);
                for (int i = 7; i >= 0; --i) {
                    frame.push_back(static_cast<uint8_t>((data_size >> (i * 8)) & 0xFF));
                }
            }
            
            // 有效载荷
            frame.insert(frame.end(), data.begin(), data.end());
            
            // 发送帧
            auto self = shared_from_this();
            asio::async_write(socket_, asio::buffer(frame),
                [self](std::error_code ec, std::size_t) {
                    if (ec) {
                        std::cerr << "Failed to send binary message: " << ec.message() << std::endl;
                        self->socket_.close();
                    }
                });
        }
        
        // 处理关闭帧
        void handle_close_frame() {
            std::cout << "Received close frame" << std::endl;
            socket_.close();
        }
        
        // 处理PING帧
        void handle_ping_frame() {
            std::cout << "Received ping frame" << std::endl;
            // 读取PING帧的有效载荷
            // 为简化起见,这里不读取有效载荷
            
            // 发送PONG帧
            send_pong_frame();
        }
        
        // 处理PONG帧
        void handle_pong_frame() {
            std::cout << "Received pong frame" << std::endl;
            // 为简化起见,这里不做任何处理
            read_frame();
        }
        
        // 发送PONG帧
        void send_pong_frame() {
            std::vector<uint8_t> pong_frame = {0x8A, 0x00}; // FIN + PONG opcode, no payload
            
            auto self = shared_from_this();
            asio::async_write(socket_, asio::buffer(pong_frame),
                [self](std::error_code ec, std::size_t) {
                    if (!ec) {
                        self->read_frame();
                    }
                });
        }
        
        tcp::socket socket_;
        asio::streambuf buffer_;
        
        // WebSocket帧相关字段
        uint8_t frame_header_[2];
        uint8_t length_bytes_[8];
        uint8_t masking_key_[4];
        std::vector<uint8_t> payload_;
    };
    
    // 接受新连接
    void do_accept() {
        acceptor_.async_accept(
            [this](std::error_code ec, tcp::socket socket) {
                if (!ec) {
                    std::make_shared<Connection>(std::move(socket))->start();
                }
                do_accept();
            });
    }
    
    tcp::acceptor acceptor_;
};

// Entry point of the WebSocket echo server. Usage: websocket_server <port>.
// Returns 1 on a usage error or when an exception escapes (e.g. the port
// cannot be bound), 0 when the event loop finishes normally.
int main(int argc, char* argv[]) {
    try {
        if (argc != 2) {
            std::cerr << "Usage: websocket_server <port>" << std::endl;
            return 1;
        }
        
        asio::io_context io_context;
        WebSocketServer server(io_context, std::atoi(argv[1]));
        
        std::cout << "WebSocket server running on port " << argv[1] << std::endl;
        std::cout << "Waiting for connections..." << std::endl;
        io_context.run();
    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }
    
    return 0;
}

2. 完整的WebSocket客户端

下面是一个完整的WebSocket客户端实现:

#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <algorithm>
#include <cctype>
#include <asio.hpp>

using asio::ip::tcp;

// WebSocket客户端类
class WebSocketClient : public std::enable_shared_from_this<WebSocketClient> {
public:
    using MessageHandler = std::function<void(const std::string&)>;
    using BinaryMessageHandler = std::function<void(const std::vector<uint8_t>&)>;
    using ConnectionHandler = std::function<void()>;
    using DisconnectionHandler = std::function<void()>;
    
    WebSocketClient(asio::io_context& io_context)
        : resolver_(io_context), socket_(io_context) {
    }
    
    // 连接到WebSocket服务器
    void connect(const std::string& host, const std::string& port, const std::string& path) {
        host_ = host;
        port_ = port;
        path_ = path;
        
        // 解析服务器地址
        resolver_.async_resolve(host, port, 
            [this](std::error_code ec, tcp::resolver::results_type results) {
                if (!ec) {
                    // 连接到服务器
                    asio::async_connect(socket_, results, 
                        [this](std::error_code ec, tcp::endpoint) {
                            if (!ec) {
                                std::cout << "Connected to " << host_ << ":" << port_ << std::endl;
                                // 执行WebSocket握手
                                perform_handshake();
                            } else {
                                std::cerr << "Connection failed: " << ec.message() << std::endl;
                                if (disconnection_handler_) {
                                    disconnection_handler_();
                                }
                            }
                        });
                } else {
                    std::cerr << "Resolution failed: " << ec.message() << std::endl;
                    if (disconnection_handler_) {
                        disconnection_handler_();
                    }
                }
            });
    }
    
    // Sends `message` as one text frame (FIN + opcode 0x1). Client-to-server
    // frames must be masked (RFC 6455 sections 5.1 and 5.3), so the frame carries
    // a 4-byte masking key and the payload is XOR-ed with it. Write errors are
    // logged and close the connection.
    // NOTE(review): there is no write queue; issuing another send before this
    // write completes may interleave frames on the socket.
    void send_text(const std::string& message) {
        if (!is_connected_) {
            std::cerr << "Not connected to WebSocket server" << std::endl;
            return;
        }
        
        // Owned by the completion handler so the bytes outlive async_write.
        auto frame = std::make_shared<std::vector<uint8_t>>();
        frame->reserve(message.size() + 14);  // max header: 2 + 8 (length) + 4 (mask)
        frame->push_back(0x81);  // FIN + TEXT opcode
        
        // Payload length with the MASK bit (0x80) set. 64-bit so the 8-byte
        // encoding is well defined even where size_t is 32 bits.
        const uint64_t message_size = message.size();
        if (message_size <= 125) {
            frame->push_back(static_cast<uint8_t>(0x80 | message_size));
        } else if (message_size <= 65535) {
            frame->push_back(static_cast<uint8_t>(0x80 | 126));
            frame->push_back(static_cast<uint8_t>((message_size >> 8) & 0xFF));
            frame->push_back(static_cast<uint8_t>(message_size & 0xFF));
        } else {
            frame->push_back(static_cast<uint8_t>(0x80 | 127));
            for (int i = 7; i >= 0; --i) {
                frame->push_back(static_cast<uint8_t>((message_size >> (i * 8)) & 0xFF));
            }
        }
        
        // Masking key followed by the masked payload.
        // NOTE(review): rand() is not a strong entropy source as the RFC asks for.
        uint8_t masking_key[4];
        for (uint8_t& key_byte : masking_key) {
            key_byte = static_cast<uint8_t>(rand() % 256);
            frame->push_back(key_byte);
        }
        for (std::size_t i = 0; i < message.size(); ++i) {
            frame->push_back(static_cast<uint8_t>(static_cast<uint8_t>(message[i]) ^ masking_key[i % 4]));
        }
        
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(*frame),
            [self, frame](std::error_code ec, std::size_t) {
                if (ec) {
                    std::cerr << "Failed to send message: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
    // Sends `data` as one binary frame (FIN + opcode 0x2). Client-to-server
    // frames must be masked (RFC 6455 sections 5.1 and 5.3), so the frame carries
    // a 4-byte masking key and the payload is XOR-ed with it. Write errors are
    // logged and close the connection.
    // NOTE(review): there is no write queue; issuing another send before this
    // write completes may interleave frames on the socket.
    void send_binary(const std::vector<uint8_t>& data) {
        if (!is_connected_) {
            std::cerr << "Not connected to WebSocket server" << std::endl;
            return;
        }
        
        // Owned by the completion handler so the bytes outlive async_write.
        auto frame = std::make_shared<std::vector<uint8_t>>();
        frame->reserve(data.size() + 14);  // max header: 2 + 8 (length) + 4 (mask)
        frame->push_back(0x82);  // FIN + BINARY opcode
        
        // Payload length with the MASK bit (0x80) set. 64-bit so the 8-byte
        // encoding is well defined even where size_t is 32 bits.
        const uint64_t data_size = data.size();
        if (data_size <= 125) {
            frame->push_back(static_cast<uint8_t>(0x80 | data_size));
        } else if (data_size <= 65535) {
            frame->push_back(static_cast<uint8_t>(0x80 | 126));
            frame->push_back(static_cast<uint8_t>((data_size >> 8) & 0xFF));
            frame->push_back(static_cast<uint8_t>(data_size & 0xFF));
        } else {
            frame->push_back(static_cast<uint8_t>(0x80 | 127));
            for (int i = 7; i >= 0; --i) {
                frame->push_back(static_cast<uint8_t>((data_size >> (i * 8)) & 0xFF));
            }
        }
        
        // Masking key followed by the masked payload.
        // NOTE(review): rand() is not a strong entropy source as the RFC asks for.
        uint8_t masking_key[4];
        for (uint8_t& key_byte : masking_key) {
            key_byte = static_cast<uint8_t>(rand() % 256);
            frame->push_back(key_byte);
        }
        for (std::size_t i = 0; i < data.size(); ++i) {
            frame->push_back(static_cast<uint8_t>(data[i] ^ masking_key[i % 4]));
        }
        
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(*frame),
            [self, frame](std::error_code ec, std::size_t) {
                if (ec) {
                    std::cerr << "Failed to send binary message: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
    // Closes an open connection: sends a masked CLOSE frame (synchronously,
    // best effort, errors ignored), shuts down and closes the socket, then
    // invokes the disconnection handler. Does nothing if not connected.
    void close() {
        if (!is_connected_) {
            return;
        }
        is_connected_ = false;
        
        // FIN + CLOSE opcode; second byte has the MASK bit set with a zero-length
        // payload, followed by a 4-byte masking key. Client frames must always be
        // masked (RFC 6455 section 5.1), even when the payload is empty.
        std::vector<uint8_t> close_frame = {0x88, 0x80};
        for (int i = 0; i < 4; ++i) {
            close_frame.push_back(static_cast<uint8_t>(rand() % 256));
        }
        asio::error_code ignored_ec;
        asio::write(socket_, asio::buffer(close_frame), ignored_ec);
        
        socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
        socket_.close(ignored_ec);
        
        std::cout << "WebSocket connection closed" << std::endl;
        
        if (disconnection_handler_) {
            disconnection_handler_();
        }
    }
    
    // 设置消息处理回调
    void set_message_handler(const MessageHandler& handler) {
        message_handler_ = handler;
    }
    
    // 设置二进制消息处理回调
    void set_binary_message_handler(const BinaryMessageHandler& handler) {
        binary_message_handler_ = handler;
    }
    
    // 设置连接成功回调
    void set_connection_handler(const ConnectionHandler& handler) {
        connection_handler_ = handler;
    }
    
    // 设置断开连接回调
    void set_disconnection_handler(const DisconnectionHandler& handler) {
        disconnection_handler_ = handler;
    }
    
    // Sends a masked PING control frame with an empty payload if connected.
    // Write errors are logged and close the connection.
    void ping() {
        if (!is_connected_) {
            return;
        }
        // FIN + PING opcode; second byte has the MASK bit set with a zero-length
        // payload, followed by a 4-byte masking key (client frames must be masked).
        // Heap-allocated so the bytes outlive the asynchronous write.
        auto ping_frame = std::make_shared<std::vector<uint8_t>>();
        ping_frame->push_back(0x89);
        ping_frame->push_back(0x80);
        for (int i = 0; i < 4; ++i) {
            ping_frame->push_back(static_cast<uint8_t>(rand() % 256));
        }
        
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(*ping_frame),
            [self, ping_frame](std::error_code ec, std::size_t) {
                if (ec) {
                    std::cerr << "Failed to send ping: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
private:
    // 执行WebSocket握手
    void perform_handshake() {
        // 生成随机的Sec-WebSocket-Key
        std::string sec_websocket_key = generate_websocket_key();
        
        // 构建HTTP请求
        std::string request = "GET " + path_ + " HTTP/1.1\r\n";
        request += "Host: " + host_ + ":" + port_ + "\r\n";
        request += "Upgrade: websocket\r\n";
        request += "Connection: Upgrade\r\n";
        request += "Sec-WebSocket-Key: " + sec_websocket_key + "\r\n";
        request += "Sec-WebSocket-Version: 13\r\n";
        request += "\r\n";
        
        // 发送握手请求
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(request),
            [self](std::error_code ec, std::size_t) {
                if (!ec) {
                    // 读取握手响应
                    self->read_handshake_response();
                } else {
                    std::cerr << "Failed to send handshake request: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
    // Returns a Sec-WebSocket-Key value: 16 random bytes encoded as base64
    // (24 characters, the last two being '=' padding; RFC 6455 section 4.1).
    // NOTE(review): rand() is not seeded in this code and is not a strong
    // entropy source, so generated keys are predictable.
    std::string generate_websocket_key() {
        std::vector<uint8_t> random_bytes(16);
        std::generate(random_bytes.begin(), random_bytes.end(), []() {
            return static_cast<uint8_t>(rand() % 256);
        });
        
        static const char kAlphabet[] =
            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            "abcdefghijklmnopqrstuvwxyz"
            "0123456789+/";
        
        // Standard base64 over 3-byte groups. All arithmetic is unsigned and
        // bounded to 24 bits, which avoids signed-overflow UB in an int accumulator.
        std::string encoded;
        encoded.reserve((random_bytes.size() + 2) / 3 * 4);
        std::size_t i = 0;
        for (; i + 3 <= random_bytes.size(); i += 3) {
            const unsigned int group = (static_cast<unsigned int>(random_bytes[i]) << 16) |
                                       (static_cast<unsigned int>(random_bytes[i + 1]) << 8) |
                                       static_cast<unsigned int>(random_bytes[i + 2]);
            encoded.push_back(kAlphabet[(group >> 18) & 0x3F]);
            encoded.push_back(kAlphabet[(group >> 12) & 0x3F]);
            encoded.push_back(kAlphabet[(group >> 6) & 0x3F]);
            encoded.push_back(kAlphabet[group & 0x3F]);
        }
        
        const std::size_t remaining = random_bytes.size() - i;
        if (remaining > 0) {
            unsigned int group = static_cast<unsigned int>(random_bytes[i]) << 16;
            if (remaining == 2) {
                group |= static_cast<unsigned int>(random_bytes[i + 1]) << 8;
            }
            encoded.push_back(kAlphabet[(group >> 18) & 0x3F]);
            encoded.push_back(kAlphabet[(group >> 12) & 0x3F]);
            encoded.push_back(remaining == 2 ? kAlphabet[(group >> 6) & 0x3F] : '=');
            encoded.push_back('=');
        }
        
        return encoded;
    }
    
    // Read the server's reply to the opening handshake and validate it (RFC 6455 §4.1).
    // The reply must have status 101, an "Upgrade: websocket" header, and a Connection
    // header containing the "Upgrade" token. Header names and these values are compared
    // ASCII case-insensitively.
    // On success: marks the client connected, invokes connection_handler_, and starts
    // the frame read loop. Otherwise: logs the problem and closes.
    //
    // NOTE(review): Sec-WebSocket-Accept is not verified, because that needs SHA-1.
    // async_read_until may also pull bytes past "\r\n\r\n" into response_buffer_.
    // read_frame() reads from socket_ directly, so a frame the server sends in the same
    // segment as the handshake response would not be seen.
    void read_handshake_response() {
        auto self = shared_from_this();
        asio::async_read_until(socket_, response_buffer_, "\r\n\r\n",
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    std::cerr << "Failed to read handshake response: " << ec.message() << std::endl;
                    self->close();
                    return;
                }

                // ASCII lower-casing. The unsigned char parameter keeps ::tolower's
                // argument in range for bytes >= 0x80.
                auto to_lower = [](std::string text) -> std::string {
                    std::transform(text.begin(), text.end(), text.begin(), [](unsigned char ch) {
                        return static_cast<char>(::tolower(ch));
                    });
                    return text;
                };
                // Strip the optional spaces/tabs HTTP allows around field values.
                auto trim = [](const std::string& text) -> std::string {
                    const std::size_t first = text.find_first_not_of(" \t");
                    if (first == std::string::npos) {
                        return std::string();
                    }
                    const std::size_t last = text.find_last_not_of(" \t");
                    return text.substr(first, last - first + 1);
                };

                std::istream response_stream(&self->response_buffer_);
                std::string status_line;
                std::getline(response_stream, status_line);

                // The status line looks like "HTTP/1.1 101 Switching Protocols". It is
                // parsed with operator>> so a malformed line is reported here instead of
                // throwing out of the completion handler.
                std::istringstream status_stream(status_line);
                std::string version;
                int status_code = 0;
                if (!(status_stream >> version >> status_code)) {
                    std::cerr << "Malformed handshake status line: " << status_line << std::endl;
                    self->close();
                    return;
                }
                if (status_code != 101) {
                    std::cerr << "WebSocket handshake failed with status code: " << status_code << std::endl;
                    self->close();
                    return;
                }

                // Collect header fields up to the blank line; names are stored lower-cased.
                std::map<std::string, std::string> headers;
                std::string header_line;
                while (std::getline(response_stream, header_line)) {
                    if (!header_line.empty() && header_line.back() == '\r') {
                        header_line.pop_back();
                    }
                    if (header_line.empty()) {
                        break;
                    }
                    const std::size_t colon_pos = header_line.find(':');
                    if (colon_pos != std::string::npos) {
                        // colon_pos + 1 <= size(), so this substr is always in range,
                        // even for a "Name:" line with no value.
                        headers[to_lower(trim(header_line.substr(0, colon_pos)))] =
                            trim(header_line.substr(colon_pos + 1));
                    }
                }

                const auto upgrade_it = headers.find("upgrade");
                const bool upgrade_valid =
                    upgrade_it != headers.end() && to_lower(upgrade_it->second) == "websocket";

                // Connection is a comma-separated token list, e.g. "keep-alive, Upgrade".
                bool connection_valid = false;
                const auto connection_it = headers.find("connection");
                if (connection_it != headers.end()) {
                    std::istringstream tokens(to_lower(connection_it->second));
                    std::string token;
                    while (!connection_valid && std::getline(tokens, token, ',')) {
                        connection_valid = (trim(token) == "upgrade");
                    }
                }

                if (!(upgrade_valid && connection_valid)) {
                    std::cerr << "Invalid WebSocket handshake response" << std::endl;
                    self->close();
                    return;
                }

                // Handshake succeeded.
                self->is_connected_ = true;
                std::cout << "WebSocket handshake successful" << std::endl;

                if (self->connection_handler_) {
                    self->connection_handler_();
                }

                self->read_frame();
            });
    }
    
    // Read the 2-byte header of the next WebSocket frame and dispatch on its opcode.
    //
    // Byte 0 carries FIN (0x80) and the opcode (low nibble). Byte 1 carries MASK (0x80)
    // and the 7-bit payload length.
    // - Text/binary frames continue through read_payload_length().
    // - PING/PONG frames may carry up to 125 bytes of application data, plus a 4-byte
    //   key if masked. Those bytes are consumed here before the handler runs, so the
    //   next read starts on a frame header.
    // - Close, continuation and unknown opcodes end the connection.
    // Does nothing once is_connected_ is false.
    void read_frame() {
        if (!is_connected_) {
            return;
        }

        auto self = shared_from_this();
        asio::async_read(socket_, asio::buffer(frame_header_),
            [self](std::error_code ec, std::size_t) {
                if (ec) {
                    std::cerr << "Failed to read frame header: " << ec.message() << std::endl;
                    self->close();
                    return;
                }

                const bool fin = (self->frame_header_[0] & 0x80) != 0;
                const uint8_t opcode = self->frame_header_[0] & 0x0F;
                const bool mask = (self->frame_header_[1] & 0x80) != 0;
                const uint64_t payload_length = self->frame_header_[1] & 0x7F;

                if (opcode == 0x08) { // close frame
                    self->handle_close_frame();
                } else if (opcode == 0x09 || opcode == 0x0A) { // PING / PONG
                    // Control frames never use the 126/127 extended lengths (RFC 6455 §5.5).
                    if (payload_length > 125) {
                        std::cerr << "Invalid control frame length: " << payload_length << std::endl;
                        self->close();
                        return;
                    }
                    auto dispatch = [self, opcode]() {
                        if (opcode == 0x09) {
                            self->handle_ping_frame();
                        } else {
                            self->handle_pong_frame();
                        }
                    };
                    // The masking key (if any) and the application data; the handlers
                    // below do not use them, but they must be removed from the stream.
                    const std::size_t skip_size =
                        (mask ? 4u : 0u) + static_cast<std::size_t>(payload_length);
                    if (skip_size == 0) {
                        dispatch();
                        return;
                    }
                    auto discarded = std::make_shared<std::vector<uint8_t>>(skip_size);
                    asio::async_read(self->socket_, asio::buffer(*discarded),
                        [self, discarded, dispatch](std::error_code read_ec, std::size_t) {
                            if (read_ec) {
                                std::cerr << "Failed to read control frame payload: "
                                          << read_ec.message() << std::endl;
                                self->close();
                                return;
                            }
                            dispatch();
                        });
                } else if (opcode == 0x01 || opcode == 0x02) { // text or binary frame
                    self->read_payload_length(payload_length, opcode, fin, mask);
                } else if (opcode == 0x00) { // continuation frame
                    // Simplified client: fragmented messages are not supported.
                    std::cerr << "Continuation frames not supported" << std::endl;
                    self->close();
                } else {
                    std::cerr << "Unknown WebSocket opcode: " << static_cast<int>(opcode) << std::endl;
                    self->close();
                }
            });
    }
    
    // 读取有效载荷长度
    void read_payload_length(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
        auto self = shared_from_this();
        
        if (payload_length == 126) {
            // 16位长度
            asio::async_read(socket_, asio::buffer(length_bytes_, 2),
                [self, opcode, fin, mask](std::error_code ec, std::size_t) {
                    if (!ec) {
                        uint64_t extended_length = (static_cast<uint64_t>(self->length_bytes_[0]) << 8) | 
                                                  static_cast<uint64_t>(self->length_bytes_[1]);
                        self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
                    } else {
                        std::cerr << "Failed to read payload length: " << ec.message() << std::endl;
                        self->close();
                    }
                });
        } else if (payload_length == 127) {
            // 64位长度
            asio::async_read(socket_, asio::buffer(length_bytes_, 8),
                [self, opcode, fin, mask](std::error_code ec, std::size_t) {
                    if (!ec) {
                        uint64_t extended_length = 0;
                        for (int i = 0; i < 8; ++i) {
                            extended_length = (extended_length << 8) | self->length_bytes_[i];
                        }
                        self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
                    } else {
                        std::cerr << "Failed to read payload length: " << ec.message() << std::endl;
                        self->close();
                    }
                });
        } else {
            // 7位长度
            read_masking_key_and_payload(payload_length, opcode, fin, mask);
        }
    }
    
    // Frame-parsing step after the length is known. If the MASK bit was set, first read
    // the 4-byte masking key into masking_key_. Either way, continue with read_payload()
    // using the same length/opcode/fin/mask. A failed key read logs and closes.
    void read_masking_key_and_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
        auto self = shared_from_this();

        if (!mask) {
            // Unmasked frame: the payload immediately follows the length bytes.
            read_payload(payload_length, opcode, fin, mask);
            return;
        }

        auto on_key_read = [self, payload_length, opcode, fin, mask](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to read masking key: " << ec.message() << std::endl;
                self->close();
                return;
            }
            self->read_payload(payload_length, opcode, fin, mask);
        };
        asio::async_read(socket_, asio::buffer(masking_key_, 4), on_key_read);
    }
    
    // 读取有效载荷
    void read_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
        auto self = shared_from_this();
        payload_.resize(payload_length);
        
        asio::async_read(socket_, asio::buffer(payload_),
            [self, opcode, fin, mask](std::error_code ec, std::size_t bytes_transferred) {
                if (!ec) {
                    // 处理掩码
                    if (mask) {
                        for (size_t i = 0; i < bytes_transferred; ++i) {
                            self->payload_[i] ^= self->masking_key_[i % 4];
                        }
                    }
                    
                    // 处理消息
                    self->handle_message(opcode, fin);
                    
                    // 继续读取下一帧
                    self->read_frame();
                } else {
                    std::cerr << "Failed to read payload: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
    // Deliver the data frame currently held in payload_ to the user.
    // - Opcode 0x1 (text): payload is copied into a std::string for message_handler_.
    // - Opcode 0x2 (binary): raw bytes go to binary_message_handler_.
    // Each case is logged first. Any other opcode is ignored.
    void handle_message(uint8_t opcode, bool fin) {
        // FIN is not consulted: fragmented messages are not supported by this client.
        static_cast<void>(fin);

        switch (opcode) {
        case 0x01: {
            std::string text(payload_.begin(), payload_.end());
            std::cout << "Received text message: " << text << std::endl;
            if (message_handler_) {
                message_handler_(text);
            }
            break;
        }
        case 0x02:
            std::cout << "Received binary message of size: " << payload_.size() << " bytes" << std::endl;
            if (binary_message_handler_) {
                binary_message_handler_(payload_);
            }
            break;
        default:
            break;
        }
    }
    
    // 处理关闭帧
    void handle_close_frame() {
        std::cout << "Received close frame" << std::endl;
        close();
    }
    
    // 处理PING帧
    void handle_ping_frame() {
        std::cout << "Received ping frame" << std::endl;
        // 读取PING帧的有效载荷
        // 为简化起见,这里不读取有效载荷
        
        // 发送PONG帧
        send_pong_frame();
    }
    
    // 处理PONG帧
    void handle_pong_frame() {
        std::cout << "Received pong frame" << std::endl;
        // 为简化起见,这里不做任何处理
        read_frame();
    }
    
    // Answer a server ping with an empty PONG frame, then resume reading frames.
    //
    // RFC 6455 §5.1 requires every client-to-server frame to be masked. The header is
    // therefore 0x8A (FIN + PONG opcode) and 0x80 (MASK bit, payload length 0),
    // followed by a 4-byte masking key.
    // The bytes live in a shared_ptr captured by the completion handler. asio::buffer()
    // does not copy, and async_write keeps using the memory after this function returns.
    //
    // NOTE(review): RFC 6455 §5.5.3 expects the pong to echo the ping's application
    // data; this simplified client always sends an empty pong.
    void send_pong_frame() {
        auto pong_frame = std::make_shared<std::vector<uint8_t>>(6);
        (*pong_frame)[0] = 0x8A;
        (*pong_frame)[1] = 0x80;
        // With a zero-length payload the key masks nothing, but it must still be present.
        for (std::size_t i = 2; i < pong_frame->size(); ++i) {
            (*pong_frame)[i] = static_cast<uint8_t>(rand() % 256);
        }

        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(*pong_frame),
            [self, pong_frame](std::error_code ec, std::size_t) {
                if (!ec) {
                    self->read_frame();
                } else {
                    std::cerr << "Failed to send pong: " << ec.message() << std::endl;
                    self->close();
                }
            });
    }
    
    // Networking objects used for name resolution and the TCP connection.
    tcp::resolver resolver_;
    tcp::socket socket_;
    // Receives the HTTP handshake response via async_read_until in
    // read_handshake_response(). NOTE(review): bytes read past "\r\n\r\n" stay here
    // and are not consumed by the frame readers, which read socket_ directly.
    asio::streambuf response_buffer_;
    // Set to true once the handshake response is accepted; read_frame() returns
    // without reading while this is false.
    bool is_connected_ = false;
    
    // Connection target; presumably filled in by connect(host, port, path) — not
    // visible in this chunk, TODO confirm.
    std::string host_;
    std::string port_;
    std::string path_;
    
    // Scratch buffers for the frame currently being parsed, reused for every frame:
    // 2-byte header, 16/64-bit extended length, 4-byte masking key, payload.
    uint8_t frame_header_[2];
    uint8_t length_bytes_[8];
    uint8_t masking_key_[4];
    std::vector<uint8_t> payload_;
    
    // User-supplied callbacks.
    MessageHandler message_handler_;
    BinaryMessageHandler binary_message_handler_;
    ConnectionHandler connection_handler_;
    DisconnectionHandler disconnection_handler_;
};

// WebSocket客户端示例
int main() {
    try {
        asio::io_context io_context;
        auto client = std::make_shared<WebSocketClient>(io_context);
        
        // 设置回调函数
        client->set_message_handler([](const std::string& message) {
            std::cout << "Message received: " << message << std::endl;
        });
        
        client->set_binary_message_handler([](const std::vector<uint8_t>& data) {
            std::cout << "Binary data received, size: " << data.size() << " bytes" << std::endl;
        });
        
        client->set_connection_handler([]() {
            std::cout << "Connection established successfully" << std::endl;
        });
        
        client->set_disconnection_handler([]() {
            std::cout << "Disconnected from server" << std::endl;
        });
        
        // 连接到WebSocket服务器
        // 注意:这里使用的是示例地址,请替换为实际的WebSocket服务器地址
        client->connect("localhost", "8080", "/ws");
        
        // 创建一个工作对象以防止io_context在没有事件时退出
        asio::executor_work_guard<asio::io_context::executor_type> work = 
            asio::make_work_guard(io_context);
        
        // 在单独的线程中运行io_context
        std::thread io_thread([&io_context]() {
            io_context.run();
        });
        
        // 等待连接建立
        std::this_thread::sleep_for(std::chrono::seconds(1));
        
        // 发送一条文本消息
        client->send_text("Hello, WebSocket Server!");
        
        // 发送一条二进制消息
        std::vector<uint8_t> binary_data = {0x01, 0x02, 0x03, 0x04, 0x05};
        client->send_binary(binary_data);
        
        // 发送PING
        client->ping();
        
        // 等待一段时间,以便接收响应
        std::this_thread::sleep_for(std::chrono::seconds(5));
        
        // 关闭连接
        client->close();
        
        // 停止io_context
        io_context.stop();
        if (io_thread.joinable()) {
            io_thread.join();
        }
        
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    
    return 0;
}
Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐