Skip to content
48 changes: 36 additions & 12 deletions v2/pkg/engine/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe
}

type trigger struct {
id uint64
cancel context.CancelFunc
subscriptions map[*Context]*sub
id uint64
cancel context.CancelFunc
subscriptions map[*Context]*sub
subscriptionIdentifiers map[SubscriptionIdentifier]*Context
// initialized is set to true when the trigger is started and initialized
initialized bool
updater *subscriptionUpdater
Expand Down Expand Up @@ -817,6 +818,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
// After the startup hooks are executed, we can add the subscription to the subscriptions registry
// so that it can start receive events
trig.subscriptions[add.ctx] = s
trig.subscriptionIdentifiers[s.id] = add.ctx
return
}

Expand All @@ -832,13 +834,15 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
}
cloneCtx := add.ctx.clone(ctx)
trig = &trigger{
id: triggerID,
subscriptions: make(map[*Context]*sub),
cancel: cancel,
updater: updater,
id: triggerID,
subscriptions: make(map[*Context]*sub),
subscriptionIdentifiers: make(map[SubscriptionIdentifier]*Context),
cancel: cancel,
updater: updater,
}
r.triggers[triggerID] = trig
trig.subscriptions[add.ctx] = s
trig.subscriptionIdentifiers[s.id] = add.ctx
updater.subsFn = trig.subscriptionIds

if r.reporter != nil {
Expand Down Expand Up @@ -1004,13 +1008,31 @@ func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifie
fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID)
}

for c, s := range trig.subscriptions {
if s.id != subIdentifier {
continue
// Fast path: O(1) lookup of the subscriptions resolver context via map
resolverCtx, exists := trig.subscriptionIdentifiers[subIdentifier]

// Fallback O(N) lookup in case we couldn't find the resolver context by map:
// Loop through trig.subscriptions and find the corresponding resolver context.
if !exists {
for i := range trig.subscriptions {
if trig.subscriptions[i].id == subIdentifier {
resolverCtx = i
exists = true
break
}
}
r.sendUpdateToSubscription(data, c, s)
break
}

if !exists {
return
}

subscription, exists := trig.subscriptions[resolverCtx]
if !exists {
return
}

r.sendUpdateToSubscription(data, resolverCtx, subscription)
}

func (r *Resolver) sendUpdateToSubscription(data []byte, c *Context, s *sub) {
Expand Down Expand Up @@ -1111,6 +1133,7 @@ func (r *Resolver) completeTriggerSubscriptions(id uint64, completeMatcher func(
// Important because we remove the subscription from the trigger on the same goroutine
// as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point.
delete(trig.subscriptions, c)
delete(trig.subscriptionIdentifiers, s.id)

if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID)
Expand Down Expand Up @@ -1142,6 +1165,7 @@ func (r *Resolver) closeTriggerSubscriptions(id uint64, closeKind SubscriptionCl
// Important because we remove the subscription from the trigger on the same goroutine
// as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point.
delete(trig.subscriptions, c)
delete(trig.subscriptionIdentifiers, s.id)

if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID)
Expand Down
163 changes: 163 additions & 0 deletions v2/pkg/engine/resolve/resolve_subscription_benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package resolve

import (
"bytes"
"context"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func BenchmarkUpdateSubscription(b *testing.B) {
for _, n := range []int{1, 10, 100, 1000} {
b.Run(func() string {
switch n {
case 1:
return "1_subs"
case 10:
return "10_subs"
case 100:
return "100_subs"
default:
return "1000_subs"
}
}(), func(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
b.Cleanup(cancel)

streamDone := make(chan struct{})
b.Cleanup(func() { close(streamDone) })

updaters := make([]func([]byte), n)
var setupWg sync.WaitGroup
setupWg.Add(n)

// Each subscription gets its own fakeStream so that its
// subscriptionOnStartFn captures the correct slot index i.
// executeStartupHooks uses add.resolve.Trigger.Source (the
// subscription's own plan source), so the hook fires on the
// right stream regardless of goroutine scheduling order.
// All subscriptions share the same static input, so they all
// land on a single trigger whose subscriptionIdentifiers map
// grows to N entries — the map we want to exercise.
makePlan := func(i int) *GraphQLSubscription {
stream := createFakeStream(
func(counter int) (message string, done bool) {
<-streamDone
return "", true
},
0,
nil,
func(hookCtx StartupHookContext, _ []byte) error {
updaters[i] = hookCtx.Updater
setupWg.Done()
return nil
},
)

fetches := Sequence()
fetches.Trigger = &FetchTreeNode{
Kind: FetchTreeNodeKindTrigger,
Item: &FetchItem{
Fetch: &SingleFetch{
FetchDependencies: FetchDependencies{
FetchID: 0,
},
Info: &FetchInfo{
DataSourceID: "0",
DataSourceName: "counter",
QueryPlan: &QueryPlan{
Query: "subscription {\n counter\n}",
},
},
},
ResponsePath: "counter",
},
}

return &GraphQLSubscription{
Trigger: GraphQLSubscriptionTrigger{
Source: stream,
InputTemplate: InputTemplate{
Segments: []TemplateSegment{
{
SegmentType: StaticSegmentType,
Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`),
},
},
},
PostProcessing: PostProcessingConfiguration{
SelectResponseDataPath: []string{"data"},
SelectResponseErrorsPath: []string{"errors"},
},
},
Response: &GraphQLResponse{
Data: &Object{
Fields: []*Field{
{
Name: []byte("counter"),
Value: &Integer{
Path: []string{"counter"},
},
Info: &FieldInfo{
Name: "counter",
ExactParentTypeName: "Subscription",
Source: TypeFieldSource{
IDs: []string{"0"},
Names: []string{"counter"},
},
FetchID: 0,
},
},
},
},
Fetches: fetches,
},
}
}

resolver := newResolver(ctx)

recorders := make([]*SubscriptionRecorder, n)
for i := 0; i < n; i++ {
recorders[i] = &SubscriptionRecorder{
buf: &bytes.Buffer{},
messages: []string{},
complete: atomic.Bool{},
}
subCtx := &Context{ctx: context.Background()}
id := SubscriptionIdentifier{
ConnectionID: 1,
SubscriptionID: int64(i + 1),
}
err := resolver.AsyncResolveGraphQLSubscription(subCtx, makePlan(i), recorders[i], id)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
require.NoError(b, err)
}

// Block until all N startup hooks have fired, guaranteeing all
// entries are in subscriptionIdentifiers before timing starts.
setupWg.Wait()

b.ResetTimer()
b.ReportAllocs()

data := []byte(`{"data":{"counter":1}}`)
for i := 0; i < b.N; i++ {
// Update every subscription on the trigger sequentially.
// Each call does an O(1) map lookup in subscriptionIdentifiers
// then delivers to that subscription's worker.
for j := 0; j < n; j++ {
updaters[j](data)
}
}

// Every recorder must have received exactly b.N messages.
for i := 0; i < n; i++ {
recorders[i].AwaitMessages(b, b.N, 30*time.Second)
}
})
}
}
Loading
Loading