diff --git a/client/control.go b/client/control.go index 3da61712..54b25b2f 100644 --- a/client/control.go +++ b/client/control.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/fatedier/frp/models/transport" "io" "net" "runtime/debug" @@ -208,11 +209,19 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { conn = stream } else { var tlsConfig *tls.Config + if ctl.clientCfg.TLSEnable { - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, + tlsConfig, err = transport.NewServerTLSConfig( + ctl.clientCfg.TLSCertFile, + ctl.clientCfg.TLSKeyFile, + ctl.clientCfg.TLSTrustedCaFile) + + if err != nil { + xl.Warn("fail to build tls configuration when connecting to server, err: %v", err) + return } } + conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HttpProxy, ctl.clientCfg.Protocol, fmt.Sprintf("%s:%d", ctl.clientCfg.ServerAddr, ctl.clientCfg.ServerPort), tlsConfig) if err != nil { diff --git a/client/service.go b/client/service.go index b6ea0634..53ea93d9 100644 --- a/client/service.go +++ b/client/service.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/fatedier/frp/models/transport" "io/ioutil" "net" "runtime" @@ -205,10 +206,17 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { xl := xlog.FromContextSafe(svr.ctx) var tlsConfig *tls.Config if svr.cfg.TLSEnable { - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, + tlsConfig, err = transport.NewClientTLSConfig( + svr.cfg.TLSCertFile, + svr.cfg.TLSKeyFile, + svr.cfg.TLSTrustedCaFile, + svr.cfg.ServerAddr) + if err != nil { + xl.Warn("fail to build tls configuration when service login, err: %v", err) + return } } + conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HttpProxy, svr.cfg.Protocol, fmt.Sprintf("%s:%d", svr.cfg.ServerAddr, svr.cfg.ServerPort), tlsConfig) if err != nil { diff --git a/models/config/client_common.go b/models/config/client_common.go index 3f8c485d..2216bf86 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -104,8 +104,21 @@ type ClientCommonConf struct { // is "tcp". Protocol string `json:"protocol"` // TLSEnable specifies whether or not TLS should be used when communicating - // with the server. + // with the server. If "tls_cert_file" and "tls_key_file" are valid, + // client will load the supplied tls configuration. Otherwise, it will + // load the tls configuration generated by itself. TLSEnable bool `json:"tls_enable"` + // ClientTLSCertPath specifies the path of the cert file that client will + // load. It only works when "tls_enable" is true and "tls_key_file" is valid. + TLSCertFile string `json:"tls_cert_file"` + // ClientTLSKeyPath specifies the path of the secret key file that client + // will load. It only works when "tls_enable" is true and "tls_cert_file" + // is valid. + TLSKeyFile string `json:"tls_key_file"` + // TrustedCaFile specifies the path of the trusted ca file that will load. + // It only works when "tls_enable", "tls_cert_file" and "tls_key_file" are + // valid. + TLSTrustedCaFile string `json:"tls_trusted_ca_file"` // HeartBeatInterval specifies at what interval heartbeats are sent to the // server, in seconds. It is not recommended to change this value. By // default, this value is 30. @@ -142,6 +155,9 @@ func GetDefaultClientConf() ClientCommonConf { Start: make(map[string]struct{}), Protocol: "tcp", TLSEnable: false, + TLSCertFile:"", + TLSKeyFile:"", + TLSTrustedCaFile:"", HeartBeatInterval: 30, HeartBeatTimeout: 90, Metas: make(map[string]string), @@ -276,6 +292,18 @@ func UnmarshalClientConfFromIni(content string) (cfg ClientCommonConf, err error cfg.TLSEnable = false } + if tmpStr, ok = conf.Get("common", "tls_cert_file"); ok { + cfg.TLSCertFile = tmpStr + } + + if tmpStr, ok := conf.Get("common", "tls_key_file"); ok { + cfg.TLSKeyFile = tmpStr + } + + if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok { + cfg.TLSTrustedCaFile = tmpStr + } + if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok { if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") diff --git a/models/config/server_common.go b/models/config/server_common.go index a6aa969e..aae5a9ee 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -133,9 +133,23 @@ type ServerCommonConf struct { // may proxy to. If this value is 0, no limit will be applied. By default, // this value is 0. MaxPortsPerClient int64 `json:"max_ports_per_client"` - // TlsOnly specifies whether to only accept TLS-encrypted connections. By - // default, the value is false. + // TlsOnly specifies whether to only accept TLS-encrypted connections. If + // "tls_cert_file", "tls_key_file" are valid, the server will use this + // supplied tls configuration. Otherwise, the server will use the tls + // configuration generated by itself. By default, the value is false. TlsOnly bool `json:"tls_only"` + // TLSCertFile specifies the path of the cert file that the server will + // load. When "tls_only" is true, it will works with "tls_key_file" and + // "tls_trusted_ca_file". + TLSCertFile string `json:"tls_cert_file"` + // TLSKeyFile specifies the path of the secret key that the server will + // load. When "tls_only" is true, it will works with "server_tls_cert_path" and + // "tls_trusted_ca_file". + TLSKeyFile string `json:"tls_key_file"` + // TLSTrustedCaFile specifies the paths of the client cert files that the + // server will load. If "tls_cert_file" is not empty, it only works + // when "tls_only" is true and "tls_cert_file" and "tls_key_file" are valid. + TLSTrustedCaFile string `json:"tls_trusted_ca_file"` // HeartBeatTimeout specifies the maximum time to wait for a heartbeat // before terminating the connection. It is not recommended to change this // value. By default, this value is 90. @@ -178,6 +192,9 @@ func GetDefaultServerConf() ServerCommonConf { MaxPoolCount: 5, MaxPortsPerClient: 0, TlsOnly: false, + TLSCertFile:"", + TLSKeyFile:"", + TLSTrustedCaFile:"", HeartBeatTimeout: 90, UserConnTimeout: 10, Custom404Page: "", @@ -416,6 +433,19 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error } else { cfg.TlsOnly = false } + + if tmpStr, ok := conf.Get("common", "tls_cert_file"); ok { + cfg.TLSCertFile = tmpStr + } + + if tmpStr, ok := conf.Get("common", "tls_key_file"); ok { + cfg.TLSKeyFile = tmpStr + } + + if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok { + cfg.TLSTrustedCaFile = tmpStr + } + return } diff --git a/models/transport/tls.go b/models/transport/tls.go new file mode 100644 index 00000000..74519f0b --- /dev/null +++ b/models/transport/tls.go @@ -0,0 +1,138 @@ +package transport + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" +) + +/* + +Self CA: +openssl genrsa -out ca.key 2048 +openssl req -x509 -new -nodes -key ca.key -subj "/CN=example.ca.com" -days 5000 -out ca.crt + +Server: +openssl genrsa -out server.key 2048 +openssl req -new -key server.key -subj "/CN=example.server.com" -out server.csr +openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out server.crt -days 5000 + +Client: +openssl genrsa -out client.key 2048 +openssl req -new -key client.key -subj "/CN=example.client.com" -out client.csr +openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out client.crt -days 5000 + +*/ + +func newCustomTLSKeyPair(certfile, keyfile string) (*tls.Certificate, error) { + tlsCert, err := tls.LoadX509KeyPair(certfile, keyfile) + if err != nil { + return nil, err + } + return &tlsCert, nil +} + +func newRandomTLSKeyPair() *tls.Certificate { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate( + rand.Reader, + &template, + &template, + &key.PublicKey, + key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tlsCert +} + +// Only supprt one ca file to add +func newCertPool(caPath string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + + caCrt, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, err + } + + pool.AppendCertsFromPEM(caCrt) + + return pool, nil +} + +func NewServerTLSConfig(certPath, keyPath, caPath string ) (*tls.Config, error) { + var base = &tls.Config{} + + if certPath == "" || keyPath == "" { + // server will generate tls conf by itself + cert := newRandomTLSKeyPair() + base.Certificates = []tls.Certificate{*cert} + } else { + cert, err := newCustomTLSKeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + + base.Certificates = []tls.Certificate{*cert} + } + + if caPath != "" { + pool, err := newCertPool(caPath) + if err != nil { + return nil, err + } + + base.ClientAuth = tls.RequireAndVerifyClientCert + base.ClientCAs = pool + } + + return base, nil +} + + +func NewClientTLSConfig(certPath, keyPath, caPath, servearName string) (*tls.Config, error) { + var base = &tls.Config{} + fmt.Printf("yyl-test client tls servername is %s\n", servearName) + + if certPath == "" || keyPath == "" { + // client will not generate tls conf by itself + } else { + cert, err := newCustomTLSKeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + + base.Certificates = []tls.Certificate{*cert} + } + + if caPath != "" { + pool, err := newCertPool(caPath) + if err != nil { + return nil, err + } + + base.RootCAs = pool + base.ServerName = servearName + base.InsecureSkipVerify = false + } else { + base.InsecureSkipVerify = true + } + + return base, nil +} diff --git a/server/service.go b/server/service.go index 4f1c702f..039a419b 100644 --- a/server/service.go +++ b/server/service.go @@ -17,14 +17,10 @@ package server import ( "bytes" "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "encoding/pem" "fmt" + "github.com/fatedier/frp/models/transport" "io/ioutil" - "math/big" "net" "net/http" "time" @@ -99,6 +95,19 @@ type Service struct { } func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { + var tlsConfig *tls.Config + if cfg.TlsOnly { + tlsConfig, err = transport.NewServerTLSConfig( + cfg.TLSCertFile, + cfg.TLSKeyFile, + cfg.TLSTrustedCaFile) + if err != nil { + return + } + } + + fmt.Printf("yyl-test tlsonly is %t, tlsconfig is %+v\n", cfg.TlsOnly, tlsConfig) + svr = &Service{ ctlManager: NewControlManager(), pxyManager: proxy.NewProxyManager(), @@ -110,8 +119,8 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { }, httpVhostRouter: vhost.NewVhostRouters(), authVerifier: auth.NewAuthVerifier(cfg.AuthServerConfig), - tlsConfig: generateTLSConfig(), - cfg: cfg, + tlsConfig: tlsConfig, + cfg: cfg, } // Create tcpmux httpconnect multiplexer. @@ -497,25 +506,4 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error { return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, newMsg.UseEncryption, newMsg.UseCompression) -} - -// Setup a bare-bones TLS config for the server -func generateTLSConfig() *tls.Config { - key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - panic(err) - } - template := x509.Certificate{SerialNumber: big.NewInt(1)} - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - panic(err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - panic(err) - } - return &tls.Config{Certificates: []tls.Certificate{tlsCert}} -} +} \ No newline at end of file