diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index f735752ef9..7375eb43d3 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -437,9 +437,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe } type trigger struct { - id uint64 - cancel context.CancelFunc - subscriptions map[*Context]*sub + id uint64 + cancel context.CancelFunc + subscriptions map[*Context]*sub + subscriptionIdentifiers map[SubscriptionIdentifier]*Context // initialized is set to true when the trigger is started and initialized initialized bool updater *subscriptionUpdater @@ -817,6 +818,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) // After the startup hooks are executed, we can add the subscription to the subscriptions registry // so that it can start receive events trig.subscriptions[add.ctx] = s + trig.subscriptionIdentifiers[s.id] = add.ctx return } @@ -832,13 +834,15 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) } cloneCtx := add.ctx.clone(ctx) trig = &trigger{ - id: triggerID, - subscriptions: make(map[*Context]*sub), - cancel: cancel, - updater: updater, + id: triggerID, + subscriptions: make(map[*Context]*sub), + subscriptionIdentifiers: make(map[SubscriptionIdentifier]*Context), + cancel: cancel, + updater: updater, } r.triggers[triggerID] = trig trig.subscriptions[add.ctx] = s + trig.subscriptionIdentifiers[s.id] = add.ctx updater.subsFn = trig.subscriptionIds if r.reporter != nil { @@ -1004,13 +1008,31 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) } - for c, s := range trig.subscriptions { - if s.id != subIdentifier { - continue + // Fast path: O(1) lookup of the subscriptions resolver context via map + resolverCtx, exists := trig.subscriptionIdentifiers[subIdentifier] + + // Fallback O(N) lookup in case we couldn't find the resolver context by map: + // Loop through trig.subscriptions and find the corresponding resolver context. + if !exists { + for i := range trig.subscriptions { + if trig.subscriptions[i].id == subIdentifier { + resolverCtx = i + exists = true + break + } } - r.sendUpdateToSubscription(data, c, s) - break } + + if !exists { + return + } + + subscription, exists := trig.subscriptions[resolverCtx] + if !exists { + return + } + + r.sendUpdateToSubscription(data, resolverCtx, subscription) } func (r *Resolver) sendUpdateToSubscription(data []byte, c *Context, s *sub) { @@ -1111,6 +1133,7 @@ func (r *Resolver) completeTriggerSubscriptions(id uint64, completeMatcher func( // Important because we remove the subscription from the trigger on the same goroutine // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. delete(trig.subscriptions, c) + delete(trig.subscriptionIdentifiers, s.id) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) @@ -1142,6 +1165,7 @@ func (r *Resolver) closeTriggerSubscriptions(id uint64, closeKind SubscriptionCl // Important because we remove the subscription from the trigger on the same goroutine // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. delete(trig.subscriptions, c) + delete(trig.subscriptionIdentifiers, s.id) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) diff --git a/v2/pkg/engine/resolve/resolve_subscription_benchmark_test.go b/v2/pkg/engine/resolve/resolve_subscription_benchmark_test.go new file mode 100644 index 0000000000..b943a15921 --- /dev/null +++ b/v2/pkg/engine/resolve/resolve_subscription_benchmark_test.go @@ -0,0 +1,163 @@ +package resolve + +import ( + "bytes" + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func BenchmarkUpdateSubscription(b *testing.B) { + for _, n := range []int{1, 10, 100, 1000} { + b.Run(func() string { + switch n { + case 1: + return "1_subs" + case 10: + return "10_subs" + case 100: + return "100_subs" + default: + return "1000_subs" + } + }(), func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + b.Cleanup(cancel) + + streamDone := make(chan struct{}) + b.Cleanup(func() { close(streamDone) }) + + updaters := make([]func([]byte), n) + var setupWg sync.WaitGroup + setupWg.Add(n) + + // Each subscription gets its own fakeStream so that its + // subscriptionOnStartFn captures the correct slot index i. + // executeStartupHooks uses add.resolve.Trigger.Source (the + // subscription's own plan source), so the hook fires on the + // right stream regardless of goroutine scheduling order. + // All subscriptions share the same static input, so they all + // land on a single trigger whose subscriptionIdentifiers map + // grows to N entries — the map we want to exercise. + makePlan := func(i int) *GraphQLSubscription { + stream := createFakeStream( + func(counter int) (message string, done bool) { + <-streamDone + return "", true + }, + 0, + nil, + func(hookCtx StartupHookContext, _ []byte) error { + updaters[i] = hookCtx.Updater + setupWg.Done() + return nil + }, + ) + + fetches := Sequence() + fetches.Trigger = &FetchTreeNode{ + Kind: FetchTreeNodeKindTrigger, + Item: &FetchItem{ + Fetch: &SingleFetch{ + FetchDependencies: FetchDependencies{ + FetchID: 0, + }, + Info: &FetchInfo{ + DataSourceID: "0", + DataSourceName: "counter", + QueryPlan: &QueryPlan{ + Query: "subscription {\n counter\n}", + }, + }, + }, + ResponsePath: "counter", + }, + } + + return &GraphQLSubscription{ + Trigger: GraphQLSubscriptionTrigger{ + Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("counter"), + Value: &Integer{ + Path: []string{"counter"}, + }, + Info: &FieldInfo{ + Name: "counter", + ExactParentTypeName: "Subscription", + Source: TypeFieldSource{ + IDs: []string{"0"}, + Names: []string{"counter"}, + }, + FetchID: 0, + }, + }, + }, + }, + Fetches: fetches, + }, + } + } + + resolver := newResolver(ctx) + + recorders := make([]*SubscriptionRecorder, n) + for i := 0; i < n; i++ { + recorders[i] = &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + subCtx := NewContext(context.Background()) + id := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: int64(i + 1), + } + err := resolver.AsyncResolveGraphQLSubscription(subCtx, makePlan(i), recorders[i], id) + require.NoError(b, err) + } + + // Block until all N startup hooks have fired, guaranteeing all + // entries are in subscriptionIdentifiers before timing starts. + setupWg.Wait() + + b.ResetTimer() + b.ReportAllocs() + + data := []byte(`{"data":{"counter":1}}`) + for i := 0; i < b.N; i++ { + // Update every subscription on the trigger sequentially. + // Each call does an O(1) map lookup in subscriptionIdentifiers + // then delivers to that subscription's worker. + for j := 0; j < n; j++ { + updaters[j](data) + } + } + + // Every recorder must have received exactly b.N messages. + for i := 0; i < n; i++ { + recorders[i].AwaitMessages(b, b.N, 30*time.Second) + } + }) + } +} diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 82a8e1e635..32f73e218a 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -5351,7 +5351,7 @@ type SubscriptionRecorder struct { var _ SubscriptionResponseWriter = (*SubscriptionRecorder)(nil) -func (s *SubscriptionRecorder) AwaitMessages(t *testing.T, count int, timeout time.Duration) { +func (s *SubscriptionRecorder) AwaitMessages(t testing.TB, count int, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for { @@ -5368,7 +5368,7 @@ func (s *SubscriptionRecorder) AwaitMessages(t *testing.T, count int, timeout ti } } -func (s *SubscriptionRecorder) AwaitAnyMessageCount(t *testing.T, timeout time.Duration) { +func (s *SubscriptionRecorder) AwaitAnyMessageCount(t testing.TB, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for { @@ -5385,7 +5385,7 @@ func (s *SubscriptionRecorder) AwaitAnyMessageCount(t *testing.T, timeout time.D } } -func (s *SubscriptionRecorder) AwaitComplete(t *testing.T, timeout time.Duration) { +func (s *SubscriptionRecorder) AwaitComplete(t testing.TB, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for { @@ -5399,7 +5399,7 @@ func (s *SubscriptionRecorder) AwaitComplete(t *testing.T, timeout time.Duration } } -func (s *SubscriptionRecorder) AwaitClosed(t *testing.T, timeout time.Duration) { +func (s *SubscriptionRecorder) AwaitClosed(t testing.TB, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for { @@ -6546,6 +6546,140 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.Contains(t, errorMessage, "errors", "Expected error message in GraphQL format") assert.Contains(t, errorMessage, expectedErr.Error(), "Expected actual error message to be included") }) + + t.Run("subscription added to existing trigger can be targeted by UpdateSubscription", func(t *testing.T) { + // Verifies that subscriptionIdentifiers is populated for subscriptions joining an + // already-running trigger (the existing-trigger path in handleAddSubscription), so + // that handleUpdateSubscription can reach them via O(1) map lookup. + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + id2 := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 2, + } + + sub1HookDone := make(chan struct{}) + streamCanSend := make(chan struct{}) + + // The startup hook is shared by both subscriptions. + // Sub1 (new-trigger path) closes sub1HookDone and does nothing else. + // Sub2 (existing-trigger path) calls ctx.Updater, sending a targeted update only to sub2. + // Because the test waits for sub1HookDone before registering sub2, sub2's hook always + // finds sub1HookDone already closed, making the branch selection deterministic. + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + <-streamCanSend + return `{"data":{"counter":0}}`, true + }, 0, nil, func(ctx StartupHookContext, input []byte) error { + select { + case <-sub1HookDone: + // sub1HookDone is already closed: this is sub2 (existing-trigger path). + ctx.Updater([]byte(`{"data":{"counter":1000}}`)) + default: + // First call: this is sub1 (new-trigger path). + close(sub1HookDone) + } + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx1 := NewContext(context.Background()) + err := resolver.AsyncResolveGraphQLSubscription(ctx1, plan, recorder, id) + assert.NoError(t, err) + + // Wait for sub1's startup hook to complete before adding sub2, + // guaranteeing sub2 joins the existing trigger. + select { + case <-sub1HookDone: + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for sub1 startup hook") + } + + recorder2 := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + recorder2.complete.Store(false) + + ctx2 := NewContext(context.Background()) + err2 := resolver.AsyncResolveGraphQLSubscription(ctx2, plan, recorder2, id2) + assert.NoError(t, err2) + + // Wait for sub2 to receive its targeted update from the startup hook. + recorder2.AwaitAnyMessageCount(t, defaultTimeout) + + // Signal the stream to send its final message and complete both subscriptions. + close(streamCanSend) + + recorder.AwaitComplete(t, defaultTimeout) + recorder2.AwaitComplete(t, defaultTimeout) + + // sub1 receives only the stream message — it was not the target of ctx.Updater. + assert.Len(t, recorder.Messages(), 1) + assert.Equal(t, `{"data":{"counter":0}}`, recorder.Messages()[0]) + + // sub2 receives both: the targeted startup update and the stream message. + assert.Len(t, recorder2.Messages(), 2) + assert.Equal(t, `{"data":{"counter":1000}}`, recorder2.Messages()[0]) + assert.Equal(t, `{"data":{"counter":0}}`, recorder2.Messages()[1]) + }) + + t.Run("subscriptionIdentifiers entry is removed when subscription is unsubscribed", func(t *testing.T) { + // Verifies that the subscriptionIdentifiers map is cleaned up when a subscription is + // removed. Without cleanup, calling ctx.Updater after unsubscription would find a stale + // entry, attempt to send work to the already-closed work channel, and panic. + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + streamCanEnd := make(chan struct{}) + var capturedUpdater func([]byte) + hookDone := make(chan struct{}) + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + <-streamCanEnd + return "", true + }, 0, nil, func(ctx StartupHookContext, input []byte) error { + capturedUpdater = ctx.Updater + ctx.Updater([]byte(`{"data":{"counter":1000}}`)) + close(hookDone) + return nil + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx1 := NewContext(context.Background()) + err := resolver.AsyncResolveGraphQLSubscription(ctx1, plan, recorder, id) + assert.NoError(t, err) + + select { + case <-hookDone: + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for startup hook") + } + recorder.AwaitAnyMessageCount(t, defaultTimeout) + + // Unsubscribe before the stream sends any messages. + err = resolver.AsyncUnsubscribeSubscription(id) + assert.NoError(t, err) + recorder.AwaitClosed(t, defaultTimeout) + + // Unblock the stream so its goroutine can exit cleanly. + close(streamCanEnd) + + // Calling the captured updater after the subscription has been cleaned up must be + // a no-op. If subscriptionIdentifiers was not cleaned up, this would find a stale + // entry, try to send work to the closed work channel, and panic. + capturedUpdater([]byte(`{"data":{"counter":2000}}`)) + + // Give the event loop time to process the update event. + time.Sleep(50 * time.Millisecond) + + // Only the startup update should have been delivered; the post-removal call is dropped. + assert.Len(t, recorder.Messages(), 1) + assert.Equal(t, `{"data":{"counter":1000}}`, recorder.Messages()[0]) + }) } func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {