Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: add xDS transport custom Dialer support #7586

Merged
merged 10 commits into from
Sep 27, 2024
18 changes: 18 additions & 0 deletions internal/xds/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ package bootstrap

import (
"bytes"
"context"
"encoding/json"
"fmt"
"maps"
"net"
"net/url"
"os"
"slices"
Expand Down Expand Up @@ -179,6 +181,7 @@ type ServerConfig struct {
// credentials and store it here for easy access.
selectedCreds ChannelCreds
credsDialOption grpc.DialOption
dialerOption grpc.DialOption

cleanups []func()
}
Expand Down Expand Up @@ -223,6 +226,12 @@ func (sc *ServerConfig) CredsDialOption() grpc.DialOption {
return sc.credsDialOption
}

// DialerOption returns the first supported Dialer function that specifies how
// to dial the xDS server from the configuration, as a dial option.
easwars marked this conversation as resolved.
Show resolved Hide resolved
func (sc *ServerConfig) DialerOption() grpc.DialOption {
return sc.dialerOption
}
easwars marked this conversation as resolved.
Show resolved Hide resolved

// Cleanups returns a collection of functions to be called when the xDS client
// for this server is closed. Allows cleaning up resources created specifically
// for this server.
Expand Down Expand Up @@ -275,6 +284,12 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(server)
}

// dialer captures the Dialer method specified via the credentials bundle.
type dialer interface {
// Dialer specifies how to dial the xDS server.
Dialer(context.Context, string) (net.Conn, error)
easwars marked this conversation as resolved.
Show resolved Hide resolved
}

// UnmarshalJSON takes the json data (a server) and unmarshals it to the struct.
func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
server := serverConfigJSON{}
Expand All @@ -298,6 +313,9 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
}
sc.selectedCreds = cc
sc.credsDialOption = grpc.WithCredentialsBundle(bundle)
if d, ok := bundle.(dialer); ok {
sc.dialerOption = grpc.WithContextDialer(d.Dialer)
}
sc.cleanups = append(sc.cleanups, cancel)
break
}
Expand Down
158 changes: 158 additions & 0 deletions test/xds/xds_client_custom_dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package xds_test

import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"sync/atomic"
	"testing"

	"github.com/google/uuid"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/internal/stubserver"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/internal/testutils/xds/e2e"
	internalbootstrap "google.golang.org/grpc/internal/xds/bootstrap"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/xds/bootstrap"

	testgrpc "google.golang.org/grpc/interop/grpc_testing"
	testpb "google.golang.org/grpc/interop/grpc_testing"
)

const testDialerCredsBuilderName = "test_dialer_creds"

var (
mgmtServerAddress string
customDialerCalled bool
)

func init() {
bootstrap.RegisterCredentials(&testDialerCredsBuilder{})
}

// testDialerCredsBuilder implements the `Credentials` interface defined in
// package `xds/bootstrap` and encapsulates an insecure credential with a
// custom Dialer that specifies how to dial the xDS server.
type testDialerCredsBuilder struct{}

func (t *testDialerCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return &testDialerCredsBundle{}, func() {}, nil
}

func (t *testDialerCredsBuilder) Name() string {
return testDialerCredsBuilderName
}

// testDialerCredsBundle implements the `Bundle` interface defined in package
// `credentials` and encapsulates an insecure credential with a custom Dialer
// that specifies how to dial the xDS server.
type testDialerCredsBundle struct{}

func (t *testDialerCredsBundle) TransportCredentials() credentials.TransportCredentials {
return insecure.NewCredentials()
}

func (t *testDialerCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
return nil
}

func (t *testDialerCredsBundle) NewWithMode(string) (credentials.Bundle, error) {
return &testDialerCredsBundle{}, nil
}

// Dialer specifies how to dial the xDS management server.
func (t *testDialerCredsBundle) Dialer(context.Context, string) (net.Conn, error) {
customDialerCalled = true
// Create a pass-through connection (no-op) to the xDS management server.
return net.Dial("tcp", mgmtServerAddress)
}

func (s) TestClientCustomDialerFromCredentialsBundle(t *testing.T) {
customDialerCalled = false
easwars marked this conversation as resolved.
Show resolved Hide resolved

// Start an xDS management server.
mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{})

// Create bootstrap configuration pointing to the above management server.
nodeID := uuid.New().String()
bc, err := internalbootstrap.NewContentsForTesting(internalbootstrap.ConfigOptionsForTesting{
Servers: []byte(fmt.Sprintf(`[{
"server_uri": %q,
"channel_creds": [{"type": %q}]
}]`, mgmtServer.Address, testDialerCredsBuilderName)),
Node: []byte(fmt.Sprintf(`{"id": "%s"}`, nodeID)),
})
if err != nil {
t.Fatalf("Failed to create bootstrap configuration: %v", err)
}

// Set the management server address to be used by the custom dialer.
mgmtServerAddress = mgmtServer.Address

// Create an xDS resolver with the above bootstrap configuration.
var resolverBuilder resolver.Builder
if newResolver := internal.NewXDSResolverWithConfigForTesting; newResolver != nil {
resolverBuilder, err = newResolver.(func([]byte) (resolver.Builder, error))(bc)
if err != nil {
t.Fatalf("Failed to create xDS resolver for testing: %v", err)
}
}

// Spin up a test backend.
server := stubserver.StartTestService(t, nil)
defer server.Stop()

// Configure client side xDS resources on the management server.
const serviceName = "my-service-client-side-xds"
resources := e2e.DefaultClientResources(e2e.ResourceParams{
DialTarget: serviceName,
NodeID: nodeID,
Host: "localhost",
Port: testutils.ParsePort(t, server.Address),
SecLevel: e2e.SecurityLevelNone,
})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := mgmtServer.Update(ctx, resources); err != nil {
t.Fatal(err)
}

// Create a ClientConn and make a successful RPC. The insecure transport credentials passed into
// the gRPC.NewClient is the credentials for the data plane communication with the test backend.
cc, err := grpc.NewClient(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolverBuilder))
if err != nil {
t.Fatalf("failed to dial local test server: %v", err)
}
defer cc.Close()

client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
danielzhaotongliu marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("EmptyCall() failed: %v", err)
}

if !customDialerCalled {
t.Fatalf("xDS client transport custom dialer called = false, want true")
}
}
3 changes: 3 additions & 0 deletions xds/internal/xdsclient/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ func New(opts Options) (*Transport, error) {
Timeout: 20 * time.Second,
}),
}
if dialerOpts := opts.ServerCfg.DialerOption(); dialerOpts != nil {
dopts = append(dopts, dialerOpts)
}
grpcNewClient := transportinternal.GRPCNewClient.(func(string, ...grpc.DialOption) (*grpc.ClientConn, error))
cc, err := grpcNewClient(opts.ServerCfg.ServerURI(), dopts...)
if err != nil {
Expand Down
83 changes: 81 additions & 2 deletions xds/internal/xdsclient/transport/transport_test.go
easwars marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
package transport_test

import (
"context"
"encoding/json"
"net"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/internal/xds/bootstrap"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
internalbootstrap "google.golang.org/grpc/internal/xds/bootstrap"
"google.golang.org/grpc/xds/bootstrap"
"google.golang.org/grpc/xds/internal/xdsclient/transport"
"google.golang.org/grpc/xds/internal/xdsclient/transport/internal"

Expand All @@ -39,7 +45,7 @@ func (s) TestNewWithGRPCDial(t *testing.T) {
internal.GRPCNewClient = customDialer
defer func() { internal.GRPCNewClient = oldDial }()

serverCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: "server-address"})
serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{URI: "server-address"})
if err != nil {
t.Fatalf("Failed to create server config for testing: %v", err)
}
Expand Down Expand Up @@ -82,3 +88,76 @@ func (s) TestNewWithGRPCDial(t *testing.T) {
t.Fatalf("transport.New(%+v) custom dialer called = true, want false", opts)
}
}

const testDialerCredsBuilderName = "test_dialer_creds"

func init() {
bootstrap.RegisterCredentials(&testDialerCredsBuilder{})
}

// testDialerCredsBuilder implements the `Credentials` interface defined in
// package `xds/bootstrap` and encapsulates an insecure credential with a
// custom Dialer that specifies how to dial the xDS server.
type testDialerCredsBuilder struct{}

func (t *testDialerCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return &testDialerCredsBundle{}, func() {}, nil
}

func (t *testDialerCredsBuilder) Name() string {
return testDialerCredsBuilderName
}

// testDialerCredsBundle implements the `Bundle` interface defined in package
// `credentials` and encapsulates an insecure credential with a custom Dialer
// that specifies how to dial the xDS server.
type testDialerCredsBundle struct{}
easwars marked this conversation as resolved.
Show resolved Hide resolved

func (t *testDialerCredsBundle) TransportCredentials() credentials.TransportCredentials {
return insecure.NewCredentials()
}

func (t *testDialerCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
return nil
}

func (t *testDialerCredsBundle) NewWithMode(string) (credentials.Bundle, error) {
return &testDialerCredsBundle{}, nil
}

func (t *testDialerCredsBundle) Dialer(context.Context, string) (net.Conn, error) {
return nil, nil
}

// TestNewWithDialerFromCredentialsBundle verifies that a server config whose
// channel credentials bundle implements a Dialer method exposes a non-nil
// DialerOption, and that transport.New succeeds with such a config.
func (s) TestNewWithDialerFromCredentialsBundle(t *testing.T) {
	serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{
		URI:          "trafficdirector.googleapis.com:443",
		ChannelCreds: []internalbootstrap.ChannelCreds{{Type: testDialerCredsBuilderName}},
	})
	if err != nil {
		t.Fatalf("Failed to create server config for testing: %v", err)
	}
	if serverCfg.DialerOption() == nil {
		t.Fatalf("Dialer for xDS transport in server config for testing is nil, want non-nil")
	}

	// Create a new transport.
	opts := transport.Options{
		ServerCfg: serverCfg,
		NodeProto: &v3corepb.Node{},
		OnRecvHandler: func(update transport.ResourceUpdate, onDone func()) error {
			onDone()
			return nil
		},
		OnErrorHandler: func(error) {},
		OnSendHandler:  func(*transport.ResourceSendInfo) {},
	}
	c, err := transport.New(opts)
	if err != nil {
		t.Fatalf("transport.New(%v) failed: %v", opts, err)
	}
	defer c.Close()
}