diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index df5e87d8..2408beb3 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -31,6 +31,7 @@ import ( "github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/util/limit" frpNet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/golib/errors" @@ -101,6 +102,7 @@ func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.Cl pxy = &XTCPProxy{ BaseProxy: &baseProxy, cfg: cfg, + NatType: util.GetNatType(cfg.StunServer), } case *config.SUDPProxyConf: pxy = &SUDPProxy{ @@ -274,6 +276,7 @@ type XTCPProxy struct { cfg *config.XTCPProxyConf proxyPlugin plugin.Plugin + NatType string } func (pxy *XTCPProxy) Run() (err error) { @@ -305,6 +308,7 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { natHoleClientMsg := &msg.NatHoleClient{ ProxyName: pxy.cfg.ProxyName, Sid: natHoleSidMsg.Sid, + NatType: pxy.NatType, } raddr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.clientCfg.ServerAddr, pxy.serverUDPPort)) @@ -321,9 +325,9 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { return } - // Wait for client address at most 5 seconds. + // Wait for client address at most 30 seconds. var natHoleRespMsg msg.NatHoleResp - clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + clientConn.SetReadDeadline(time.Now().Add(30 * time.Second)) buf := pool.GetBuf(1024) n, err := clientConn.Read(buf) @@ -344,50 +348,77 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { return } - xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) + xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s] visitor nat type [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr, natHoleRespMsg.VisitorNatType) - // Send detect message - host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr) - if err != nil { - xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err) - } - laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) - - port, err := strconv.ParseInt(portStr, 10, 64) - if err != nil { - xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) + if pxy.NatType == "Symmetric NAT" && natHoleRespMsg.VisitorNatType == "Symmetric NAT" { + xl.Error("NAT type is Symmetric NAT!") return } - pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid)) - xl.Trace("send all detect msg done") - msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}) + // Send detect message + laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) + daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.VisitorAddr) + if err != nil { + xl.Error("resolve visitor udp address error: %v", err) + return + } // Listen for clientConn's address and wait for visitor connection lConn, err := net.ListenUDP("udp", laddr) if err != nil { - xl.Error("listen on visitorConn's local address error: %v", err) + xl.Error("listen on visitorConn's local adress error: %v", err) return } defer lConn.Close() - lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) - sidBuf := pool.GetBuf(1024) + var sid_ok bool = false var uAddr *net.UDPAddr - n, uAddr, err = lConn.ReadFromUDP(sidBuf) - if err != nil { - xl.Warn("get sid from visitor error: %v", err) - return - } - lConn.SetReadDeadline(time.Time{}) - if string(sidBuf[:n]) != natHoleRespMsg.Sid { - xl.Warn("incorrect sid from visitor") - return - } - pool.PutBuf(sidBuf) - xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + for i := 0; i < 10; i++ { - lConn.WriteToUDP(sidBuf[:n], uAddr) + var detectRange int = 0 + if natHoleRespMsg.VisitorNatType == "Symmetric NAT" { + // TODO: How to detect the port of Symmetric NAT faster? AI? + detectRange = 2000 + } + for j := -detectRange; j <= detectRange; j++ { + daddr1, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", daddr.IP.String(), daddr.Port+j)) + _, err = lConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr1) + if err != nil { + xl.Warn("get sid from visitor error, %+v", err) + } + } + xl.Trace("send all detect msg done, detectRange=%d", detectRange) + + msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}) + + lConn.SetReadDeadline(time.Now().Add(3 * time.Second)) + + sidBuf := pool.GetBuf(1024) + for { + n, tAddr, err := lConn.ReadFromUDP(sidBuf) + if err != nil { + xl.Warn("get sid from visitor error, retry") + break + } + if string(sidBuf[:n]) == natHoleRespMsg.Sid { + uAddr = tAddr + sid_ok = true + lConn.WriteToUDP(sidBuf[:n], uAddr) + } + } + pool.PutBuf(sidBuf) + if sid_ok { + xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + break + } + } + + if !sid_ok { + xl.Warn("get sid from visitor error") + return + } + + lConn.SetReadDeadline(time.Time{}) kcpConn, err := frpNet.NewKCPConnFromUDP(lConn, false, uAddr.String()) if err != nil { diff --git a/client/visitor.go b/client/visitor.go index ea548298..703c2971 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -59,6 +59,7 @@ func NewVisitor(ctx context.Context, ctl *Control, cfg config.VisitorConf) (visi visitor = &XTCPVisitor{ BaseVisitor: &baseVisitor, cfg: cfg, + NatType: util.GetNatType(cfg.StunServer), } case *config.SUDPVisitorConf: visitor = &SUDPVisitor{ @@ -171,7 +172,8 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) { type XTCPVisitor struct { *BaseVisitor - cfg *config.XTCPVisitorConf + cfg *config.XTCPVisitorConf + NatType string } func (sv *XTCPVisitor) Run() (err error) { @@ -230,6 +232,7 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { ProxyName: sv.cfg.ServerName, SignKey: util.GetAuthKey(sv.cfg.Sk, now), Timestamp: now, + NatType: sv.NatType, } err = msg.WriteMsg(visitorConn, natHoleVisitorMsg) if err != nil { @@ -237,9 +240,9 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { return } - // Wait for client address at most 10 seconds. + // Wait for client address at most 30 seconds. var natHoleRespMsg msg.NatHoleResp - visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + visitorConn.SetReadDeadline(time.Now().Add(30 * time.Second)) buf := pool.GetBuf(1024) n, err := visitorConn.Read(buf) if err != nil { @@ -260,11 +263,16 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { return } - xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) + xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s] client nat type [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr, natHoleRespMsg.ClientNatType) // Close visitorConn, so we can use it's local address. visitorConn.Close() + if sv.NatType == "Symmetric NAT" && natHoleRespMsg.ClientNatType == "Symmetric NAT" { + xl.Error("NAT type is Symmetric NAT!") + return + } + // send sid message to client laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) @@ -272,35 +280,62 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { xl.Error("resolve client udp address error: %v", err) return } - lConn, err := net.DialUDP("udp", laddr, daddr) + lConn, err := net.ListenUDP("udp", laddr) if err != nil { xl.Error("dial client udp address error: %v", err) return } defer lConn.Close() - lConn.Write([]byte(natHoleRespMsg.Sid)) + var sid_ok bool = false + var uAddr *net.UDPAddr + for i := 0; i < 10; i++ { + var detectRange int = 0 + if natHoleRespMsg.ClientNatType == "Symmetric NAT" { + // TODO: How to detect the port of Symmetric NAT faster? AI? + detectRange = 2000 + } + for j := -detectRange; j <= detectRange; j++ { + daddr1, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", daddr.IP.String(), daddr.Port+j)) + _, err = lConn.WriteToUDP([]byte(natHoleRespMsg.Sid), daddr1) + if err != nil { + xl.Warn("get sid from client error, %+v", err) + } + } + xl.Trace("send all detect msg done, detectRange=%d", detectRange) - // read ack sid from client - sidBuf := pool.GetBuf(1024) - lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) - n, err = lConn.Read(sidBuf) - if err != nil { - xl.Warn("get sid from client error: %v", err) + // read ack sid from client + lConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + + sidBuf := pool.GetBuf(1024) + for { + n, tAddr, err := lConn.ReadFromUDP(sidBuf) + if err != nil { + xl.Warn("get sid from client error, retry") + break + } + if string(sidBuf[:n]) == natHoleRespMsg.Sid { + sid_ok = true + uAddr = tAddr + } + } + pool.PutBuf(sidBuf) + if sid_ok { + xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + break + } + } + + if !sid_ok { + xl.Error("get sid from client error") return } + lConn.SetReadDeadline(time.Time{}) - if string(sidBuf[:n]) != natHoleRespMsg.Sid { - xl.Warn("incorrect sid from client") - return - } - pool.PutBuf(sidBuf) - - xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) // wrap kcp connection var remote io.ReadWriteCloser - remote, err = frpNet.NewKCPConnFromUDP(lConn, true, natHoleRespMsg.ClientAddr) + remote, err = frpNet.NewKCPConnFromUDP(lConn, false, uAddr.String()) if err != nil { xl.Error("create kcp connection from udp connection error: %v", err) return diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index f8d7eb17..8c875e74 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -72,6 +72,7 @@ var ( serverName string bindAddr string bindPort int + stunServer string tlsEnable bool ) diff --git a/cmd/frpc/sub/xtcp.go b/cmd/frpc/sub/xtcp.go index 6c6c7d84..543a45c2 100644 --- a/cmd/frpc/sub/xtcp.go +++ b/cmd/frpc/sub/xtcp.go @@ -37,6 +37,7 @@ func init() { xtcpCmd.PersistentFlags().IntVarP(&bindPort, "bind_port", "", 0, "bind port") xtcpCmd.PersistentFlags().BoolVarP(&useEncryption, "ue", "", false, "use encryption") xtcpCmd.PersistentFlags().BoolVarP(&useCompression, "uc", "", false, "use compression") + xtcpCmd.PersistentFlags().StringVarP(&stunServer, "stun_server", "", "stun.voipbuster.com:3478", "stun server") rootCmd.AddCommand(xtcpCmd) } @@ -69,6 +70,7 @@ var xtcpCmd = &cobra.Command{ cfg.Sk = sk cfg.LocalIP = localIP cfg.LocalPort = localPort + cfg.StunServer = stunServer err = cfg.CheckForCli() if err != nil { fmt.Println(err) @@ -86,6 +88,7 @@ var xtcpCmd = &cobra.Command{ cfg.ServerName = serverName cfg.BindAddr = bindAddr cfg.BindPort = bindPort + cfg.StunServer = stunServer err = cfg.Check() if err != nil { fmt.Println(err) diff --git a/go.mod b/go.mod index 0a92dfb5..d1a8734c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 + github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 github.com/coreos/go-oidc v2.2.1+incompatible github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb github.com/fatedier/golib v0.1.1-0.20220321042308-c306138b83ac diff --git a/go.sum b/go.sum index 433f03e2..41e6b30f 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= +github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 h1:POmUHfxXdeyM8Aomg4tKDcwATCFuW+cYLkj6pwsw9pc= +github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029/go.mod h1:Rpr5n9cGHYdM3S3IK8ROSUUUYjQOu+MSUCZDcJbYWi8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index c000bb30..de181cc4 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -196,8 +196,9 @@ type STCPProxyConf struct { type XTCPProxyConf struct { BaseProxyConf `ini:",extends"` - Role string `ini:"role" json:"role"` - Sk string `ini:"sk" json:"sk"` + Role string `ini:"role" json:"role"` + Sk string `ini:"sk" json:"sk"` + StunServer string `ini:"stun_server" json:"stun_server"` } // UDP @@ -299,7 +300,6 @@ func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyC } conf.UnmarshalFromMsg(pMsg) - err := conf.CheckForSvr(serverCfg) if err != nil { return nil, err @@ -1027,6 +1027,10 @@ func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section * cfg.Role = "server" } + if cfg.StunServer == "" { + cfg.StunServer = "stun.voipstunt.com:3478" + } + return nil } diff --git a/pkg/config/visitor.go b/pkg/config/visitor.go index aede14d8..a726f0ed 100644 --- a/pkg/config/visitor.go +++ b/pkg/config/visitor.go @@ -61,6 +61,7 @@ type STCPVisitorConf struct { type XTCPVisitorConf struct { BaseVisitorConf `ini:",extends"` + StunServer string `ini:"stun_server" json:"stun_server"` } // DefaultVisitorConf creates a empty VisitorConf object by visitorType. @@ -269,7 +270,9 @@ func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section } // Add custom logic unmarshal, if exists - + if cfg.StunServer == "" { + cfg.StunServer = "stun.voipstunt.com:3478" + } return } diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index 6e59b7cf..5e539533 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -172,18 +172,22 @@ type NatHoleVisitor struct { ProxyName string `json:"proxy_name"` SignKey string `json:"sign_key"` Timestamp int64 `json:"timestamp"` + NatType string `json:"nat_type"` } type NatHoleClient struct { ProxyName string `json:"proxy_name"` Sid string `json:"sid"` + NatType string `json:"nat_type"` } type NatHoleResp struct { - Sid string `json:"sid"` - VisitorAddr string `json:"visitor_addr"` - ClientAddr string `json:"client_addr"` - Error string `json:"error"` + Sid string `json:"sid"` + VisitorAddr string `json:"visitor_addr"` + ClientAddr string `json:"client_addr"` + Error string `json:"error"` + VisitorNatType string `json:"visitor_nat_type"` + ClientNatType string `json:"client_nat_type"` } type NatHoleClientDetectOK struct { diff --git a/pkg/nathole/nathole.go b/pkg/nathole/nathole.go index 545ad7aa..db32145e 100644 --- a/pkg/nathole/nathole.go +++ b/pkg/nathole/nathole.go @@ -15,8 +15,19 @@ import ( "github.com/fatedier/golib/pool" ) +/* ++---------+--------------------+--------------------+--------------------+---------------------+--------------------+ +| O: SUPPORT X: UNSUPPORT | Full cone NAT | IP Restricted NAT | Port Restricted NAT | Symmetric NAT | ++---------+--------------------+--------------------+--------------------+---------------------+--------------------+ +| |Full cone NAT | O | O | O | O | +| VISITOR |IP Restricted NAT | O | O | O | O | +| |Port restricted NAT | O | O | O | O | +| |Symmetric NAT | O | O | O | X | ++---------+--------------------+--------------------+--------------------+---------------------+--------------------+ +*/ + // Timeout seconds. -var NatHoleTimeout int64 = 10 +var NatHoleTimeout int64 = 30 type SidRequest struct { Sid string @@ -105,9 +116,10 @@ func (nc *Controller) GenSid() string { func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) { sid := nc.GenSid() session := &Session{ - Sid: sid, - VisitorAddr: raddr, - NotifyCh: make(chan struct{}, 0), + Sid: sid, + VisitorAddr: raddr, + VisitorNatType: m.NatType, + NotifyCh: make(chan struct{}, 0), } nc.mu.Lock() clientCfg, ok := nc.clientCfgs[m.ProxyName] @@ -166,6 +178,7 @@ func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) { } log.Trace("handle client message, sid [%s]", session.Sid) session.ClientAddr = raddr + session.ClientNatType = m.NatType resp := nc.GenNatHoleResponse(session, "") log.Trace("send nat hole response to client") @@ -174,20 +187,26 @@ func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) { func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte { var ( - sid string - visitorAddr string - clientAddr string + sid string + visitorAddr string + clientAddr string + visitorNatType string + clientNatType string ) if session != nil { sid = session.Sid visitorAddr = session.VisitorAddr.String() clientAddr = session.ClientAddr.String() + visitorNatType = session.VisitorNatType + clientNatType = session.ClientNatType } m := &msg.NatHoleResp{ - Sid: sid, - VisitorAddr: visitorAddr, - ClientAddr: clientAddr, - Error: errInfo, + Sid: sid, + VisitorAddr: visitorAddr, + ClientAddr: clientAddr, + Error: errInfo, + VisitorNatType: visitorNatType, + ClientNatType: clientNatType, } b := bytes.NewBuffer(nil) err := msg.WriteMsg(b, m) @@ -198,9 +217,11 @@ func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byt } type Session struct { - Sid string - VisitorAddr *net.UDPAddr - ClientAddr *net.UDPAddr + Sid string + VisitorAddr *net.UDPAddr + ClientAddr *net.UDPAddr + VisitorNatType string + ClientNatType string NotifyCh chan struct{} } diff --git a/pkg/util/util/util.go b/pkg/util/util/util.go index 032675fb..f5461cdd 100644 --- a/pkg/util/util/util.go +++ b/pkg/util/util/util.go @@ -24,6 +24,8 @@ import ( "strconv" "strings" "time" + + "github.com/ccding/go-stun/stun" ) // RandID return a rand string used in frp. @@ -125,3 +127,24 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati time.Sleep(d) return d } + +var g_nat_type string = "" + +func GetNatType(stunServer string) string { + // 检测NAT 类型 + if stunServer != "disable" { + if g_nat_type == "" { + client := stun.NewClient() + client.SetServerAddr(stunServer) + nat, _, err := client.Discover() + if err != nil { + nat = stun.NATError + } + g_nat_type = nat.String() + } + } else { + g_nat_type = stun.NATUnknown.String() + } + + return g_nat_type +}