From 380973a2009ac09595bd5a8bbdd3b2499fc6e337 Mon Sep 17 00:00:00 2001 From: UUBulb <35923940+uubulb@users.noreply.github.com> Date: Fri, 25 Oct 2024 21:45:05 +0800 Subject: [PATCH] prevent writing response to websocket connections (#457) --- cmd/dashboard/controller/controller.go | 27 ++++++++++++++++++++++++-- cmd/dashboard/controller/fm.go | 7 ++++--- cmd/dashboard/controller/terminal.go | 5 +++-- cmd/dashboard/controller/ws.go | 2 +- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 87eddfe..0376b15 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -136,6 +136,22 @@ func (ge *gormError) Error() string { return fmt.Sprintf(ge.msg, ge.a...) } +type wsError struct { + msg string + a []interface{} +} + +func newWsError(format string, args ...interface{}) error { + return &wsError{ + msg: format, + a: args, + } +} + +func (we *wsError) Error() string { + return fmt.Sprintf(we.msg, we.a...) +} + func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { return func(c *gin.Context) { data, err := handler(c) @@ -143,11 +159,18 @@ func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data}) return } - if _, ok := err.(*gormError); ok { + switch err.(type) { + case *gormError: log.Printf("NEZHA>> gorm error: %v", err) c.JSON(http.StatusOK, newErrorResponse(errors.New("database error"))) return - } else { + case *wsError: + // Connection is upgraded to WebSocket, so c.Writer is no longer usable + if msg := err.Error(); msg != "" { + log.Printf("NEZHA>> websocket error: %v", err) + } + return + default: c.JSON(http.StatusOK, newErrorResponse(err)) return } diff --git a/cmd/dashboard/controller/fm.go b/cmd/dashboard/controller/fm.go index 6541ffe..d593042 100644 --- a/cmd/dashboard/controller/fm.go +++ b/cmd/dashboard/controller/fm.go @@ -21,7 +21,7 @@ import ( // @Description Create an "attached" FM. It is advised to only call this within a terminal session. // @Tags auth required // @Accept json -// @Param id path uint true "Server ID" +// @Param id query uint true "Server ID" // @Produce json // @Success 200 {object} model.CreateFMResponse // @Router /file [get] @@ -66,6 +66,7 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) { // @Description Start FM stream // @Tags auth required // @Param id path string true "Stream UUID" +// @Success 200 {object} model.CommonResponse[any] // @Router /ws/file/{id} [get] func fmStream(c *gin.Context) (any, error) { streamId := c.Param("id") @@ -92,8 +93,8 @@ func fmStream(c *gin.Context) (any, error) { }() if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil { - return nil, err + return nil, newWsError("%v", err) } - return nil, rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) + return nil, newWsError("%v", rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)) } diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index 7ffbd04..1eb54c3 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -66,6 +66,7 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { // @Description Terminal stream // @Tags auth required // @Param id path string true "Stream UUID" +// @Success 200 {object} model.CommonResponse[any] // @Router /ws/terminal/{id} [get] func terminalStream(c *gin.Context) (any, error) { streamId := c.Param("id") @@ -92,8 +93,8 @@ func terminalStream(c *gin.Context) (any, error) { }() if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil { - return nil, err + return nil, newWsError("%v", err) } - return nil, rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) + return nil, newWsError("%v", rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)) } diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go index c4f1711..53817df 100644 --- a/cmd/dashboard/controller/ws.go +++ b/cmd/dashboard/controller/ws.go @@ -51,7 +51,7 @@ func serverStream(c *gin.Context) (any, error) { } time.Sleep(time.Second * 2) } - return nil, nil + return nil, newWsError("") } var requestGroup singleflight.Group