Asio C++零基础入门(十四):Asio C++的网络协议实现
摘要:网络协议是网络通信的基础,Asio提供了构建各种网络协议的底层框架。本教程将详细介绍如何使用Asio实现常见的网络协议,包括HTTP、WebSocket、自定义TCP协议和UDP协议。HTTP是现代Web应用程序的基础协议,文中首先给出一个能够处理基本GET请求的简单HTTP服务器实现;除了标准协议外,Asio还非常适合实现自定义网络协议,因此随后介绍自定义TCP协议的设计与实现,最后介绍UDP协议的实现。
网络协议是网络通信的基础,Asio提供了构建各种网络协议的底层框架。本教程将详细介绍如何使用Asio实现常见的网络协议,包括HTTP、WebSocket、自定义TCP协议和UDP协议。
一、HTTP协议实现
HTTP是现代Web应用程序的基础协议。下面我们将学习如何使用Asio实现HTTP服务器和客户端。
1. 基础HTTP服务器
以下是一个简单的HTTP服务器实现,它能够处理基本的GET请求:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <asio.hpp>
using asio::ip::tcp;
// One HTTP exchange over an accepted TCP socket.
//
// Reads the request head (everything up to the blank line), answers GET requests
// for a few fixed paths with a canned HTML page, and shuts the socket down once
// the response has been written. Instances must be owned by a std::shared_ptr:
// every asynchronous operation keeps the object alive via shared_from_this().
class HttpConnection : public std::enable_shared_from_this<HttpConnection> {
public:
    // Takes ownership of an already-connected socket.
    HttpConnection(tcp::socket socket)
        : socket_(std::move(socket)) {
    }

    // Starts processing the connection by reading the request.
    void start() {
        read_request();
    }

private:
    // Reads up to the end of the request head ("\r\n\r\n"), extracts the method
    // and path from the request line and passes them to handle_request().
    // On a read error no further operation is scheduled.
    void read_request() {
        auto self = shared_from_this();
        asio::async_read_until(socket_, buffer_, "\r\n\r\n",
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    return;
                }

                std::istream request_stream(&self->buffer_);
                std::string request_line;
                std::getline(request_stream, request_line);

                // Request line: "<method> <path> <version>".
                std::string method, path, version;
                std::istringstream request_line_stream(request_line);
                request_line_stream >> method >> path >> version;

                // Consume the header lines up to the blank line ("\r" once
                // getline has stripped the '\n'); they are not interpreted.
                std::string header;
                while (std::getline(request_stream, header) && header != "\r") {
                }

                self->handle_request(method, path);
            });
    }

    // Chooses the status line and HTML body for the request and sends them.
    // Only GET is supported; any other method yields 405, unknown paths 404.
    void handle_request(const std::string& method, const std::string& path) {
        std::string status;
        std::string body;
        if (method != "GET") {
            status = "405 Method Not Allowed";
            body = "<html><body><h1>405 Method Not Allowed</h1></body></html>";
        } else if (path == "/" || path == "/index.html") {
            status = "200 OK";
            body = "<html><body><h1>Hello, World!</h1></body></html>";
        } else if (path == "/about") {
            status = "200 OK";
            body = "<html><body><h1>About Page</h1></body></html>";
        } else {
            status = "404 Not Found";
            body = "<html><body><h1>404 Page Not Found</h1></body></html>";
        }
        send_response(build_response(status, body));
    }

    // Assembles a complete HTTP/1.1 response. Content-Length is computed from
    // the body so it can never disagree with the bytes actually sent.
    static std::string build_response(const std::string& status, const std::string& body) {
        std::string response = "HTTP/1.1 " + status + "\r\n";
        response += "Content-Type: text/html\r\n";
        response += "Content-Length: " + std::to_string(body.size()) + "\r\n";
        response += "\r\n";
        response += body;
        return response;
    }

    // Writes the response and shuts the socket down afterwards.
    // The text is moved into a member first: the buffer given to async_write
    // must stay valid until the completion handler runs, which is after this
    // function (and its caller) have returned.
    void send_response(std::string response) {
        response_ = std::move(response);
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(response_),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (!ec) {
                    // Response sent; shutdown errors are deliberately ignored.
                    asio::error_code ignored_ec;
                    self->socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
                }
            });
    }

    tcp::socket socket_;
    asio::streambuf buffer_;   // receives the request bytes
    std::string response_;     // owns the response while it is being written
};
// HTTP服务器类
class HttpServer {
public:
HttpServer(asio::io_context& io_context, short port)
: acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
do_accept();
}
private:
// 接受新连接
void do_accept() {
acceptor_.async_accept(
[this](std::error_code ec, tcp::socket socket) {
if (!ec) {
std::make_shared<HttpConnection>(std::move(socket))->start();
}
do_accept();
});
}
tcp::acceptor acceptor_;
};
// Entry point of the HTTP server example: expects the listening port as the
// only command-line argument and runs the io_context until it has no more work.
int main(int argc, char* argv[]) {
    try {
        const bool has_port_argument = (argc == 2);
        if (!has_port_argument) {
            std::cerr << "Usage: http_server <port>" << std::endl;
            return 1;
        }

        const char* port_text = argv[1];

        asio::io_context io_context;
        HttpServer server(io_context, std::atoi(port_text));

        std::cout << "HTTP server running on port " << port_text << std::endl;
        io_context.run();
    } catch (std::exception& e) {
        // Errors are reported but do not change the exit status.
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
三、自定义TCP协议实现
除了标准协议外,Asio还非常适合实现自定义网络协议。下面我们将学习如何设计和实现一个简单的自定义TCP协议。
1. 自定义协议设计
设计一个名为"SimpleMessageProtocol"的自定义协议,具有以下特点:
- 每个消息都有固定的头部,包含消息长度和消息类型
- 支持文本和二进制两种消息类型
- 使用小端字节序进行数据传输
协议格式如下:
+-------------+-------------+------------------+
| 消息长度(4字节) | 消息类型(1字节) | 消息体 |
+-------------+-------------+------------------+
其中:
- 消息长度:不包括自身(4字节长度字段),仅表示消息类型和消息体的总长度,即"1 + 消息体字节数"。例如,文本消息"hi"的消息长度字段为3(1字节类型 + 2字节消息体),整帧共7字节
- 消息类型:0表示文本消息,1表示二进制消息
- 消息体:根据消息类型包含文本或二进制数据
2. 自定义协议消息类
下面是实现自定义协议消息的类:
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include <asio.hpp>
// One frame of the "SimpleMessageProtocol".
//
// Wire layout:
//   bytes 0-3  length, little-endian; counts the type byte plus the payload
//              (it does not count itself)
//   byte  4    type: 0 = TEXT, 1 = BINARY
//   bytes 5..  payload
class SimpleMessage {
public:
    enum MessageType {
        TEXT = 0,
        BINARY = 1
    };

    // Size of the fixed header: 4-byte length field + 1-byte type.
    static constexpr size_t kHeaderSize = 5;

    // Builds a TEXT message whose payload is the bytes of `text`.
    SimpleMessage(const std::string& text)
        : type_(TEXT), body_(text.begin(), text.end()) {
    }

    // Builds a BINARY message carrying a copy of `binary_data`.
    SimpleMessage(const std::vector<uint8_t>& binary_data)
        : type_(BINARY), body_(binary_data) {
    }

    // Parses one frame from the start of `buffer`.
    //
    // Returns std::nullopt when the buffer does not yet hold a complete frame,
    // when the length field is 0 (a frame always contains at least the type
    // byte) or when the type byte is neither TEXT nor BINARY. Bytes after the
    // end of the frame are ignored.
    static std::optional<SimpleMessage> from_buffer(const std::vector<uint8_t>& buffer) {
        if (buffer.size() < kHeaderSize) {
            return std::nullopt;
        }

        // Little-endian length field.
        uint32_t length_field = 0;
        for (size_t i = 0; i < kLengthFieldSize; ++i) {
            length_field |= static_cast<uint32_t>(buffer[i]) << (i * 8);
        }
        if (length_field < 1) {
            return std::nullopt;
        }

        // The length field already includes the type byte, so the whole frame
        // is "length field size + length", not "header size + length".
        // 64-bit arithmetic keeps the sum from wrapping.
        const uint64_t frame_size = static_cast<uint64_t>(kLengthFieldSize) + length_field;
        if (buffer.size() < frame_size) {
            return std::nullopt;
        }

        const uint8_t raw_type = buffer[kLengthFieldSize];
        if (raw_type != TEXT && raw_type != BINARY) {
            return std::nullopt;
        }

        std::vector<uint8_t> body(buffer.begin() + kHeaderSize,
                                  buffer.begin() + static_cast<size_t>(frame_size));
        return SimpleMessage(static_cast<MessageType>(raw_type), std::move(body));
    }

    // Serializes the message into one frame (header followed by payload).
    // The length field is 32 bits wide; larger payloads are not representable.
    std::vector<uint8_t> serialize() const {
        // Length excludes the length field itself: 1 type byte + payload.
        const uint32_t length_field = static_cast<uint32_t>(1 + body_.size());

        std::vector<uint8_t> serialized;
        serialized.reserve(kHeaderSize + body_.size());
        for (size_t i = 0; i < kLengthFieldSize; ++i) {
            serialized.push_back(static_cast<uint8_t>((length_field >> (i * 8)) & 0xFF));
        }
        serialized.push_back(static_cast<uint8_t>(type_));
        serialized.insert(serialized.end(), body_.begin(), body_.end());
        return serialized;
    }

    // Returns the message type.
    MessageType get_type() const {
        return type_;
    }

    // Returns the payload as text, or an empty string for a non-TEXT message.
    std::string get_text() const {
        if (type_ == TEXT) {
            return std::string(body_.begin(), body_.end());
        }
        return "";
    }

    // Returns the payload bytes, or an empty vector for a non-BINARY message.
    std::vector<uint8_t> get_binary() const {
        if (type_ == BINARY) {
            return body_;
        }
        return {};
    }

    // Returns the size in bytes of the frame produced by serialize().
    size_t get_serialized_size() const {
        return kHeaderSize + body_.size();
    }

private:
    // Width of the leading length field in bytes.
    static constexpr size_t kLengthFieldSize = 4;

    // Used by from_buffer() to build a message from already-validated parts.
    SimpleMessage(MessageType type, std::vector<uint8_t> body)
        : type_(type), body_(std::move(body)) {
    }

    MessageType type_;
    std::vector<uint8_t> body_;
};
3. 自定义协议服务器
下面是实现自定义协议服务器的代码:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <cstdint>
#include <asio.hpp>
// SimpleMessage类的定义(如前所述)
// ...
// One server-side connection speaking the SimpleMessage framing.
//
// Frame layout: 4-byte little-endian length (type byte + payload, not counting
// itself), 1 type byte, payload. The object reads frames in a loop, hands each
// decoded message to the message handler and writes queued outgoing frames one
// at a time. Instances must be owned by a std::shared_ptr; all members are
// expected to be used from the thread running the io_context.
class Connection : public std::enable_shared_from_this<Connection> {
public:
    Connection(asio::io_context& io_context)
        : socket_(io_context) {
    }

    // Gives access to the socket so an acceptor can accept into it.
    asio::ip::tcp::socket& socket() {
        return socket_;
    }

    // Starts the read loop.
    void start() {
        read_header();
    }

    // Queues `message` for sending; starts the writer if it is idle.
    // Messages queued after the connection has failed are discarded.
    void send_message(const SimpleMessage& message) {
        if (closed_) {
            return;
        }
        // A non-empty queue means an async_write is already in flight and will
        // pick up the new element when it completes.
        const bool write_in_progress = !write_queue_.empty();
        write_queue_.push_back(message.serialize());
        if (!write_in_progress) {
            do_write();
        }
    }

    // Sets the callback invoked for every received message.
    void set_message_handler(std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> handler) {
        message_handler_ = handler;
    }

    // Sets the callback invoked once when the connection fails or is closed.
    void set_close_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
        close_handler_ = handler;
    }

    // Shuts down and closes the socket; errors are ignored.
    void close() {
        asio::error_code ec;
        socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ec);
        socket_.close(ec);
    }

private:
    // Width of the length field at the start of every frame.
    static constexpr std::size_t kLengthFieldSize = 4;
    // Upper bound accepted for the length field. The value comes from the
    // peer, so it is capped before it is used to size the read buffer.
    static constexpr uint32_t kMaxFrameLength = 16u * 1024u * 1024u;

    // Reads the fixed 5-byte header and validates the length field.
    void read_header() {
        auto self = shared_from_this();
        read_buffer_.resize(SimpleMessage::kHeaderSize);
        asio::async_read(socket_, asio::buffer(read_buffer_, SimpleMessage::kHeaderSize),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    self->handle_failure();
                    return;
                }

                uint32_t frame_length = 0;
                for (std::size_t i = 0; i < kLengthFieldSize; ++i) {
                    frame_length |= static_cast<uint32_t>(self->read_buffer_[i]) << (i * 8);
                }

                // The length covers the type byte, so 0 is malformed.
                if (frame_length < 1 || frame_length > kMaxFrameLength) {
                    self->handle_failure();
                    return;
                }

                // The type byte was already read as part of the header, so
                // only (length - 1) payload bytes remain on the wire.
                self->read_body(frame_length - 1);
            });
    }

    // Reads `body_size` payload bytes behind the header, then delivers the
    // message and goes back to reading the next header.
    void read_body(uint32_t body_size) {
        auto self = shared_from_this();
        read_buffer_.resize(SimpleMessage::kHeaderSize + body_size);
        asio::async_read(socket_,
            asio::buffer(read_buffer_.data() + SimpleMessage::kHeaderSize, body_size),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    self->handle_failure();
                    return;
                }
                self->deliver_message();
                self->read_header();
            });
    }

    // Builds a SimpleMessage from the frame held in read_buffer_ and passes it
    // to the message handler. Frames with an unknown type byte are dropped.
    void deliver_message() {
        if (!message_handler_) {
            return;
        }
        const uint8_t type = read_buffer_[kLengthFieldSize];
        const auto payload_begin = read_buffer_.begin() + SimpleMessage::kHeaderSize;
        const auto payload_end = read_buffer_.end();
        if (type == SimpleMessage::TEXT) {
            message_handler_(shared_from_this(),
                             SimpleMessage(std::string(payload_begin, payload_end)));
        } else if (type == SimpleMessage::BINARY) {
            message_handler_(shared_from_this(),
                             SimpleMessage(std::vector<uint8_t>(payload_begin, payload_end)));
        }
    }

    // Writes the frame at the front of the queue, then continues with the next.
    void do_write() {
        auto self = shared_from_this();
        asio::async_write(socket_, asio::buffer(write_queue_.front()),
            [self](std::error_code ec, std::size_t /*bytes_transferred*/) {
                if (ec) {
                    // The write has completed, so its buffer may be released.
                    self->write_queue_.clear();
                    self->handle_failure();
                    return;
                }
                self->write_queue_.pop_front();
                if (!self->write_queue_.empty()) {
                    self->do_write();
                }
            });
    }

    // Closes the socket and notifies the close handler exactly once, even when
    // the read side and the write side both report an error.
    void handle_failure() {
        if (closed_) {
            return;
        }
        closed_ = true;
        close();
        if (close_handler_) {
            close_handler_(shared_from_this());
        }
    }

    asio::ip::tcp::socket socket_;
    std::vector<uint8_t> read_buffer_ = std::vector<uint8_t>(SimpleMessage::kHeaderSize);
    std::deque<std::vector<uint8_t>> write_queue_;   // frames waiting to be written
    bool closed_ = false;                            // set once failure was reported
    std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> message_handler_;
    std::function<void(std::shared_ptr<Connection>)> close_handler_;
};
// 自定义协议服务器类
class SimpleMessageServer {
public:
SimpleMessageServer(asio::io_context& io_context, short port)
: io_context_(io_context),
acceptor_(io_context, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), port)) {
do_accept();
}
// 设置消息处理回调
void set_message_handler(std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> handler) {
message_handler_ = handler;
}
// 设置连接处理回调
void set_connection_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
connection_handler_ = handler;
}
// 设置断开连接处理回调
void set_disconnection_handler(std::function<void(std::shared_ptr<Connection>)> handler) {
disconnection_handler_ = handler;
}
private:
// 接受新连接
void do_accept() {
auto connection = std::make_shared<Connection>(io_context_);
acceptor_.async_accept(connection->socket(),
[this, connection](std::error_code ec) {
if (!ec) {
// 设置连接的回调函数
connection->set_message_handler(message_handler_);
connection->set_close_handler([this](std::shared_ptr<Connection> conn) {
if (disconnection_handler_) {
disconnection_handler_(conn);
}
// 从连接列表中移除
connections_.erase(std::remove(connections_.begin(), connections_.end(), conn),
connections_.end());
});
// 将连接添加到列表
connections_.push_back(connection);
// 调用连接回调
if (connection_handler_) {
connection_handler_(connection);
}
// 开始处理连接
connection->start();
}
// 继续接受下一个连接
do_accept();
});
}
asio::io_context& io_context_;
asio::ip::tcp::acceptor acceptor_;
std::vector<std::shared_ptr<Connection>> connections_;
std::function<void(std::shared_ptr<Connection>, const SimpleMessage&)> message_handler_;
std::function<void(std::shared_ptr<Connection>)> connection_handler_;
std::function<void(std::shared_ptr<Connection>)> disconnection_handler_;
};
// Example server: echoes every received message back to its sender and greets
// each new connection. Expects the listening port as the only argument.
int main(int argc, char* argv[]) {
    try {
        if (argc != 2) {
            std::cerr << "Usage: simple_message_server <port>" << std::endl;
            return 1;
        }

        asio::io_context io_context;
        SimpleMessageServer server(io_context, std::atoi(argv[1]));

        // Logs the message and sends it back: text gets a prefix, binary data
        // is returned unchanged.
        auto echo_message = [](std::shared_ptr<Connection> connection, const SimpleMessage& message) {
            const bool is_text = (message.get_type() == SimpleMessage::TEXT);
            std::cout << "Received message of type: " << (is_text ? "TEXT" : "BINARY") << std::endl;

            if (!is_text) {
                std::cout << "Binary message size: " << message.get_binary().size() << " bytes" << std::endl;
                connection->send_message(SimpleMessage(message.get_binary()));
                return;
            }

            std::cout << "Message content: " << message.get_text() << std::endl;
            connection->send_message(SimpleMessage("Server received: " + message.get_text()));
        };

        // Logs the new connection and sends the welcome message.
        auto greet_client = [](std::shared_ptr<Connection> connection) {
            std::cout << "New connection established" << std::endl;
            connection->send_message(SimpleMessage("Welcome to Simple Message Protocol Server!"));
        };

        auto log_disconnect = [](std::shared_ptr<Connection>) {
            std::cout << "Connection closed" << std::endl;
        };

        server.set_message_handler(echo_message);
        server.set_connection_handler(greet_client);
        server.set_disconnection_handler(log_disconnect);

        std::cout << "Simple Message Protocol Server running on port " << argv[1] << std::endl;
        io_context.run();
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
4. 自定义协议客户端
下面是实现自定义协议客户端的代码:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <cstdint>
#include <asio.hpp>
// SimpleMessage类的定义(如前所述)
// ...
// 自定义协议客户端类
class SimpleMessageClient : public std::enable_shared_from_this<SimpleMessageClient> {
public:
SimpleMessageClient(asio::io_context& io_context)
: io_context_(io_context),
socket_(io_context),
resolver_(io_context) {
}
// 连接到服务器
void connect(const std::string& host, const std::string& port) {
auto self = shared_from_this();
resolver_.async_resolve(host, port,
[self](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
if (!ec) {
asio::async_connect(self->socket_, results,
[self](std::error_code ec, asio::ip::tcp::endpoint) {
if (!ec) {
std::cout << "Connected to server" << std::endl;
// 连接成功,调用回调
if (self->connection_handler_) {
self->connection_handler_();
}
// 开始接收消息
self->read_header();
} else {
std::cerr << "Connection failed: " << ec.message() << std::endl;
if (self->error_handler_) {
self->error_handler_(ec);
}
}
});
} else {
std::cerr << "Resolution failed: " << ec.message() << std::endl;
if (self->error_handler_) {
self->error_handler_(ec);
}
}
});
}
// 发送文本消息
void send_text(const std::string& text) {
auto self = shared_from_this();
io_context_.post([self, text]() {
SimpleMessage message(text);
self->send_message(message);
});
}
// 发送二进制消息
void send_binary(const std::vector<uint8_t>& binary_data) {
auto self = shared_from_this();
io_context_.post([self, binary_data]() {
SimpleMessage message(binary_data);
self->send_message(message);
});
}
// 关闭连接
void close() {
auto self = shared_from_this();
io_context_.post([self]() {
asio::error_code ec;
self->socket_.shutdown(asio::ip::tcp::socket::shutdown_both, ec);
self->socket_.close(ec);
if (self->disconnection_handler_) {
self->disconnection_handler_();
}
});
}
// 设置消息处理回调
void set_message_handler(std::function<void(const SimpleMessage&)> handler) {
message_handler_ = handler;
}
// 设置连接处理回调
void set_connection_handler(std::function<void()> handler) {
connection_handler_ = handler;
}
// 设置断开连接处理回调
void set_disconnection_handler(std::function<void()> handler) {
disconnection_handler_ = handler;
}
// 设置错误处理回调
void set_error_handler(std::function<void(const std::error_code&)> handler) {
error_handler_ = handler;
}
private:
// 发送消息(内部方法)
void send_message(const SimpleMessage& message) {
bool write_in_progress = !write_queue_.empty();
write_queue_.push_back(message.serialize());
if (!write_in_progress) {
do_write();
}
}
// 写入消息
void do_write() {
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(write_queue_.front()),
[self](std::error_code ec, std::size_t /*bytes_transferred*/) {
if (!ec) {
self->write_queue_.pop_front();
if (!self->write_queue_.empty()) {
self->do_write();
}
} else {
std::cerr << "Write error: " << ec.message() << std::endl;
self->close();
if (self->error_handler_) {
self->error_handler_(ec);
}
}
});
}
// 读取消息头部
void read_header() {
auto self = shared_from_this();
asio::async_read(socket_, asio::buffer(read_buffer_, SimpleMessage::kHeaderSize),
[self](std::error_code ec, std::size_t /*bytes_transferred*/) {
if (!ec) {
// 解析消息长度
uint32_t body_size = 0;
for (int i = 0; i < 4; ++i) {
body_size |= static_cast<uint32_t>(self->read_buffer_[i]) << (i * 8);
}
// 调整缓冲区大小
self->read_buffer_.resize(SimpleMessage::kHeaderSize + body_size);
// 读取消息体
self->read_body(body_size);
} else {
std::cerr << "Read header error: " << ec.message() << std::endl;
self->close();
if (self->error_handler_) {
self->error_handler_(ec);
}
}
});
}
// 读取消息体
void read_body(uint32_t body_size) {
auto self = shared_from_this();
asio::async_read(socket_,
asio::buffer(read_buffer_.data() + SimpleMessage::kHeaderSize, body_size),
[self](std::error_code ec, std::size_t /*bytes_transferred*/) {
if (!ec) {
// 解析消息
auto message = SimpleMessage::from_buffer(self->read_buffer_);
if (message && self->message_handler_) {
self->message_handler_(*message);
}
// 准备读取下一条消息
self->read_buffer_.resize(SimpleMessage::kHeaderSize);
self->read_header();
} else {
std::cerr << "Read body error: " << ec.message() << std::endl;
self->close();
if (self->error_handler_) {
self->error_handler_(ec);
}
}
});
}
asio::io_context& io_context_;
asio::ip::tcp::socket socket_;
asio::ip::tcp::resolver resolver_;
std::vector<uint8_t> read_buffer_ = std::vector<uint8_t>(SimpleMessage::kHeaderSize);
std::deque<std::vector<uint8_t>> write_queue_;
std::function<void(const SimpleMessage&)> message_handler_;
std::function<void()> connection_handler_;
std::function<void()> disconnection_handler_;
std::function<void(const std::error_code&)> error_handler_;
};
// Example client: connects to localhost:8080, sends one text and one binary
// message, waits for replies and shuts down. The io_context runs on its own
// thread; fixed sleeps stand in for real synchronisation.
int main() {
    try {
        asio::io_context io_context;
        auto client = std::make_shared<SimpleMessageClient>(io_context);

        // Prints type and content (or size) of every received message.
        auto print_message = [](const SimpleMessage& message) {
            const bool is_text = (message.get_type() == SimpleMessage::TEXT);
            std::cout << "Received message of type: " << (is_text ? "TEXT" : "BINARY") << std::endl;
            if (is_text) {
                std::cout << "Message content: " << message.get_text() << std::endl;
            } else {
                std::cout << "Binary message size: " << message.get_binary().size() << " bytes" << std::endl;
            }
        };
        auto on_connected = []() {
            std::cout << "Connection established successfully" << std::endl;
        };
        auto on_disconnected = []() {
            std::cout << "Disconnected from server" << std::endl;
        };
        auto on_error = [](const std::error_code& ec) {
            std::cerr << "Error: " << ec.message() << std::endl;
        };

        client->set_message_handler(print_message);
        client->set_connection_handler(on_connected);
        client->set_disconnection_handler(on_disconnected);
        client->set_error_handler(on_error);

        client->connect("localhost", "8080");

        // Keeps run() from returning while no operation is pending.
        auto work = asio::make_work_guard(io_context);

        std::thread io_thread([&io_context]() {
            io_context.run();
        });

        // Give the connection time to be established.
        std::this_thread::sleep_for(std::chrono::seconds(1));

        client->send_text("Hello, Simple Message Protocol Server!");

        const std::vector<uint8_t> binary_data = {0x01, 0x02, 0x03, 0x04, 0x05};
        client->send_binary(binary_data);

        // Give the server time to answer.
        std::this_thread::sleep_for(std::chrono::seconds(5));

        client->close();
        io_context.stop();

        if (io_thread.joinable()) {
            io_thread.join();
        }
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
四、UDP协议实现
UDP是一种无连接的传输协议,适用于对实时性要求高但对可靠性要求相对较低的场景。下面我们将学习如何使用Asio实现UDP服务器和客户端。
1. 基础UDP回显服务器
以下是一个简单的UDP回显服务器实现:
#include <iostream>
#include <string>
#include <asio.hpp>
using asio::ip::udp;
// UDP echo server: every received datagram is sent back to its sender.
//
// Only one operation is outstanding at a time. data_ and sender_endpoint_ are
// shared by the receive and the send, so the next receive is started only
// after the echo of the previous datagram has completed; otherwise an incoming
// datagram could overwrite the buffer and endpoint the send is still using.
class UDPServer {
public:
    // Binds the socket to `port` on all IPv4 interfaces and starts receiving.
    UDPServer(asio::io_context& io_context, short port)
        : socket_(io_context, udp::endpoint(udp::v4(), port)) {
        do_receive();
    }

private:
    // Waits for one datagram. A non-empty datagram is echoed via do_send(),
    // which resumes receiving when it is done; after an error or an empty
    // datagram the next receive is started right away.
    void do_receive() {
        socket_.async_receive_from(
            asio::buffer(data_, max_length), sender_endpoint_,
            [this](std::error_code ec, std::size_t bytes_recvd) {
                if (!ec && bytes_recvd > 0) {
                    std::cout << "Received " << bytes_recvd << " bytes from "
                              << sender_endpoint_.address().to_string() << ":"
                              << sender_endpoint_.port() << std::endl;
                    do_send(bytes_recvd);
                } else {
                    do_receive();
                }
            });
    }

    // Sends the first `length` bytes of data_ back to the last sender and
    // resumes receiving once the send has completed, successfully or not.
    void do_send(std::size_t length) {
        socket_.async_send_to(
            asio::buffer(data_, length), sender_endpoint_,
            [this](std::error_code ec, std::size_t /*bytes_sent*/) {
                if (ec) {
                    std::cerr << "Error sending response: " << ec.message() << std::endl;
                }
                do_receive();
            });
    }

    udp::socket socket_;
    udp::endpoint sender_endpoint_;   // sender of the datagram currently in data_
    enum { max_length = 1024 };
    char data_[max_length];
};
// Entry point of the UDP echo server example: expects the port as the only
// command-line argument and runs until the io_context has no more work.
int main(int argc, char* argv[]) {
    try {
        const bool has_port_argument = (argc == 2);
        if (!has_port_argument) {
            std::cerr << "Usage: udp_server <port>" << std::endl;
            return 1;
        }

        const char* port_text = argv[1];

        asio::io_context io_context;
        UDPServer server(io_context, std::atoi(port_text));

        std::cout << "UDP echo server running on port " << port_text << std::endl;
        std::cout << "Press Ctrl+C to exit" << std::endl;
        io_context.run();
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
2. 基础UDP客户端
以下是一个简单的UDP客户端实现:
#include <iostream>
#include <string>
#include <asio.hpp>
using asio::ip::udp;
int main(int argc, char* argv[]) {
try {
if (argc != 3) {
std::cerr << "Usage: udp_client <host> <port>" << std::endl;
return 1;
}
asio::io_context io_context;
// 解析服务器地址
udp::resolver resolver(io_context);
udp::endpoint endpoint = *resolver.resolve(udp::v4(), argv[1], argv[2]).begin();
// 创建socket并连接到服务器
udp::socket socket(io_context);
socket.open(udp::v4());
// 发送消息到服务器
std::string message = "Hello, UDP Server!";
std::cout << "Sending: " << message << std::endl;
socket.send_to(asio::buffer(message), endpoint);
// 等待并接收服务器响应
udp::endpoint sender_endpoint;
enum { max_length = 1024 };
char reply[max_length];
std::size_t reply_length = socket.receive_from(
asio::buffer(reply, max_length), sender_endpoint);
std::cout << "Received reply: ";
std::cout.write(reply, reply_length);
std::cout << std::endl;
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}
3. 高级UDP协议实现
下面是一个更复杂的UDP协议实现,包含以下特性:
- 消息封装(头部+数据)
- 简单的错误检测
- 超时重传机制
- 异步处理模型
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <map>
#include <cstdint>
#include <chrono>
#include <asio.hpp>
using asio::ip::udp;
using asio::steady_timer;
// Datagram of the example "reliable UDP" protocol.
//
// Wire layout (all fields big-endian):
//   bytes 0-1  message ID
//   bytes 2-3  payload size
//   bytes 4-5  checksum
//   bytes 6..  payload
// The size field is 16 bits wide, so payloads longer than 65535 bytes cannot
// be represented.
class UdpMessage {
public:
    // Builds a message and computes its checksum.
    UdpMessage(uint16_t message_id, const std::vector<uint8_t>& data)
        : message_id_(message_id), data_(data) {
        calculate_checksum();
    }

    // Parses one datagram. Returns std::nullopt when the datagram is shorter
    // than the header, when its length disagrees with the size field, or when
    // the stored checksum does not match the recomputed one.
    static std::optional<UdpMessage> from_raw(const std::vector<uint8_t>& raw_data) {
        if (raw_data.size() < kHeaderSize) {
            return std::nullopt;
        }

        const uint16_t message_id = read_u16(raw_data, 0);
        const uint16_t data_size = read_u16(raw_data, 2);
        const uint16_t checksum = read_u16(raw_data, 4);

        if (raw_data.size() != kHeaderSize + data_size) {
            return std::nullopt;
        }

        std::vector<uint8_t> data(raw_data.begin() + kHeaderSize, raw_data.end());

        // Constructing the message recomputes the checksum over ID, size and data.
        UdpMessage message(message_id, data);
        if (message.get_checksum() != checksum) {
            return std::nullopt;
        }
        return message;
    }

    // Serializes the message into header + payload.
    std::vector<uint8_t> to_raw() const {
        std::vector<uint8_t> raw;
        raw.reserve(kHeaderSize + data_.size());
        append_u16(raw, message_id_);
        append_u16(raw, static_cast<uint16_t>(data_.size()));
        append_u16(raw, checksum_);
        raw.insert(raw.end(), data_.begin(), data_.end());
        return raw;
    }

    // Returns the message ID.
    uint16_t get_message_id() const { return message_id_; }
    // Returns the payload.
    const std::vector<uint8_t>& get_data() const { return data_; }
    // Returns the checksum computed at construction.
    uint16_t get_checksum() const { return checksum_; }

    // Header size: ID(2) + Size(2) + Checksum(2).
    static constexpr size_t kHeaderSize = 6;

private:
    // Reads a big-endian 16-bit value starting at `offset`.
    static uint16_t read_u16(const std::vector<uint8_t>& bytes, size_t offset) {
        return static_cast<uint16_t>((static_cast<uint16_t>(bytes[offset]) << 8) | bytes[offset + 1]);
    }

    // Appends `value` in big-endian order.
    static void append_u16(std::vector<uint8_t>& out, uint16_t value) {
        out.push_back(static_cast<uint8_t>((value >> 8) & 0xFF));
        out.push_back(static_cast<uint8_t>(value & 0xFF));
    }

    // Computes the one's-complement checksum over the message ID, the payload
    // size and the payload taken as big-endian 16-bit words (an odd trailing
    // byte is padded with a zero low byte).
    //
    // The sum is accumulated in 32 bits: with a 16-bit accumulator the carries
    // would be discarded before the fold below could add them back in.
    void calculate_checksum() {
        uint32_t sum = 0;
        sum += message_id_;
        sum += static_cast<uint16_t>(data_.size());

        for (size_t i = 0; i < data_.size(); i += 2) {
            uint32_t word = static_cast<uint32_t>(data_[i]) << 8;
            if (i + 1 < data_.size()) {
                word |= data_[i + 1];
            }
            sum += word;
        }

        // Fold the carries back into the low 16 bits.
        while (sum >> 16) {
            sum = (sum & 0xFFFF) + (sum >> 16);
        }

        checksum_ = static_cast<uint16_t>(~sum & 0xFFFF);
    }

    uint16_t message_id_;
    std::vector<uint8_t> data_;
    uint16_t checksum_;
};
// 高级UDP客户端类
class ReliableUdpClient : public std::enable_shared_from_this<ReliableUdpClient> {
public:
ReliableUdpClient(asio::io_context& io_context)
: io_context_(io_context),
socket_(io_context, udp::v4()),
resolver_(io_context),
next_message_id_(1),
timeout_(500) { // 默认超时时间500ms
socket_.set_option(udp::socket::reuse_address(true));
}
// 连接到服务器
void connect(const std::string& host, const std::string& port) {
server_endpoint_ = *resolver_.resolve(udp::v4(), host, port).begin();
// 开始接收数据
do_receive();
std::cout << "Connected to UDP server at " << server_endpoint_.address().to_string()
<< ":" << server_endpoint_.port() << std::endl;
}
// 发送消息(带确认和重传)
void send_message(const std::vector<uint8_t>& data,
std::function<void(bool)> completion_handler = nullptr) {
auto self = shared_from_this();
uint16_t message_id = next_message_id_++;
// 创建消息
UdpMessage message(message_id, data);
std::vector<uint8_t> raw_message = message.to_raw();
// 保存消息和回调
OutgoingMessage outgoing;
outgoing.raw_data = raw_message;
outgoing.attempts = 0;
outgoing.max_attempts = 3;
outgoing.completion_handler = completion_handler;
outgoing_messages_[message_id] = outgoing;
// 发送消息
send_with_retry(message_id);
}
// 设置消息接收回调
void set_message_handler(std::function<void(const UdpMessage&)> handler) {
message_handler_ = handler;
}
// 设置超时时间(毫秒)
void set_timeout(uint32_t milliseconds) {
timeout_ = milliseconds;
}
// 关闭客户端
void close() {
socket_.close();
// 取消所有定时器
for (auto& pair : outgoing_messages_) {
if (pair.second.timer) {
pair.second.timer->cancel();
}
}
}
private:
// 待发送消息结构
struct OutgoingMessage {
std::vector<uint8_t> raw_data;
int attempts;
int max_attempts;
std::shared_ptr<steady_timer> timer;
std::function<void(bool)> completion_handler;
};
// 发送消息并重试
void send_with_retry(uint16_t message_id) {
auto self = shared_from_this();
auto it = outgoing_messages_.find(message_id);
if (it == outgoing_messages_.end()) {
return; // 消息已被处理
}
OutgoingMessage& outgoing = it->second;
// 检查重传次数
if (outgoing.attempts >= outgoing.max_attempts) {
std::cout << "Message " << message_id << " failed after "
<< outgoing.attempts << " attempts" << std::endl;
// 调用完成回调,指示失败
if (outgoing.completion_handler) {
outgoing.completion_handler(false);
}
// 移除消息
outgoing_messages_.erase(it);
return;
}
// 增加尝试次数
outgoing.attempts++;
std::cout << "Sending message " << message_id << " (attempt "
<< outgoing.attempts << ")" << std::endl;
// 发送消息
socket_.async_send_to(
asio::buffer(outgoing.raw_data), server_endpoint_,
[self, message_id](std::error_code ec, std::size_t bytes_sent) {
if (!ec) {
// 设置超时定时器
self->setup_retry_timer(message_id);
} else {
std::cerr << "Error sending message: " << ec.message() << std::endl;
// 立即重试
self->send_with_retry(message_id);
}
});
}
// 设置重传定时器
void setup_retry_timer(uint16_t message_id) {
auto self = shared_from_this();
auto it = outgoing_messages_.find(message_id);
if (it == outgoing_messages_.end()) {
return;
}
// 创建或重置定时器
if (!it->second.timer) {
it->second.timer = std::make_shared<steady_timer>(io_context_);
}
it->second.timer->expires_after(std::chrono::milliseconds(timeout_));
it->second.timer->async_wait(
[self, message_id](std::error_code ec) {
if (!ec) { // 定时器未被取消
self->send_with_retry(message_id);
}
});
}
// 接收数据
void do_receive() {
auto self = shared_from_this();
socket_.async_receive_from(
asio::buffer(recv_buffer_), sender_endpoint_,
[this, self](std::error_code ec, std::size_t bytes_recvd) {
if (!ec && bytes_recvd > 0) {
// 将接收到的数据转换为vector
std::vector<uint8_t> raw_data(recv_buffer_.begin(), recv_buffer_.begin() + bytes_recvd);
// 尝试解析消息
auto message = UdpMessage::from_raw(raw_data);
if (message) {
// 检查是否是确认消息(假设消息ID为0表示确认)
if (message->get_message_id() == 0) {
handle_acknowledgment(*message);
} else {
// 处理普通消息
handle_message(*message);
}
}
}
// 继续接收下一个数据包
do_receive();
});
}
// 处理确认消息
void handle_acknowledgment(const UdpMessage& message) {
// 假设确认消息的数据部分包含被确认的消息ID
if (message.get_data().size() >= 2) {
uint16_t acknowledged_id = (static_cast<uint16_t>(message.get_data()[0]) << 8) | message.get_data()[1];
auto it = outgoing_messages_.find(acknowledged_id);
if (it != outgoing_messages_.end()) {
std::cout << "Message " << acknowledged_id << " acknowledged" << std::endl;
// 取消定时器
if (it->second.timer) {
it->second.timer->cancel();
}
// 调用完成回调,指示成功
if (it->second.completion_handler) {
it->second.completion_handler(true);
}
// 移除消息
outgoing_messages_.erase(it);
}
}
}
// 处理接收到的消息
void handle_message(const UdpMessage& message) {
std::cout << "Received message with ID: " << message.get_message_id()
<< ", data size: " << message.get_data().size() << " bytes" << std::endl;
// 发送确认消息
std::vector<uint8_t> ack_data;
ack_data.push_back(static_cast<uint8_t>((message.get_message_id() >> 8) & 0xFF));
ack_data.push_back(static_cast<uint8_t>(message.get_message_id() & 0xFF));
UdpMessage ack_message(0, ack_data); // ID为0表示确认消息
std::vector<uint8_t> raw_ack = ack_message.to_raw();
socket_.async_send_to(
asio::buffer(raw_ack), server_endpoint_,
[](std::error_code /*ec*/, std::size_t /*bytes_sent*/) {
// 忽略发送确认的错误
});
// 调用消息处理回调
if (message_handler_) {
message_handler_(message);
}
}
asio::io_context& io_context_;
udp::socket socket_;
udp::resolver resolver_;
udp::endpoint server_endpoint_;
udp::endpoint sender_endpoint_;
std::vector<uint8_t> recv_buffer_{4096}; // 接收缓冲区
std::map<uint16_t, OutgoingMessage> outgoing_messages_; // 待发送消息
uint16_t next_message_id_; // 下一个消息ID
uint32_t timeout_; // 超时时间(毫秒)
std::function<void(const UdpMessage&)> message_handler_; // 消息处理回调
};
// Example for ReliableUdpClient: sends the same text three times to
// localhost:8080, two seconds apart, and reports whether each send was
// acknowledged. The io_context runs on its own thread.
int main() {
    try {
        asio::io_context io_context;
        auto client = std::make_shared<ReliableUdpClient>(io_context);

        // Prints the payload of every received message as characters.
        auto print_payload = [](const UdpMessage& message) {
            std::cout << "Received message content: ";
            for (uint8_t byte : message.get_data()) {
                std::cout << static_cast<char>(byte);
            }
            std::cout << std::endl;
        };
        client->set_message_handler(print_payload);

        // Retransmit after one second without an ACK.
        client->set_timeout(1000);

        client->connect("localhost", "8080");

        // Keeps run() from returning while no operation is pending.
        auto work = asio::make_work_guard(io_context);

        std::thread io_thread([&io_context]() {
            io_context.run();
        });

        const std::string text_message = "Hello, Reliable UDP Server!";
        const std::vector<uint8_t> message_data(text_message.begin(), text_message.end());

        for (int i = 0; i < 3; ++i) {
            auto report_result = [i](bool success) {
                std::cout << "Message " << (i+1) << " send "
                          << (success ? "succeeded" : "failed") << std::endl;
            };
            client->send_message(message_data, report_result);

            // Pause between messages.
            std::this_thread::sleep_for(std::chrono::seconds(2));
        }

        // Leave time for outstanding ACKs and retransmissions.
        std::this_thread::sleep_for(std::chrono::seconds(5));

        client->close();
        io_context.stop();

        if (io_thread.joinable()) {
            io_thread.join();
        }
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
4. 高级UDP服务器实现
下面是与上面客户端对应的高级UDP服务器实现:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <map>
#include <cstdint>
#include <asio.hpp>
using asio::ip::udp;
// UdpMessage类的定义(如前所述)
// ...
// 高级UDP服务器类
class ReliableUdpServer {
public:
ReliableUdpServer(asio::io_context& io_context, short port)
: io_context_(io_context),
socket_(io_context, udp::endpoint(udp::v4(), port)),
next_message_id_(1) {
socket_.set_option(udp::socket::reuse_address(true));
do_receive();
}
// 设置消息处理回调
void set_message_handler(std::function<void(const udp::endpoint&, const UdpMessage&)> handler) {
message_handler_ = handler;
}
private:
// 接收数据
void do_receive() {
auto self = shared_from_this();
socket_.async_receive_from(
asio::buffer(recv_buffer_), remote_endpoint_,
[this](std::error_code ec, std::size_t bytes_recvd) {
if (!ec && bytes_recvd > 0) {
// 将接收到的数据转换为vector
std::vector<uint8_t> raw_data(recv_buffer_.begin(), recv_buffer_.begin() + bytes_recvd);
// 尝试解析消息
auto message = UdpMessage::from_raw(raw_data);
if (message) {
// 检查是否是确认消息(假设消息ID为0表示确认)
if (message->get_message_id() == 0) {
handle_acknowledgment(remote_endpoint_, *message);
} else {
// 处理普通消息
handle_message(remote_endpoint_, *message);
}
}
}
// 继续接收下一个数据包
do_receive();
});
}
// 处理确认消息
void handle_acknowledgment(const udp::endpoint& sender, const UdpMessage& message) {
// 在此实现中,服务器可能不需要处理确认消息
// 因为服务器发送的消息可能不需要客户端确认
// 但如果需要,可以按照与客户端类似的方式实现
}
// 处理接收到的消息
void handle_message(const udp::endpoint& sender, const UdpMessage& message) {
std::cout << "Received message from " << sender.address().to_string()
<< ":" << sender.port() << " with ID: " << message.get_message_id()
<< ", data size: " << message.get_data().size() << " bytes" << std::endl;
// 发送确认消息
std::vector<uint8_t> ack_data;
ack_data.push_back(static_cast<uint8_t>((message.get_message_id() >> 8) & 0xFF));
ack_data.push_back(static_cast<uint8_t>(message.get_message_id() & 0xFF));
UdpMessage ack_message(0, ack_data); // ID为0表示确认消息
std::vector<uint8_t> raw_ack = ack_message.to_raw();
socket_.async_send_to(
asio::buffer(raw_ack), sender,
[](std::error_code /*ec*/, std::size_t /*bytes_sent*/) {
// 忽略发送确认的错误
});
// 调用消息处理回调
if (message_handler_) {
message_handler_(sender, message);
}
}
// 向客户端发送消息
void send_message(const udp::endpoint& client, const std::vector<uint8_t>& data) {
uint16_t message_id = next_message_id_++;
UdpMessage message(message_id, data);
std::vector<uint8_t> raw_message = message.to_raw();
socket_.async_send_to(
asio::buffer(raw_message), client,
[message_id](std::error_code ec, std::size_t bytes_sent) {
if (!ec) {
std::cout << "Sent message with ID: " << message_id
<< ", size: " << bytes_sent << " bytes" << std::endl;
} else {
std::cerr << "Error sending message: " << ec.message() << std::endl;
}
});
}
asio::io_context& io_context_;
udp::socket socket_;
udp::endpoint remote_endpoint_;
std::vector<uint8_t> recv_buffer_{4096}; // 接收缓冲区
uint16_t next_message_id_; // 下一个消息ID
std::function<void(const udp::endpoint&, const UdpMessage&)> message_handler_; // 消息处理回调
};
// Example program: starts a ReliableUdpServer on the port given on the command
// line, prints every received payload as text and echoes it back to the sender.
int main(int argc, char* argv[]) {
    try {
        const bool has_port_argument = (argc == 2);
        if (!has_port_argument) {
            std::cerr << "Usage: reliable_udp_server <port>" << std::endl;
            return 1;
        }

        asio::io_context io_context;
        ReliableUdpServer server(io_context, std::atoi(argv[1]));

        // Print the payload, then send the same bytes back.
        auto echo_handler = [&server](const udp::endpoint& client, const UdpMessage& message) {
            const auto& payload = message.get_data();

            std::cout << "Received message content: ";
            for (const uint8_t byte : payload) {
                std::cout << static_cast<char>(byte);
            }
            std::cout << std::endl;

            std::vector<uint8_t> response_data(payload.begin(), payload.end());
            server.send_message(client, response_data);
        };
        server.set_message_handler(echo_handler);

        std::cout << "Reliable UDP server running on port " << argv[1] << std::endl;
        std::cout << "Press Ctrl+C to exit" << std::endl;

        // Blocks until the io_context runs out of work or is stopped.
        io_context.run();
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
五、协议实现最佳实践
在使用Asio实现网络协议时,以下是一些最佳实践:
-
错误处理
- 始终检查异步操作返回的错误码
- 为网络错误提供适当的恢复机制
- 使用异常处理来捕获不可恢复的错误
-
内存管理
- 使用智能指针管理长期存在的对象
- 避免不必要的内存分配和复制
- 使用预分配的缓冲区来减少内存碎片
-
性能优化
- 使用asio::streambuf来高效处理可变大小的数据
- 对大数据使用零拷贝技术
- 合理设置缓冲区大小以平衡内存使用和性能
-
安全性考虑
- 验证所有输入数据的长度和格式
- 实施适当的认证和授权机制
- 保护敏感数据的传输(使用SSL/TLS)
-
可扩展性设计
- 将协议逻辑与业务逻辑分离
- 使用回调或观察者模式处理事件
- 设计模块化的组件以支持未来的扩展
-
测试策略
- 为协议实现单元测试
- 使用模拟对象测试网络交互
- 进行性能测试以确定瓶颈
六、总结
本教程详细介绍了如何使用Asio实现各种网络协议,包括:
- HTTP协议:实现了基本和高级的HTTP服务器和客户端,支持请求路由、静态文件服务等功能
- WebSocket协议:实现了完整的WebSocket握手和消息处理,支持文本和二进制消息
- 自定义TCP协议:设计并实现了一个简单但功能完整的自定义协议,包括消息封装、解析和错误处理
- UDP协议:实现了基础的UDP回显服务器/客户端,以及更高级的带确认和重传机制的可靠UDP实现
通过这些实现,我们学习了Asio的核心概念和技术,包括异步I/O模型、缓冲区管理、错误处理等。这些知识将帮助你构建高性能、可靠的网络应用程序。
在实际项目中,你可以根据具体需求选择或设计合适的协议,并结合Asio提供的强大功能来实现。记住,良好的协议设计是构建可靠网络应用的基础,而Asio则为你提供了实现这些协议的理想框架。
### 2. 高级HTTP服务器功能
下面是一个更高级的HTTP服务器实现,包含以下功能:
- 请求路由
- 查询参数解析
- 表单数据处理
- 文件服务
```cpp
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <map>
#include <fstream>
#include <sstream>
#include <filesystem>
#include <regex>
#include <asio.hpp>
namespace fs = std::filesystem;
using asio::ip::tcp;
// A parsed HTTP request. Plain aggregate: every field starts out empty and is
// filled in by whoever parses the request.
struct HttpRequest {
    // Request line, split into its three whitespace-separated tokens.
    std::string method;
    std::string path;
    std::string version;

    // Header fields, keyed by header name.
    std::map<std::string, std::string> headers;

    // key=value pairs taken from the query string.
    std::map<std::string, std::string> query_params;

    // Request body.
    std::string body;
};
// An HTTP response under construction. Defaults to "HTTP/1.1 200 OK" with no
// headers and an empty body.
struct HttpResponse {
    std::string version = "HTTP/1.1";
    int status_code = 200;
    std::string status_message = "OK";
    std::map<std::string, std::string> headers;
    std::string body;

    // Serializes the response into its wire format: status line, headers in
    // key order, a blank line, then the body.
    //
    // Content-Length is always emitted and always equals body.size(); a
    // Content-Length entry in `headers` is overridden. The function is const
    // and leaves `headers` untouched, so it can be called on a const response.
    std::string to_string() const {
        // Work on a copy so the Content-Length override does not leak into the
        // caller's header map while keeping the sorted emission order.
        std::map<std::string, std::string> wire_headers = headers;
        wire_headers["Content-Length"] = std::to_string(body.size());

        std::stringstream ss;
        ss << version << " " << status_code << " " << status_message << "\r\n";
        for (const auto& [key, value] : wire_headers) {
            ss << key << ": " << value << "\r\n";
        }
        ss << "\r\n" << body;
        return ss.str();
    }
};
// Route handler signature: receives the parsed request (read-only) and fills
// in the response that the server sends back afterwards.
using HttpHandler = std::function<void(const HttpRequest&, HttpResponse&)>;
// HTTP服务器类
class AdvancedHttpServer {
public:
AdvancedHttpServer(asio::io_context& io_context, short port)
: acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
do_accept();
}
// 注册GET路由
void get(const std::string& path, const HttpHandler& handler) {
routes_["GET"][path] = handler;
}
// 注册POST路由
void post(const std::string& path, const HttpHandler& handler) {
routes_["POST"][path] = handler;
}
// 设置静态文件目录
void set_static_dir(const std::string& dir) {
static_dir_ = dir;
}
private:
// HTTP连接类
class Connection : public std::enable_shared_from_this<Connection> {
public:
Connection(tcp::socket socket, AdvancedHttpServer& server)
: socket_(std::move(socket)), server_(server) {
}
void start() {
read_request_line();
}
private:
// 读取请求行
void read_request_line() {
auto self = shared_from_this();
asio::async_read_until(socket_, buffer_, "\r\n",
[self](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
std::istream request_stream(&self->buffer_);
std::string request_line;
std::getline(request_stream, request_line);
// 移除回车符
if (!request_line.empty() && request_line.back() == '\r') {
request_line.pop_back();
}
self->parse_request_line(request_line);
self->read_headers();
}
});
}
// 解析请求行
void parse_request_line(const std::string& request_line) {
std::istringstream request_line_stream(request_line);
request_line_stream >> request_.method >> request_.path >> request_.version;
// 解析查询参数
size_t query_pos = request_.path.find('?');
if (query_pos != std::string::npos) {
std::string path_part = request_.path.substr(0, query_pos);
std::string query_part = request_.path.substr(query_pos + 1);
request_.path = path_part;
parse_query_params(query_part);
}
}
// 解析查询参数
void parse_query_params(const std::string& query_part) {
std::stringstream query_stream(query_part);
std::string param;
while (std::getline(query_stream, param, '&')) {
size_t eq_pos = param.find('=');
if (eq_pos != std::string::npos) {
std::string key = param.substr(0, eq_pos);
std::string value = param.substr(eq_pos + 1);
request_.query_params[key] = value;
}
}
}
// 读取请求头
void read_headers() {
auto self = shared_from_this();
asio::async_read_until(socket_, buffer_, "\r\n\r\n",
[self](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
std::istream request_stream(&self->buffer_);
std::string header_line;
while (std::getline(request_stream, header_line) && header_line != "\r") {
// 移除回车符
if (!header_line.empty() && header_line.back() == '\r') {
header_line.pop_back();
}
size_t colon_pos = header_line.find(':');
if (colon_pos != std::string::npos) {
std::string key = header_line.substr(0, colon_pos);
std::string value = header_line.substr(colon_pos + 2); // 跳过冒号和空格
self->request_.headers[key] = value;
}
}
// 检查是否有请求体
auto content_length_it = self->request_.headers.find("Content-Length");
if (content_length_it != self->request_.headers.end()) {
self->read_body(std::stoi(content_length_it->second));
} else {
self->handle_request();
}
}
});
}
// 读取请求体
void read_body(size_t content_length) {
auto self = shared_from_this();
asio::async_read(socket_, buffer_, asio::transfer_exactly(content_length),
[self](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
// 读取请求体
std::istream request_stream(&self->buffer_);
std::vector<char> body_content(content_length);
request_stream.read(body_content.data(), content_length);
self->request_.body = std::string(body_content.data(), content_length);
self->handle_request();
}
});
}
// 处理请求
void handle_request() {
HttpResponse response;
// 检查是否为静态文件请求
if (!server_.static_dir_.empty() && request_.method == "GET") {
fs::path file_path = fs::path(server_.static_dir_) / fs::path(request_.path.substr(1));
if (fs::exists(file_path) && !fs::is_directory(file_path)) {
serve_static_file(file_path.string(), response);
send_response(response);
return;
}
}
// 查找路由处理器
auto method_it = server_.routes_.find(request_.method);
if (method_it != server_.routes_.end()) {
auto path_it = method_it->second.find(request_.path);
if (path_it != method_it->second.end()) {
// 调用路由处理器
path_it->second(request_, response);
send_response(response);
return;
}
}
// 未找到路由,返回404
response.status_code = 404;
response.status_message = "Not Found";
response.body = "<html><body><h1>404 Not Found</h1></body></html>";
response.headers["Content-Type"] = "text/html";
send_response(response);
}
// 提供静态文件
void serve_static_file(const std::string& file_path, HttpResponse& response) {
std::ifstream file(file_path, std::ios::binary);
if (file) {
// 读取文件内容
std::stringstream buffer;
buffer << file.rdbuf();
response.body = buffer.str();
// 设置内容类型
std::string extension = fs::path(file_path).extension().string();
response.headers["Content-Type"] = get_mime_type(extension);
} else {
response.status_code = 404;
response.status_message = "Not Found";
response.body = "<html><body><h1>404 File Not Found</h1></body></html>";
response.headers["Content-Type"] = "text/html";
}
}
// 获取MIME类型
std::string get_mime_type(const std::string& extension) {
if (extension == ".html" || extension == ".htm") return "text/html";
if (extension == ".css") return "text/css";
if (extension == ".js") return "application/javascript";
if (extension == ".json") return "application/json";
if (extension == ".png") return "image/png";
if (extension == ".jpg" || extension == ".jpeg") return "image/jpeg";
if (extension == ".gif") return "image/gif";
if (extension == ".svg") return "image/svg+xml";
if (extension == ".txt") return "text/plain";
return "application/octet-stream";
}
// 发送响应
void send_response(const HttpResponse& response) {
auto self = shared_from_this();
std::string response_str = response.to_string();
asio::async_write(socket_, asio::buffer(response_str),
[self](std::error_code ec, std::size_t) {
if (!ec) {
// 响应发送完成后关闭连接
asio::error_code ignored_ec;
self->socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
}
});
}
tcp::socket socket_;
asio::streambuf buffer_;
HttpRequest request_;
AdvancedHttpServer& server_;
};
// 接受新连接
void do_accept() {
acceptor_.async_accept(
[this](std::error_code ec, tcp::socket socket) {
if (!ec) {
std::make_shared<Connection>(std::move(socket), *this)->start();
}
do_accept();
});
}
tcp::acceptor acceptor_;
std::map<std::string, std::map<std::string, HttpHandler>> routes_;
std::string static_dir_;
};
// Example program: serves ./public as static files and exposes two JSON
// endpoints, GET /api/hello?name=... and POST /api/data.
int main(int argc, char* argv[]) {
    try {
        if (argc != 2) {
            std::cerr << "Usage: advanced_http_server <port>" << std::endl;
            return 1;
        }

        asio::io_context io_context;
        AdvancedHttpServer server(io_context, std::atoi(argv[1]));

        server.set_static_dir("./public");

        // Escapes a string for use inside a JSON string literal. The "name"
        // query parameter is attacker-controlled and must not be able to break
        // out of the quotes.
        const auto json_escape = [](const std::string& text) {
            static const char* const kHexDigits = "0123456789abcdef";
            std::string escaped;
            escaped.reserve(text.size());
            for (const char ch : text) {
                const unsigned char byte = static_cast<unsigned char>(ch);
                if (ch == '"' || ch == '\\') {
                    escaped.push_back('\\');
                    escaped.push_back(ch);
                } else if (byte < 0x20) {
                    escaped += "\\u00";
                    escaped.push_back(kHexDigits[byte >> 4]);
                    escaped.push_back(kHexDigits[byte & 0x0F]);
                } else {
                    escaped.push_back(ch);
                }
            }
            return escaped;
        };

        // GET /api/hello[?name=X] -> {"message": "Hello, X!"} (X defaults to "World")
        server.get("/api/hello", [json_escape](const HttpRequest& request, HttpResponse& response) {
            std::string name = "World";
            // find() instead of operator[]: the request is const, and
            // operator[] would also insert a missing key.
            const auto name_it = request.query_params.find("name");
            if (name_it != request.query_params.end()) {
                name = name_it->second;
            }
            response.body = "{\"message\": \"Hello, " + json_escape(name) + "!\"}";
            response.headers["Content-Type"] = "application/json";
        });

        // POST /api/data -> reports how many body bytes were received.
        server.post("/api/data", [](const HttpRequest& request, HttpResponse& response) {
            response.body = "{\"received\": true, \"data_length\": " +
                            std::to_string(request.body.size()) + "}";
            response.headers["Content-Type"] = "application/json";
        });

        std::cout << "Advanced HTTP server running on port " << argv[1] << std::endl;
        std::cout << "Static files served from ./public" << std::endl;
        std::cout << "API endpoints: /api/hello, /api/data" << std::endl;

        io_context.run();
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
3. HTTP客户端实现
下面是一个简单的HTTP客户端实现,可以发送GET和POST请求:
#include <iostream>
#include <string>
#include <functional>
#include <asio.hpp>
using asio::ip::tcp;
// HTTP客户端类
class HttpClient {
public:
HttpClient(asio::io_context& io_context)
: resolver_(io_context), socket_(io_context) {
}
// 发送GET请求
void get(const std::string& host, const std::string& port, const std::string& path,
std::function<void(const std::string&)> callback) {
callback_ = callback;
// 解析服务器地址
resolver_.async_resolve(host, port,
[this, host, path](std::error_code ec, tcp::resolver::results_type results) {
if (!ec) {
// 连接到服务器
asio::async_connect(socket_, results,
[this, host, path](std::error_code ec, tcp::endpoint) {
if (!ec) {
// 构建HTTP GET请求
std::string request = "GET " + path + " HTTP/1.1\r\n";
request += "Host: " + host + "\r\n";
request += "Connection: close\r\n";
request += "\r\n";
// 发送请求
asio::async_write(socket_, asio::buffer(request),
[this](std::error_code ec, std::size_t) {
if (!ec) {
// 读取响应
read_response();
}
});
}
});
}
});
}
// 发送POST请求
void post(const std::string& host, const std::string& port, const std::string& path,
const std::string& body, const std::string& content_type,
std::function<void(const std::string&)> callback) {
callback_ = callback;
// 解析服务器地址
resolver_.async_resolve(host, port,
[this, host, path, body, content_type](std::error_code ec, tcp::resolver::results_type results) {
if (!ec) {
// 连接到服务器
asio::async_connect(socket_, results,
[this, host, path, body, content_type](std::error_code ec, tcp::endpoint) {
if (!ec) {
// 构建HTTP POST请求
std::string request = "POST " + path + " HTTP/1.1\r\n";
request += "Host: " + host + "\r\n";
request += "Content-Type: " + content_type + "\r\n";
request += "Content-Length: " + std::to_string(body.size()) + "\r\n";
request += "Connection: close\r\n";
request += "\r\n";
request += body;
// 发送请求
asio::async_write(socket_, asio::buffer(request),
[this](std::error_code ec, std::size_t) {
if (!ec) {
// 读取响应
read_response();
}
});
}
});
}
});
}
private:
// 读取HTTP响应
void read_response() {
auto self(shared_from_this());
asio::async_read(socket_, response_buffer_,
[this](std::error_code ec, std::size_t /*bytes_transferred*/) {
if (!ec) {
// 读取完整响应
std::string response(asio::buffers_begin(response_buffer_.data()),
asio::buffers_end(response_buffer_.data()));
// 调用回调函数处理响应
callback_(response);
}
});
}
tcp::resolver resolver_;
tcp::socket socket_;
asio::streambuf response_buffer_;
std::function<void(const std::string&)> callback_;
};
int main() {
try {
asio::io_context io_context;
auto client = std::make_shared<HttpClient>(io_context);
// 发送GET请求示例
client->get("www.example.com", "80", "/",
[](const std::string& response) {
std::cout << "GET Response: " << std::endl;
std::cout << response.substr(0, 500) << "..." << std::endl; // 只显示前500个字符
});
// 运行事件循环
io_context.run();
// 重新创建io_context以发送第二个请求
asio::io_context io_context2;
auto client2 = std::make_shared<HttpClient>(io_context2);
// 发送POST请求示例
std::string json_body = "{\"name\": \"Asio\", \"type\": \"HTTP Client\"}";
client2->post("httpbin.org", "80", "/post", json_body, "application/json",
[](const std::string& response) {
std::cout << "\nPOST Response: " << std::endl;
std::cout << response.substr(0, 500) << "..." << std::endl; // 只显示前500个字符
});
// 运行事件循环
io_context2.run();
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}
二、WebSocket协议实现
WebSocket提供了双向通信的能力,适用于需要实时通信的应用程序。下面我们将学习如何使用Asio实现WebSocket服务器和客户端。
1. WebSocket握手
WebSocket通信的第一步是握手,客户端发送HTTP请求,服务器返回HTTP响应,完成协议升级:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <string_view>
#include <vector>
#include <map>
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <asio.hpp>
using asio::ip::tcp;
// Encodes arbitrary bytes as standard Base64 (RFC 4648 alphabet, '=' padding).
//
// @param input  Raw bytes to encode; may contain NULs and may be empty.
// @return       The Base64 text; empty for empty input.
std::string base64_encode(const std::string& input) {
    static constexpr char kAlphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";

    std::string encoded;
    encoded.reserve(((input.size() + 2) / 3) * 4);

    // Bit accumulator. Unsigned on purpose: bytes are shifted in forever and
    // old bits simply fall off the top, which would be undefined behaviour
    // with a signed type.
    unsigned int accumulator = 0;
    // Number of bits in the accumulator not yet emitted, minus 6.
    int pending_bits = -6;

    for (const unsigned char byte : input) {
        accumulator = (accumulator << 8) | byte;
        pending_bits += 8;
        while (pending_bits >= 0) {
            encoded.push_back(kAlphabet[(accumulator >> pending_bits) & 0x3F]);
            pending_bits -= 6;
        }
    }

    // 2 or 4 leftover bits: left-align them in a final sextet.
    if (pending_bits > -6) {
        encoded.push_back(kAlphabet[((accumulator << 8) >> (pending_bits + 8)) & 0x3F]);
    }
    while (encoded.size() % 4 != 0) {
        encoded.push_back('=');
    }
    return encoded;
}
// Computes the SHA-1 digest of `str` (FIPS 180-4).
//
// @param str  Message bytes; may be empty.
// @return     The raw 20-byte digest (binary, NOT hex text), ready to be
//             Base64-encoded for the Sec-WebSocket-Accept header.
//
// SHA-1 is used only because the WebSocket handshake (RFC 6455) mandates it;
// it must not be used for anything security-sensitive.
std::string sha1(const std::string& str) {
    const auto rotate_left = [](std::uint32_t value, unsigned int count) -> std::uint32_t {
        return (value << count) | (value >> (32U - count));
    };

    // Padding: a single 0x80 byte, zeros up to 56 mod 64, then the message
    // length in bits as a 64-bit big-endian integer.
    std::string message = str;
    const std::uint64_t bit_length = static_cast<std::uint64_t>(str.size()) * 8U;
    message.push_back(static_cast<char>(0x80));
    while (message.size() % 64 != 56) {
        message.push_back('\0');
    }
    for (int shift = 56; shift >= 0; shift -= 8) {
        message.push_back(static_cast<char>((bit_length >> shift) & 0xFF));
    }

    std::uint32_t state[5] = {0x67452301U, 0xEFCDAB89U, 0x98BADCFEU, 0x10325476U, 0xC3D2E1F0U};

    for (std::size_t offset = 0; offset < message.size(); offset += 64) {
        // Message schedule: 16 big-endian words expanded to 80.
        std::uint32_t w[80];
        for (std::size_t i = 0; i < 16; ++i) {
            const std::size_t at = offset + i * 4;
            w[i] = (static_cast<std::uint32_t>(static_cast<unsigned char>(message[at])) << 24) |
                   (static_cast<std::uint32_t>(static_cast<unsigned char>(message[at + 1])) << 16) |
                   (static_cast<std::uint32_t>(static_cast<unsigned char>(message[at + 2])) << 8) |
                   static_cast<std::uint32_t>(static_cast<unsigned char>(message[at + 3]));
        }
        for (std::size_t i = 16; i < 80; ++i) {
            w[i] = rotate_left(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1);
        }

        std::uint32_t a = state[0];
        std::uint32_t b = state[1];
        std::uint32_t c = state[2];
        std::uint32_t d = state[3];
        std::uint32_t e = state[4];

        for (std::size_t i = 0; i < 80; ++i) {
            std::uint32_t f = 0;
            std::uint32_t k = 0;
            if (i < 20) {
                f = (b & c) | (~b & d);
                k = 0x5A827999U;
            } else if (i < 40) {
                f = b ^ c ^ d;
                k = 0x6ED9EBA1U;
            } else if (i < 60) {
                f = (b & c) | (b & d) | (c & d);
                k = 0x8F1BBCDCU;
            } else {
                f = b ^ c ^ d;
                k = 0xCA62C1D6U;
            }
            const std::uint32_t next = rotate_left(a, 5) + f + e + k + w[i];
            e = d;
            d = c;
            c = rotate_left(b, 30);
            b = a;
            a = next;
        }

        state[0] += a;
        state[1] += b;
        state[2] += c;
        state[3] += d;
        state[4] += e;
    }

    // Serialize the five state words big-endian.
    std::string digest;
    digest.reserve(20);
    for (const std::uint32_t word : state) {
        for (int shift = 24; shift >= 0; shift -= 8) {
            digest.push_back(static_cast<char>((word >> shift) & 0xFF));
        }
    }
    return digest;
}
// WebSocket服务器类
class WebSocketServer {
public:
WebSocketServer(asio::io_context& io_context, short port)
: acceptor_(io_context, tcp::endpoint(tcp::v4(), port)) {
do_accept();
}
private:
// WebSocket连接类
class Connection : public std::enable_shared_from_this<Connection> {
public:
Connection(tcp::socket socket)
: socket_(std::move(socket)) {
}
void start() {
perform_handshake();
}
private:
// 执行WebSocket握手
void perform_handshake() {
auto self = shared_from_this();
asio::async_read_until(socket_, buffer_, "\r\n\r\n",
[self](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
std::istream request_stream(&self->buffer_);
std::string request_line;
std::getline(request_stream, request_line);
// 解析请求头
std::map<std::string, std::string> headers;
std::string header_line;
while (std::getline(request_stream, header_line) && header_line != "\r") {
if (!header_line.empty() && header_line.back() == '\r') {
header_line.pop_back();
}
size_t colon_pos = header_line.find(':');
if (colon_pos != std::string::npos) {
std::string key = header_line.substr(0, colon_pos);
// 去除键中的空格并转为小写
key.erase(std::remove_if(key.begin(), key.end(), ::isspace), key.end());
std::transform(key.begin(), key.end(), key.begin(), ::tolower);
std::string value = header_line.substr(colon_pos + 1);
// 去除值前面的空格
size_t start_pos = value.find_first_not_of(" ");
if (start_pos != std::string::npos) {
value = value.substr(start_pos);
}
headers[key] = value;
}
}
// 验证WebSocket握手请求
if (self->validate_handshake(headers)) {
// 发送WebSocket握手响应
self->send_handshake_response(headers);
// 握手完成后,开始处理WebSocket消息
self->read_frame();
} else {
std::cerr << "Invalid WebSocket handshake request" << std::endl;
self->socket_.close();
}
}
});
}
// 验证WebSocket握手请求
bool validate_handshake(const std::map<std::string, std::string>& headers) {
// 检查必要的头信息
if (headers.find("upgrade") == headers.end() ||
headers.find("connection") == headers.end() ||
headers.find("sec-websocket-key") == headers.end() ||
headers.find("sec-websocket-version") == headers.end()) {
return false;
}
// 检查Upgrade头的值是否为websocket
std::string upgrade = headers.at("upgrade");
std::transform(upgrade.begin(), upgrade.end(), upgrade.begin(), ::tolower);
if (upgrade != "websocket") {
return false;
}
// 检查Connection头是否包含Upgrade
std::string connection = headers.at("connection");
std::transform(connection.begin(), connection.end(), connection.begin(), ::tolower);
if (connection.find("upgrade") == std::string::npos) {
return false;
}
// 检查WebSocket版本
if (headers.at("sec-websocket-version") != "13") {
return false;
}
return true;
}
// 发送WebSocket握手响应
void send_handshake_response(const std::map<std::string, std::string>& headers) {
// 获取Sec-WebSocket-Key
const std::string& key = headers.at("sec-websocket-key");
// 计算Sec-WebSocket-Accept
// 实际实现中应该使用正确的SHA-1哈希计算
std::string accept_key = base64_encode(sha1(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
// 构建HTTP响应
std::string response = "HTTP/1.1 101 Switching Protocols\r\n";
response += "Upgrade: websocket\r\n";
response += "Connection: Upgrade\r\n";
response += "Sec-WebSocket-Accept: " + accept_key + "\r\n";
response += "\r\n";
// 发送响应
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(response),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send handshake response: " << ec.message() << std::endl;
self->socket_.close();
}
});
}
// 读取WebSocket帧
void read_frame() {
auto self = shared_from_this();
asio::async_read(socket_, asio::buffer(&frame_header_, 2),
[self](std::error_code ec, std::size_t) {
if (!ec) {
// 解析帧头部
bool fin = (frame_header_ & 0x80) != 0;
uint8_t opcode = frame_header_ & 0x0F;
bool mask = (frame_header_[1] & 0x80) != 0;
uint64_t payload_length = frame_header_[1] & 0x7F;
// 处理不同的操作码
if (opcode == 0x08) { // 连接关闭帧
self->handle_close_frame();
} else if (opcode == 0x09) { // PING帧
self->handle_ping_frame();
} else if (opcode == 0x0A) { // PONG帧
self->handle_pong_frame();
} else if (opcode == 0x01 || opcode == 0x02) { // 文本或二进制帧
self->read_payload_length(payload_length, opcode, fin, mask);
} else {
std::cerr << "Unknown WebSocket opcode: " << static_cast<int>(opcode) << std::endl;
self->socket_.close();
}
}
});
}
// 读取有效载荷长度
void read_payload_length(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
auto self = shared_from_this();
if (payload_length == 126) {
// 16位长度
asio::async_read(socket_, asio::buffer(length_bytes_, 2),
[self, opcode, fin, mask](std::error_code ec, std::size_t) {
if (!ec) {
uint64_t extended_length = (static_cast<uint64_t>(self->length_bytes_[0]) << 8) |
static_cast<uint64_t>(self->length_bytes_[1]);
self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
}
});
} else if (payload_length == 127) {
// 64位长度
asio::async_read(socket_, asio::buffer(length_bytes_, 8),
[self, opcode, fin, mask](std::error_code ec, std::size_t) {
if (!ec) {
uint64_t extended_length = 0;
for (int i = 0; i < 8; ++i) {
extended_length = (extended_length << 8) | self->length_bytes_[i];
}
self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
}
});
} else {
// 7位长度
read_masking_key_and_payload(payload_length, opcode, fin, mask);
}
}
// 读取掩码键和有效载荷
void read_masking_key_and_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
auto self = shared_from_this();
if (mask) {
// 读取掩码键
asio::async_read(socket_, asio::buffer(masking_key_, 4),
[self, payload_length, opcode, fin, mask](std::error_code ec, std::size_t) {
if (!ec) {
self->read_payload(payload_length, opcode, fin, mask);
}
});
} else {
// 没有掩码键
read_payload(payload_length, opcode, fin, mask);
}
}
// 读取有效载荷
void read_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
auto self = shared_from_this();
payload_.resize(payload_length);
asio::async_read(socket_, asio::buffer(payload_),
[self, opcode, fin, mask](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
// 处理掩码
if (mask) {
for (size_t i = 0; i < bytes_transferred; ++i) {
self->payload_[i] ^= self->masking_key_[i % 4];
}
}
// 处理消息
self->handle_message(opcode, fin);
// 继续读取下一帧
self->read_frame();
}
});
}
// 处理WebSocket消息
void handle_message(uint8_t opcode, bool fin) {
if (opcode == 0x01) { // 文本消息
std::string message(payload_.begin(), payload_.end());
std::cout << "Received text message: " << message << std::endl;
// 回显消息
send_text_message(message);
} else if (opcode == 0x02) { // 二进制消息
std::cout << "Received binary message of size: " << payload_.size() << " bytes" << std::endl;
// 回显二进制消息
send_binary_message(payload_);
}
}
// 发送文本消息
void send_text_message(const std::string& message) {
// 构建WebSocket帧
std::vector<uint8_t> frame;
// 帧头部
frame.push_back(0x81); // FIN + TEXT opcode
// 有效载荷长度
size_t message_size = message.size();
if (message_size <= 125) {
frame.push_back(static_cast<uint8_t>(message_size));
} else if (message_size <= 65535) {
frame.push_back(126);
frame.push_back(static_cast<uint8_t>((message_size >> 8) & 0xFF));
frame.push_back(static_cast<uint8_t>(message_size & 0xFF));
} else {
frame.push_back(127);
for (int i = 7; i >= 0; --i) {
frame.push_back(static_cast<uint8_t>((message_size >> (i * 8)) & 0xFF));
}
}
// 有效载荷
frame.insert(frame.end(), message.begin(), message.end());
// 发送帧
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(frame),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send message: " << ec.message() << std::endl;
self->socket_.close();
}
});
}
// 发送二进制消息
void send_binary_message(const std::vector<uint8_t>& data) {
// 构建WebSocket帧
std::vector<uint8_t> frame;
// 帧头部
frame.push_back(0x82); // FIN + BINARY opcode
// 有效载荷长度
size_t data_size = data.size();
if (data_size <= 125) {
frame.push_back(static_cast<uint8_t>(data_size));
} else if (data_size <= 65535) {
frame.push_back(126);
frame.push_back(static_cast<uint8_t>((data_size >> 8) & 0xFF));
frame.push_back(static_cast<uint8_t>(data_size & 0xFF));
} else {
frame.push_back(127);
for (int i = 7; i >= 0; --i) {
frame.push_back(static_cast<uint8_t>((data_size >> (i * 8)) & 0xFF));
}
}
// 有效载荷
frame.insert(frame.end(), data.begin(), data.end());
// 发送帧
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(frame),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send binary message: " << ec.message() << std::endl;
self->socket_.close();
}
});
}
// 处理关闭帧
void handle_close_frame() {
std::cout << "Received close frame" << std::endl;
socket_.close();
}
// 处理PING帧
void handle_ping_frame() {
std::cout << "Received ping frame" << std::endl;
// 读取PING帧的有效载荷
// 为简化起见,这里不读取有效载荷
// 发送PONG帧
send_pong_frame();
}
// 处理PONG帧
void handle_pong_frame() {
std::cout << "Received pong frame" << std::endl;
// 为简化起见,这里不做任何处理
read_frame();
}
// 发送PONG帧
void send_pong_frame() {
std::vector<uint8_t> pong_frame = {0x8A, 0x00}; // FIN + PONG opcode, no payload
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(pong_frame),
[self](std::error_code ec, std::size_t) {
if (!ec) {
self->read_frame();
}
});
}
tcp::socket socket_;
asio::streambuf buffer_;
// WebSocket帧相关字段
uint8_t frame_header_[2];
uint8_t length_bytes_[8];
uint8_t masking_key_[4];
std::vector<uint8_t> payload_;
};
// 接受新连接
void do_accept() {
acceptor_.async_accept(
[this](std::error_code ec, tcp::socket socket) {
if (!ec) {
std::make_shared<Connection>(std::move(socket))->start();
}
do_accept();
});
}
tcp::acceptor acceptor_;
};
// Example program: runs the WebSocket echo server on the port given as the
// only command-line argument.
int main(int argc, char* argv[]) {
    try {
        const bool has_port_argument = (argc == 2);
        if (!has_port_argument) {
            std::cerr << "Usage: websocket_server <port>" << std::endl;
            return 1;
        }

        const char* const port_text = argv[1];

        asio::io_context io_context;
        WebSocketServer server(io_context, std::atoi(port_text));

        std::cout << "WebSocket server running on port " << port_text << std::endl;
        std::cout << "Waiting for connections..." << std::endl;

        // Blocks while the server has pending asynchronous work.
        io_context.run();
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }
    return 0;
}
2. 完整的WebSocket客户端
下面是一个完整的WebSocket客户端实现:
#include <iostream>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include <algorithm>
#include <cctype>
#include <asio.hpp>
using asio::ip::tcp;
// WebSocket客户端类
class WebSocketClient : public std::enable_shared_from_this<WebSocketClient> {
public:
using MessageHandler = std::function<void(const std::string&)>;
using BinaryMessageHandler = std::function<void(const std::vector<uint8_t>&)>;
using ConnectionHandler = std::function<void()>;
using DisconnectionHandler = std::function<void()>;
WebSocketClient(asio::io_context& io_context)
: resolver_(io_context), socket_(io_context) {
}
// 连接到WebSocket服务器
void connect(const std::string& host, const std::string& port, const std::string& path) {
host_ = host;
port_ = port;
path_ = path;
// 解析服务器地址
resolver_.async_resolve(host, port,
[this](std::error_code ec, tcp::resolver::results_type results) {
if (!ec) {
// 连接到服务器
asio::async_connect(socket_, results,
[this](std::error_code ec, tcp::endpoint) {
if (!ec) {
std::cout << "Connected to " << host_ << ":" << port_ << std::endl;
// 执行WebSocket握手
perform_handshake();
} else {
std::cerr << "Connection failed: " << ec.message() << std::endl;
if (disconnection_handler_) {
disconnection_handler_();
}
}
});
} else {
std::cerr << "Resolution failed: " << ec.message() << std::endl;
if (disconnection_handler_) {
disconnection_handler_();
}
}
});
}
// 发送文本消息
void send_text(const std::string& message) {
if (!is_connected_) {
std::cerr << "Not connected to WebSocket server" << std::endl;
return;
}
// 构建WebSocket帧
std::vector<uint8_t> frame;
// 帧头部
frame.push_back(0x81); // FIN + TEXT opcode
// 有效载荷长度
size_t message_size = message.size();
if (message_size <= 125) {
frame.push_back(static_cast<uint8_t>(message_size));
} else if (message_size <= 65535) {
frame.push_back(126);
frame.push_back(static_cast<uint8_t>((message_size >> 8) & 0xFF));
frame.push_back(static_cast<uint8_t>(message_size & 0xFF));
} else {
frame.push_back(127);
for (int i = 7; i >= 0; --i) {
frame.push_back(static_cast<uint8_t>((message_size >> (i * 8)) & 0xFF));
}
}
// 有效载荷
frame.insert(frame.end(), message.begin(), message.end());
// 发送帧
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(frame),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send message: " << ec.message() << std::endl;
self->close();
}
});
}
// 发送二进制消息
void send_binary(const std::vector<uint8_t>& data) {
if (!is_connected_) {
std::cerr << "Not connected to WebSocket server" << std::endl;
return;
}
// 构建WebSocket帧
std::vector<uint8_t> frame;
// 帧头部
frame.push_back(0x82); // FIN + BINARY opcode
// 有效载荷长度
size_t data_size = data.size();
if (data_size <= 125) {
frame.push_back(static_cast<uint8_t>(data_size));
} else if (data_size <= 65535) {
frame.push_back(126);
frame.push_back(static_cast<uint8_t>((data_size >> 8) & 0xFF));
frame.push_back(static_cast<uint8_t>(data_size & 0xFF));
} else {
frame.push_back(127);
for (int i = 7; i >= 0; --i) {
frame.push_back(static_cast<uint8_t>((data_size >> (i * 8)) & 0xFF));
}
}
// 有效载荷
frame.insert(frame.end(), data.begin(), data.end());
// 发送帧
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(frame),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send binary message: " << ec.message() << std::endl;
self->close();
}
});
}
// 关闭连接
void close() {
if (is_connected_) {
is_connected_ = false;
// 发送关闭帧
std::vector<uint8_t> close_frame = {0x88, 0x00}; // FIN + CLOSE opcode, no payload
asio::error_code ignored_ec;
asio::write(socket_, asio::buffer(close_frame), ignored_ec);
socket_.shutdown(tcp::socket::shutdown_both, ignored_ec);
socket_.close(ignored_ec);
std::cout << "WebSocket connection closed" << std::endl;
if (disconnection_handler_) {
disconnection_handler_();
}
}
}
// 设置消息处理回调
void set_message_handler(const MessageHandler& handler) {
message_handler_ = handler;
}
// 设置二进制消息处理回调
void set_binary_message_handler(const BinaryMessageHandler& handler) {
binary_message_handler_ = handler;
}
// 设置连接成功回调
void set_connection_handler(const ConnectionHandler& handler) {
connection_handler_ = handler;
}
// 设置断开连接回调
void set_disconnection_handler(const DisconnectionHandler& handler) {
disconnection_handler_ = handler;
}
// 发送PING
void ping() {
if (is_connected_) {
std::vector<uint8_t> ping_frame = {0x89, 0x00}; // FIN + PING opcode, no payload
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(ping_frame),
[self](std::error_code ec, std::size_t) {
if (ec) {
std::cerr << "Failed to send ping: " << ec.message() << std::endl;
self->close();
}
});
}
}
private:
// 执行WebSocket握手
// Sends the HTTP Upgrade request that opens the WebSocket handshake.
//
// On a successful write the response is read via read_handshake_response();
// on failure the error is logged and close() is called.
void perform_handshake() {
    // Random Sec-WebSocket-Key for this handshake.
    const std::string sec_websocket_key = generate_websocket_key();
    // The request text is heap-allocated and captured by the completion
    // handler: asio::buffer() does not copy, so a local std::string would be
    // destroyed when this function returns, before the asynchronous write
    // has finished with it.
    auto request = std::make_shared<std::string>();
    *request += "GET " + path_ + " HTTP/1.1\r\n";
    *request += "Host: " + host_ + ":" + port_ + "\r\n";
    *request += "Upgrade: websocket\r\n";
    *request += "Connection: Upgrade\r\n";
    *request += "Sec-WebSocket-Key: " + sec_websocket_key + "\r\n";
    *request += "Sec-WebSocket-Version: 13\r\n";
    *request += "\r\n";
    auto self = shared_from_this();
    asio::async_write(socket_, asio::buffer(*request),
        [self, request](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to send handshake request: " << ec.message() << std::endl;
                self->close();
                return;
            }
            self->read_handshake_response();
        });
}
// 生成WebSocket Key
// Builds a Sec-WebSocket-Key value: 16 random bytes, Base64-encoded.
//
// Returns a 24-character string (22 Base64 symbols followed by "==").
//
// NOTE(review): the bytes come from rand(), which this function does not
// seed; unless srand() is called elsewhere the same key is produced on every
// run. A production client should draw the bytes from <random>.
std::string generate_websocket_key() {
    std::vector<uint8_t> random_bytes(16);
    std::generate(random_bytes.begin(), random_bytes.end(), []() {
        return static_cast<uint8_t>(rand() % 256);
    });
    // Simplified Base64 encoder; real applications should use a vetted library.
    static const char kBase64Chars[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";
    std::string encoded;
    encoded.reserve(24);
    // The accumulator is unsigned: it is shifted left by 8 for each of the
    // 16 input bytes, which overflows a signed int (undefined behaviour).
    // Only the low bits are ever read, so unsigned wrap-around is harmless.
    unsigned int val = 0;
    int valb = -6;
    for (const uint8_t byte : random_bytes) {
        val = (val << 8) | byte;
        valb += 8;
        while (valb >= 0) {
            encoded.push_back(kBase64Chars[(val >> valb) & 0x3F]);
            valb -= 6;
        }
    }
    // Emit the final partial sextet, padded with zero bits on the right.
    if (valb > -6) {
        encoded.push_back(kBase64Chars[((val << 8) >> (valb + 8)) & 0x3F]);
    }
    while (encoded.size() % 4) {
        encoded.push_back('=');
    }
    return encoded;
}
// 读取握手响应
// Reads the server's handshake response (up to the blank line that ends the
// HTTP headers) and validates it.
//
// The handshake is accepted when the status code is 101, the Upgrade header
// equals "websocket" and the Connection header contains "upgrade" (both
// compared case-insensitively). On success is_connected_ is set, the
// connection handler is invoked and frame reading starts; any failure is
// logged and close() is called.
//
// NOTE(review): Sec-WebSocket-Accept is not verified against the key sent in
// the request. Also, async_read_until() may pull bytes beyond the header
// terminator into response_buffer_; read_frame() reads from the socket
// directly, so frame bytes already buffered here would be missed.
void read_handshake_response() {
    auto self = shared_from_this();
    asio::async_read_until(socket_, response_buffer_, "\r\n\r\n",
        [self](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to read handshake response: " << ec.message() << std::endl;
                self->close();
                return;
            }
            std::istream response_stream(&self->response_buffer_);
            std::string status_line;
            std::getline(response_stream, status_line);
            // Stream extraction is used instead of std::stoi so a missing or
            // non-numeric status code is reported rather than thrown out of
            // this completion handler (which would escape io_context::run()).
            std::istringstream status_stream(status_line);
            std::string version;
            int status_code = 0;
            if (!(status_stream >> version >> status_code)) {
                std::cerr << "Malformed WebSocket handshake status line" << std::endl;
                self->close();
                return;
            }
            if (status_code != 101) {
                std::cerr << "WebSocket handshake failed with status code: " << status_code << std::endl;
                self->close();
                return;
            }
            const auto to_lower = [](std::string text) -> std::string {
                std::transform(text.begin(), text.end(), text.begin(),
                    [](unsigned char c) { return static_cast<char>(::tolower(c)); });
                return text;
            };
            const auto trim = [](const std::string& text) -> std::string {
                const char* const whitespace = " \t\r";
                const std::size_t first = text.find_first_not_of(whitespace);
                if (first == std::string::npos) {
                    return std::string();
                }
                const std::size_t last = text.find_last_not_of(whitespace);
                return text.substr(first, last - first + 1);
            };
            // Collect the headers, keyed by lower-cased name.
            std::map<std::string, std::string> headers;
            std::string header_line;
            while (std::getline(response_stream, header_line) && header_line != "\r") {
                const std::size_t colon_pos = header_line.find(':');
                if (colon_pos == std::string::npos) {
                    continue;
                }
                // Trimming (rather than skipping a fixed "colon + space")
                // copes with "Name:value" and with a header that has no
                // value at all, which previously made substr() throw.
                headers[to_lower(trim(header_line.substr(0, colon_pos)))] =
                    trim(header_line.substr(colon_pos + 1));
            }
            // Header values are case-insensitive, and Connection may list
            // several tokens (e.g. "keep-alive, Upgrade").
            const auto upgrade_it = headers.find("upgrade");
            const bool upgrade_valid = upgrade_it != headers.end() &&
                to_lower(upgrade_it->second) == "websocket";
            const auto connection_it = headers.find("connection");
            const bool connection_valid = connection_it != headers.end() &&
                to_lower(connection_it->second).find("upgrade") != std::string::npos;
            if (!upgrade_valid || !connection_valid) {
                std::cerr << "Invalid WebSocket handshake response" << std::endl;
                self->close();
                return;
            }
            // Handshake accepted.
            self->is_connected_ = true;
            std::cout << "WebSocket handshake successful" << std::endl;
            if (self->connection_handler_) {
                self->connection_handler_();
            }
            // Start consuming WebSocket frames.
            self->read_frame();
        });
}
// 读取WebSocket帧
// Reads the two-byte frame header and dispatches on the opcode.
//
// Text (0x1) and binary (0x2) frames continue with read_payload_length();
// close, ping and pong frames go to their handlers; continuation frames and
// unknown opcodes are logged and close the connection. Does nothing when the
// client is not connected.
void read_frame() {
    if (!is_connected_) {
        return;
    }
    auto self = shared_from_this();
    asio::async_read(socket_, asio::buffer(&frame_header_, 2),
        [self](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to read frame header: " << ec.message() << std::endl;
                self->close();
                return;
            }
            // Byte 0: FIN flag + opcode. Byte 1: MASK flag + 7-bit length.
            const uint8_t first_byte = self->frame_header_[0];
            const uint8_t second_byte = self->frame_header_[1];
            const bool fin = (first_byte & 0x80) != 0;
            const uint8_t opcode = first_byte & 0x0F;
            const bool mask = (second_byte & 0x80) != 0;
            const uint64_t payload_length = second_byte & 0x7F;
            switch (opcode) {
            case 0x08:  // close frame
                self->handle_close_frame();
                break;
            case 0x09:  // ping frame
                self->handle_ping_frame();
                break;
            case 0x0A:  // pong frame
                self->handle_pong_frame();
                break;
            case 0x01:  // text frame
            case 0x02:  // binary frame
                self->read_payload_length(payload_length, opcode, fin, mask);
                break;
            case 0x00:  // continuation frame: not handled by this simplified client
                std::cerr << "Continuation frames not supported" << std::endl;
                self->close();
                break;
            default:
                std::cerr << "Unknown WebSocket opcode: " << static_cast<int>(opcode) << std::endl;
                self->close();
                break;
            }
        });
}
// 读取有效载荷长度
// Resolves the real payload length from the 7-bit length field.
//
// Values 0-125 are the length itself. 126 means a 16-bit big-endian length
// follows, 127 means a 64-bit big-endian length follows; those bytes are
// read into length_bytes_ first. Processing then continues in
// read_masking_key_and_payload(). A read error is logged and closes the
// connection.
void read_payload_length(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
    const bool has_extended_length = (payload_length == 126 || payload_length == 127);
    if (!has_extended_length) {
        // 7-bit length: nothing more to read.
        read_masking_key_and_payload(payload_length, opcode, fin, mask);
        return;
    }
    // 126 -> 2 extra bytes, 127 -> 8 extra bytes.
    const std::size_t extended_bytes = (payload_length == 126) ? 2 : 8;
    auto self = shared_from_this();
    asio::async_read(socket_, asio::buffer(length_bytes_, extended_bytes),
        [self, extended_bytes, opcode, fin, mask](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to read payload length: " << ec.message() << std::endl;
                self->close();
                return;
            }
            // Assemble the big-endian integer one byte at a time.
            uint64_t extended_length = 0;
            for (std::size_t i = 0; i < extended_bytes; ++i) {
                extended_length = (extended_length << 8) | self->length_bytes_[i];
            }
            self->read_masking_key_and_payload(extended_length, opcode, fin, mask);
        });
}
// 读取掩码键和有效载荷
// Reads the 4-byte masking key into masking_key_ when the frame's MASK flag
// is set, then hands over to read_payload(). Unmasked frames go straight to
// read_payload(). A read error is logged and closes the connection.
void read_masking_key_and_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
    auto self = shared_from_this();
    if (!mask) {
        // No masking key on the wire.
        read_payload(payload_length, opcode, fin, mask);
        return;
    }
    asio::async_read(socket_, asio::buffer(masking_key_, 4),
        [self, payload_length, opcode, fin, mask](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to read masking key: " << ec.message() << std::endl;
                self->close();
                return;
            }
            self->read_payload(payload_length, opcode, fin, mask);
        });
}
// 读取有效载荷
void read_payload(uint64_t payload_length, uint8_t opcode, bool fin, bool mask) {
auto self = shared_from_this();
payload_.resize(payload_length);
asio::async_read(socket_, asio::buffer(payload_),
[self, opcode, fin, mask](std::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
// 处理掩码
if (mask) {
for (size_t i = 0; i < bytes_transferred; ++i) {
self->payload_[i] ^= self->masking_key_[i % 4];
}
}
// 处理消息
self->handle_message(opcode, fin);
// 继续读取下一帧
self->read_frame();
} else {
std::cerr << "Failed to read payload: " << ec.message() << std::endl;
self->close();
}
});
}
// 处理WebSocket消息
// Delivers the frame currently held in payload_ to the matching callback.
//
// Text frames (0x1) are logged and passed to message_handler_ as a
// std::string; binary frames (0x2) have their size logged and are passed to
// binary_message_handler_. Other opcodes are ignored. `fin` is accepted but
// not consulted, so every frame is delivered as if it were a whole message.
void handle_message(uint8_t opcode, bool fin) {
    switch (opcode) {
    case 0x01: {  // text message
        std::string message(payload_.begin(), payload_.end());
        std::cout << "Received text message: " << message << std::endl;
        if (message_handler_) {
            message_handler_(message);
        }
        break;
    }
    case 0x02:  // binary message
        std::cout << "Received binary message of size: " << payload_.size() << " bytes" << std::endl;
        if (binary_message_handler_) {
            binary_message_handler_(payload_);
        }
        break;
    default:
        break;
    }
}
// 处理关闭帧
void handle_close_frame() {
std::cout << "Received close frame" << std::endl;
close();
}
// 处理PING帧
void handle_ping_frame() {
std::cout << "Received ping frame" << std::endl;
// 读取PING帧的有效载荷
// 为简化起见,这里不读取有效载荷
// 发送PONG帧
send_pong_frame();
}
// 处理PONG帧
void handle_pong_frame() {
std::cout << "Received pong frame" << std::endl;
// 为简化起见,这里不做任何处理
read_frame();
}
// 发送PONG帧
// Sends an empty PONG control frame. Once the write completes, frame reading
// resumes; a failed write is logged and closes the connection.
void send_pong_frame() {
    // asio::buffer() only references the bytes and async_write() returns
    // before they are sent, so the frame is heap-allocated and kept alive by
    // the completion handler rather than stored in a local that would be
    // destroyed while the write is still pending.
    auto pong_frame = std::make_shared<std::vector<uint8_t>>();
    pong_frame->push_back(0x8A);  // FIN + PONG opcode
    pong_frame->push_back(0x80);  // MASK bit set, payload length 0
    // RFC 6455 requires client-to-server frames to carry a masking key,
    // even when there is no payload to mask.
    for (int i = 0; i < 4; ++i) {
        pong_frame->push_back(static_cast<uint8_t>(rand() % 256));
    }
    auto self = shared_from_this();
    asio::async_write(socket_, asio::buffer(*pong_frame),
        [self, pong_frame](std::error_code ec, std::size_t) {
            if (ec) {
                std::cerr << "Failed to send pong: " << ec.message() << std::endl;
                self->close();
                return;
            }
            self->read_frame();
        });
}
// NOTE(review): resolver_ is not used in the methods shown here; presumably
// connect() uses it to resolve host_/port_ — confirm against connect().
tcp::resolver resolver_;
// Socket carrying the handshake and all WebSocket frames.
tcp::socket socket_;
// Receives the HTTP handshake response in read_handshake_response().
asio::streambuf response_buffer_;
// Set to true once the handshake has been validated; cleared by close().
bool is_connected_ = false;
// Used to build the handshake request line and Host header.
std::string host_;
std::string port_;
std::string path_;
// Fields used while parsing a WebSocket frame
// First two bytes of the frame: FIN/opcode and MASK/7-bit length.
uint8_t frame_header_[2];
// Extended payload length bytes (2 or 8 are filled, big-endian).
uint8_t length_bytes_[8];
// Masking key of the frame being read (only filled for masked frames).
uint8_t masking_key_[4];
// Payload of the frame being read.
std::vector<uint8_t> payload_;
// Callbacks
MessageHandler message_handler_;
BinaryMessageHandler binary_message_handler_;
ConnectionHandler connection_handler_;
DisconnectionHandler disconnection_handler_;
};
// WebSocket客户端示例
int main() {
try {
asio::io_context io_context;
auto client = std::make_shared<WebSocketClient>(io_context);
// 设置回调函数
client->set_message_handler([](const std::string& message) {
std::cout << "Message received: " << message << std::endl;
});
client->set_binary_message_handler([](const std::vector<uint8_t>& data) {
std::cout << "Binary data received, size: " << data.size() << " bytes" << std::endl;
});
client->set_connection_handler([]() {
std::cout << "Connection established successfully" << std::endl;
});
client->set_disconnection_handler([]() {
std::cout << "Disconnected from server" << std::endl;
});
// 连接到WebSocket服务器
// 注意:这里使用的是示例地址,请替换为实际的WebSocket服务器地址
client->connect("localhost", "8080", "/ws");
// 创建一个工作对象以防止io_context在没有事件时退出
asio::executor_work_guard<asio::io_context::executor_type> work =
asio::make_work_guard(io_context);
// 在单独的线程中运行io_context
std::thread io_thread([&io_context]() {
io_context.run();
});
// 等待连接建立
std::this_thread::sleep_for(std::chrono::seconds(1));
// 发送一条文本消息
client->send_text("Hello, WebSocket Server!");
// 发送一条二进制消息
std::vector<uint8_t> binary_data = {0x01, 0x02, 0x03, 0x04, 0x05};
client->send_binary(binary_data);
// 发送PING
client->ping();
// 等待一段时间,以便接收响应
std::this_thread::sleep_for(std::chrono::seconds(5));
// 关闭连接
client->close();
// 停止io_context
io_context.stop();
if (io_thread.joinable()) {
io_thread.join();
}
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}
更多推荐
所有评论(0)