diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index 6c41647e..49a4b4df 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -26,6 +26,11 @@ type HTLCAttemptInfo struct { // EC operations used by btcec.PrivKeyFromBytes. sessionKey [btcec.PrivKeyBytesLen]byte + // cachedSessionKey is our fully deserialized sesionKey. This value + // may be nil if the attempt has just been read from disk and its + // session key has not been used yet. + cachedSessionKey *btcec.PrivateKey + // Route is the route attempted to send the HTLC. Route route.Route @@ -49,19 +54,25 @@ func NewHtlcAttemptInfo(attemptID uint64, sessionKey *btcec.PrivateKey, copy(scratch[:], sessionKey.Serialize()) return &HTLCAttemptInfo{ - AttemptID: attemptID, - sessionKey: scratch, - Route: route, - AttemptTime: attemptTime, - Hash: hash, + AttemptID: attemptID, + sessionKey: scratch, + cachedSessionKey: sessionKey, + Route: route, + AttemptTime: attemptTime, + Hash: hash, } } // SessionKey returns the ephemeral key used for a htlc attempt. This function -// performs expensive ec-ops to obtain the session key. +// performs expensive ec-ops to obtain the session key if it is not cached. func (h *HTLCAttemptInfo) SessionKey() *btcec.PrivateKey { - priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), h.sessionKey[:]) - return priv + if h.cachedSessionKey == nil { + h.cachedSessionKey, _ = btcec.PrivKeyFromBytes( + btcec.S256(), h.sessionKey[:], + ) + } + + return h.cachedSessionKey } // HTLCAttempt contains information about a specific HTLC attempt for a given diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 37b20b51..370a1d9c 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -120,6 +120,10 @@ func TestSentPaymentSerialization(t *testing.T) { newWireInfo.Route = route.Route{} s.Route = route.Route{} + // Call session key method to set our cached session key so we can use + // DeepEqual, and assert that our key equals the original key. + require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey()) + if !reflect.DeepEqual(s, newWireInfo) { t.Fatalf("Payments do not match after "+ "serialization/deserialization %v vs %v",