diff --git a/client/control.go b/client/control.go index 1c11bffa..6469c4d7 100644 --- a/client/control.go +++ b/client/control.go @@ -29,7 +29,6 @@ import ( "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/transport" utilnet "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -106,8 +105,6 @@ func NewControl( ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) ctl.registerMsgHandlers() - ctl.xl.Info("get pxy cfgs: %v", util.JSONDump(ctl.pxyCfgs)) - ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) @@ -171,8 +168,6 @@ func (ctl *Control) handleNewProxyResp(m msg.Message) { inMsg := m.(*msg.NewProxyResp) - xl.Info("proxy: %v, remote addr: %v, err: %v", inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) - // Server will return NewProxyResp message to each NewProxy message. // Start a new proxy handler if no error got err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go index a07c58e8..d91002cb 100644 --- a/pkg/config/v1/server.go +++ b/pkg/config/v1/server.go @@ -21,10 +21,13 @@ import ( "github.com/fatedier/frp/pkg/util/util" ) -type SSHGatewayConfig struct { - SSHBindPort int `json:"sshBindPort,omitempty" validate:"gte=0,lte=65535"` - SSHPrivateKeyFilePath string `json:"sshPrivateKeyFilePath,omitempty"` - SSHPublicKeyFilesPath string `json:"sshPublicKeyFilesPath,omitempty"` +type SSHTunnelGateway struct { + BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"` + PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"` + PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"` + + // store all public key file. load all when init + PublicKeyFilesMap map[string]string } type ServerConfig struct { @@ -38,7 +41,7 @@ type ServerConfig struct { // value is 7000. BindPort int `json:"bindPort,omitempty"` - SSHGatewayConfig SSHGatewayConfig `json:"sshGatewayConfig"` + SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"` // KCPBindPort specifies the KCP port that the server listens on. If this // value is 0, the server will not listen for KCP connections. diff --git a/pkg/config/v1/ssh.go b/pkg/config/v1/ssh.go index d0f96c5c..d21581ce 100644 --- a/pkg/config/v1/ssh.go +++ b/pkg/config/v1/ssh.go @@ -1,6 +1,69 @@ package v1 +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "os" + "path/filepath" +) + const ( // custom define SSHClientLoginUserPrefix = "_frpc_ssh_client_" ) + +// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format +func GeneratePrivateKey() ([]byte, error) { + privateKey, err := generatePrivateKey() + if err != nil { + return nil, errors.New("gen private key error") + } + + privBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + return pem.EncodeToMemory(&privBlock), nil +} + +// generatePrivateKey creates a RSA Private Key of specified byte size +func generatePrivateKey() (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, err + } + + err = privateKey.Validate() + if err != nil { + return nil, err + } + return privateKey, nil +} + +func LoadFilesInDirectory(dirPath string) (map[string]string, error) { + fileMap := make(map[string]string) + + files, err := os.ReadDir(dirPath) + if err != nil { + return nil, err + } + + for _, file := range files { + filename := file.Name() + filePath := filepath.Join(dirPath, filename) + + content, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + + fileMap[filename] = string(content) + } + + return fileMap, nil +} diff --git a/pkg/util/util/util.go b/pkg/util/util/util.go index d80d226a..d2562b01 100644 --- a/pkg/util/util/util.go +++ b/pkg/util/util/util.go @@ -19,7 +19,6 @@ import ( "crypto/rand" "crypto/subtle" "encoding/hex" - "encoding/json" "fmt" mathrand "math/rand" "net" @@ -145,8 +144,3 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati func ConstantTimeEqString(a, b string) bool { return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 } - -func JSONDump(v interface{}) string { - prettyJSON, _ := json.MarshalIndent(v, "", "\t") - return string(prettyJSON) -} diff --git a/server/service.go b/server/service.go index 9c20d68b..abfa2d7d 100644 --- a/server/service.go +++ b/server/service.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -205,26 +206,33 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { svr.listener = ln log.Info("frps tcp listen on %s", address) - if cfg.SSHGatewayConfig.SSHBindPort > 0 { - svr.sshConfig = &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - file := fmt.Sprintf("%v/%v.pub", cfg.SSHGatewayConfig.SSHPublicKeyFilesPath, ssh.FingerprintSHA256(key)) - authorizedKey, err := os.ReadFile(file) - if err != nil { - return nil, err - } + if cfg.SSHTunnelGateway.BindPort > 0 { - parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(authorizedKey) + if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" { + cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadFilesInDirectory(cfg.SSHTunnelGateway.PublicKeyFilesPath) + if err != nil { + return nil, fmt.Errorf("load ssh all public key files error: %v", err) + } + log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath) + } + + svr.sshConfig = &ssh.ServerConfig{ + NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false), + + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + keyContent, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)] + if !ok { + return nil, errors.New("cannot find public key file") + } + parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyContent)) if err != nil { return nil, err } if key.Type() == parsedAuthorizedKey.Type() && bytes.Equal(key.Marshal(), parsedAuthorizedKey.Marshal()) { - log.Info("ssh file: %v fingerprint sha256: %v", file, ssh.FingerprintSHA256(key)) - return &ssh.Permissions{ Extensions: map[string]string{ - ssh.FingerprintSHA256(key): string(authorizedKey), + ssh.FingerprintSHA256(key): keyContent, }, }, nil } @@ -232,12 +240,22 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { }, } - privateBytes, err := os.ReadFile(cfg.SSHGatewayConfig.SSHPrivateKeyFilePath) - if err != nil { - log.Error("Failed to load private key") - return nil, err + var privateBytes []byte + if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" { + privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath) + if err != nil { + log.Error("Failed to load private key") + return nil, err + } + log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath) + } else { + privateBytes, err = v1.GeneratePrivateKey() + if err != nil { + log.Error("Failed to load private key") + return nil, err + } + log.Info("auto gen private key file success") } - private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { log.Error("Failed to parse private key, error: %v", err) @@ -246,7 +264,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { svr.sshConfig.AddHostKey(private) - sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHGatewayConfig.SSHBindPort)) + sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort)) svr.sshListener, err = net.Listen("tcp", sshAddr) if err != nil { log.Error("Failed to listen on %v, error: %v", sshAddr, err) @@ -582,18 +600,9 @@ func (svr *Service) HandleSSHListener(listener net.Listener) { ctx := context.Background() - vs, err := NewVirtualService( - ctx, - v1.ClientCommonConfig{}, - *svr.cfg, - msg.Login{ - User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String(), - }, - svr.rc, - pxyCfg, - ss, - replyCh, - ) + vs, err := NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg, + msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()}, + svr.rc, pxyCfg, ss, replyCh) if err != nil { log.Error("new virtual service error: %v", err) ss.Close() diff --git a/server/ssh_service.go b/server/ssh_service.go index c9c78c19..90c68ed6 100644 --- a/server/ssh_service.go +++ b/server/ssh_service.go @@ -18,7 +18,6 @@ import ( v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/util/log" - "github.com/fatedier/frp/pkg/util/util" ) const ( @@ -174,7 +173,7 @@ func (ss *SSHService) loopParseCmdPayload() { if req.Type == RequestTypeHeartbeat { log.Debug("ssh heartbeat data") } else { - log.Info("default req, data: %v", util.JSONDump(req)) + log.Info("default req, data: %v", req) } } if req.WantReply { @@ -231,15 +230,18 @@ func (ss *SSHService) loopParseExtraPayload() { log.Info("r.payload is less than 4") continue } - - dataLen := binary.BigEndian.Uint32(r.Payload[:4]) - p := string(r.Payload[4 : 4+dataLen]) - - if !strings.Contains(p, "frpc") { - log.Info("payload not contains frp keyword: %v", p) + if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") { + log.Info("ssh protocol exchange data") continue } + // [4byte data_len|data] + end := 4 + binary.BigEndian.Uint32(r.Payload[:4]) + if end > uint32(len(r.Payload)) { + end = uint32(len(r.Payload)) + } + p := string(r.Payload[4:end]) + msg, err := parseSSHExtraMessage(p) if err != nil { log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) @@ -331,35 +333,42 @@ func (ss *SSHService) loopGenerateProxy() { ProxyBaseConfig: v1.ProxyBaseConfig{ Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), Type: p2.Type, + + ProxyBackend: v1.ProxyBackend{ + LocalIP: p1.Address, + }, }, RemotePort: int(p1.Port), } default: log.Warn("invalid frp proxy type: %v", p2.Type) } - } } func parseSSHExtraMessage(s string) (p SSHExtraPayload, err error) { + sn := len(s) + + log.Info("parse ssh extra message: %v", s) + ss := strings.Fields(s) - if len(ss) <= 1 { - return p, fmt.Errorf("invalid ssh input, args: %v", ss) + if len(ss) == 0 { + if sn != 0 { + ss = append(ss, s) + } else { + return p, fmt.Errorf("invalid ssh input, args: %v", ss) + } } for i, v := range ss { ss[i] = strings.TrimSpace(v) } - if ss[0] != "frpc" { - return p, fmt.Errorf("first input should be frpc, but got: %v", ss[0]) - } - - if ss[1] != "tcp" && ss[1] != "http" { + if ss[0] != "tcp" && ss[0] != "http" { return p, fmt.Errorf("only support tcp/http now") } - switch ss[1] { + switch ss[0] { case "tcp": tcpCmd, err := ParseTCPCommand(ss) if err != nil { @@ -407,7 +416,7 @@ func ParseHTTPCommand(params []string) (*HTTPCommand, error) { basicAuthPass string ) - fs := flag.NewFlagSet("frpc http", flag.ContinueOnError) + fs := flag.NewFlagSet("http", flag.ContinueOnError) fs.StringVar(&basicAuth, "basic-auth", "", "") fs.StringVar(&domainURL, "domain", "", "") @@ -442,21 +451,25 @@ type TCPCommand struct { } func ParseTCPCommand(params []string) (*TCPCommand, error) { - if len(params) < 2 || params[0] != "frpc" || params[1] != "tcp" { + if len(params) == 0 || params[0] != "tcp" { return nil, errors.New("invalid TCP command") } + if len(params) == 1 { + return &TCPCommand{}, nil + } + var ( address string port string ) - fs := flag.NewFlagSet("frpc tcp", flag.ContinueOnError) + fs := flag.NewFlagSet("tcp", flag.ContinueOnError) fs.StringVar(&address, "address", "", "The IP address to listen on") fs.StringVar(&port, "port", "", "The port to listen on") fs.SetOutput(&nullWriter{}) // Disables usage output - args := params[2:] + args := params[1:] err := fs.Parse(args) if err != nil { if !errors.Is(err, flag.ErrHelp) { diff --git a/server/vclient_service.go b/server/vclient_service.go index 1e0167e6..5c012b90 100644 --- a/server/vclient_service.go +++ b/server/vclient_service.go @@ -87,8 +87,6 @@ func (svr *VirtualService) Run(ctx context.Context) (err error) { svr.ctx = xlog.NewContext(ctx, xlog.New()) svr.cancel = cancel - log.Info("get svr pxy: %v", util.JSONDump(svr.pxyCfg)) - remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{ ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name, ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type, @@ -172,10 +170,12 @@ func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr strin func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) { // tell ssh client open a new stream for work payload := forwardedTCPPayload{ - Addr: svr.serverCfg.BindAddr, + Addr: svr.serverCfg.BindAddr, // TODO refine Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort), } + log.Info("get work conn payload: %v", payload) + channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload)) if err != nil { return nil, fmt.Errorf("open ssh channel error: %v", err)