diff --git a/channeldb/db.go b/channeldb/db.go
index eace4997..aecb75e4 100644
--- a/channeldb/db.go
+++ b/channeldb/db.go
@@ -992,7 +992,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
 				chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection
 			}
 
-			err = updateEdgePolicy(tx, &chanEdge)
+			_, err = updateEdgePolicy(tx, &chanEdge)
 			if err != nil {
 				return err
 			}
diff --git a/channeldb/graph.go b/channeldb/graph.go
index 4815a647..25c59f5f 100644
--- a/channeldb/graph.go
+++ b/channeldb/graph.go
@@ -1815,34 +1815,62 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error {
 	c.cacheMu.Lock()
 	defer c.cacheMu.Unlock()
 
+	var isUpdate1 bool
 	err := c.db.Update(func(tx *bbolt.Tx) error {
-		return updateEdgePolicy(tx, edge)
+		var err error
+		isUpdate1, err = updateEdgePolicy(tx, edge)
+		return err
 	})
 	if err != nil {
 		return err
 	}
 
-	c.rejectCache.remove(edge.ChannelID)
-	c.chanCache.remove(edge.ChannelID)
+	// If an entry for this channel is found in reject cache, we'll modify
+	// the entry with the updated timestamp for the direction that was just
+	// written. If the edge doesn't exist, we'll load the cache entry lazily
+	// during the next query for this edge.
+	if entry, ok := c.rejectCache.get(edge.ChannelID); ok {
+		if isUpdate1 {
+			entry.upd1Time = edge.LastUpdate.Unix()
+		} else {
+			entry.upd2Time = edge.LastUpdate.Unix()
+		}
+		c.rejectCache.insert(edge.ChannelID, entry)
+	}
+
+	// If an entry for this channel is found in channel cache, we'll modify
+	// the entry with the updated policy for the direction that was just
+	// written. If the edge doesn't exist, we'll defer loading the info and
+	// policies and lazily read from disk during the next query.
+	if channel, ok := c.chanCache.get(edge.ChannelID); ok {
+		if isUpdate1 {
+			channel.Policy1 = edge
+		} else {
+			channel.Policy2 = edge
+		}
+		c.chanCache.insert(edge.ChannelID, channel)
+	}
 
 	return nil
 }
 
 // updateEdgePolicy attempts to update an edge's policy within the relevant
-// buckets using an existing database transaction.
-func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) error {
+// buckets using an existing database transaction. The returned boolean will be
+// true if the updated policy belongs to node1, and false if the policy belonged
+// to node2.
+func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) (bool, error) {
 	edges := tx.Bucket(edgeBucket)
 	if edges == nil {
-		return ErrEdgeNotFound
+		return false, ErrEdgeNotFound
 
 	}
 	edgeIndex := edges.Bucket(edgeIndexBucket)
 	if edgeIndex == nil {
-		return ErrEdgeNotFound
+		return false, ErrEdgeNotFound
 	}
 	nodes, err := tx.CreateBucketIfNotExists(nodeBucket)
 	if err != nil {
-		return err
+		return false, err
 	}
 
 	// Create the channelID key be converting the channel ID
@@ -1854,23 +1882,31 @@ func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) error {
 	// nodes which connect this channel edge.
 	nodeInfo := edgeIndex.Get(chanID[:])
 	if nodeInfo == nil {
-		return ErrEdgeNotFound
+		return false, ErrEdgeNotFound
 	}
 
 	// Depending on the flags value passed above, either the first
 	// or second edge policy is being updated.
 	var fromNode, toNode []byte
+	var isUpdate1 bool
 	if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
 		fromNode = nodeInfo[:33]
 		toNode = nodeInfo[33:66]
+		isUpdate1 = true
 	} else {
 		fromNode = nodeInfo[33:66]
 		toNode = nodeInfo[:33]
+		isUpdate1 = false
 	}
 
 	// Finally, with the direction of the edge being updated
 	// identified, we update the on-disk edge representation.
-	return putChanEdgePolicy(edges, nodes, edge, fromNode, toNode)
+	err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode)
+	if err != nil {
+		return false, err
+	}
+
+	return isUpdate1, nil
 }
 
 // LightningNode represents an individual vertex/node within the channel graph.
diff --git a/channeldb/migrations.go b/channeldb/migrations.go
index f86e416b..72ba7882 100644
--- a/channeldb/migrations.go
+++ b/channeldb/migrations.go
@@ -564,7 +564,7 @@ func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error {
 			return err
 		}
 
-		err = updateEdgePolicy(tx, edgePolicy)
+		_, err = updateEdgePolicy(tx, edgePolicy)
 		if err != nil {
 			return err
 		}