135 lines
3.1 KiB
Go
135 lines
3.1 KiB
Go
package controller
|
|
|
|
import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"golang.org/x/sync/singleflight"

	"github.com/nezhahq/nezha/model"
	"github.com/nezhahq/nezha/pkg/utils"
	"github.com/nezhahq/nezha/service/singleton"
)
|
|
|
|
var upgrader *websocket.Upgrader
|
|
|
|
func InitUpgrader() {
|
|
var checkOrigin func(r *http.Request) bool
|
|
|
|
// Allow CORS from loopback addresses in debug mode
|
|
if singleton.Conf.Debug {
|
|
checkOrigin = func(r *http.Request) bool {
|
|
hostAddr := r.Host
|
|
host, _, err := net.SplitHostPort(hostAddr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
if ip.IsLoopback() {
|
|
return true
|
|
}
|
|
} else {
|
|
// Handle domains like "localhost"
|
|
ip, err := net.LookupHost(host)
|
|
if err != nil || len(ip) == 0 {
|
|
return false
|
|
}
|
|
if netIP := net.ParseIP(ip[0]); netIP != nil && netIP.IsLoopback() {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
}
|
|
|
|
upgrader = &websocket.Upgrader{
|
|
ReadBufferSize: 32768,
|
|
WriteBufferSize: 32768,
|
|
CheckOrigin: checkOrigin,
|
|
}
|
|
}
|
|
|
|
// Websocket server stream
|
|
// @Summary Websocket server stream
|
|
// @tags common
|
|
// @Schemes
|
|
// @Description Websocket server stream
|
|
// @security BearerAuth
|
|
// @Produce json
|
|
// @Success 200 {object} model.StreamServerData
|
|
// @Router /ws/server [get]
|
|
func serverStream(c *gin.Context) (any, error) {
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
return nil, newWsError("%v", err)
|
|
}
|
|
defer conn.Close()
|
|
count := 0
|
|
for {
|
|
stat, err := getServerStat(c, count == 0)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil {
|
|
break
|
|
}
|
|
count += 1
|
|
if count%4 == 0 {
|
|
err = conn.WriteMessage(websocket.PingMessage, []byte{})
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
time.Sleep(time.Second * 2)
|
|
}
|
|
return nil, newWsError("")
|
|
}
|
|
|
|
var requestGroup singleflight.Group
|
|
|
|
func getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) {
|
|
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
|
|
authorized := isMember // TODO || isViewPasswordVerfied
|
|
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
|
|
singleton.SortedServerLock.RLock()
|
|
defer singleton.SortedServerLock.RUnlock()
|
|
|
|
var serverList []*model.Server
|
|
if authorized {
|
|
serverList = singleton.SortedServerList
|
|
} else {
|
|
serverList = singleton.SortedServerListForGuest
|
|
}
|
|
|
|
servers := make([]model.StreamServer, 0, len(serverList))
|
|
for _, server := range serverList {
|
|
var countryCode string
|
|
if server.GeoIP != nil {
|
|
countryCode = server.GeoIP.CountryCode
|
|
}
|
|
servers = append(servers, model.StreamServer{
|
|
ID: server.ID,
|
|
Name: server.Name,
|
|
PublicNote: utils.IfOr(withPublicNote, server.PublicNote, ""),
|
|
DisplayIndex: server.DisplayIndex,
|
|
Host: server.Host,
|
|
State: server.State,
|
|
CountryCode: countryCode,
|
|
LastActive: server.LastActive,
|
|
})
|
|
}
|
|
|
|
return utils.Json.Marshal(model.StreamServerData{
|
|
Now: time.Now().Unix() * 1000,
|
|
Servers: servers,
|
|
})
|
|
})
|
|
|
|
return v.([]byte), err
|
|
}
|