feat: custom dial function

This commit is contained in:
blizard863 2021-10-16 12:48:35 +08:00
parent 0fb6aeef58
commit a37542f1bb
8 changed files with 200 additions and 91 deletions

View File

@ -234,9 +234,15 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
}
}
address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
opts := []frpNet.DialOption{
frpNet.WithProxy(ctl.clientCfg.HTTPProxy),
frpNet.WithProtocol(ctl.clientCfg.Protocol),
frpNet.WithRemoteAddress(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte),
}
conn, err = frpNet.Dial(opts...)
if err != nil {
xl.Warn("start new connection to server error: %v", err)
return

View File

@ -790,7 +790,11 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
return
}
localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort))
opts := []frpNet.DialOption{
frpNet.WithRemoteAddress(fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort)),
}
localConn, err := frpNet.Dial(opts...)
if err != nil {
workConn.Close()
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)

View File

@ -228,8 +228,16 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
}
}
address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
opts := []frpNet.DialOption{
frpNet.WithProxy(svr.cfg.HTTPProxy),
frpNet.WithProtocol(svr.cfg.Protocol),
frpNet.WithBindAddress(net.JoinHostPort(svr.cfg.BindAddr, strconv.Itoa(svr.cfg.BindPort))),
frpNet.WithRemoteAddress(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte),
}
conn, err = frpNet.Dial(opts...)
if err != nil {
return
}

View File

@ -38,6 +38,10 @@ type ClientCommonConf struct {
// ServerPort specifies the port to connect to the server on. By default,
// this value is 7000.
ServerPort int `ini:"server_port" json:"server_port"`
BindAddr string `ini:"bind_addr" json:"bind_addr"`
BindPort int `ini:"bind_port" json:"bind_port"`
// HTTPProxy specifies a proxy address to connect to the server through. If
// this value is "", the server will be connected to directly. By default,
// this value is read from the "http_proxy" environment variable.

View File

@ -16,18 +16,13 @@ package net
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
"time"
"github.com/fatedier/frp/pkg/util/xlog"
gnet "github.com/fatedier/golib/net"
kcp "github.com/fatedier/kcp-go"
)
type ContextGetter interface {
@ -188,56 +183,3 @@ func (statsConn *StatsConn) Close() (err error) {
}
return
}
func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
switch protocol {
case "tcp":
return net.Dial("tcp", addr)
case "kcp":
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
if errRet != nil {
err = errRet
return
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
c = kcpConn
return
default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
}
}
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
switch protocol {
case "tcp":
return gnet.DialTcpByProxy(proxyURL, addr)
case "kcp":
// http proxy is not supported for kcp
return ConnectServer(protocol, addr)
case "websocket":
return ConnectWebsocketServer(addr)
default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
}
}
func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
c, err = ConnectServerByProxy(proxyURL, protocol, addr)
if err != nil {
return
}
if tlsConfig == nil {
return
}
c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
return
}

164
pkg/util/net/dial.go Normal file
View File

@ -0,0 +1,164 @@
package net
import (
"crypto/tls"
"fmt"
"net"
"net/url"
"time"
gnet "github.com/fatedier/golib/net"
kcp "github.com/fatedier/kcp-go"
"golang.org/x/net/websocket"
)
type dialOptions struct {
proxyURL string
protocol string
laddr string
addr string
tlsConfig *tls.Config
disableCustomTLSHeadByte bool
}
// DialOption configures how we set up the connection.
type DialOption interface {
apply(*dialOptions)
}
type EmptyDialOption struct{}
func (EmptyDialOption) apply(*dialOptions) {}
type funcDialOption struct {
f func(*dialOptions)
}
func (fdo *funcDialOption) apply(do *dialOptions) {
fdo.f(do)
}
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
return &funcDialOption{
f: f,
}
}
func DefaultDialOptions() dialOptions {
return dialOptions{
protocol: "tcp",
}
}
func WithProxy(proxyURL string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.proxyURL = proxyURL
})
}
func WithBindAddress(laddr string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.laddr = laddr
})
}
func WithTLSConfig(tlsConfig *tls.Config) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.tlsConfig = tlsConfig
})
}
func WithRemoteAddress(addr string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.addr = addr
})
}
func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
})
}
func WithProtocol(protocol string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.protocol = protocol
})
}
func Dial(opts ...DialOption) (c net.Conn, err error) {
op := DefaultDialOptions()
for _, opt := range opts {
opt.apply(&op)
}
c, err = dialServer(op.proxyURL, op.protocol, op.addr)
if err != nil {
return
}
if op.tlsConfig == nil {
return
}
c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
return
}
func dialServer(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
var d = &net.Dialer{}
switch protocol {
case "tcp":
return gnet.DialTcpByProxy(d, proxyURL, addr)
case "kcp":
return DialKCPServer(addr)
case "websocket":
return DialWebsocketServer(d, addr)
default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
}
}
func DialKCPServer(addr string) (net.Conn, error) {
// http proxy is not supported for kcp
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
if errRet != nil {
return nil, errRet
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
return kcpConn, nil
}
// addr: domain:port
func DialWebsocketServer(d *net.Dialer, addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return nil, err
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = d
cfg.Dialer.Timeout = 10 * time.Second
conn, err := websocket.DialConfig(cfg)
if err != nil {
return nil, err
}
return conn, nil
}

View File

@ -5,8 +5,6 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"time"
"golang.org/x/net/websocket"
)
@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error {
func (p *WebsocketListener) Addr() net.Addr {
return p.ln.Addr()
}
// addr: domain:port
func ConnectWebsocketServer(addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return nil, err
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := websocket.DialConfig(cfg)
if err != nil {
return nil, err
}
return conn, nil
}

View File

@ -13,7 +13,8 @@ import (
"time"
"github.com/fatedier/frp/test/e2e/pkg/rpc"
libnet "github.com/fatedier/golib/net"
frpNet "github.com/fatedier/frp/pkg/util/net"
)
type Request struct {
@ -141,7 +142,13 @@ func (r *Request) Do() (*Response, error) {
if r.protocol != "tcp" {
return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
}
conn, err = libnet.DialTcpByProxy(r.proxyURL, addr)
opts := []frpNet.DialOption{
frpNet.WithProxy(r.proxyURL),
frpNet.WithRemoteAddress(addr),
}
conn, err = frpNet.Dial(opts...)
if err != nil {
return nil, err
}