diff --git a/packages/client-proxy/internal/proxy/proxy.go b/packages/client-proxy/internal/proxy/proxy.go index 92f5cc9df8..b43f14655f 100644 --- a/packages/client-proxy/internal/proxy/proxy.go +++ b/packages/client-proxy/internal/proxy/proxy.go @@ -16,7 +16,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/e2b-dev/infra/packages/shared/pkg/env" "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" proxygrpc "github.com/e2b-dev/infra/packages/shared/pkg/grpc/proxy" "github.com/e2b-dev/infra/packages/shared/pkg/logger" @@ -63,6 +62,24 @@ func normalizeNodeIP(nodeIP string) (string, error) { return nodeIP, nil } +func orchestratorSandboxHost(host string, sandboxID string, port uint64) *string { + hostname := strings.Split(host, ":")[0] + if hostname == "localhost" { + orchestratorHost := fmt.Sprintf("%d-%s.localhost", port, sandboxID) + + return &orchestratorHost + } + + domain, ok := strings.CutPrefix(hostname, "envd.") + if !ok || domain == "" { + return nil + } + + orchestratorHost := fmt.Sprintf("%d-%s.%s", port, sandboxID, domain) + + return &orchestratorHost +} + func catalogResolution(ctx context.Context, sandboxId string, sandboxPort uint64, trafficAccessToken string, envdAccessToken string, c catalog.SandboxesCatalog, pausedChecker PausedSandboxResumer, featureFlags *featureflags.Client) (string, error) { s, err := c.GetSandbox(ctx, sandboxId) if err != nil { @@ -133,7 +150,7 @@ func handlePausedSandbox( } func NewClientProxy(meterProvider metric.MeterProvider, serviceName string, port uint16, catalog catalog.SandboxesCatalog, pausedSandboxResumer PausedSandboxResumer, featureFlagsClient *featureflags.Client) (*reverseproxy.Proxy, error) { - getTargetFromRequest := reverseproxy.GetTargetFromRequest(env.IsLocal()) + getTargetFromRequest := reverseproxy.GetTargetFromRequest(reverseproxy.HeaderRoutingEnabled) proxy := reverseproxy.New( port, // Retries that are needed to handle port forwarding delays in sandbox envd are handled by the orchestrator proxy @@ -198,6 +215,11 @@ func NewClientProxy(meterProvider metric.MeterProvider, serviceName string, port SandboxPort: port, ConnectionKey: pool.ClientProxyConnectionKey, Url: url, + MaskRequestHost: orchestratorSandboxHost( + r.Host, + sandboxId, + port, + ), }, nil }, nil, diff --git a/packages/orchestrator/pkg/proxy/proxy.go b/packages/orchestrator/pkg/proxy/proxy.go index ecd9141515..d27d1076fd 100644 --- a/packages/orchestrator/pkg/proxy/proxy.go +++ b/packages/orchestrator/pkg/proxy/proxy.go @@ -15,7 +15,6 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox" "github.com/e2b-dev/infra/packages/shared/pkg/connlimit" "github.com/e2b-dev/infra/packages/shared/pkg/consts" - "github.com/e2b-dev/infra/packages/shared/pkg/env" "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/logger" reverseproxy "github.com/e2b-dev/infra/packages/shared/pkg/proxy" @@ -40,7 +39,7 @@ type SandboxProxy struct { } func NewSandboxProxy(meterProvider metric.MeterProvider, port uint16, sandboxes *sandbox.Map, featureFlags *featureflags.Client) (*SandboxProxy, error) { - getTargetFromRequest := reverseproxy.GetTargetFromRequest(env.IsLocal()) + getTargetFromRequest := reverseproxy.GetTargetFromRequest(reverseproxy.HeaderRoutingDisabled) limiter := connlimit.NewConnectionLimiter() metrics := NewMetrics(meterProvider) diff --git a/packages/shared/pkg/proxy/host.go b/packages/shared/pkg/proxy/host.go index 52321e2e3f..30aad1d6e7 100644 --- a/packages/shared/pkg/proxy/host.go +++ b/packages/shared/pkg/proxy/host.go @@ -9,9 +9,16 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/id" ) -func GetTargetFromRequest(processHeaders bool) func(r *http.Request) (sandboxId string, port uint64, err error) { +type HeaderRoutingMode uint8 + +const ( + HeaderRoutingDisabled HeaderRoutingMode = iota + HeaderRoutingEnabled +) + +func GetTargetFromRequest(headerRouting HeaderRoutingMode) func(r *http.Request) (sandboxId string, port uint64, err error) { return func(r *http.Request) (sandboxId string, port uint64, err error) { - if processHeaders { + if headerRouting == HeaderRoutingEnabled && shouldParseHeaders(r.Host) && hasRoutingHeaders(r.Header) { var ok bool sandboxId, port, ok, err = parseHeaders(r.Header) if err != nil { @@ -38,6 +45,16 @@ func GetTargetFromRequest(processHeaders bool) func(r *http.Request) (sandboxId } } +func shouldParseHeaders(host string) bool { + host = strings.Split(host, ":")[0] + + return host == "localhost" || strings.HasPrefix(host, "envd.") +} + +func hasRoutingHeaders(h http.Header) bool { + return h.Get(headerSandboxID) != "" || h.Get(headerSandboxPort) != "" +} + func parseHost(host string) (sandboxID string, port uint64, err error) { dot := strings.Index(host, ".") diff --git a/packages/shared/pkg/proxy/host_test.go b/packages/shared/pkg/proxy/host_test.go index db42d62840..f27e497908 100644 --- a/packages/shared/pkg/proxy/host_test.go +++ b/packages/shared/pkg/proxy/host_test.go @@ -7,10 +7,10 @@ import ( "github.com/stretchr/testify/require" ) -func TestGetTargetFromRequest(t *testing.T) { //nolint:tparallel // cannot call t.Setenv with t.Parallel - t.Setenv("ENVIRONMENT", "local") +func TestGetTargetFromRequest(t *testing.T) { + t.Parallel() - getTargetFromRequest := GetTargetFromRequest(true) + getTargetFromRequest := GetTargetFromRequest(HeaderRoutingEnabled) tests := []struct { name string @@ -142,6 +142,26 @@ func TestGetTargetFromRequest(t *testing.T) { //nolint:tparallel // cannot call }, wantErrIs: MissingHeaderError{Header: headerSandboxPort}, }, + { + name: "headers: envd shared host", + host: "envd.e2b.app", + headers: http.Header{ + headerSandboxID: []string{"isv6ril5xadwn1k9t2jye"}, + headerSandboxPort: []string{"49983"}, + }, + wantID: "isv6ril5xadwn1k9t2jye", + wantPort: 49983, + }, + { + name: "headers: ignored on regular sandbox host", + host: "49983-isv6ril5xadwn1k9t2jye.e2b.app", + headers: http.Header{ + headerSandboxID: []string{"iother5b5aiixd410phsjv"}, + headerSandboxPort: []string{"3000"}, + }, + wantID: "isv6ril5xadwn1k9t2jye", + wantPort: 49983, + }, } for _, tt := range tests {