diff --git a/lntest/itest/lnd_test.go b/lntest/itest/lnd_test.go index 96f5e82a..a06bb8df 100644 --- a/lntest/itest/lnd_test.go +++ b/lntest/itest/lnd_test.go @@ -4,13 +4,21 @@ package itest import ( "bytes" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/hex" + "encoding/pem" "fmt" "io" "io/ioutil" "math" + "math/big" + "net" "os" "path/filepath" "reflect" @@ -13494,6 +13502,133 @@ func testHoldInvoicePersistence(net *lntest.NetworkHarness, t *harnessTest) { } } +// testTLSAutoRegeneration creates an expired TLS certificate, to test that a +// new TLS certificate pair is regenerated when the old pair expires. This is +// necessary because the pair expires after a little over a year. +func testTLSAutoRegeneration(lnNet *lntest.NetworkHarness, t *harnessTest) { + certPath := lnNet.Alice.TLSCertStr() + keyPath := lnNet.Alice.TLSKeyStr() + + // Create an expired certificate. + expiredCert := genExpiredCertPair( + t, lnNet, certPath, keyPath, + ) + + // Restart the node to test that the cert is automatically regenerated. + lnNet.RestartNode(lnNet.Alice, nil, nil) + + // Grab the newly generated certificate. + newCertData, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + t.Fatalf("couldn't grab new certificate") + } + + newCert, err := x509.ParseCertificate(newCertData.Certificate[0]) + if err != nil { + t.Fatalf("couldn't parse new certificate") + } + + // Check that the expired certificate was successfully deleted and + // replaced with a new one. + if !newCert.NotAfter.After(expiredCert.NotAfter) { + t.Fatalf("New certificate expiration is too old") + } +} + +// genExpiredCertPair generates an expired key/cert pair to the paths +// provided to test that expired certificates are being regenerated correctly. +func genExpiredCertPair(t *harnessTest, lnNet *lntest.NetworkHarness, certPath, + keyPath string) *x509.Certificate { + // Max serial number. + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + + // Generate a serial number that's below the serialNumberLimit. + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatalf("failed to generate serial number: %s", err) + } + + host := "lightning" + + // Create a simple ip address for the fake certificate. + ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} + + dnsNames := []string{host, "unix", "unixpacket"} + + // Construct the certificate template. + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"lnd autogenerated cert"}, + CommonName: host, + }, + NotBefore: time.Now().Add(-time.Hour * 24), + NotAfter: time.Now(), + + KeyUsage: x509.KeyUsageKeyEncipherment | + x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + IsCA: true, // so can sign self. + BasicConstraintsValid: true, + + DNSNames: dnsNames, + IPAddresses: ipAddresses, + } + + // Generate a private key for the certificate. + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate a private key") + } + + derBytes, err := x509.CreateCertificate( + rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + expiredCert, err := x509.ParseCertificate(derBytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + } + + certBuf := bytes.Buffer{} + err = pem.Encode( + &certBuf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }, + ) + if err != nil { + t.Fatalf("failed to encode certificate: %v", err) + } + + keybytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + t.Fatalf("unable to encode privkey: %v", err) + } + keyBuf := bytes.Buffer{} + err = pem.Encode( + &keyBuf, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keybytes, + }, + ) + if err != nil { + t.Fatalf("failed to encode private key: %v", err) + } + + // Write cert and key files. + if err = ioutil.WriteFile(certPath, certBuf.Bytes(), 0644); err != nil { + t.Fatalf("failed to write cert file: %v", err) + } + if err = ioutil.WriteFile(keyPath, keyBuf.Bytes(), 0600); err != nil { + os.Remove(certPath) + t.Fatalf("failed to write key file: %v", err) + } + + return expiredCert +} + type testCase struct { name string test func(net *lntest.NetworkHarness, t *harnessTest) @@ -13741,6 +13876,10 @@ var testsCases = []*testCase{ name: "cpfp", test: testCPFP, }, + { + name: "automatic certificate regeneration", + test: testTLSAutoRegeneration, + }, } // TestLightningNetworkDaemon performs a series of integration tests amongst a