fix: comments

This commit is contained in:
int7 2023-11-09 21:07:00 +08:00
parent fdd069915e
commit aaa9ed5167
7 changed files with 147 additions and 70 deletions

View File

@ -29,7 +29,6 @@ import (
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/wait"
"github.com/fatedier/frp/pkg/util/xlog"
)
@ -106,8 +105,6 @@ func NewControl(
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
ctl.registerMsgHandlers()
ctl.xl.Info("get pxy cfgs: %v", util.JSONDump(ctl.pxyCfgs))
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
@ -171,8 +168,6 @@ func (ctl *Control) handleNewProxyResp(m msg.Message) {
inMsg := m.(*msg.NewProxyResp)
xl.Info("proxy: %v, remote addr: %v, err: %v", inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)
// Server will return NewProxyResp message to each NewProxy message.
// Start a new proxy handler if no error got
err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)

View File

@ -21,10 +21,13 @@ import (
"github.com/fatedier/frp/pkg/util/util"
)
type SSHGatewayConfig struct {
SSHBindPort int `json:"sshBindPort,omitempty" validate:"gte=0,lte=65535"`
SSHPrivateKeyFilePath string `json:"sshPrivateKeyFilePath,omitempty"`
SSHPublicKeyFilesPath string `json:"sshPublicKeyFilesPath,omitempty"`
type SSHTunnelGateway struct {
BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"`
PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"`
PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"`
// store all public key file. load all when init
PublicKeyFilesMap map[string]string
}
type ServerConfig struct {
@ -38,7 +41,7 @@ type ServerConfig struct {
// value is 7000.
BindPort int `json:"bindPort,omitempty"`
SSHGatewayConfig SSHGatewayConfig `json:"sshGatewayConfig"`
SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"`
// KCPBindPort specifies the KCP port that the server listens on. If this
// value is 0, the server will not listen for KCP connections.

View File

@ -1,6 +1,69 @@
package v1
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"os"
"path/filepath"
)
const (
// custom define
SSHClientLoginUserPrefix = "_frpc_ssh_client_"
)
// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format
func GeneratePrivateKey() ([]byte, error) {
privateKey, err := generatePrivateKey()
if err != nil {
return nil, errors.New("gen private key error")
}
privBlock := pem.Block{
Type: "RSA PRIVATE KEY",
Headers: nil,
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}
return pem.EncodeToMemory(&privBlock), nil
}
// generatePrivateKey creates a RSA Private Key of specified byte size
func generatePrivateKey() (*rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
err = privateKey.Validate()
if err != nil {
return nil, err
}
return privateKey, nil
}
func LoadFilesInDirectory(dirPath string) (map[string]string, error) {
fileMap := make(map[string]string)
files, err := os.ReadDir(dirPath)
if err != nil {
return nil, err
}
for _, file := range files {
filename := file.Name()
filePath := filepath.Join(dirPath, filename)
content, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
fileMap[filename] = string(content)
}
return fileMap, nil
}

View File

@ -19,7 +19,6 @@ import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
mathrand "math/rand"
"net"
@ -145,8 +144,3 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
func ConstantTimeEqString(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
func JSONDump(v interface{}) string {
prettyJSON, _ := json.MarshalIndent(v, "", "\t")
return string(prettyJSON)
}

View File

@ -18,6 +18,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
@ -205,26 +206,33 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
svr.listener = ln
log.Info("frps tcp listen on %s", address)
if cfg.SSHGatewayConfig.SSHBindPort > 0 {
svr.sshConfig = &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
file := fmt.Sprintf("%v/%v.pub", cfg.SSHGatewayConfig.SSHPublicKeyFilesPath, ssh.FingerprintSHA256(key))
authorizedKey, err := os.ReadFile(file)
if err != nil {
return nil, err
}
if cfg.SSHTunnelGateway.BindPort > 0 {
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(authorizedKey)
if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" {
cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadFilesInDirectory(cfg.SSHTunnelGateway.PublicKeyFilesPath)
if err != nil {
return nil, fmt.Errorf("load ssh all public key files error: %v", err)
}
log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath)
}
svr.sshConfig = &ssh.ServerConfig{
NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false),
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
keyContent, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)]
if !ok {
return nil, errors.New("cannot find public key file")
}
parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyContent))
if err != nil {
return nil, err
}
if key.Type() == parsedAuthorizedKey.Type() && bytes.Equal(key.Marshal(), parsedAuthorizedKey.Marshal()) {
log.Info("ssh file: %v fingerprint sha256: %v", file, ssh.FingerprintSHA256(key))
return &ssh.Permissions{
Extensions: map[string]string{
ssh.FingerprintSHA256(key): string(authorizedKey),
ssh.FingerprintSHA256(key): keyContent,
},
}, nil
}
@ -232,12 +240,22 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
},
}
privateBytes, err := os.ReadFile(cfg.SSHGatewayConfig.SSHPrivateKeyFilePath)
if err != nil {
log.Error("Failed to load private key")
return nil, err
var privateBytes []byte
if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" {
privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath)
if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath)
} else {
privateBytes, err = v1.GeneratePrivateKey()
if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("auto gen private key file success")
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Error("Failed to parse private key, error: %v", err)
@ -246,7 +264,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
svr.sshConfig.AddHostKey(private)
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHGatewayConfig.SSHBindPort))
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort))
svr.sshListener, err = net.Listen("tcp", sshAddr)
if err != nil {
log.Error("Failed to listen on %v, error: %v", sshAddr, err)
@ -582,18 +600,9 @@ func (svr *Service) HandleSSHListener(listener net.Listener) {
ctx := context.Background()
vs, err := NewVirtualService(
ctx,
v1.ClientCommonConfig{},
*svr.cfg,
msg.Login{
User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String(),
},
svr.rc,
pxyCfg,
ss,
replyCh,
)
vs, err := NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg,
msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()},
svr.rc, pxyCfg, ss, replyCh)
if err != nil {
log.Error("new virtual service error: %v", err)
ss.Close()

View File

@ -18,7 +18,6 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
)
const (
@ -174,7 +173,7 @@ func (ss *SSHService) loopParseCmdPayload() {
if req.Type == RequestTypeHeartbeat {
log.Debug("ssh heartbeat data")
} else {
log.Info("default req, data: %v", util.JSONDump(req))
log.Info("default req, data: %v", req)
}
}
if req.WantReply {
@ -231,15 +230,18 @@ func (ss *SSHService) loopParseExtraPayload() {
log.Info("r.payload is less than 4")
continue
}
dataLen := binary.BigEndian.Uint32(r.Payload[:4])
p := string(r.Payload[4 : 4+dataLen])
if !strings.Contains(p, "frpc") {
log.Info("payload not contains frp keyword: %v", p)
if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") {
log.Info("ssh protocol exchange data")
continue
}
// [4byte data_len|data]
end := 4 + binary.BigEndian.Uint32(r.Payload[:4])
if end > uint32(len(r.Payload)) {
end = uint32(len(r.Payload))
}
p := string(r.Payload[4:end])
msg, err := parseSSHExtraMessage(p)
if err != nil {
log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload)
@ -331,35 +333,42 @@ func (ss *SSHService) loopGenerateProxy() {
ProxyBaseConfig: v1.ProxyBaseConfig{
Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()),
Type: p2.Type,
ProxyBackend: v1.ProxyBackend{
LocalIP: p1.Address,
},
},
RemotePort: int(p1.Port),
}
default:
log.Warn("invalid frp proxy type: %v", p2.Type)
}
}
}
func parseSSHExtraMessage(s string) (p SSHExtraPayload, err error) {
sn := len(s)
log.Info("parse ssh extra message: %v", s)
ss := strings.Fields(s)
if len(ss) <= 1 {
return p, fmt.Errorf("invalid ssh input, args: %v", ss)
if len(ss) == 0 {
if sn != 0 {
ss = append(ss, s)
} else {
return p, fmt.Errorf("invalid ssh input, args: %v", ss)
}
}
for i, v := range ss {
ss[i] = strings.TrimSpace(v)
}
if ss[0] != "frpc" {
return p, fmt.Errorf("first input should be frpc, but got: %v", ss[0])
}
if ss[1] != "tcp" && ss[1] != "http" {
if ss[0] != "tcp" && ss[0] != "http" {
return p, fmt.Errorf("only support tcp/http now")
}
switch ss[1] {
switch ss[0] {
case "tcp":
tcpCmd, err := ParseTCPCommand(ss)
if err != nil {
@ -407,7 +416,7 @@ func ParseHTTPCommand(params []string) (*HTTPCommand, error) {
basicAuthPass string
)
fs := flag.NewFlagSet("frpc http", flag.ContinueOnError)
fs := flag.NewFlagSet("http", flag.ContinueOnError)
fs.StringVar(&basicAuth, "basic-auth", "", "")
fs.StringVar(&domainURL, "domain", "", "")
@ -442,21 +451,25 @@ type TCPCommand struct {
}
func ParseTCPCommand(params []string) (*TCPCommand, error) {
if len(params) < 2 || params[0] != "frpc" || params[1] != "tcp" {
if len(params) == 0 || params[0] != "tcp" {
return nil, errors.New("invalid TCP command")
}
if len(params) == 1 {
return &TCPCommand{}, nil
}
var (
address string
port string
)
fs := flag.NewFlagSet("frpc tcp", flag.ContinueOnError)
fs := flag.NewFlagSet("tcp", flag.ContinueOnError)
fs.StringVar(&address, "address", "", "The IP address to listen on")
fs.StringVar(&port, "port", "", "The port to listen on")
fs.SetOutput(&nullWriter{}) // Disables usage output
args := params[2:]
args := params[1:]
err := fs.Parse(args)
if err != nil {
if !errors.Is(err, flag.ErrHelp) {

View File

@ -87,8 +87,6 @@ func (svr *VirtualService) Run(ctx context.Context) (err error) {
svr.ctx = xlog.NewContext(ctx, xlog.New())
svr.cancel = cancel
log.Info("get svr pxy: %v", util.JSONDump(svr.pxyCfg))
remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{
ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name,
ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type,
@ -172,10 +170,12 @@ func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr strin
func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) {
// tell ssh client open a new stream for work
payload := forwardedTCPPayload{
Addr: svr.serverCfg.BindAddr,
Addr: svr.serverCfg.BindAddr, // TODO refine
Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort),
}
log.Info("get work conn payload: %v", payload)
channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload))
if err != nil {
return nil, fmt.Errorf("open ssh channel error: %v", err)