]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
simplify tracking of Transports for connection migration (#5111)
authorMarten Seemann <martenseemann@gmail.com>
Sat, 3 May 2025 14:06:23 +0000 (22:06 +0800)
committerGitHub <noreply@github.com>
Sat, 3 May 2025 14:06:23 +0000 (16:06 +0200)
No functional change expected.

conn_id_generator.go
conn_id_generator_test.go
connection.go
transport.go

index 74f52414235005efc4b0b34a448ff693ec79017e..0b0b90fd8a5d6d1b9084a5ca1dd88ebfe5b15500 100644 (file)
@@ -16,7 +16,8 @@ type connRunnerCallbacks struct {
        ReplaceWithClosed  func([]protocol.ConnectionID, []byte)
 }
 
-type connRunners map[transportID]connRunnerCallbacks
+// The memory address of the Transport is used as the key.
+type connRunners map[*Transport]connRunnerCallbacks
 
 func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
        for _, c := range cr {
@@ -56,7 +57,7 @@ type connIDGenerator struct {
 }
 
 func newConnIDGenerator(
-       tID transportID,
+       tr *Transport,
        initialConnectionID protocol.ConnectionID,
        initialClientDestConnID *protocol.ConnectionID, // nil for the client
        statelessResetter *statelessResetter,
@@ -68,7 +69,7 @@ func newConnIDGenerator(
                generator:         generator,
                activeSrcConnIDs:  make(map[uint64]protocol.ConnectionID),
                statelessResetter: statelessResetter,
-               connRunners:       map[transportID]connRunnerCallbacks{tID: connRunner},
+               connRunners:       map[*Transport]connRunnerCallbacks{tr: connRunner},
                queueControlFrame: queueControlFrame,
        }
        m.activeSrcConnIDs[0] = initialConnectionID
@@ -201,7 +202,7 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
        m.connRunners.ReplaceWithClosed(connIDs, connClose)
 }
 
-func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
+func (m *connIDGenerator) AddConnRunner(id *Transport, r connRunnerCallbacks) {
        // The transport might have already been added earlier.
        // This happens if the application migrates back to and old path.
        if _, ok := m.connRunners[id]; ok {
index cd6d24c9a0e49f4d4bb9893e2d9ac42f53ddac97..bdd31bdb394119cd7ed82fa9ca2f97f1a0d1c008 100644 (file)
@@ -34,7 +34,7 @@ func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID
                initialClientDestConnID = &connID
        }
        g := newConnIDGenerator(
-               1,
+               &Transport{},
                protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
                initialClientDestConnID,
                sr,
@@ -107,7 +107,7 @@ func TestConnIDGeneratorRetiring(t *testing.T) {
        initialConnID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
        var added, removed []protocol.ConnectionID
        g := newConnIDGenerator(
-               1,
+               &Transport{},
                protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
                &initialConnID,
                newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -178,7 +178,7 @@ func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool)
                removed []protocol.ConnectionID
        )
        g := newConnIDGenerator(
-               0,
+               &Transport{},
                protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
                initialClientDestConnID,
                newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -228,7 +228,7 @@ func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConn
                replacedWith []byte
        )
        g := newConnIDGenerator(
-               1,
+               &Transport{},
                protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
                initialClientDestConnID,
                newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
@@ -274,7 +274,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
                added, removed, replaced []protocol.ConnectionID
        }
 
-       var tracker1, tracker2 connIDTracker
+       var tracker1, tracker2, tracker3 connIDTracker
        runner1 := connRunnerCallbacks{
                AddConnectionID:    func(c protocol.ConnectionID) { tracker1.added = append(tracker1.added, c) },
                RemoveConnectionID: func(c protocol.ConnectionID) { tracker1.removed = append(tracker1.removed, c) },
@@ -289,12 +289,20 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
                        tracker2.replaced = append(tracker2.replaced, connIDs...)
                },
        }
+       runner3 := connRunnerCallbacks{
+               AddConnectionID:    func(c protocol.ConnectionID) { tracker3.added = append(tracker3.added, c) },
+               RemoveConnectionID: func(c protocol.ConnectionID) { tracker3.removed = append(tracker3.removed, c) },
+               ReplaceWithClosed: func(connIDs []protocol.ConnectionID, _ []byte) {
+                       tracker3.replaced = append(tracker3.replaced, connIDs...)
+               },
+       }
 
        sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
        var queuedFrames []wire.Frame
 
+       tr := &Transport{}
        g := newConnIDGenerator(
-               1,
+               tr,
                initialConnID,
                &clientDestConnID,
                sr,
@@ -306,7 +314,7 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
        require.Len(t, tracker1.added, 2)
 
        // add the second runner - it should get all existing connection IDs
-       g.AddConnRunner(2, runner2)
+       g.AddConnRunner(&Transport{}, runner2)
        require.Len(t, tracker1.added, 2) // unchanged
        require.Len(t, tracker2.added, 4)
        require.Contains(t, tracker2.added, initialConnID)
@@ -314,6 +322,11 @@ func TestConnIDGeneratorAddConnRunner(t *testing.T) {
        require.Contains(t, tracker2.added, tracker1.added[0])
        require.Contains(t, tracker2.added, tracker1.added[1])
 
+       // adding the same transport again doesn't do anything
+       trCopy := tr
+       g.AddConnRunner(trCopy, runner3)
+       require.Empty(t, tracker3.added)
+
        var connIDToRetire protocol.ConnectionID
        var seqToRetire uint64
        ncid := queuedFrames[0].(*wire.NewConnectionIDFrame)
index 682e356b5f356efd6eda2044572c8ddeaf5c43d7..1165078e77872d36eb571b88b564b77ae1a29ab9 100644 (file)
@@ -269,7 +269,7 @@ var newConnection = func(
                s.queueControlFrame,
        )
        s.connIDGenerator = newConnIDGenerator(
-               tr.id(),
+               tr,
                srcConnID,
                &clientDestConnID,
                statelessResetter,
@@ -383,7 +383,7 @@ var newClientConnection = func(
                s.queueControlFrame,
        )
        s.connIDGenerator = newConnIDGenerator(
-               tr.id(),
+               tr,
                srcConnID,
                nil,
                statelessResetter,
@@ -2652,7 +2652,7 @@ func (s *connection) AddPath(t *Transport) (*Path, error) {
                func() {
                        runner := t.connRunner()
                        s.connIDGenerator.AddConnRunner(
-                               t.id(),
+                               t,
                                connRunnerCallbacks{
                                        AddConnectionID:    func(connID protocol.ConnectionID) { runner.Add(connID, s) },
                                        RemoveConnectionID: runner.Remove,
index a7775fb1476d725754387fe5ca667d204e81d9c7..915c1ce085d681585e5f5a923fe56c61ae1243d3 100644 (file)
@@ -38,10 +38,6 @@ func (e *errTransportClosed) Is(target error) bool {
        return ok
 }
 
-type transportID uint64
-
-var transportIDCounter atomic.Uint64
-
 var errListenerAlreadySet = errors.New("listener already set")
 
 // The Transport is the central point to manage incoming and outgoing QUIC connections.
@@ -136,8 +132,6 @@ type Transport struct {
        initOnce sync.Once
        initErr  error
 
-       // Set in init.
-       transportID transportID
        // If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
        connIDLen int
        // Set in init.
@@ -376,7 +370,6 @@ func (t *Transport) doDial(
 
 func (t *Transport) init(allowZeroLengthConnIDs bool) error {
        t.initOnce.Do(func() {
-               t.transportID = transportID(transportIDCounter.Add(1))
                var conn rawConn
                if c, ok := t.Conn.(rawConn); ok {
                        conn = c
@@ -430,8 +423,6 @@ func (t *Transport) connRunner() packetHandlerManager {
        return t.handlerMap
 }
 
-func (t *Transport) id() transportID { return t.transportID }
-
 // WriteTo sends a packet on the underlying connection.
 func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
        if err := t.init(false); err != nil {