// Point selects the worker hook site at which a fault is parked. The numeric
// values have to stay in step with the parent package's faultPhase iota,
// because the hook converts between the two with a plain cast.
type Point uint8

const (
	// BeforeRLock parks before settleRequests.RLock.
	BeforeRLock Point = iota
	// BeforeFaultPage parks after settleRequests.RLock, before UFFDIO_COPY.
	BeforeFaultPage
)

// key is the (address, hook site) pair a barrier is looked up by.
type key struct {
	addr  uintptr
	point Point
}

// slot is one installed barrier. arrived is closed (once) by the first fault
// that reaches the barrier; release is closed when the barrier is released.
type slot struct {
	addr        uintptr
	point       Point
	arrived     chan struct{}
	release     chan struct{}
	arrivedOnce sync.Once
}

// Registry is the child-side barrier store consulted by the per-fault hook.
// tokens owns every installed slot; byKey points each (addr, point) pair at
// the token that was installed for it most recently. mu guards all fields.
type Registry struct {
	mu     sync.Mutex
	next   uint64
	tokens map[uint64]*slot
	byKey  map[key]uint64
}

// NewRegistry returns an empty Registry with its maps allocated.
func NewRegistry() *Registry {
	r := new(Registry)
	r.tokens = map[uint64]*slot{}
	r.byKey = map[key]uint64{}

	return r
}

// Install registers a barrier for faults on addr at the given hook site and
// returns the token that identifies it. Tokens start at 1 and increase by one
// per call. Installing again for the same (addr, point) redirects the lookup
// to the new barrier; the earlier one stays registered under its own token.
func (r *Registry) Install(addr uintptr, point Point) uint64 {
	barrier := &slot{
		addr:    addr,
		point:   point,
		arrived: make(chan struct{}),
		release: make(chan struct{}),
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	r.next++
	r.tokens[r.next] = barrier
	r.byKey[key{addr: addr, point: point}] = r.next

	return r.next
}

// lookupByAddr returns the slot currently mapped to (addr, point), or nil
// when nothing is installed there.
func (r *Registry) lookupByAddr(addr uintptr, point Point) *slot {
	r.mu.Lock()
	defer r.mu.Unlock()

	if token, found := r.byKey[key{addr: addr, point: point}]; found {
		return r.tokens[token]
	}

	return nil
}

// WaitArrived blocks until a fault has reached the barrier identified by
// token, or until ctx is done (in which case ctx.Err() is returned). A token
// that is not currently installed yields an error immediately.
func (r *Registry) WaitArrived(ctx context.Context, token uint64) error {
	r.mu.Lock()
	barrier, known := r.tokens[token]
	r.mu.Unlock()

	if !known {
		return fmt.Errorf("unknown barrier token %d", token)
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-barrier.arrived:
		return nil
	}
}

// detach removes token from the registry and hands back its slot, or nil if
// the token is not installed.
func (r *Registry) detach(token uint64) *slot {
	r.mu.Lock()
	defer r.mu.Unlock()

	barrier, known := r.tokens[token]
	if !known {
		return nil
	}
	delete(r.tokens, token)

	// A later Install at this key overwrites byKey; only drop the entry
	// when it still refers to the token being released.
	k := key{addr: barrier.addr, point: barrier.point}
	if r.byKey[k] == token {
		delete(r.byKey, k)
	}

	return barrier
}

// Release frees the barrier; unknown token is a no-op.
func (r *Registry) Release(token uint64) {
	barrier := r.detach(token)
	if barrier == nil {
		return
	}

	// Unparks every hook invocation waiting on this slot.
	select {
	case <-barrier.release:
	default:
		close(barrier.release)
	}
}

// ReleaseAll releases every still-installed barrier.
func (r *Registry) ReleaseAll() {
	r.mu.Lock()
	pending := make([]uint64, 0, len(r.tokens))
	for token := range r.tokens {
		pending = append(pending, token)
	}
	r.mu.Unlock()

	// Release takes the lock itself, so it runs on the snapshot only after
	// the lock has been dropped.
	for i := range pending {
		r.Release(pending[i])
	}
}

// park is the hook body: it signals arrival on the matching slot and then
// blocks until that slot is released. Faults with no installed slot return
// immediately.
func (r *Registry) park(addr uintptr, point Point) {
	barrier := r.lookupByAddr(addr, point)
	if barrier == nil {
		return
	}

	barrier.arrivedOnce.Do(func() {
		close(barrier.arrived)
	})

	<-barrier.release
}

// Hook returns the per-fault hook to install on *Userfaultfd; faults
// without an installed slot are no-ops.
func (r *Registry) Hook() func(addr uintptr, point Point) {
	return r.park
}

// Client is the typed parent-side wrapper around the JSON-RPC channel
// to the child helper process.
type Client struct {
	rpc  *rpc.Client
	conn io.Closer
}

// NewClient wraps an already-connected duplex stream. Closing the
// returned Client closes the underlying conn.
+func NewClient(conn io.ReadWriteCloser) *Client { + return &Client{ + rpc: jsonrpc.NewClient(conn), + conn: conn, + } +} + +func (c *Client) Bootstrap(args BootstrapArgs) error { + return c.rpc.Call("Lifecycle.Bootstrap", &args, &BootstrapReply{}) +} + +func (c *Client) WaitReady() error { + return c.rpc.Call("Lifecycle.WaitReady", &Empty{}, &Empty{}) +} + +func (c *Client) Shutdown() error { + return c.rpc.Call("Lifecycle.Shutdown", &Empty{}, &Empty{}) +} + +func (c *Client) Pause() error { + return c.rpc.Call("Paging.Pause", &Empty{}, &Empty{}) +} + +func (c *Client) Resume() error { + return c.rpc.Call("Paging.Resume", &Empty{}, &Empty{}) +} + +func (c *Client) PageStates() ([]PageStateEntry, error) { + var reply PageStatesReply + if err := c.rpc.Call("Paging.States", &Empty{}, &reply); err != nil { + return nil, err + } + + return reply.Entries, nil +} + +func (c *Client) InstallBarrier(addr uintptr, point Point) (uint64, error) { + var reply FaultBarrierReply + if err := c.rpc.Call("Barriers.Install", &FaultBarrierArgs{Addr: uint64(addr), Point: uint8(point)}, &reply); err != nil { + return 0, err + } + + return reply.Token, nil +} + +func (c *Client) WaitFaultHeld(token uint64) error { + return c.rpc.Call("Barriers.WaitHeld", &TokenArgs{Token: token}, &Empty{}) +} + +func (c *Client) ReleaseFault(token uint64) error { + return c.rpc.Call("Barriers.Release", &TokenArgs{Token: token}, &Empty{}) +} + +func (c *Client) Close() error { + return c.conn.Close() +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go b/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go new file mode 100644 index 0000000000..7a475b775d --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go @@ -0,0 +1,43 @@ +// Package testharness provides the wire types, typed RPC client, and +// barrier registry shared between the parent and child halves of the +// userfaultfd test harness. 
type (
	// Empty is the placeholder for net/rpc methods that take or return
	// nothing; net/rpc still requires both args and reply pointers.
	Empty struct{}

	// BootstrapArgs is the argument of the Lifecycle.Bootstrap call.
	BootstrapArgs struct {
		// MmapStart is the start address of the parent's mapped area.
		MmapStart uint64
		// Pagesize is the page size of that area, in bytes.
		Pagesize int64
		// TotalSize is the size of the backing content, in bytes.
		TotalSize int64
		// AlwaysWP carries the test's alwaysWP setting to the helper.
		AlwaysWP bool
		// Barriers gates the test-only worker hooks (off by default).
		Barriers bool
		// Content is the backing data for the mapped area.
		Content []byte
	}

	// BootstrapReply is the (empty) reply of the Lifecycle.Bootstrap call.
	BootstrapReply struct{}

	// PageStateEntry is the wire form of the parent package's pageState enum.
	PageStateEntry struct {
		State  uint8
		Offset uint64
	}

	// PageStatesReply is the reply of the Paging.States call.
	PageStatesReply struct {
		Entries []PageStateEntry
	}

	// FaultBarrierArgs is the argument of the Barriers.Install call: the
	// fault address and the Point value, both widened to fixed-size integers.
	FaultBarrierArgs struct {
		Addr  uint64
		Point uint8
	}

	// FaultBarrierReply is the reply of the Barriers.Install call.
	FaultBarrierReply struct {
		Token uint64
	}

	// TokenArgs is the argument of the Barriers.WaitHeld and
	// Barriers.Release calls.
	TokenArgs struct {
		Token uint64
	}
)
-// It prevents problems with Go mmap during testing (https://pojntfx.github.io/networked-linux-memsync/main.html#limitations) and also more accurately simulates what we do with Firecracker. -// These problems are not affecting Firecracker, because: -// 1. It is a different process that handles the page faults -// 2. Does not use garbage collection - -import ( - "context" - "crypto/rand" - "encoding/binary" - "errors" - "fmt" - "io" - "os" - "os/exec" - "os/signal" - "slices" - "strconv" - "strings" - "sync" - "syscall" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" - - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory" - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" - "github.com/e2b-dev/infra/packages/shared/pkg/logger" -) - -// MemorySlicer exposes byte slice via the Slicer interface. -// This is used for testing purposes. 
-type MemorySlicer struct { - content []byte - pagesize int64 -} - -var _ block.Slicer = (*MemorySlicer)(nil) - -func NewMemorySlicer(content []byte, pagesize int64) *MemorySlicer { - return &MemorySlicer{ - content: content, - pagesize: pagesize, - } -} - -func (s *MemorySlicer) Slice(_ context.Context, offset, size int64) ([]byte, error) { - return s.content[offset : offset+size], nil -} - -func (s *MemorySlicer) Size() (int64, error) { - return int64(len(s.content)), nil -} - -func (s *MemorySlicer) Content() []byte { - return s.content -} - -func (s *MemorySlicer) BlockSize() int64 { - return s.pagesize -} - -func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer { - size := pagesize * numberOfPages - - n := int(size) - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - panic(err) - } - - return NewMemorySlicer(buf, int64(pagesize)) -} - -// Main process, FC in our case -func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error) { - t.Helper() - - data := RandomPages(tt.pagesize, tt.numberOfPages) - - size, err := data.Size() - require.NoError(t, err) - - memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), tt.pagesize) - require.NoError(t, err) - - // We can pass mapping nil as the serve is used only in the helper process. 
- uffdFd, err := newFd(syscall.O_CLOEXEC | syscall.O_NONBLOCK) - require.NoError(t, err) - - t.Cleanup(func() { - uffdFd.close() - }) - - err = configureApi(uffdFd, tt.pagesize) - require.NoError(t, err) - - err = register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP) - require.NoError(t, err) - - cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess", "-test.timeout=0") - cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") - cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) - cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) - if tt.alwaysWP { - cmd.Env = append(cmd.Env, "GO_ALWAYS_WP=1") - } - if tt.gated { - cmd.Env = append(cmd.Env, "GO_GATED=1") - } - - dup, err := syscall.Dup(int(uffdFd)) - require.NoError(t, err) - - // clear FD_CLOEXEC on the dup we pass across exec - _, err = unix.FcntlInt(uintptr(dup), unix.F_SETFD, 0) - require.NoError(t, err) - - uffdFile := os.NewFile(uintptr(dup), "uffd") - - contentReader, contentWriter, err := os.Pipe() - require.NoError(t, err) - - go func() { - _, writeErr := contentWriter.Write(data.Content()) - assert.NoError(t, writeErr) - - closeErr := contentWriter.Close() - assert.NoError(t, closeErr) - }() - - offsetsReader, offsetsWriter, err := os.Pipe() - require.NoError(t, err) - - t.Cleanup(func() { - offsetsReader.Close() - }) - - readyReader, readyWriter, err := os.Pipe() - require.NoError(t, err) - - t.Cleanup(func() { - readyReader.Close() - }) - - readySignal := make(chan struct{}, 1) - go func() { - _, err := io.ReadAll(readyReader) - assert.NoError(t, err) - - readySignal <- struct{}{} - }() - - extraFiles := []*os.File{ - uffdFile, - contentReader, - offsetsWriter, - readyWriter, - } - - var gateCmdWriter *os.File - var gateSyncReader *os.File - if tt.gated { - var gateCmdReader *os.File - gateCmdReader, gateCmdWriter, err = os.Pipe() - require.NoError(t, err) - - var 
gateSyncWriter *os.File - gateSyncReader, gateSyncWriter, err = os.Pipe() - require.NoError(t, err) - - t.Cleanup(func() { - gateCmdWriter.Close() - gateSyncReader.Close() - }) - - extraFiles = append(extraFiles, gateCmdReader) // fd 7 - extraFiles = append(extraFiles, gateSyncWriter) // fd 8 - } - - cmd.ExtraFiles = extraFiles - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - err = cmd.Start() - require.NoError(t, err) - - contentReader.Close() - offsetsWriter.Close() - readyWriter.Close() - uffdFile.Close() - if tt.gated { - extraFiles[4].Close() // gateCmdReader - extraFiles[5].Close() // gateSyncWriter - } - - t.Cleanup(func() { - signalErr := cmd.Process.Signal(syscall.SIGUSR1) - assert.NoError(t, signalErr) - - waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) - - // Tear down the UFFD registration before the early uffdFd.close() - // cleanup runs. Today this is a no-op (no test enables - // UFFD_FEATURE_EVENT_REMOVE) but a follow-up that does will - // otherwise see munmap block on un-acked REMOVE events queued - // against the still-registered range. Cleanups run LIFO, so - // this fires before the close registered earlier. - assert.NoError(t, unregister(uffdFd, memoryStart, uint64(size))) - }) - - // pageStatesOnce asks the serving process for a snapshot of its pageTracker - // and decodes it into a per-state view. It can only be called once. 
- pageStatesOnce := func() (handlerPageStates, error) { - err := cmd.Process.Signal(syscall.SIGUSR2) - if err != nil { - return handlerPageStates{}, err - } - - var result handlerPageStates - - for { - var entry pageStateEntry - - // binary.Read uses the same field layout as binary.Write on - // the producer side (sum of fixed-size fields, no struct - // padding), so we never have to hard-code the wire size. - err := binary.Read(offsetsReader, binary.LittleEndian, &entry) - if errors.Is(err, io.EOF) { - break - } - - if err != nil { - return handlerPageStates{}, fmt.Errorf("decoding page state entry: %w", err) - } - - if pageState(entry.State) == faulted { - result.faulted = append(result.faulted, uint(entry.Offset)) - } - } - - slices.Sort(result.faulted) - - return result, nil - } - - select { - case <-t.Context().Done(): - return nil, t.Context().Err() - case <-readySignal: - } - - h := &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - pageStatesOnce: pageStatesOnce, - } - - if tt.gated { - h.servePause = func() error { - if _, err := gateCmdWriter.Write([]byte{'P'}); err != nil { - return err - } - var buf [1]byte - _, err := gateSyncReader.Read(buf[:]) - - return err - } - h.serveResume = func() error { - _, err := gateCmdWriter.Write([]byte{'R'}) - - return err - } - } - - return h, nil -} - -// Secondary process, orchestrator in our case -func TestHelperServingProcess(t *testing.T) { - t.Parallel() - - if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { - t.Skip("this is a helper process, skipping direct execution") - } - - err := crossProcessServe() - if err != nil { - fmt.Println("exit serving process", err) - os.Exit(1) - } - - os.Exit(0) -} - -func crossProcessServe() error { - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - - startRaw, err := strconv.Atoi(os.Getenv("GO_MMAP_START")) - if err != nil { - return fmt.Errorf("exit parsing mmap start: %w", err) - } - - memoryStart := 
uintptr(startRaw) - - uffdFile := os.NewFile(uintptr(3), os.Getenv("GO_UFFD_FILE")) - defer uffdFile.Close() - - uffdFd := uffdFile.Fd() - - contentFile := os.NewFile(uintptr(4), "content") - defer contentFile.Close() - - content, err := io.ReadAll(contentFile) - if err != nil { - return fmt.Errorf("exit reading content: %w", err) - } - - pageSize, err := strconv.ParseInt(os.Getenv("GO_MMAP_PAGE_SIZE"), 10, 64) - if err != nil { - return fmt.Errorf("exit parsing page size: %w", err) - } - - data := NewMemorySlicer(content, pageSize) - - m := memory.NewMapping([]memory.Region{ - { - BaseHostVirtAddr: memoryStart, - Size: uintptr(len(content)), - Offset: 0, - PageSize: uintptr(pageSize), - }, - }) - - exitUffd := make(chan struct{}, 1) - defer close(exitUffd) - - l, err := logger.NewDevelopmentLogger() - if err != nil { - return fmt.Errorf("exit creating logger: %w", err) - } - - uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l) - if err != nil { - return fmt.Errorf("exit creating uffd: %w", err) - } - - if os.Getenv("GO_ALWAYS_WP") == "1" { - uffd.defaultCopyMode = UFFDIO_COPY_MODE_WP - } - - offsetsFile := os.NewFile(uintptr(5), "offsets") - - offsetsSignal := make(chan os.Signal, 1) - signal.Notify(offsetsSignal, syscall.SIGUSR2) - defer signal.Stop(offsetsSignal) - - go func() { - defer offsetsFile.Close() - - for { - select { - case <-ctx.Done(): - return - case <-offsetsSignal: - entries, entriesErr := uffd.pageStateEntries() - if entriesErr != nil { - cancel(fmt.Errorf("error getting page state entries: %w", entriesErr)) - - return - } - - for _, entry := range entries { - writeErr := binary.Write(offsetsFile, binary.LittleEndian, entry) - if writeErr != nil { - cancel(fmt.Errorf("error writing page state entry: %w", writeErr)) - - return - } - } - - return - } - } - }() - - fdExit, err := fdexit.New() - if err != nil { - return fmt.Errorf("exit creating fd exit: %w", err) - } - defer fdExit.Close() - - go func() { - defer func() { - exitUffd <- struct{}{} 
- }() - - serverErr := uffd.Serve(ctx, fdExit) - if serverErr != nil { - msg := fmt.Errorf("error serving: %w", serverErr) - - fmt.Fprint(os.Stderr, msg.Error()) - - cancel(msg) - - return - } - }() - - // stopFn drains whichever Serve goroutine is currently running. The - // running flag plus stopMu makes both pause-then-exit (no resume in - // between) and pause-resume-pause-exit safe, and rejects nonsensical - // command sequences from the gated channel: 'P' when already paused - // is a no-op, 'R' when already running is a no-op so a stray or - // duplicate resume can't leak an untracked Serve goroutine and break - // later pauses. - var ( - stopMu sync.Mutex - running = true - stopFn = func() { - err := fdExit.SignalExit() - if err != nil { - msg := fmt.Errorf("error signaling exit: %w", err) - - fmt.Fprint(os.Stderr, msg.Error()) - - cancel(msg) - - return - } - - <-exitUffd - } - ) - - stopServe := func() { - stopMu.Lock() - if !running { - stopMu.Unlock() - - return - } - fn := stopFn - stopFn = func() {} - running = false - stopMu.Unlock() - - fn() - } - - defer stopServe() - - if os.Getenv("GO_GATED") == "1" { - gateCmdFile := os.NewFile(uintptr(7), "gate-cmd") - defer gateCmdFile.Close() - - gateSyncFile := os.NewFile(uintptr(8), "gate-sync") - defer gateSyncFile.Close() - - startServe := func() { - stopMu.Lock() - if running { - stopMu.Unlock() - - return - } - stopMu.Unlock() - - newExit, fdErr := fdexit.New() - if fdErr != nil { - cancel(fmt.Errorf("error creating fd exit: %w", fdErr)) - - return - } - - done := make(chan struct{}) - go func() { - defer close(done) - if err := uffd.Serve(ctx, newExit); err != nil { - cancel(fmt.Errorf("error serving: %w", err)) - } - }() - - stopMu.Lock() - stopFn = func() { - newExit.SignalExit() - <-done - newExit.Close() - } - running = true - stopMu.Unlock() - } - - go func() { - var buf [1]byte - for { - if _, err := gateCmdFile.Read(buf[:]); err != nil { - return - } - - switch buf[0] { - case 'P': - stopServe() 
- if _, err := gateSyncFile.Write([]byte{1}); err != nil { - cancel(fmt.Errorf("writing gate sync: %w", err)) - - return - } - case 'R': - startServe() - } - } - }() - } - - exitSignal := make(chan os.Signal, 1) - signal.Notify(exitSignal, syscall.SIGUSR1) - defer signal.Stop(exitSignal) - - readyFile := os.NewFile(uintptr(6), "ready") - - closeErr := readyFile.Close() - if closeErr != nil { - return fmt.Errorf("error closing ready file: %w", closeErr) - } - - select { - case <-ctx.Done(): - return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx)) - case <-exitSignal: - return nil - } -} - -// pageStateEntry is the wire format used between the main test process -// and the serving helper process. State is emitted as a single byte so it -// can be written directly with binary.Write and decoded on the other side. -type pageStateEntry struct { - State uint8 - Offset uint64 -} - -// pageStateEntries returns a snapshot of every tracked page and its state. -// It first drains in-flight faultPage workers via settleRequests.Lock(), then -// holds the pageTracker RLock while iterating so any future writer that -// doesn't go through settleRequests (e.g. the upcoming REMOVE handler) -// still can't mutate the map under us. 
// TestHelperServingProcess is the entry point of the re-exec'd helper. It is
// skipped unless envHelperFlag is set to "1" (the parent test sets it when it
// launches this binary with -test.run=TestHelperServingProcess). In helper
// mode it runs crossProcessServe and terminates the process: exit code 1 on
// error, 0 otherwise.
func TestHelperServingProcess(t *testing.T) {
	t.Parallel()

	if os.Getenv(envHelperFlag) != "1" {
		t.Skip("this is a helper process, skipping direct execution")
	}

	if err := crossProcessServe(); err != nil {
		fmt.Fprintln(os.Stderr, "exit serving process:", err)
		os.Exit(1)
	}

	os.Exit(0)
}

// crossProcessServe runs the helper side of the harness: it adopts the two
// inherited descriptors (fd 3 = uffd, fd 4 = RPC socket), registers the
// Lifecycle, Paging and Barriers RPC services on a JSON-RPC server over the
// socket, and serves until either state.ctx is done or the codec returns.
// It then releases any installed barriers, stops serving, closes the
// connection and waits for the codec goroutine to finish.
//
// It returns an error only for setup failures (wrapping the socket or
// registering a service); a normal shutdown returns nil.
func crossProcessServe() error {
	// fork+exec dup3's the parent's ExtraFiles to fd 3 (uffd) and fd 4 (rpc).
	uffdFile := os.NewFile(uintptr(3), "uffd")
	defer uffdFile.Close()

	rpcFile := os.NewFile(uintptr(4), "rpc")
	conn, err := net.FileConn(rpcFile)
	// net.FileConn works on a copy of the descriptor, so rpcFile is closed
	// right away regardless of whether the conversion succeeded.
	rpcFile.Close()
	if err != nil {
		return fmt.Errorf("net.FileConn rpc: %w", err)
	}
	// Explicit close before <-codecDone unblocks ServeCodec on the success
	// path; the deferred close is the safety net for early returns.
	var closeConnOnce sync.Once
	closeConn := func() { closeConnOnce.Do(func() { _ = conn.Close() }) }
	defer closeConn()

	// NOTE(review): newHarnessState and the Lifecycle/Paging/Barriers types
	// are defined outside this view; presumably state owns ctx, the barrier
	// registry (br) and the serve loop — confirm against their definitions.
	state := newHarnessState(uffdFile.Fd())

	server := rpc.NewServer()
	if err := server.Register(&Lifecycle{state: state}); err != nil {
		return fmt.Errorf("rpc Register Lifecycle: %w", err)
	}
	if err := server.Register(&Paging{state: state}); err != nil {
		return fmt.Errorf("rpc Register Paging: %w", err)
	}
	if err := server.Register(&Barriers{state: state}); err != nil {
		return fmt.Errorf("rpc Register Barriers: %w", err)
	}

	// Run the codec in a goroutine so Shutdown can unblock us via ctx.
	codecDone := make(chan struct{})
	go func() {
		defer close(codecDone)
		server.ServeCodec(jsonrpc.NewServerCodec(conn))
	}()

	// Block until shutdown is requested (ctx) or the RPC stream ends.
	select {
	case <-state.ctx.Done():
	case <-codecDone:
	}

	// Read br under the lock, but release the barriers outside it.
	state.mu.Lock()
	br := state.br
	state.mu.Unlock()
	if br != nil {
		// Unpark hooks still waiting on a barrier before stopping the serve loop.
		br.ReleaseAll()
	}
	state.stopServe()

	closeConn()
	<-codecDone

	return nil
}

// MemorySlicer exposes an in-memory byte slice through the block.Slicer
// interface. It is used by the tests as the backing data source.
type MemorySlicer struct {
	content  []byte
	pagesize int64
}

// Compile-time check that *MemorySlicer satisfies block.Slicer.
var _ block.Slicer = (*MemorySlicer)(nil)

// NewMemorySlicer wraps content (without copying it) and records pagesize as
// the value reported by BlockSize.
func NewMemorySlicer(content []byte, pagesize int64) *MemorySlicer {
	return &MemorySlicer{content: content, pagesize: pagesize}
}

// Slice returns content[offset:offset+size]. The result aliases the
// underlying buffer. There is no bounds checking: an out-of-range request
// panics rather than returning an error. The context is ignored.
func (s *MemorySlicer) Slice(_ context.Context, offset, size int64) ([]byte, error) {
	return s.content[offset : offset+size], nil
}

// Size returns the length of the content in bytes; the error is always nil.
func (s *MemorySlicer) Size() (int64, error) {
	return int64(len(s.content)), nil
}

// Content returns the underlying buffer itself, not a copy.
func (s *MemorySlicer) Content() []byte {
	return s.content
}

// BlockSize returns the page size given to NewMemorySlicer.
func (s *MemorySlicer) BlockSize() int64 {
	return s.pagesize
}

// RandomPages returns a MemorySlicer holding pagesize*numberOfPages bytes
// read from crypto/rand. It panics if the random read fails.
func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer {
	size := pagesize * numberOfPages
	buf := make([]byte, int(size))
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}

	return NewMemorySlicer(buf, int64(pagesize))
}

// envHelperFlag is the environment variable that switches the test binary
// into helper mode (see TestHelperServingProcess).
const envHelperFlag = "GO_TEST_HELPER_PROCESS"

// configureCrossProcessTest sets up the parent half of a cross-process test:
// it generates random page content, maps a memory area, creates and registers
// a userfaultfd over it, starts this test binary again as the helper process
// (passing the uffd and one end of a Unix socket pair via ExtraFiles), and
// drives the helper through Lifecycle.Bootstrap and Lifecycle.WaitReady over
// JSON-RPC.
//
// Most setup failures abort the test through require; only Bootstrap and
// WaitReady failures are returned as errors. Teardown (helper shutdown or
// kill, Wait, uffd unregister and close) is registered with t.Cleanup.
func configureCrossProcessTest(ctx context.Context, t *testing.T, tt testConfig) (*testHandler, error) {
	t.Helper()

	data := RandomPages(tt.pagesize, tt.numberOfPages)

	size, err := data.Size()
	require.NoError(t, err)

	memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), tt.pagesize)
	require.NoError(t, err)

	uffdFd, err := newFd(syscall.O_CLOEXEC | syscall.O_NONBLOCK)
	require.NoError(t, err)
	t.Cleanup(func() { uffdFd.close() })

	require.NoError(t, configureApi(uffdFd, tt.pagesize))
	require.NoError(t, register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP))
	t.Cleanup(func() {
		// Unregister before close (LIFO): a future test enabling
		// UFFD_FEATURE_EVENT_REMOVE would otherwise see munmap block on
		// un-acked REMOVE events against a still-registered range.
		assert.NoError(t, unregister(uffdFd, memoryStart, uint64(size)))
	})

	// Re-exec this test binary restricted to the helper test, with the test
	// timeout disabled and the helper flag set in its environment.
	cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestHelperServingProcess", "-test.timeout=0")
	cmd.Env = append(os.Environ(), envHelperFlag+"=1")

	// F_DUPFD_CLOEXEC dup's atomically with CLOEXEC set; a concurrent
	// fork cannot inherit the fd before we hand it off via ExtraFiles.
	dup, err := unix.FcntlInt(uintptr(uffdFd), unix.F_DUPFD_CLOEXEC, 0)
	require.NoError(t, err)
	uffdFile := os.NewFile(uintptr(dup), "uffd")

	// Socket pair used as the RPC channel: fds[0] stays in the parent,
	// fds[1] is handed to the helper.
	fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0)
	require.NoError(t, err)
	parentEnd := os.NewFile(uintptr(fds[0]), "rpc-parent")
	childEnd := os.NewFile(uintptr(fds[1]), "rpc-child")

	// Order matters: the helper opens these as fd 3 (uffd) and fd 4 (rpc).
	cmd.ExtraFiles = []*os.File{uffdFile, childEnd}
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	startErr := cmd.Start()
	// Close the parent's copies of the descriptors handed to the helper,
	// whether or not Start succeeded.
	uffdFile.Close()
	childEnd.Close()
	if startErr != nil {
		parentEnd.Close()
		require.NoError(t, startErr)
	}

	// client is assigned further down; the cleanup closure reads it at
	// teardown time, so it is still nil if setup failed before the RPC
	// client was created.
	var client *testharness.Client
	t.Cleanup(func() {
		if client != nil {
			_ = client.Shutdown()
			_ = client.Close()
		} else {
			// No RPC channel to ask for a shutdown; kill the helper instead.
			_ = cmd.Process.Kill()
		}
		// Already closed on the success path; the error is ignored.
		_ = parentEnd.Close()

		waitErr := cmd.Wait()
		if waitErr != nil {
			// An ExitError is not reported; any other Wait failure is
			// only logged.
			var exitErr *exec.ExitError
			if !errors.As(waitErr, &exitErr) {
				t.Logf("helper process Wait: %v", waitErr)
			}
		}
	})

	parentConn, err := net.FileConn(parentEnd)
	// net.FileConn works on a copy of the descriptor, so the *os.File can
	// be closed immediately.
	parentEnd.Close()
	require.NoError(t, err)

	client = testharness.NewClient(parentConn)

	h := &testHandler{
		memoryArea: &memoryArea,
		pagesize:   tt.pagesize,
		data:       data,
		client:     client,
	}

	if err := client.Bootstrap(testharness.BootstrapArgs{
		MmapStart: uint64(memoryStart),
		Pagesize:  int64(tt.pagesize),
		TotalSize: size,
		AlwaysWP:  tt.alwaysWP,
		Barriers:  tt.barriers,
		Content:   data.Content(),
	}); err != nil {
		return nil, fmt.Errorf("Lifecycle.Bootstrap: %w", err)
	}

	if err := client.WaitReady(); err != nil {
		return nil, fmt.Errorf("Lifecycle.WaitReady: %w", err)
	}

	return h, nil
}
package userfaultfd import ( "bytes" "context" - "errors" "fmt" "slices" "sync" @@ -16,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness" ) type testConfig struct { @@ -28,8 +28,8 @@ type testConfig struct { operations []operation // alwaysWP makes the handler copy with UFFDIO_COPY_MODE_WP for all faults. alwaysWP bool - // gated enables pause/resume control over the handler's serve loop. - gated bool + // barriers enables the per-worker fault hook (race tests only). + barriers bool } type operationMode uint32 @@ -52,25 +52,10 @@ type operation struct { async bool } -// handlerPageStates is a snapshot of the pageTracker grouped by state. It -// lets tests assert on the set of pages that the handler observed in each -// state, rather than a flat list of "accessed" offsets. Follow-up PRs can -// add more state-specific fields (e.g. removed) without touching the -// existing call sites. type handlerPageStates struct { faulted []uint } -// allAccessed returns the sorted union of offsets that the handler touched -// in any non-missing state. Tests that only care about "which pages did the -// handler see" can compare directly against this. -// -// pageStatesOnce already returns each per-state slice sorted, and a page -// has exactly one state at a time in pageTracker, so the per-state slices -// are disjoint. Follow-up PRs that add more state-specific fields should -// sorted-merge them here instead of reaching for a bitset — byte offsets -// make poor bit indices (a single hugepage offset would force ~1.8 MB of -// backing storage). func (s handlerPageStates) allAccessed() []uint { return slices.Clone(s.faulted) } @@ -79,15 +64,25 @@ type testHandler struct { memoryArea *[]byte pagesize uint64 data *MemorySlicer - // pageStatesOnce returns a per-state snapshot of the handler's pageTracker. 
- // It can only be called once. - pageStatesOnce func() (handlerPageStates, error) - // servePause and serveResume gate the UFFD event loop in the child process. - // Tests use them to deterministically batch a sequence of UFFD events - // before more faults are processed. - servePause func() error - serveResume func() error - mutex sync.RWMutex + client *testharness.Client + mutex sync.RWMutex +} + +func (h *testHandler) pageStates() (handlerPageStates, error) { + entries, err := h.client.PageStates() + if err != nil { + return handlerPageStates{}, err + } + + var states handlerPageStates + for _, e := range entries { + if pageState(e.State) == faulted { + states.faulted = append(states.faulted, uint(e.Offset)) + } + } + slices.Sort(states.faulted) + + return states, nil } func (h *testHandler) executeAll(t *testing.T, operations []operation) { @@ -174,17 +169,9 @@ func (h *testHandler) executeOperation(ctx context.Context, op operation) error case operationModeWrite: return h.executeWrite(ctx, op) case operationModeServePause: - if h.servePause == nil { - return errors.New("operationModeServePause requires testConfig.gated = true") - } - - return h.servePause() + return h.client.Pause() case operationModeServeResume: - if h.serveResume == nil { - return errors.New("operationModeServeResume requires testConfig.gated = true") - } - - return h.serveResume() + return h.client.Resume() case operationModeSleep: time.Sleep(50 * time.Millisecond) diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/hooks_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/hooks_test.go new file mode 100644 index 0000000000..5a6e81c7ca --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/hooks_test.go @@ -0,0 +1,12 @@ +package userfaultfd + +// SetTestFaultHook installs the per-fault hook atomically; pass nil to +// clear. Lives in a _test.go file so production binaries cannot link it. 
+func (u *Userfaultfd) SetTestFaultHook(h func(uintptr, faultPhase)) { + if h == nil { + u.testFaultHook.Store(nil) + + return + } + u.testFaultHook.Store(&h) +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_test.go index 6e99c865f9..6aeb8d1244 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_test.go @@ -121,14 +121,14 @@ func TestMissing(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) h.executeAll(t, tt.operations) expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") @@ -141,14 +141,14 @@ func TestMissing(t *testing.T) { func TestParallelMissing(t *testing.T) { t.Parallel() - parallelOperations := 10_000 + parallelOperations := 1_000_000 tt := testConfig{ pagesize: header.PageSize, numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) readOp := operation{ @@ -169,7 +169,7 @@ func TestParallelMissing(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") @@ -185,7 +185,7 @@ func TestParallelMissingWithPrefault(t *testing.T) { numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) 
readOp := operation{ @@ -209,7 +209,7 @@ func TestParallelMissingWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") @@ -218,14 +218,14 @@ func TestParallelMissingWithPrefault(t *testing.T) { func TestSerialMissing(t *testing.T) { t.Parallel() - serialOperations := 10_000 + serialOperations := 1_000_000 tt := testConfig{ pagesize: header.PageSize, numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) readOp := operation{ @@ -240,7 +240,7 @@ func TestSerialMissing(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_write_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_write_test.go index 0a20b62f59..ff9ca621d5 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_write_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/missing_write_test.go @@ -116,14 +116,14 @@ func TestMissingWrite(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) h.executeAll(t, tt.operations) expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), 
"checking which pages were faulted") @@ -136,14 +136,14 @@ func TestMissingWrite(t *testing.T) { func TestParallelMissingWrite(t *testing.T) { t.Parallel() - parallelOperations := 10_000 + parallelOperations := 1_000_000 tt := testConfig{ pagesize: header.PageSize, numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) writeOp := operation{ @@ -164,7 +164,7 @@ func TestParallelMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") @@ -173,14 +173,14 @@ func TestParallelMissingWrite(t *testing.T) { func TestParallelMissingWriteWithPrefault(t *testing.T) { t.Parallel() - parallelOperations := 10_000 + parallelOperations := 1_000_000 tt := testConfig{ pagesize: header.PageSize, numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) writeOp := operation{ @@ -204,7 +204,7 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") @@ -213,14 +213,14 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { func TestSerialMissingWrite(t *testing.T) { t.Parallel() - serialOperations := 10_000 + serialOperations := 1_000_000 tt := testConfig{ pagesize: header.PageSize, numberOfPages: 2, } - h, err := configureCrossProcessTest(t, tt) + h, err := configureCrossProcessTest(t.Context(), t, tt) require.NoError(t, err) writeOp := 
operation{ @@ -235,7 +235,7 @@ func TestSerialMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - states, err := h.pageStatesOnce() + states, err := h.pageStates() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, states.allAccessed(), "checking which pages were faulted") diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go new file mode 100644 index 0000000000..380e9c1d2f --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go @@ -0,0 +1,253 @@ +package userfaultfd + +// RPC service implementations for the cross-process UFFD test harness; +// in _test.go because they need *Userfaultfd internals. + +import ( + "context" + "errors" + "fmt" + "os" + "sync" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" +) + +//nolint:containedctx // shutdown-aware ctx shared with RPC handlers; lifetime is the child process. +type harnessState struct { + uffdFd uintptr + + mu sync.Mutex + uffd *Userfaultfd + br *testharness.Registry + stop func() // serve-stop fn; nil when paused + ctx context.Context + cancel context.CancelFunc + closed bool +} + +func newHarnessState(uffdFd uintptr) *harnessState { + ctx, cancel := context.WithCancel(context.Background()) + + return &harnessState{ + uffdFd: uffdFd, + ctx: ctx, + cancel: cancel, + } +} + +// startServeLocked is idempotent so a stray duplicate Resume cannot +// leak an untracked Serve goroutine. Caller must hold s.mu. 
+func (s *harnessState) startServeLocked() error { + if s.stop != nil { + return nil + } + + exit, err := fdexit.New() + if err != nil { + return fmt.Errorf("fdexit.New: %w", err) + } + + uffd := s.uffd + done := make(chan struct{}) + go func() { + defer close(done) + if err := uffd.Serve(context.Background(), exit); err != nil { + fmt.Fprintln(os.Stderr, "uffd.Serve:", err) + } + }() + + s.stop = func() { + _ = exit.SignalExit() + <-done + exit.Close() + } + + return nil +} + +func (s *harnessState) stopServe() { + // Drop s.mu before stop() — stop() blocks on the Serve drain, and any + // concurrent RPC handler needing s.mu (e.g. WaitFaultHeld during a + // parked barrier) would otherwise stall until the drain completes. + s.mu.Lock() + stop := s.stop + s.stop = nil + s.mu.Unlock() + + if stop != nil { + stop() + } +} + +type Lifecycle struct { + state *harnessState +} + +func (l *Lifecycle) Bootstrap(args *testharness.BootstrapArgs, _ *testharness.BootstrapReply) error { + if int64(len(args.Content)) != args.TotalSize { + return fmt.Errorf("content size %d != expected %d", len(args.Content), args.TotalSize) + } + + data := NewMemorySlicer(args.Content, args.Pagesize) + + mapping := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: uintptr(args.MmapStart), + Size: uintptr(args.TotalSize), + Offset: 0, + PageSize: uintptr(args.Pagesize), + }, + }) + + log, err := logger.NewDevelopmentLogger() + if err != nil { + return fmt.Errorf("logger: %w", err) + } + + uffd, err := NewUserfaultfdFromFd(l.state.uffdFd, data, mapping, log) + if err != nil { + return fmt.Errorf("NewUserfaultfdFromFd: %w", err) + } + + if args.AlwaysWP { + uffd.defaultCopyMode = UFFDIO_COPY_MODE_WP + } + + var br *testharness.Registry + if args.Barriers { + br = testharness.NewRegistry() + hook := br.Hook() + uffd.SetTestFaultHook(func(addr uintptr, p faultPhase) { + hook(addr, testharness.Point(p)) + }) + } + + l.state.mu.Lock() + defer l.state.mu.Unlock() + l.state.uffd = uffd + 
l.state.br = br + + return l.state.startServeLocked() +} + +// WaitReady is a no-op today (Bootstrap is synchronous); kept as a separate +// RPC so an async-Bootstrap variant can hold the parent here unchanged. +func (l *Lifecycle) WaitReady(_ *testharness.Empty, _ *testharness.Empty) error { + return nil +} + +func (l *Lifecycle) Shutdown(_ *testharness.Empty, _ *testharness.Empty) error { + l.state.mu.Lock() + defer l.state.mu.Unlock() + if !l.state.closed { + l.state.closed = true + l.state.cancel() + } + + return nil +} + +type Paging struct { + state *harnessState +} + +func (p *Paging) States(_ *testharness.Empty, reply *testharness.PageStatesReply) error { + p.state.mu.Lock() + uffd := p.state.uffd + p.state.mu.Unlock() + if uffd == nil { + return errors.New("Paging.States called before Lifecycle.Bootstrap") + } + + entries, err := uffd.pageStateEntries() + if err != nil { + return err + } + reply.Entries = entries + + return nil +} + +func (p *Paging) Pause(_ *testharness.Empty, _ *testharness.Empty) error { + p.state.stopServe() + + return nil +} + +func (p *Paging) Resume(_ *testharness.Empty, _ *testharness.Empty) error { + p.state.mu.Lock() + defer p.state.mu.Unlock() + + return p.state.startServeLocked() +} + +// pageStateEntries returns a wire-format snapshot of pageTracker. +// settleRequests.Lock drains fault workers (mirrors PrefetchData); +// pageTracker.mu.RLock is defensive against a future REMOVE writer +// that mutates pageTracker.m outside settleRequests. 
+func (u *Userfaultfd) pageStateEntries() ([]testharness.PageStateEntry, error) { + u.settleRequests.Lock() + defer u.settleRequests.Unlock() + + u.pageTracker.mu.RLock() + defer u.pageTracker.mu.RUnlock() + + entries := make([]testharness.PageStateEntry, 0, len(u.pageTracker.m)) + for addr, state := range u.pageTracker.m { + offset, err := u.ma.GetOffset(addr) + if err != nil { + return nil, fmt.Errorf("address %#x not in mapping: %w", addr, err) + } + entries = append(entries, testharness.PageStateEntry{State: uint8(state), Offset: uint64(offset)}) + } + + return entries, nil +} + +type Barriers struct { + state *harnessState +} + +func (b *Barriers) Install(args *testharness.FaultBarrierArgs, reply *testharness.FaultBarrierReply) error { + br, err := b.registry() + if err != nil { + return err + } + reply.Token = br.Install(uintptr(args.Addr), testharness.Point(args.Point)) + + return nil +} + +func (b *Barriers) WaitHeld(args *testharness.TokenArgs, _ *testharness.Empty) error { + br, err := b.registry() + if err != nil { + return err + } + + return br.WaitArrived(b.state.ctx, args.Token) +} + +func (b *Barriers) Release(args *testharness.TokenArgs, _ *testharness.Empty) error { + br, err := b.registry() + if err != nil { + return err + } + br.Release(args.Token) + + return nil +} + +func (b *Barriers) registry() (*testharness.Registry, error) { + b.state.mu.Lock() + br := b.state.br + b.state.mu.Unlock() + if br == nil { + return nil, errors.New("Barriers RPC requires args.Barriers=true at Bootstrap") + } + + return br, nil +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go index 133af0f547..43b9fee492 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "sync" + "sync/atomic" "syscall" "time" "unsafe" @@ -63,9 +64,20 @@ 
type Userfaultfd struct { // defaultCopyMode overrides the UFFDIO_COPY mode for all faults when non-zero. defaultCopyMode CULong + // testFaultHook is set only by SetTestFaultHook in test builds. + testFaultHook atomic.Pointer[func(uintptr, faultPhase)] + logger logger.Logger } +// faultPhase identifies the worker fault hook call site (test-only). +type faultPhase uint8 + +const ( + faultPhaseBeforeRLock faultPhase = iota + faultPhaseBeforeFaultPage +) + // NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { blockSize := src.BlockSize() @@ -250,6 +262,10 @@ func (u *Userfaultfd) Serve( // For the write to be executed, we first need to copy the page from the source to the guest memory. if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { u.wg.Go(func() error { + if h := u.testFaultHook.Load(); h != nil { + (*h)(addr, faultPhaseBeforeRLock) + } + return u.faultPage(ctx, addr, offset, u.src, fdExit.SignalExit, block.Write) }) @@ -260,6 +276,10 @@ func (u *Userfaultfd) Serve( // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. if flags == 0 { u.wg.Go(func() error { + if h := u.testFaultHook.Load(); h != nil { + (*h)(addr, faultPhaseBeforeRLock) + } + return u.faultPage(ctx, addr, offset, u.src, fdExit.SignalExit, block.Read) }) @@ -299,6 +319,10 @@ func (u *Userfaultfd) faultPage( u.settleRequests.RLock() defer u.settleRequests.RUnlock() + if h := u.testFaultHook.Load(); h != nil { + (*h)(addr, faultPhaseBeforeFaultPage) + } + defer func() { if r := recover(); r != nil { u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", u.pageSize), zap.Any("panic", r))