diff --git a/channeldb/db.go b/channeldb/db.go index eace4997..aecb75e4 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -992,7 +992,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection } - err = updateEdgePolicy(tx, &chanEdge) + _, err = updateEdgePolicy(tx, &chanEdge) if err != nil { return err } diff --git a/channeldb/graph.go b/channeldb/graph.go index 4815a647..25c59f5f 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -1815,34 +1815,62 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error { c.cacheMu.Lock() defer c.cacheMu.Unlock() + var isUpdate1 bool err := c.db.Update(func(tx *bbolt.Tx) error { - return updateEdgePolicy(tx, edge) + var err error + isUpdate1, err = updateEdgePolicy(tx, edge) + return err }) if err != nil { return err } - c.rejectCache.remove(edge.ChannelID) - c.chanCache.remove(edge.ChannelID) + // If an entry for this channel is found in reject cache, we'll modify + // the entry with the updated timestamp for the direction that was just + // written. If the edge doesn't exist, we'll load the cache entry lazily + // during the next query for this edge. + if entry, ok := c.rejectCache.get(edge.ChannelID); ok { + if isUpdate1 { + entry.upd1Time = edge.LastUpdate.Unix() + } else { + entry.upd2Time = edge.LastUpdate.Unix() + } + c.rejectCache.insert(edge.ChannelID, entry) + } + + // If an entry for this channel is found in channel cache, we'll modify + // the entry with the updated policy for the direction that was just + // written. If the edge doesn't exist, we'll defer loading the info and + // policies and lazily read from disk during the next query. + if channel, ok := c.chanCache.get(edge.ChannelID); ok { + if isUpdate1 { + channel.Policy1 = edge + } else { + channel.Policy2 = edge + } + c.chanCache.insert(edge.ChannelID, channel) + } return nil } // updateEdgePolicy attempts to update an edge's policy within the relevant -// buckets using an existing database transaction. -func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) error { +// buckets using an existing database transaction. The returned boolean will be +// true if the updated policy belongs to node1, and false if the policy belonged +// to node2. +func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) (bool, error) { edges := tx.Bucket(edgeBucket) if edges == nil { - return ErrEdgeNotFound + return false, ErrEdgeNotFound } edgeIndex := edges.Bucket(edgeIndexBucket) if edgeIndex == nil { - return ErrEdgeNotFound + return false, ErrEdgeNotFound } nodes, err := tx.CreateBucketIfNotExists(nodeBucket) if err != nil { - return err + return false, err } // Create the channelID key be converting the channel ID @@ -1854,23 +1882,31 @@ func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) error { // nodes which connect this channel edge. nodeInfo := edgeIndex.Get(chanID[:]) if nodeInfo == nil { - return ErrEdgeNotFound + return false, ErrEdgeNotFound } // Depending on the flags value passed above, either the first // or second edge policy is being updated. var fromNode, toNode []byte + var isUpdate1 bool if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { fromNode = nodeInfo[:33] toNode = nodeInfo[33:66] + isUpdate1 = true } else { fromNode = nodeInfo[33:66] toNode = nodeInfo[:33] + isUpdate1 = false } // Finally, with the direction of the edge being updated // identified, we update the on-disk edge representation. - return putChanEdgePolicy(edges, nodes, edge, fromNode, toNode) + err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode) + if err != nil { + return false, err + } + + return isUpdate1, nil } // LightningNode represents an individual vertex/node within the channel graph. diff --git a/channeldb/migrations.go b/channeldb/migrations.go index f86e416b..72ba7882 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -564,7 +564,7 @@ func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { return err } - err = updateEdgePolicy(tx, edgePolicy) + _, err = updateEdgePolicy(tx, edgePolicy) if err != nil { return err }