diff --git a/lnd.go b/lnd.go index 95267d35..b19eab7a 100644 --- a/lnd.go +++ b/lnd.go @@ -185,9 +185,8 @@ func Main() error { defer cancel() tlsCfg, restCreds, restProxyDest, err := getTLSConfig( - cfg.TLSCertPath, - cfg.TLSKeyPath, - cfg.RPCListeners, + cfg.TLSCertPath, cfg.TLSKeyPath, cfg.TLSExtraIPs, + cfg.TLSExtraDomains, cfg.RPCListeners, ) if err != nil { err := fmt.Errorf("Unable to load TLS credentials: %v", err) @@ -555,13 +554,15 @@ func Main() error { // getTLSConfig returns a TLS configuration for the gRPC server and credentials // and a proxy destination for the REST reverse proxy. -func getTLSConfig(tlsCertPath string, tlsKeyPath string, - rpcListeners []net.Addr) (*tls.Config, +func getTLSConfig(tlsCertPath string, tlsKeyPath string, tlsExtraIPs, + tlsExtraDomains []string, rpcListeners []net.Addr) (*tls.Config, *credentials.TransportCredentials, string, error) { // Ensure we create TLS key and certificate if they don't exist if !fileExists(tlsCertPath) && !fileExists(tlsKeyPath) { - err := genCertPair(tlsCertPath, tlsKeyPath) + err := genCertPair( + tlsCertPath, tlsKeyPath, tlsExtraIPs, tlsExtraDomains, + ) if err != nil { return nil, nil, "", err } @@ -591,7 +592,9 @@ func getTLSConfig(tlsCertPath string, tlsKeyPath string, return nil, nil, "", err } - err = genCertPair(tlsCertPath, tlsKeyPath) + err = genCertPair( + tlsCertPath, tlsKeyPath, tlsExtraIPs, tlsExtraDomains, + ) if err != nil { return nil, nil, "", err } @@ -644,7 +647,9 @@ func fileExists(name string) bool { // // This function is adapted from https://github.com/btcsuite/btcd and // https://github.com/btcsuite/btcutil -func genCertPair(certFile, keyFile string) error { +func genCertPair(certFile, keyFile string, tlsExtraIPs, + tlsExtraDomains []string) error { + rpcsLog.Infof("Generating TLS certificates...") org := "lnd autogenerated cert" @@ -687,13 +692,11 @@ func genCertPair(certFile, keyFile string) error { } } - if cfg != nil { - // Add extra IPs to the slice. - for _, ip := range cfg.TLSExtraIPs { - ipAddr := net.ParseIP(ip) - if ipAddr != nil { - addIP(ipAddr) - } + // Add extra IPs to the slice. + for _, ip := range tlsExtraIPs { + ipAddr := net.ParseIP(ip) + if ipAddr != nil { + addIP(ipAddr) } } @@ -709,9 +712,7 @@ func genCertPair(certFile, keyFile string) error { if host != "localhost" { dnsNames = append(dnsNames, "localhost") } - if cfg != nil { - dnsNames = append(dnsNames, cfg.TLSExtraDomains...) - } + dnsNames = append(dnsNames, tlsExtraDomains...) // Also add fake hostnames for unix sockets, otherwise hostname // verification will fail in the client. diff --git a/server_test.go b/server_test.go index e7375117..cb9034a0 100644 --- a/server_test.go +++ b/server_test.go @@ -71,7 +71,6 @@ func TestTLSAutoRegeneration(t *testing.T) { keyPath := tempDirPath + "/tls.key" certDerBytes, keyBytes := genExpiredCertPair(t, tempDirPath) - expiredCert, err := x509.ParseCertificate(certDerBytes) if err != nil { t.Fatalf("failed to parse certificate: %v", err) @@ -106,7 +105,6 @@ func TestTLSAutoRegeneration(t *testing.T) { } err = ioutil.WriteFile(tempDirPath+"/tls.key", keyBuf.Bytes(), 0600) if err != nil { - os.Remove(tempDirPath + "tls.cert") t.Fatalf("failed to write key file: %v", err) } @@ -116,7 +114,7 @@ func TestTLSAutoRegeneration(t *testing.T) { // Now let's run getTLSConfig. If it works properly, it should delete // the cert and create a new one. - _, _, _, err = getTLSConfig(certPath, keyPath, rpcListeners) + _, _, _, err = getTLSConfig(certPath, keyPath, nil, nil, rpcListeners) if err != nil { t.Fatalf("couldn't retrieve TLS config") } @@ -185,7 +183,8 @@ func genExpiredCertPair(t *testing.T, certDirPath string) ([]byte, []byte) { } certDerBytes, err := x509.CreateCertificate( - rand.Reader, &template, &template, &priv.PublicKey, priv) + rand.Reader, &template, &template, &priv.PublicKey, priv, + ) if err != nil { t.Fatalf("failed to create certificate: %v", err) }