diff --git a/cmd/gcs/main.go b/cmd/gcs/main.go index c21ea167fa..54dfae68a7 100644 --- a/cmd/gcs/main.go +++ b/cmd/gcs/main.go @@ -466,6 +466,12 @@ func main() { break } + // Drop every container stdio ConnSlot. Relay goroutines park inside + // ConnSlot.Write until the host re-attaches stdio with a fresh + // connection; producing processes pause naturally when their kernel + // pipe buffers fill, preserving in-flight bytes. + h.DisconnectAllStdio() + logrus.WithError(serveErr).Warn("bridge connection lost, will reconnect") time.Sleep(reconnectInterval) } diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index 0ce930dd65..d22331cce8 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ b/internal/guest/runtime/hcsv2/container.go @@ -46,9 +46,10 @@ const ( type Container struct { id string - vsock transport.Transport - logPath string // path to [logFile]. - logFile *os.File // file to redirect container's stdio to. + vsock transport.Transport + slotRegistry slotRegistry + logPath string // path to [logFile]. + logFile *os.File // file to redirect container's stdio to. spec *oci.Spec ociBundlePath string @@ -81,6 +82,14 @@ type Container struct { sandboxRoot string } +// slotRegistry is the narrow seam Container uses to register the stdio +// ConnSlots produced by stdio.Connect with the parent Host so the bridge +// reconnect loop can disconnect them after live migration. Defined here so +// container.go does not depend on the concrete *Host type. 
+type slotRegistry interface { + RegisterStdioSlots(*stdio.ConnectionSet) +} + func (c *Container) Start(ctx context.Context, conSettings stdio.ConnectionSettings) (_ int, err error) { entity := log.G(ctx).WithField(logfields.ContainerID, c.id) entity.Info("opengcs::Container::Start") @@ -116,6 +125,9 @@ func (c *Container) Start(ctx context.Context, conSettings stdio.ConnectionSetti if err != nil { return -1, err } + if c.slotRegistry != nil { + c.slotRegistry.RegisterStdioSlots(stdioSet) + } if c.initProcess.spec.Terminal { ttyr := c.container.Tty() @@ -140,6 +152,9 @@ func (c *Container) ExecProcess(ctx context.Context, process *oci.Process, conSe if err != nil { return -1, err } + if c.slotRegistry != nil { + c.slotRegistry.RegisterStdioSlots(stdioSet) + } // Add in the core rlimit specified on the container in case there was one set. This makes it so that execed processes can also generate // core dumps. diff --git a/internal/guest/runtime/hcsv2/stdio_slots_test.go b/internal/guest/runtime/hcsv2/stdio_slots_test.go new file mode 100644 index 0000000000..edb9b6f976 --- /dev/null +++ b/internal/guest/runtime/hcsv2/stdio_slots_test.go @@ -0,0 +1,135 @@ +//go:build linux + +package hcsv2 + +import ( + "errors" + "io" + "os" + "sync" + "testing" + + "github.com/Microsoft/hcsshim/internal/guest/stdio" + "github.com/Microsoft/hcsshim/internal/guest/transport" +) + +// stubConn is a minimal transport.Connection used to exercise +// Host.RegisterStdioSlots / DisconnectAllStdio without real sockets. 
+type stubConn struct { + mu sync.Mutex + closed bool +} + +func (c *stubConn) Read(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.EOF + } + return 0, nil +} + +func (c *stubConn) Write(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.ErrClosedPipe + } + return len(p), nil +} + +func (c *stubConn) Close() error { + c.mu.Lock() + c.closed = true + c.mu.Unlock() + return nil +} + +func (c *stubConn) CloseRead() error { return c.Close() } +func (c *stubConn) CloseWrite() error { return nil } +func (c *stubConn) File() (*os.File, error) { return nil, errors.New("no file") } +func (c *stubConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +var _ transport.Connection = (*stubConn)(nil) + +// Host satisfies the slotRegistry contract used by Container. +var _ slotRegistry = (*Host)(nil) + +func TestHost_RegisterStdioSlots_TracksSlots(t *testing.T) { + h := &Host{} + c1, c2, c3 := &stubConn{}, &stubConn{}, &stubConn{} + set := &stdio.ConnectionSet{ + In: stdio.NewConnSlot(c1, nil), + Out: stdio.NewConnSlot(c2, nil), + Err: stdio.NewConnSlot(c3, nil), + } + h.RegisterStdioSlots(set) + + if got, want := len(h.stdioSlots), 3; got != want { + t.Fatalf("stdioSlots len = %d, want %d", got, want) + } +} + +func TestHost_RegisterStdioSlots_IgnoresNilAndNonSlot(t *testing.T) { + h := &Host{} + // Mixed set: only Out is a ConnSlot; In is nil; Err is a non-slot. 
+ set := &stdio.ConnectionSet{ + Out: stdio.NewConnSlot(&stubConn{}, nil), + Err: &stubConn{}, + } + h.RegisterStdioSlots(set) + + if got, want := len(h.stdioSlots), 1; got != want { + t.Fatalf("stdioSlots len = %d, want %d", got, want) + } +} + +func TestHost_RegisterStdioSlots_NilSet_NoOp(t *testing.T) { + h := &Host{} + h.RegisterStdioSlots(nil) // must not panic + if len(h.stdioSlots) != 0 { + t.Fatalf("nil set must register nothing, got %d", len(h.stdioSlots)) + } +} + +func TestHost_DisconnectAllStdio_ClosesEveryUnderlyingConn(t *testing.T) { + h := &Host{} + conns := []*stubConn{{}, {}, {}} + for _, c := range conns { + h.RegisterStdioSlots(&stdio.ConnectionSet{Out: stdio.NewConnSlot(c, nil)}) + } + + h.DisconnectAllStdio() + + for i, c := range conns { + if !c.isClosed() { + t.Fatalf("conns[%d] not closed by DisconnectAllStdio", i) + } + } +} + +func TestHost_RegisterStdioSlots_CompactsClosedSlots(t *testing.T) { + h := &Host{} + live := stdio.NewConnSlot(&stubConn{}, nil) + dead := stdio.NewConnSlot(&stubConn{}, nil) + _ = dead.Close() + + h.RegisterStdioSlots(&stdio.ConnectionSet{Out: live}) + h.RegisterStdioSlots(&stdio.ConnectionSet{Out: dead}) + + // dead is registered but should compact away on the next register call. 
+ h.RegisterStdioSlots(&stdio.ConnectionSet{Out: stdio.NewConnSlot(&stubConn{}, nil)}) + + for _, s := range h.stdioSlots { + if !s.IsAlive() { + t.Fatal("compaction did not drop closed slot") + } + } + if got := len(h.stdioSlots); got != 2 { + t.Fatalf("after compact want 2 live slots, got %d", got) + } +} diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index bbec0b7564..eada4ac0f5 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -89,6 +89,7 @@ type VirtualPod struct { type Host struct { containersMutex sync.Mutex containers map[string]*Container + stdioSlots []*stdio.ConnSlot externalProcessesMutex sync.Mutex externalProcesses map[int]*externalProcess @@ -205,6 +206,72 @@ func (h *Host) Transport() transport.Transport { return h.vsock } +// RegisterStdioSlots tracks per-process stdio so the bridge reconnect loop +// can disconnect them after live migration. Called from container Start, +// ExecProcess, and runExternalProcess after stdio.Connect. Any +// *stdio.ConnSlot in the set is added to the registry; nil entries and +// other transport.Connection types are ignored. Already-closed slots are +// compacted out on each call to bound the slice's growth. +func (h *Host) RegisterStdioSlots(set *stdio.ConnectionSet) { + if set == nil { + return + } + incoming := make([]*stdio.ConnSlot, 0, 3) + for _, c := range []transport.Connection{set.In, set.Out, set.Err} { + if slot, ok := c.(*stdio.ConnSlot); ok && slot != nil { + incoming = append(incoming, slot) + } + } + if len(incoming) == 0 { + return + } + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + h.stdioSlots = compactStdioSlots(h.stdioSlots) + h.stdioSlots = append(h.stdioSlots, incoming...) +} + +// DisconnectAllStdio drops the current connection on every tracked stdio +// slot. Called from the GCS reconnect loop after the bridge connection is +// lost. 
Relays park inside slot.Write until the host re-attaches stdio with
+// a fresh connection; the producing process pauses naturally when its
+// kernel pipe buffer fills.
+//
+// Each slot's Disconnect is wrapped in a recover so a single bad slot
+// cannot break the loop and leave the rest of the container stdio without
+// back pressure.
+func (h *Host) DisconnectAllStdio() {
+	h.containersMutex.Lock()
+	h.stdioSlots = compactStdioSlots(h.stdioSlots)
+	// Snapshot the registry while holding the lock; Disconnect is then
+	// called on the copy with the lock released.
+	slots := append([]*stdio.ConnSlot(nil), h.stdioSlots...)
+	h.containersMutex.Unlock()
+	for _, s := range slots {
+		func() {
+			defer func() {
+				if r := recover(); r != nil {
+					logrus.WithField("panic", r).Error("ConnSlot: Disconnect panicked")
+				}
+			}()
+			s.Disconnect()
+		}()
+	}
+}
+
+// compactStdioSlots filters closed slots out of slots in place so the
+// registry does not grow unbounded over the UVM lifetime. The result shares
+// the backing array of slots (no new slice is allocated), so callers must
+// replace their slice with the returned one. Entries past the kept prefix
+// are set to nil so the array no longer references closed slots.
+func compactStdioSlots(slots []*stdio.ConnSlot) []*stdio.ConnSlot {
+	out := slots[:0]
+	for _, s := range slots {
+		if s.IsAlive() {
+			out = append(out, s)
+		}
+	}
+	// The in-place filter leaves stale pointers behind the new length; clear
+	// them so closed slots can be garbage collected.
+	for i := len(out); i < len(slots); i++ {
+		slots[i] = nil
+	}
+	return out
+}
+
 func (h *Host) RemoveContainer(id string) {
 	h.containersMutex.Lock()
 	defer h.containersMutex.Unlock()
@@ -406,6 +473,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
 	c := &Container{
 		id:             id,
 		vsock:          h.vsock,
+		slotRegistry:   h,
 		spec:           settings.OCISpecification,
 		ociBundlePath:  settings.OCIBundlePath,
 		isSandbox:      criType == "sandbox",
@@ -1077,6 +1145,7 @@ func (h *Host) runExternalProcess(
 	if err != nil {
 		return -1, err
 	}
+	h.RegisterStdioSlots(stdioSet)
 	defer func() {
 		if err != nil {
 			stdioSet.Close()
diff --git a/internal/guest/stdio/connection.go b/internal/guest/stdio/connection.go
index d4c7cbdf18..05bb3c3af7 100644
--- a/internal/guest/stdio/connection.go
+++ b/internal/guest/stdio/connection.go
@@ -19,7 +19,8 @@ type ConnectionSettings struct {
 // Connect returns new transport.Connection instances, one for each stdio pipe
 // to be
used. If CreateStd*Pipe for a given pipe is false, the given Connection -// is set to nil. +// is set to nil. Each connection is wrapped in a ConnSlot so the underlying +// vsock can be replaced when the bridge reconnects after live migration. func Connect(tport transport.Transport, settings ConnectionSettings) (_ *ConnectionSet, err error) { connSet := &ConnectionSet{} defer func() { @@ -28,25 +29,42 @@ func Connect(tport transport.Transport, settings ConnectionSettings) (_ *Connect } }() if settings.StdIn != nil { - c, err := tport.Dial(*settings.StdIn) + port := *settings.StdIn + c, err := tport.Dial(port) if err != nil { return nil, errors.Wrap(err, "failed creating stdin Connection") } - connSet.In = transport.NewLogConnection(c, *settings.StdIn) + connSet.In = NewConnSlot(transport.NewLogConnection(c, port), redialer(tport, port)) } if settings.StdOut != nil { - c, err := tport.Dial(*settings.StdOut) + port := *settings.StdOut + c, err := tport.Dial(port) if err != nil { return nil, errors.Wrap(err, "failed creating stdout Connection") } - connSet.Out = transport.NewLogConnection(c, *settings.StdOut) + connSet.Out = NewConnSlot(transport.NewLogConnection(c, port), redialer(tport, port)) } if settings.StdErr != nil { - c, err := tport.Dial(*settings.StdErr) + port := *settings.StdErr + c, err := tport.Dial(port) if err != nil { return nil, errors.Wrap(err, "failed creating stderr Connection") } - connSet.Err = transport.NewLogConnection(c, *settings.StdErr) + connSet.Err = NewConnSlot(transport.NewLogConnection(c, port), redialer(tport, port)) } return connSet, nil } + +// redialer returns a callback that re-dials the given vsock port via the +// provided transport. Used by ConnSlot to recover from a bridge disconnect: +// after live migration the source-host listener is gone but the destination +// host has a fresh listener on the same port number. 
+func redialer(tport transport.Transport, port uint32) func() (transport.Connection, error) {
+	return func() (transport.Connection, error) {
+		nc, err := tport.Dial(port)
+		if err != nil {
+			return nil, err
+		}
+		// Wrap the fresh connection in a log connection for the same port,
+		// as Connect does for the initial dial.
+		return transport.NewLogConnection(nc, port), nil
+	}
+}
diff --git a/internal/guest/stdio/connslot.go b/internal/guest/stdio/connslot.go
new file mode 100644
index 0000000000..7fddeed570
--- /dev/null
+++ b/internal/guest/stdio/connslot.go
@@ -0,0 +1,261 @@
+//go:build linux
+// +build linux
+
+package stdio
+
+import (
+	"errors"
+	"io"
+	"os"
+	"sync"
+	"time"
+
+	"github.com/Microsoft/hcsshim/internal/guest/transport"
+	"github.com/sirupsen/logrus"
+)
+
+// runRedial pacing. Tight fixed interval matches the bridge reconnect loop
+// in cmd/gcs/main.go; the destination listener should be live as soon as
+// the bridge re-accepts. maxRedialAttempts bounds the goroutine's lifetime
+// against a permanently broken peer.
+const (
+	redialInterval    = 100 * time.Millisecond // sleep after each failed redial attempt
+	maxRedialAttempts = 60                     // attempts per runRedial run before giving up
+)
+
+// ConnSlot wraps a transport.Connection so the underlying connection can be
+// replaced at runtime. While disconnected, Read and Write block (parking
+// the relay in acquire) so the producing process back-pressures on its own
+// kernel pipe instead of losing bytes. Set installs a fresh connection and
+// wakes blocked relays.
+type ConnSlot struct {
+	mu     sync.Mutex           // guards conn, closed and redialing
+	cond   *sync.Cond           // bound to mu; Set and Close broadcast on it, acquire waits on it
+	conn   transport.Connection // current connection; nil while disconnected or after Close
+	closed bool                 // set once by Close and never cleared
+
+	// redial, if non-nil, is invoked from a background goroutine after
+	// Disconnect to obtain a fresh connection. The slot calls Set with the
+	// returned connection so blocked Read/Write calls resume automatically.
+	redial func() (transport.Connection, error)
+	// redialing is true while a runRedial goroutine is in flight; Disconnect
+	// and dropIfCurrent check it so they do not start a second one.
+	redialing bool
+}
+
+var _ transport.Connection = (*ConnSlot)(nil)
+
+// NewConnSlot wraps an initial connection. If redial is non-nil it is
+// invoked from a background goroutine after a Disconnect or read/write
+// error to obtain a fresh connection.
+func NewConnSlot(conn transport.Connection, redial func() (transport.Connection, error)) *ConnSlot {
+	s := &ConnSlot{conn: conn, redial: redial}
+	// The condition variable shares the slot mutex, so waiters in acquire
+	// re-check conn/closed under the same lock Set and Close hold.
+	s.cond = sync.NewCond(&s.mu)
+	return s
+}
+
+// IsAlive reports whether the slot has not been permanently closed.
+// A disconnected (empty) slot still counts as alive.
+func (s *ConnSlot) IsAlive() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return !s.closed
+}
+
+// Set installs a new connection, closing any previous one, and wakes
+// goroutines blocked in Read or Write. If the slot is already closed, c is
+// closed and the slot remains empty.
+//
+// Passing a nil c leaves the slot empty: waiters are woken but go straight
+// back to waiting, and no redial is started.
+func (s *ConnSlot) Set(c transport.Connection) {
+	s.mu.Lock()
+	if s.closed {
+		s.mu.Unlock()
+		if c != nil {
+			_ = c.Close()
+		}
+		return
+	}
+	prev := s.conn
+	s.conn = c
+	s.cond.Broadcast()
+	s.mu.Unlock()
+	// The replaced connection is closed after mu is released, so its Close
+	// never runs with the slot lock held.
+	if prev != nil {
+		_ = prev.Close()
+	}
+}
+
+// Disconnect closes the current connection but keeps the slot open.
+// Subsequent Read and Write calls block until Set is called or Close is
+// called. If a redialer was installed, a background goroutine is kicked off
+// to obtain a fresh connection. Safe to call repeatedly and on closed slots.
+func (s *ConnSlot) Disconnect() {
+	s.mu.Lock()
+	if s.closed {
+		s.mu.Unlock()
+		return
+	}
+	prev := s.conn
+	s.conn = nil
+	// Start a redial only when a redialer exists and no runRedial goroutine
+	// is already in flight; the flag is claimed under the lock.
+	needReconnect := s.redial != nil && !s.redialing
+	if needReconnect {
+		s.redialing = true
+	}
+	s.mu.Unlock()
+	// Close and goroutine start both happen after mu is released.
+	if prev != nil {
+		_ = prev.Close()
+	}
+	if needReconnect {
+		go s.runRedial()
+	}
+}
+
+// runRedial loops the redialer up to maxRedialAttempts times, installing
+// the first successful connection via Set. On exhaustion the goroutine
+// exits leaving the slot empty; producers stay back-pressured until Close.
+func (s *ConnSlot) runRedial() {
+	// exhausted is set only when every attempt failed. In that case the slot
+	// is deliberately left empty and a later Disconnect starts a new run.
+	exhausted := false
+	defer func() {
+		s.mu.Lock()
+		// While redialing was true, Disconnect and dropIfCurrent skipped
+		// starting their own redial. If this run stopped because a
+		// connection was in place and that connection has been dropped
+		// again before the flag is cleared here, nobody else will redial.
+		// Decide under the lock and hand over to a fresh run so blocked
+		// readers and writers are not stranded.
+		restart := !exhausted && !s.closed && s.conn == nil && s.redial != nil
+		s.redialing = restart
+		s.mu.Unlock()
+		if restart {
+			go s.runRedial()
+		}
+	}()
+
+	for attempt := 1; attempt <= maxRedialAttempts; attempt++ {
+		s.mu.Lock()
+		// Stop when the slot was closed or someone already installed a
+		// connection (for example an explicit Set).
+		if s.closed || s.conn != nil {
+			s.mu.Unlock()
+			return
+		}
+		redial := s.redial
+		s.mu.Unlock()
+		if redial == nil {
+			return
+		}
+
+		// The dial runs without the lock so Read, Write and Close are not
+		// blocked behind it.
+		c, err := redial()
+		if err == nil && c == nil {
+			// Installing nil would leave the slot empty while looking like
+			// a success; count it as a failed attempt instead.
+			err = errors.New("redial returned no connection")
+		}
+		if err == nil {
+			logrus.WithField("attempt", attempt).Info("ConnSlot: redial succeeded")
+			s.Set(c)
+			return
+		}
+		logrus.WithError(err).WithField("attempt", attempt).Debug("ConnSlot: redial failed")
+		time.Sleep(redialInterval)
+	}
+	exhausted = true
+	logrus.WithField("attempts", maxRedialAttempts).
+		Warn("ConnSlot: redial attempts exhausted; slot left empty")
+}
+
+// Close permanently closes the slot. Any blocked Read or Write returns io.EOF.
+// Calling Close on an already-closed slot is a no-op.
+// Close always returns nil; the error from closing the underlying
+// connection is discarded.
+func (s *ConnSlot) Close() error {
+	s.mu.Lock()
+	if s.closed {
+		s.mu.Unlock()
+		return nil
+	}
+	s.closed = true
+	prev := s.conn
+	s.conn = nil
+	// Wake every goroutine parked in acquire so it can observe closed.
+	s.cond.Broadcast()
+	s.mu.Unlock()
+	if prev != nil {
+		_ = prev.Close()
+	}
+	return nil
+}
+
+// Read implements transport.Connection. Blocks until a connection is
+// available or the slot is closed. On read error other than EOF, drops the
+// current connection so the next Read waits for a replacement.
+//
+// NOTE(review): unlike Write, Read does not retry; the error from the
+// underlying connection is returned to the caller. A Read that is already
+// in flight when Disconnect closes the connection therefore surfaces that
+// error. Confirm that read-side relays tolerate this.
+func (s *ConnSlot) Read(p []byte) (int, error) {
+	c, err := s.acquire()
+	if err != nil {
+		return 0, err
+	}
+	n, rerr := c.Read(p)
+	if rerr != nil && !errors.Is(rerr, io.EOF) {
+		s.dropIfCurrent(c)
+	}
+	return n, rerr
+}
+
+// Write implements transport.Connection. Loops until all bytes are written
+// or the slot is closed; on connection failure, drops the conn and parks
+// in acquire for a replacement before retrying the remaining bytes. This
+// is the back-pressure path: while disconnected, the loop parks and the
+// caller's upstream pipe fills, eventually blocking the producing process.
+func (s *ConnSlot) Write(p []byte) (int, error) {
+	written := 0
+	for written < len(p) {
+		// Parks here while the slot is empty. The only error acquire
+		// returns is io.EOF, once the slot has been closed.
+		c, err := s.acquire()
+		if err != nil {
+			return written, err
+		}
+		n, werr := c.Write(p[written:])
+		written += n
+		if werr != nil {
+			// The connection's error is not returned to the caller: drop
+			// the conn and loop so the rest goes to its replacement.
+			s.dropIfCurrent(c)
+		}
+	}
+	return written, nil
+}
+
+// CloseRead delegates to the underlying connection if one is set.
+// With no connection installed it does not wait and returns nil.
+func (s *ConnSlot) CloseRead() error {
+	s.mu.Lock()
+	c := s.conn
+	s.mu.Unlock()
+	if c == nil {
+		return nil
+	}
+	return c.CloseRead()
+}
+
+// CloseWrite delegates to the underlying connection if one is set.
+// With no connection installed it does not wait and returns nil.
+func (s *ConnSlot) CloseWrite() error {
+	s.mu.Lock()
+	c := s.conn
+	s.mu.Unlock()
+	if c == nil {
+		return nil
+	}
+	return c.CloseWrite()
+}
+
+// File returns the current connection's file. Like Read and Write it blocks
+// in acquire while the slot is disconnected, and returns io.EOF once the
+// slot is closed. Otherwise the result is whatever the connection's own
+// File method returns.
+func (s *ConnSlot) File() (*os.File, error) {
+	c, err := s.acquire()
+	if err != nil {
+		return nil, err
+	}
+	return c.File()
+}
+
+// acquire returns the current connection, waiting on cond while the slot is
+// empty and still open. It returns io.EOF once the slot has been closed.
+func (s *ConnSlot) acquire() (transport.Connection, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	// Re-check in a loop: a Broadcast only means the state may have changed.
+	for s.conn == nil && !s.closed {
+		s.cond.Wait()
+	}
+	if s.closed {
+		return nil, io.EOF
+	}
+	return s.conn, nil
+}
+
+// dropIfCurrent clears s.conn only when it still equals c, then closes c
+// and starts a redial if needed. The conn-equality check avoids racing
+// with a concurrent Set that may have already installed a fresh conn.
+func (s *ConnSlot) dropIfCurrent(c transport.Connection) { + s.mu.Lock() + if s.closed || s.conn != c { + s.mu.Unlock() + return + } + s.conn = nil + needReconnect := s.redial != nil && !s.redialing + if needReconnect { + s.redialing = true + } + s.mu.Unlock() + _ = c.Close() + if needReconnect { + go s.runRedial() + } +} diff --git a/internal/guest/stdio/connslot_test.go b/internal/guest/stdio/connslot_test.go new file mode 100644 index 0000000000..f252c4c84e --- /dev/null +++ b/internal/guest/stdio/connslot_test.go @@ -0,0 +1,525 @@ +//go:build linux + +package stdio + +import ( + "bytes" + "errors" + "io" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Microsoft/hcsshim/internal/guest/transport" +) + +// fakeConn is a controllable transport.Connection backed by an in-memory +// buffer pair, used to exercise ConnSlot without touching real sockets. +type fakeConn struct { + mu sync.Mutex + rd *bytes.Buffer + wr *bytes.Buffer + closed bool + failNextRW error + closeReadCh chan struct{} +} + +func newFakeConn() *fakeConn { + return &fakeConn{ + rd: new(bytes.Buffer), + wr: new(bytes.Buffer), + closeReadCh: make(chan struct{}), + } +} + +func (c *fakeConn) feedRead(b []byte) { + c.mu.Lock() + c.rd.Write(b) + c.mu.Unlock() +} + +func (c *fakeConn) failNext(err error) { + c.mu.Lock() + c.failNextRW = err + c.mu.Unlock() +} + +func (c *fakeConn) writtenBytes() []byte { + c.mu.Lock() + defer c.mu.Unlock() + return append([]byte(nil), c.wr.Bytes()...) +} + +func (c *fakeConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *fakeConn) Read(p []byte) (int, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.EOF + } + if c.failNextRW != nil { + err := c.failNextRW + c.failNextRW = nil + c.mu.Unlock() + return 0, err + } + if c.rd.Len() == 0 { + c.mu.Unlock() + // Block until close or another write feeds data; simulate a live socket. 
+ <-c.closeReadCh + return 0, io.EOF + } + n, err := c.rd.Read(p) + c.mu.Unlock() + return n, err +} + +func (c *fakeConn) Write(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.ErrClosedPipe + } + if c.failNextRW != nil { + err := c.failNextRW + c.failNextRW = nil + return 0, err + } + return c.wr.Write(p) +} + +func (c *fakeConn) Close() error { + c.mu.Lock() + if !c.closed { + c.closed = true + close(c.closeReadCh) + } + c.mu.Unlock() + return nil +} + +func (c *fakeConn) CloseRead() error { return c.Close() } +func (c *fakeConn) CloseWrite() error { return nil } +func (c *fakeConn) File() (*os.File, error) { return nil, errors.New("no file") } + +var _ transport.Connection = (*fakeConn)(nil) + +// ----------------------------------------------------------------------------- +// Basic happy-path +// ----------------------------------------------------------------------------- + +func TestConnSlot_Write_PassThroughWhenConnected(t *testing.T) { + c := newFakeConn() + s := NewConnSlot(c, nil) + + n, err := s.Write([]byte("hello")) + if err != nil || n != 5 { + t.Fatalf("Write got n=%d err=%v, want n=5 err=nil", n, err) + } + if got := string(c.writtenBytes()); got != "hello" { + t.Fatalf("underlying conn got %q, want %q", got, "hello") + } +} + +func TestConnSlot_Write_BlocksWhileDisconnected_ResumesAfterSet(t *testing.T) { + c1 := newFakeConn() + s := NewConnSlot(c1, nil) + s.Disconnect() + + done := make(chan error, 1) + go func() { + _, err := s.Write([]byte("queued")) + done <- err + }() + + select { + case <-done: + t.Fatal("Write returned before Set was called") + case <-time.After(50 * time.Millisecond): + } + + c2 := newFakeConn() + s.Set(c2) + + select { + case err := <-done: + if err != nil { + t.Fatalf("Write after reconnect err=%v", err) + } + case <-time.After(time.Second): + t.Fatal("Write did not complete after Set") + } + if got := string(c2.writtenBytes()); got != "queued" { + t.Fatalf("c2 got %q, want 
%q", got, "queued") + } +} + +func TestConnSlot_Write_DropsConnOnError_RetriesRemainingOnNewConn(t *testing.T) { + c1 := newFakeConn() + c1.failNext(io.ErrShortWrite) + s := NewConnSlot(c1, nil) + + done := make(chan error, 1) + go func() { + _, err := s.Write([]byte("payload")) + done <- err + }() + + select { + case <-done: + t.Fatal("Write returned before reconnect") + case <-time.After(50 * time.Millisecond): + } + + c2 := newFakeConn() + s.Set(c2) + + select { + case err := <-done: + if err != nil { + t.Fatalf("Write after recovery err=%v", err) + } + case <-time.After(time.Second): + t.Fatal("Write did not complete after Set") + } + if got := string(c2.writtenBytes()); got != "payload" { + t.Fatalf("c2 got %q, want full payload", got) + } +} + +func TestConnSlot_Read_BlocksWhileDisconnected_ResumesAfterSet(t *testing.T) { + s := NewConnSlot(newFakeConn(), nil) + s.Disconnect() + + type readResult struct { + buf []byte + err error + } + done := make(chan readResult, 1) + go func() { + buf := make([]byte, 16) + n, err := s.Read(buf) + done <- readResult{buf: buf[:n], err: err} + }() + + select { + case <-done: + t.Fatal("Read returned before Set") + case <-time.After(50 * time.Millisecond): + } + + c2 := newFakeConn() + c2.feedRead([]byte("greetings")) + s.Set(c2) + + select { + case r := <-done: + if r.err != nil { + t.Fatalf("Read err=%v", r.err) + } + if string(r.buf) != "greetings" { + t.Fatalf("Read got %q, want %q", string(r.buf), "greetings") + } + case <-time.After(time.Second): + t.Fatal("Read did not return after Set") + } +} + +// ----------------------------------------------------------------------------- +// Lifecycle / idempotency +// ----------------------------------------------------------------------------- + +func TestConnSlot_Close_UnblocksWriteWithEOF(t *testing.T) { + s := NewConnSlot(newFakeConn(), nil) + s.Disconnect() + + done := make(chan error, 1) + go func() { + _, err := s.Write([]byte("never sent")) + done <- err + }() + + 
time.Sleep(20 * time.Millisecond) + _ = s.Close() + + select { + case err := <-done: + if !errors.Is(err, io.EOF) { + t.Fatalf("Write after Close got err=%v, want io.EOF", err) + } + case <-time.After(time.Second): + t.Fatal("Write did not return after Close") + } +} + +func TestConnSlot_Set_ClosesPreviousConnection(t *testing.T) { + c1 := newFakeConn() + c2 := newFakeConn() + s := NewConnSlot(c1, nil) + + s.Set(c2) + + if !c1.isClosed() { + t.Fatal("Set did not close previous connection") + } + if c2.isClosed() { + t.Fatal("Set must not close the new connection") + } +} + +func TestConnSlot_SetAfterClose_ClosesNewConn(t *testing.T) { + s := NewConnSlot(newFakeConn(), nil) + _ = s.Close() + + c := newFakeConn() + s.Set(c) + + if !c.isClosed() { + t.Fatal("Set on closed slot must close the new connection (otherwise we leak it)") + } +} + +func TestConnSlot_Disconnect_Idempotent(t *testing.T) { + c := newFakeConn() + s := NewConnSlot(c, nil) + + s.Disconnect() + s.Disconnect() // must not panic / double-close + s.Disconnect() + + if !c.isClosed() { + t.Fatal("Disconnect did not close underlying conn") + } +} + +func TestConnSlot_Close_Idempotent(t *testing.T) { + c := newFakeConn() + s := NewConnSlot(c, nil) + + if err := s.Close(); err != nil { + t.Fatalf("Close 1 err=%v", err) + } + if err := s.Close(); err != nil { + t.Fatalf("Close 2 err=%v", err) + } +} + +func TestConnSlot_Disconnect_AfterClose_NoOp(t *testing.T) { + s := NewConnSlot(newFakeConn(), nil) + _ = s.Close() + s.Disconnect() // must not start a runRedial goroutine +} + +func TestConnSlot_Disconnect_NoRedialer_NoGoroutine(t *testing.T) { + c := newFakeConn() + s := NewConnSlot(c, nil) + // nil redialer. + + s.Disconnect() + // Slot stays empty until explicit Set; nothing else should happen. + // We assert by giving any rogue redial goroutine a chance to run and + // confirming the conn stayed nil. 
+ time.Sleep(20 * time.Millisecond) + + if !s.IsAlive() { + t.Fatal("Disconnect closed the slot") + } +} + +// ----------------------------------------------------------------------------- +// Redial behavior +// ----------------------------------------------------------------------------- + +func TestConnSlot_Redial_SuccessOnFirstAttempt(t *testing.T) { + c2 := newFakeConn() + var calls atomic.Int32 + s := NewConnSlot(newFakeConn(), func() (transport.Connection, error) { + calls.Add(1) + return c2, nil + }) + s.Disconnect() + + // Wait for runRedial to land c2 via Set. + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + n, err := s.Write([]byte("x")) + if err == nil && n == 1 { + break + } + } + if got := string(c2.writtenBytes()); got != "x" { + t.Fatalf("c2 got %q, want %q", got, "x") + } + if got := calls.Load(); got != 1 { + t.Fatalf("redial calls=%d, want 1", got) + } +} + +func TestConnSlot_Redial_AlwaysFailing_BoundedAttempts_LeavesSlotEmpty(t *testing.T) { + // Redialer always fails. runRedial must stop after maxRedialAttempts + // attempts (no infinite goroutine), but must NOT close the slot — a + // later Disconnect() from the bridge reconnect loop should be able to + // kick off a fresh runRedial. This is the recovery path for a + // destination listener that comes back online after the first run gave + // up. + if testing.Short() { + t.Skip("slow: takes ~maxRedialAttempts * redialInterval to exhaust") + } + + var calls atomic.Int32 + s := NewConnSlot(newFakeConn(), func() (transport.Connection, error) { + calls.Add(1) + return nil, errors.New("nope") + }) + defer s.Close() + s.Disconnect() + + // Wait for runRedial to exhaust its bounded attempts. 
+ deadline := time.Now().Add(time.Duration(maxRedialAttempts+5) * redialInterval) + for time.Now().Before(deadline) && calls.Load() < int32(maxRedialAttempts) { + time.Sleep(redialInterval) + } + if got := calls.Load(); got != int32(maxRedialAttempts) { + t.Fatalf("redial calls=%d, want %d", got, maxRedialAttempts) + } + + // Slot must remain open so the next bridge cycle can revive it. + if !s.IsAlive() { + t.Fatal("slot self-closed; want still-open so next Disconnect can re-trigger redial") + } + + // Verify next Disconnect restarts redial. The previous goroutine's + // final per-attempt sleep can run for redialInterval after we observed + // the last increment, so wait two intervals to ensure the deferred + // redialing=false has executed. + time.Sleep(2 * redialInterval) + prev := calls.Load() + s.Disconnect() + deadline = time.Now().Add(2 * redialInterval) + for time.Now().Before(deadline) && calls.Load() == prev { + time.Sleep(10 * time.Millisecond) + } + if calls.Load() <= prev { + t.Fatalf("Disconnect after exhaustion did not kick a fresh redial (calls stuck at %d)", prev) + } +} + +// ----------------------------------------------------------------------------- +// Concurrency / race detector +// ----------------------------------------------------------------------------- + +func TestConnSlot_Concurrent_DisconnectWithWrites(t *testing.T) { + // Run Disconnect and many writers concurrently to exercise the lock + // discipline. With -race enabled this catches lock ordering bugs. + var redialCalls atomic.Int32 + c := newFakeConn() + s := NewConnSlot(c, func() (transport.Connection, error) { + redialCalls.Add(1) + return newFakeConn(), nil + }) + defer s.Close() + + var wg sync.WaitGroup + stop := make(chan struct{}) + + // 4 writers spinning small writes. 
+ for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := []byte("x") + for { + select { + case <-stop: + return + default: + } + _, _ = s.Write(buf) + } + }() + } + + // Disconnect 50 times with small pauses. + for i := 0; i < 50; i++ { + s.Disconnect() + time.Sleep(time.Millisecond) + } + + close(stop) + wg.Wait() +} + +// ----------------------------------------------------------------------------- +// Pipe relay integration (matches real PipeRelay usage) +// ----------------------------------------------------------------------------- + +func TestConnSlot_PipeRelayIntegration(t *testing.T) { + pipeR, pipeW, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe: %v", err) + } + defer pipeR.Close() + defer pipeW.Close() + + c1 := newFakeConn() + s := NewConnSlot(c1, nil) + + relayDone := make(chan error, 1) + go func() { + _, err := io.Copy(s, pipeR) + relayDone <- err + }() + + if _, err := pipeW.Write([]byte("first")); err != nil { + t.Fatalf("write first: %v", err) + } + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if string(c1.writtenBytes()) == "first" { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := string(c1.writtenBytes()); got != "first" { + t.Fatalf("c1 got %q, want %q", got, "first") + } + + s.Disconnect() + + if _, err := pipeW.Write([]byte("second")); err != nil { + t.Fatalf("write second: %v", err) + } + time.Sleep(50 * time.Millisecond) + if got := string(c1.writtenBytes()); got != "first" { + t.Fatalf("c1 received bytes during disconnect: %q", got) + } + + c2 := newFakeConn() + s.Set(c2) + + deadline = time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if string(c2.writtenBytes()) == "second" { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := string(c2.writtenBytes()); got != "second" { + t.Fatalf("c2 got %q after reconnect, want %q", got, "second") + } + + pipeW.Close() + select { + case <-relayDone: + case <-time.After(time.Second): + 
t.Fatal("relay did not exit after pipe close") + } + _ = s.Close() +}