diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 9d179b7f..964e9fa3 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -77,7 +77,7 @@ func (auth *JWTAuthSetterVerifier) VerifyNewWorkConn(m *msg.NewWorkConn) error { return auth.VerifyToken("", token) } -func (auth *JWTAuthSetterVerifier) VerifyToken(user, token string) error { +func (auth *JWTAuthSetterVerifier) GetVerifyData(token string) (jwt.MapClaims, error) { methodKey := map[string]string{jwt.SigningMethodHS256.Alg(): auth.secret} parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) { @@ -90,18 +90,27 @@ func (auth *JWTAuthSetterVerifier) VerifyToken(user, token string) error { if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { - return errors.New("token is expired") + return nil, errors.New("token is expired") } - return err + return nil, err } if !parsedToken.Valid { - return fmt.Errorf("token %s is invalid", token) + return nil, fmt.Errorf("token %s is invalid", token) } claims, ok := parsedToken.Claims.(jwt.MapClaims) if !ok { - return fmt.Errorf("claims %v is invalid", parsedToken.Claims) + return nil, fmt.Errorf("claims %v is invalid", parsedToken.Claims) + } + + return claims, nil +} + +func (auth *JWTAuthSetterVerifier) VerifyToken(user, token string) error { + claims, err := auth.GetVerifyData(token) + if err != nil { + return err } sub := claims["sub"] diff --git a/server/service.go b/server/service.go index 0cfa5ed6..26cd0370 100644 --- a/server/service.go +++ b/server/service.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/fatedier/golib/crypto" @@ -286,8 +287,11 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) { address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.VhostHTTPPort)) server := &http.Server{ - Addr: address, - Handler: rp, + Addr: address, + Handler: &authMiddleware{ + next: rp, + authVerify: svr.authVerifier.(*auth.JWTAuthSetterVerifier), + }, } var l net.Listener if httpMuxOn { @@ -655,3 +659,32 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, newMsg.UseEncryption, newMsg.UseCompression, visitorUser) } + +type authMiddleware struct { + authVerify *auth.JWTAuthSetterVerifier + next http.Handler +} + +func (m authMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if !strings.HasSuffix(request.Host, "remote.agi7.ai") { + m.next.ServeHTTP(writer, request) + return + } + + cookie, err := request.Cookie("agi7.forward.auth") + if err != nil { + writer.WriteHeader(http.StatusForbidden) + writer.Write([]byte(err.Error())) + return + } + + var token = cookie.Value + _, err = m.authVerify.GetVerifyData(token) + if err != nil { + writer.WriteHeader(http.StatusForbidden) + writer.Write([]byte(err.Error())) + return + } + + m.next.ServeHTTP(writer, request) +}