fix: comments

This commit is contained in:
int7 2023-11-09 21:07:00 +08:00
parent fdd069915e
commit aaa9ed5167
7 changed files with 147 additions and 70 deletions

View File

@ -29,7 +29,6 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -106,8 +105,6 @@ func NewControl(
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW) ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers() ctl.registerMsgHandlers()
ctl.xl.Info("get pxy cfgs: %v", util.JSONDump(ctl.pxyCfgs))
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel()) ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
@ -171,8 +168,6 @@ func (ctl *Control) handleNewProxyResp(m msg.Message) {
inMsg := m.(*msg.NewProxyResp) inMsg := m.(*msg.NewProxyResp)
xl.Info("proxy: %v, remote addr: %v, err: %v", inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)
// Server will return NewProxyResp message to each NewProxy message. // Server will return NewProxyResp message to each NewProxy message.
// Start a new proxy handler if no error got // Start a new proxy handler if no error got
err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)

View File

@ -21,10 +21,13 @@ import (
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
) )
type SSHGatewayConfig struct { type SSHTunnelGateway struct {
SSHBindPort int `json:"sshBindPort,omitempty" validate:"gte=0,lte=65535"` BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"`
SSHPrivateKeyFilePath string `json:"sshPrivateKeyFilePath,omitempty"` PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"`
SSHPublicKeyFilesPath string `json:"sshPublicKeyFilesPath,omitempty"` PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"`
// store all public key file. load all when init
PublicKeyFilesMap map[string]string
} }
type ServerConfig struct { type ServerConfig struct {
@ -38,7 +41,7 @@ type ServerConfig struct {
// value is 7000. // value is 7000.
BindPort int `json:"bindPort,omitempty"` BindPort int `json:"bindPort,omitempty"`
SSHGatewayConfig SSHGatewayConfig `json:"sshGatewayConfig"` SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"`
// KCPBindPort specifies the KCP port that the server listens on. If this // KCPBindPort specifies the KCP port that the server listens on. If this
// value is 0, the server will not listen for KCP connections. // value is 0, the server will not listen for KCP connections.

View File

@ -1,6 +1,69 @@
package v1 package v1
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"os"
"path/filepath"
)
const ( const (
// custom define // custom define
SSHClientLoginUserPrefix = "_frpc_ssh_client_" SSHClientLoginUserPrefix = "_frpc_ssh_client_"
) )
// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format
func GeneratePrivateKey() ([]byte, error) {
privateKey, err := generatePrivateKey()
if err != nil {
return nil, errors.New("gen private key error")
}
privBlock := pem.Block{
Type: "RSA PRIVATE KEY",
Headers: nil,
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}
return pem.EncodeToMemory(&privBlock), nil
}
// generatePrivateKey creates a RSA Private Key of specified byte size
func generatePrivateKey() (*rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
err = privateKey.Validate()
if err != nil {
return nil, err
}
return privateKey, nil
}
func LoadFilesInDirectory(dirPath string) (map[string]string, error) {
fileMap := make(map[string]string)
files, err := os.ReadDir(dirPath)
if err != nil {
return nil, err
}
for _, file := range files {
filename := file.Name()
filePath := filepath.Join(dirPath, filename)
content, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
fileMap[filename] = string(content)
}
return fileMap, nil
}

View File

@ -19,7 +19,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
mathrand "math/rand" mathrand "math/rand"
"net" "net"
@ -145,8 +144,3 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
func ConstantTimeEqString(a, b string) bool { func ConstantTimeEqString(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
} }
func JSONDump(v interface{}) string {
prettyJSON, _ := json.MarshalIndent(v, "", "\t")
return string(prettyJSON)
}

View File

@ -18,6 +18,7 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -205,26 +206,33 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
svr.listener = ln svr.listener = ln
log.Info("frps tcp listen on %s", address) log.Info("frps tcp listen on %s", address)
if cfg.SSHGatewayConfig.SSHBindPort > 0 { if cfg.SSHTunnelGateway.BindPort > 0 {
svr.sshConfig = &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
file := fmt.Sprintf("%v/%v.pub", cfg.SSHGatewayConfig.SSHPublicKeyFilesPath, ssh.FingerprintSHA256(key))
authorizedKey, err := os.ReadFile(file)
if err != nil {
return nil, err
}
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(authorizedKey) if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" {
cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadFilesInDirectory(cfg.SSHTunnelGateway.PublicKeyFilesPath)
if err != nil {
return nil, fmt.Errorf("load ssh all public key files error: %v", err)
}
log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath)
}
svr.sshConfig = &ssh.ServerConfig{
NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false),
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
keyContent, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)]
if !ok {
return nil, errors.New("cannot find public key file")
}
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyContent))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if key.Type() == parsedAuthorizedKey.Type() && bytes.Equal(key.Marshal(), parsedAuthorizedKey.Marshal()) { if key.Type() == parsedAuthorizedKey.Type() && bytes.Equal(key.Marshal(), parsedAuthorizedKey.Marshal()) {
log.Info("ssh file: %v fingerprint sha256: %v", file, ssh.FingerprintSHA256(key))
return &ssh.Permissions{ return &ssh.Permissions{
Extensions: map[string]string{ Extensions: map[string]string{
ssh.FingerprintSHA256(key): string(authorizedKey), ssh.FingerprintSHA256(key): keyContent,
}, },
}, nil }, nil
} }
@ -232,12 +240,22 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
}, },
} }
privateBytes, err := os.ReadFile(cfg.SSHGatewayConfig.SSHPrivateKeyFilePath) var privateBytes []byte
if err != nil { if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" {
log.Error("Failed to load private key") privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath)
return nil, err if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath)
} else {
privateBytes, err = v1.GeneratePrivateKey()
if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("auto gen private key file success")
} }
private, err := ssh.ParsePrivateKey(privateBytes) private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil { if err != nil {
log.Error("Failed to parse private key, error: %v", err) log.Error("Failed to parse private key, error: %v", err)
@ -246,7 +264,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
svr.sshConfig.AddHostKey(private) svr.sshConfig.AddHostKey(private)
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHGatewayConfig.SSHBindPort)) sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort))
svr.sshListener, err = net.Listen("tcp", sshAddr) svr.sshListener, err = net.Listen("tcp", sshAddr)
if err != nil { if err != nil {
log.Error("Failed to listen on %v, error: %v", sshAddr, err) log.Error("Failed to listen on %v, error: %v", sshAddr, err)
@ -582,18 +600,9 @@ func (svr *Service) HandleSSHListener(listener net.Listener) {
ctx := context.Background() ctx := context.Background()
vs, err := NewVirtualService( vs, err := NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg,
ctx, msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()},
v1.ClientCommonConfig{}, svr.rc, pxyCfg, ss, replyCh)
*svr.cfg,
msg.Login{
User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String(),
},
svr.rc,
pxyCfg,
ss,
replyCh,
)
if err != nil { if err != nil {
log.Error("new virtual service error: %v", err) log.Error("new virtual service error: %v", err)
ss.Close() ss.Close()

View File

@ -18,7 +18,6 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1" v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
) )
const ( const (
@ -174,7 +173,7 @@ func (ss *SSHService) loopParseCmdPayload() {
if req.Type == RequestTypeHeartbeat { if req.Type == RequestTypeHeartbeat {
log.Debug("ssh heartbeat data") log.Debug("ssh heartbeat data")
} else { } else {
log.Info("default req, data: %v", util.JSONDump(req)) log.Info("default req, data: %v", req)
} }
} }
if req.WantReply { if req.WantReply {
@ -231,15 +230,18 @@ func (ss *SSHService) loopParseExtraPayload() {
log.Info("r.payload is less than 4") log.Info("r.payload is less than 4")
continue continue
} }
if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") {
dataLen := binary.BigEndian.Uint32(r.Payload[:4]) log.Info("ssh protocol exchange data")
p := string(r.Payload[4 : 4+dataLen])
if !strings.Contains(p, "frpc") {
log.Info("payload not contains frp keyword: %v", p)
continue continue
} }
// [4byte data_len|data]
end := 4 + binary.BigEndian.Uint32(r.Payload[:4])
if end > uint32(len(r.Payload)) {
end = uint32(len(r.Payload))
}
p := string(r.Payload[4:end])
msg, err := parseSSHExtraMessage(p) msg, err := parseSSHExtraMessage(p)
if err != nil { if err != nil {
log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload)
@ -331,35 +333,42 @@ func (ss *SSHService) loopGenerateProxy() {
ProxyBaseConfig: v1.ProxyBaseConfig{ ProxyBaseConfig: v1.ProxyBaseConfig{
Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()),
Type: p2.Type, Type: p2.Type,
ProxyBackend: v1.ProxyBackend{
LocalIP: p1.Address,
},
}, },
RemotePort: int(p1.Port), RemotePort: int(p1.Port),
} }
default: default:
log.Warn("invalid frp proxy type: %v", p2.Type) log.Warn("invalid frp proxy type: %v", p2.Type)
} }
} }
} }
func parseSSHExtraMessage(s string) (p SSHExtraPayload, err error) { func parseSSHExtraMessage(s string) (p SSHExtraPayload, err error) {
sn := len(s)
log.Info("parse ssh extra message: %v", s)
ss := strings.Fields(s) ss := strings.Fields(s)
if len(ss) <= 1 { if len(ss) == 0 {
return p, fmt.Errorf("invalid ssh input, args: %v", ss) if sn != 0 {
ss = append(ss, s)
} else {
return p, fmt.Errorf("invalid ssh input, args: %v", ss)
}
} }
for i, v := range ss { for i, v := range ss {
ss[i] = strings.TrimSpace(v) ss[i] = strings.TrimSpace(v)
} }
if ss[0] != "frpc" { if ss[0] != "tcp" && ss[0] != "http" {
return p, fmt.Errorf("first input should be frpc, but got: %v", ss[0])
}
if ss[1] != "tcp" && ss[1] != "http" {
return p, fmt.Errorf("only support tcp/http now") return p, fmt.Errorf("only support tcp/http now")
} }
switch ss[1] { switch ss[0] {
case "tcp": case "tcp":
tcpCmd, err := ParseTCPCommand(ss) tcpCmd, err := ParseTCPCommand(ss)
if err != nil { if err != nil {
@ -407,7 +416,7 @@ func ParseHTTPCommand(params []string) (*HTTPCommand, error) {
basicAuthPass string basicAuthPass string
) )
fs := flag.NewFlagSet("frpc http", flag.ContinueOnError) fs := flag.NewFlagSet("http", flag.ContinueOnError)
fs.StringVar(&basicAuth, "basic-auth", "", "") fs.StringVar(&basicAuth, "basic-auth", "", "")
fs.StringVar(&domainURL, "domain", "", "") fs.StringVar(&domainURL, "domain", "", "")
@ -442,21 +451,25 @@ type TCPCommand struct {
} }
func ParseTCPCommand(params []string) (*TCPCommand, error) { func ParseTCPCommand(params []string) (*TCPCommand, error) {
if len(params) < 2 || params[0] != "frpc" || params[1] != "tcp" { if len(params) == 0 || params[0] != "tcp" {
return nil, errors.New("invalid TCP command") return nil, errors.New("invalid TCP command")
} }
if len(params) == 1 {
return &TCPCommand{}, nil
}
var ( var (
address string address string
port string port string
) )
fs := flag.NewFlagSet("frpc tcp", flag.ContinueOnError) fs := flag.NewFlagSet("tcp", flag.ContinueOnError)
fs.StringVar(&address, "address", "", "The IP address to listen on") fs.StringVar(&address, "address", "", "The IP address to listen on")
fs.StringVar(&port, "port", "", "The port to listen on") fs.StringVar(&port, "port", "", "The port to listen on")
fs.SetOutput(&nullWriter{}) // Disables usage output fs.SetOutput(&nullWriter{}) // Disables usage output
args := params[2:] args := params[1:]
err := fs.Parse(args) err := fs.Parse(args)
if err != nil { if err != nil {
if !errors.Is(err, flag.ErrHelp) { if !errors.Is(err, flag.ErrHelp) {

View File

@ -87,8 +87,6 @@ func (svr *VirtualService) Run(ctx context.Context) (err error) {
svr.ctx = xlog.NewContext(ctx, xlog.New()) svr.ctx = xlog.NewContext(ctx, xlog.New())
svr.cancel = cancel svr.cancel = cancel
log.Info("get svr pxy: %v", util.JSONDump(svr.pxyCfg))
remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{ remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{
ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name, ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name,
ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type, ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type,
@ -172,10 +170,12 @@ func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr strin
func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) { func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) {
// tell ssh client open a new stream for work // tell ssh client open a new stream for work
payload := forwardedTCPPayload{ payload := forwardedTCPPayload{
Addr: svr.serverCfg.BindAddr, Addr: svr.serverCfg.BindAddr, // TODO refine
Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort), Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort),
} }
log.Info("get work conn payload: %v", payload)
channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload)) channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload))
if err != nil { if err != nil {
return nil, fmt.Errorf("open ssh channel error: %v", err) return nil, fmt.Errorf("open ssh channel error: %v", err)