diff --git a/src/brpc/channel.cpp b/src/brpc/channel.cpp index 0252e97d74..0bb29fce17 100644 --- a/src/brpc/channel.cpp +++ b/src/brpc/channel.cpp @@ -77,6 +77,8 @@ ChannelSSLOptions* ChannelOptions::mutable_ssl_options() { static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { if (opt.auth == NULL && !opt.has_ssl_options() && + opt.client_host.empty() && + opt.device_name.empty() && opt.connection_group.empty() && opt.hc_option.health_check_path.empty()) { // Returning zeroized result by default is more intuitive for users. @@ -94,6 +96,14 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { buf.append("|conng="); buf.append(opt.connection_group); } + if (!opt.client_host.empty()) { + buf.append("|clih="); + buf.append(opt.client_host); + } + if (!opt.device_name.empty()) { + buf.append("|devn="); + buf.append(opt.device_name); + } if (opt.auth) { buf.append("|auth="); buf.append((char*)&opt.auth, sizeof(opt.auth)); @@ -362,14 +372,27 @@ int Channel::InitSingle(const butil::EndPoint& server_addr_and_port, LOG(ERROR) << "Invalid port=" << port; return -1; } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } _server_address = server_addr_and_port; const ChannelSignature sig = ComputeChannelSignature(_options); std::shared_ptr ssl_ctx; if (CreateSocketSSLContext(_options, &ssl_ctx) != 0) { return -1; } + SocketOptions opt; + opt.local_side = client_endpoint; + opt.initial_ssl_ctx = ssl_ctx; + opt.use_rdma = _options.use_rdma; + opt.hc_option = _options.hc_option; + opt.device_name = _options.device_name; if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig), - &_server_id, ssl_ctx, _options.use_rdma, _options.hc_option) != 0) { + &_server_id, opt) != 0) { LOG(ERROR) << "Fail to insert into SocketMap"; return -1; } @@ -397,6 +420,13 @@ int Channel::Init(const char* ns_url, _options.mutable_ssl_options()->sni_name = _service_name; } } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } std::unique_ptr lb(new (std::nothrow) LoadBalancerWithNaming); if (NULL == lb) { @@ -406,10 +436,13 @@ int Channel::Init(const char* ns_url, GetNamingServiceThreadOptions ns_opt; ns_opt.succeed_without_server = _options.succeed_without_server; ns_opt.log_succeed_without_server = _options.log_succeed_without_server; - ns_opt.use_rdma = _options.use_rdma; + ns_opt.socket_option.use_rdma = _options.use_rdma; ns_opt.channel_signature = ComputeChannelSignature(_options); - ns_opt.hc_option = _options.hc_option; - if (CreateSocketSSLContext(_options, &ns_opt.ssl_ctx) != 0) { + ns_opt.socket_option.hc_option = _options.hc_option; + ns_opt.socket_option.local_side = client_endpoint; + ns_opt.socket_option.device_name = _options.device_name; + if (CreateSocketSSLContext(_options, + &ns_opt.socket_option.initial_ssl_ctx) != 0) { return -1; } if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) { diff --git a/src/brpc/channel.h b/src/brpc/channel.h index c970209b3a..0f349ac6fe 100644 --- a/src/brpc/channel.h +++ b/src/brpc/channel.h @@ -148,6 +148,16 @@ struct ChannelOptions { // Its priority is higher than FLAGS_health_check_path and FLAGS_health_check_timeout_ms. // When it is not set, FLAGS_health_check_path and FLAGS_health_check_timeout_ms will take effect. HealthCheckOption hc_option; + + // IP address or host name of the client. + // if the client_host is "", the client IP address is determined by the OS. + // Default: "" + std::string client_host; + + // The device name of the client's network adapter. + // if the device_name is "", the flow control is determined by the OS. + // Default: "" + std::string device_name; private: // SSLOptions is large and not often used, allocate it on heap to // prevent ChannelOptions from being bloated in most cases. diff --git a/src/brpc/details/naming_service_thread.cpp b/src/brpc/details/naming_service_thread.cpp index 341ca35b09..f882b2255d 100644 --- a/src/brpc/details/naming_service_thread.cpp +++ b/src/brpc/details/naming_service_thread.cpp @@ -125,8 +125,8 @@ void NamingServiceThread::Actions::ResetServers( // Socket. SocketMapKey may be passed through AddWatcher. Make sure // to pick those Sockets with the right settings during OnAddedServers const SocketMapKey key(_added[i], _owner->_options.channel_signature); - CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx, - _owner->_options.use_rdma, _owner->_options.hc_option)); + CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, + _owner->_options.socket_option)); _added_sockets.push_back(tagged_id); } diff --git a/src/brpc/details/naming_service_thread.h b/src/brpc/details/naming_service_thread.h index 1745e5f267..9acb8f2931 100644 --- a/src/brpc/details/naming_service_thread.h +++ b/src/brpc/details/naming_service_thread.h @@ -44,15 +44,14 @@ class NamingServiceWatcher { struct GetNamingServiceThreadOptions { GetNamingServiceThreadOptions() : succeed_without_server(false) - , log_succeed_without_server(true) - , use_rdma(false) {} + , log_succeed_without_server(true) { + socket_option.use_rdma = false; +} bool succeed_without_server; bool log_succeed_without_server; - bool use_rdma; - HealthCheckOption hc_option; ChannelSignature channel_signature; - std::shared_ptr ssl_ctx; + SocketOptions socket_option; }; // A dedicated thread to map a name to ServerIds diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 9490650b78..e431aceff9 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -728,7 +728,8 @@ int Socket::OnCreated(const SocketOptions& options) { _keytable_pool = options.keytable_pool; _tos = 0; _remote_side = options.remote_side; - _local_side = butil::EndPoint(); + _local_side = options.local_side; + _device_name = options.device_name; _on_edge_triggered_events = options.on_edge_triggered_events; _user = options.user; _conn = options.conn; @@ -1296,7 +1297,25 @@ int Socket::Connect(const timespec* abstime, CHECK_EQ(0, butil::make_close_on_exec(sockfd)); // We need to do async connect (to manage the timeout by ourselves). CHECK_EQ(0, butil::make_non_blocking(sockfd)); - + if (!_device_name.empty()) { + if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, + _device_name.c_str(), _device_name.size()) < 0) { + PLOG(ERROR) << "Fail to set SO_BINDTODEVICE of fd=" << sockfd + << " to device_name=" << _device_name; + return -1; + } + } + if (local_side().ip != butil::IP_ANY) { + struct sockaddr_storage cli_addr; + if (butil::endpoint2sockaddr(local_side(), &cli_addr, &addr_size) != 0) { + PLOG(ERROR) << "Fail to get client sockaddr"; + return -1; + } + if (::bind(sockfd, (struct sockaddr*)&cli_addr, addr_size) != 0) { + PLOG(ERROR) << "Fail to bind client socket, errno=" << strerror(errno); + return -1; + } + } const int rc = ::connect( sockfd, (struct sockaddr*)&serv_addr, addr_size); if (rc != 0 && errno != EINPROGRESS) { @@ -2811,6 +2830,7 @@ int Socket::GetPooledSocket(SocketUniquePtr* pooled_socket) { if (socket_pool == NULL) { SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; opt.initial_ssl_ctx = _ssl_ctx; @@ -2912,6 +2932,7 @@ int Socket::GetShortSocket(SocketUniquePtr* short_socket) { SocketId id; SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; opt.initial_ssl_ctx = _ssl_ctx; diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 03ad43f867..a3e2323056 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -250,6 +250,8 @@ struct SocketOptions { // user->BeforeRecycle() before recycling. int fd{-1}; butil::EndPoint remote_side; + butil::EndPoint local_side; + std::string device_name; // If `connect_on_create' is true and `fd' is less than 0, // a client connection will be established to remote_side() // regarding deadline `connect_abstime' when Socket is being created. @@ -830,6 +832,9 @@ friend void DereferenceSocket(Socket*); // Address of self. Initialized in ResetFileDescriptor(). butil::EndPoint _local_side; + // The device name of the client's network adapter. + std::string _device_name; + // Called when edge-triggered events happened on `_fd'. Read comments // of EventDispatcher::AddConsumer (event_dispatcher.h) // carefully before implementing the callback. diff --git a/src/brpc/socket_map.cpp b/src/brpc/socket_map.cpp index 14bea71db5..3984f6b866 100644 --- a/src/brpc/socket_map.cpp +++ b/src/brpc/socket_map.cpp @@ -90,11 +90,9 @@ SocketMap* get_or_new_client_side_socket_map() { } int SocketMapInsert(const SocketMapKey& key, SocketId* id, - const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option) { - return get_or_new_client_side_socket_map()->Insert(key, id, ssl_ctx, use_rdma, hc_option); -} + SocketOptions& opt) { + return get_or_new_client_side_socket_map()->Insert(key, id, opt); +} int SocketMapFind(const SocketMapKey& key, SocketId* id) { SocketMap* m = get_client_side_socket_map(); @@ -227,9 +225,7 @@ void SocketMap::ShowSocketMapInBvarIfNeed() { } int SocketMap::Insert(const SocketMapKey& key, SocketId* id, - const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option) { + SocketOptions& opt) { ShowSocketMapInBvarIfNeed(); std::unique_lock mu(_mutex); @@ -249,11 +245,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id, sc = NULL; } SocketId tmp_id; - SocketOptions opt; opt.remote_side = key.peer.addr; - opt.initial_ssl_ctx = ssl_ctx; - opt.use_rdma = use_rdma; - opt.hc_option = hc_option; if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) { PLOG(FATAL) << "Fail to create socket to " << key.peer; return -1; diff --git a/src/brpc/socket_map.h b/src/brpc/socket_map.h index b0d542e78e..7cf0880498 100644 --- a/src/brpc/socket_map.h +++ b/src/brpc/socket_map.h @@ -80,9 +80,19 @@ struct SocketMapKeyHasher { // successfully, SocketMapRemove() MUST be called when the Socket is not needed. // Return 0 on success, -1 otherwise. int SocketMapInsert(const SocketMapKey& key, SocketId* id, + SocketOptions& opt); + +inline int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option); + const HealthCheckOption& hc_option) { + SocketOptions opt; + opt.remote_side = key.peer.addr; + opt.initial_ssl_ctx = ssl_ctx; + opt.use_rdma = use_rdma; + opt.hc_option = hc_option; + return SocketMapInsert(key, id, opt); +} inline int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { @@ -155,7 +165,14 @@ class SocketMap { int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, bool use_rdma, - const HealthCheckOption& hc_option); + const HealthCheckOption& hc_option) { + SocketOptions opt; + opt.remote_side = key.peer.addr; + opt.initial_ssl_ctx = ssl_ctx; + opt.use_rdma = use_rdma; + opt.hc_option = hc_option; + return Insert(key, id, opt); +} int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { @@ -167,6 +184,7 @@ class SocketMap { HealthCheckOption hc_option; return Insert(key, id, empty_ptr, false, hc_option); } + int Insert(const SocketMapKey& key, SocketId* id, SocketOptions& opt); void Remove(const SocketMapKey& key, SocketId expected_id); int Find(const SocketMapKey& key, SocketId* id); diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 4a774fab2a..8508a7986c 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -2070,4 +2070,49 @@ TEST_F(ServerTest, auth) { ASSERT_EQ(0, server.Join()); } +void TestClientHost(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed, + brpc::ChannelOptions& copt) { + brpc::Channel chan; + copt.max_retry = 0; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +TEST_F(ServerTest, bind_client_host_and_network_device) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions opt; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Controller cntl; + brpc::ChannelOptions copt; + copt.client_host = "localhost"; + copt.device_name = "lo"; + std::vector connection_types = { + brpc::CONNECTION_TYPE_SINGLE, + brpc::CONNECTION_TYPE_POOLED, + brpc::CONNECTION_TYPE_SHORT + }; + for (auto connect_type : connection_types) { + copt.connection_type = connect_type; + TestClientHost(ep, cntl, 0, false, copt); + cntl.Reset(); + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace