Connor McCutcheon
/ SkyCode
ws.go
go
package controllers
import (
	"encoding/json"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"sync"
	"time"
	"github.com/creack/pty"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"theskyscape.com/repo/skycode/models"
	"theskyscape.com/repo/skykit"
)
// Tunables for the WebSocket transport and the terminals it serves.
const (
	// Buffer sizes, in bytes, handed to the WebSocket upgrader.
	wsReadBufferSize  = 1024
	wsWriteBufferSize = 1024

	// Upper bound on terminals open at once within one session.
	maxTerminalsPerSession = 5

	// Window size applied to a freshly spawned PTY.
	initialPTYRows = 24
	initialPTYCols = 80

	// Session cookie lifetime in seconds (30 minutes); shared with code.go.
	wsSessionCookieMaxAge = 30 * 60
)
// WS returns this controller's name ("ws") together with a fresh
// *WSController, presumably for registration with skykit.
func WS() (string, skykit.Handler) {
	ctrl := new(WSController)
	return "ws", ctrl
}

// WSController serves the terminal WebSocket endpoint and the REST endpoints
// that list, create and delete terminals. It embeds skykit.Controller, through
// which its handlers reach members such as Users and Protect.
type WSController struct {
	skykit.Controller
}
// Setup initialises the embedded controller and mounts the terminal routes on
// http.DefaultServeMux, wrapping every handler with the requireAuth guard.
func (c *WSController) Setup(app *skykit.Application) {
	c.Controller.Setup(app)

	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"GET /api/terminal", c.handleTerminal},          // WebSocket PTY stream
		{"GET /api/terminals", c.listTerminals},          // list
		{"POST /api/terminals", c.createTerminal},        // create
		{"DELETE /api/terminals/{id}", c.deleteTerminal}, // delete
	}
	for _, rt := range routes {
		http.HandleFunc(rt.pattern, c.Protect(rt.handler, c.requireAuth))
	}
}
// Handle returns a per-request copy of the controller bound to r. Because the
// receiver is a value, the shared controller itself is never mutated.
func (c WSController) Handle(r *http.Request) skykit.Handler {
	bound := c
	bound.Request = r
	return &bound
}
// requireAuth is the guard passed to Protect. It reports whether the request
// carries an authenticated user and, when it does not, replies 401
// "unauthorized" itself.
func (c *WSController) requireAuth(app *skykit.Application, w http.ResponseWriter, r *http.Request) bool {
	if user, err := app.Users.Authenticate(r); err == nil && user != nil {
		return true
	}
	http.Error(w, "unauthorized", http.StatusUnauthorized)
	return false
}
// upgrader turns HTTP requests into WebSocket connections. Its origin check
// guards against cross-site WebSocket hijacking: a request without an Origin
// header is accepted, otherwise Origin must be exactly http://<Host> or
// https://<Host>.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  wsReadBufferSize,
	WriteBufferSize: wsWriteBufferSize,
	CheckOrigin: func(r *http.Request) bool {
		switch r.Header.Get("Origin") {
		case "":
			// Some browsers omit Origin on same-origin requests.
			return true
		case "http://" + r.Host, "https://" + r.Host:
			return true
		default:
			return false
		}
	},
}
// TerminalSession is one PTY-backed shell. The PTY and its process are kept
// independent of any single WebSocket connection so a client can reattach.
type TerminalSession struct {
	ID   string    // UUID assigned when the terminal is created
	Name string    // display label shown to the client
	PTY  *os.File  // PTY handle the shell's I/O flows through
	Cmd  *exec.Cmd // the shell process attached to the PTY

	mu     sync.Mutex // guards closed
	closed bool
}

// Close kills the shell process and closes the PTY. Only the first call does
// anything; later calls return immediately. Errors from Kill and Close are
// ignored because teardown is best-effort.
func (ts *TerminalSession) Close() {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	if !ts.closed {
		ts.closed = true
		if cmd := ts.Cmd; cmd != nil && cmd.Process != nil {
			_ = cmd.Process.Kill()
		}
		if f := ts.PTY; f != nil {
			_ = f.Close()
		}
	}
}
// In-memory terminal registry, keyed sessionID -> terminalID -> session.
// terminalsMu is the lock the registry helpers (getTerminals, getTerminal,
// createTerminalSession, deleteTerminalSession) take around userTerminals.
var (
	userTerminals = map[string]map[string]*TerminalSession{}
	terminalsMu   sync.RWMutex
)
// getTerminals returns the terminals registered for sessionID, keyed by
// terminal ID, or nil if the session is unknown.
//
// The returned map is a snapshot copy, so callers may range over it without
// holding terminalsMu while other requests create or delete terminals. The
// *TerminalSession values themselves are shared, not copied.
func getTerminals(sessionID string) map[string]*TerminalSession {
	terminalsMu.RLock()
	defer terminalsMu.RUnlock()
	terms, ok := userTerminals[sessionID]
	if !ok {
		return nil
	}
	snapshot := make(map[string]*TerminalSession, len(terms))
	for id, ts := range terms {
		snapshot[id] = ts
	}
	return snapshot
}
// getTerminal looks up a single terminal by session and terminal ID. It
// returns nil when either ID is unknown; indexing a missing inner map yields
// the zero value.
func getTerminal(sessionID, terminalID string) *TerminalSession {
	terminalsMu.RLock()
	defer terminalsMu.RUnlock()
	return userTerminals[sessionID][terminalID]
}
// createTerminalSession starts a bash shell on a new PTY and registers it
// under a fresh UUID in the given session.
//
// Parameters:
//   - sessionID keys the in-memory registry (userTerminals).
//   - userID selects the per-user tool directories placed on PATH.
//   - workDir is the shell's working directory and its HOME.
//   - name is the display label; when empty, "Terminal N" is generated.
//
// It returns *tooManyTerminalsError when the session already holds
// maxTerminalsPerSession terminals, or the error from starting the PTY.
//
// The shell is started without holding terminalsMu, so filesystem setup and
// process creation don't stall every other registry lookup. The limit is
// checked again under the write lock before the terminal is registered.
func createTerminalSession(sessionID, userID, workDir string, name string) (*TerminalSession, error) {
	// Cheap early rejection so we don't spawn a shell only to discard it.
	terminalsMu.RLock()
	existing := len(userTerminals[sessionID])
	terminalsMu.RUnlock()
	if existing >= maxTerminalsPerSession {
		return nil, &tooManyTerminalsError{}
	}

	ptmx, cmd, err := startTerminalShell(userID, workDir)
	if err != nil {
		return nil, err
	}
	ts := &TerminalSession{
		ID:  uuid.NewString(),
		PTY: ptmx,
		Cmd: cmd,
	}

	terminalsMu.Lock()
	defer terminalsMu.Unlock()
	// Concurrent requests may have filled the session while the shell started.
	if len(userTerminals[sessionID]) >= maxTerminalsPerSession {
		ts.Close()
		return nil, &tooManyTerminalsError{}
	}
	terms := userTerminals[sessionID]
	if terms == nil {
		terms = make(map[string]*TerminalSession)
		userTerminals[sessionID] = terms
	}
	if name == "" {
		// One-digit numbering is enough while maxTerminalsPerSession < 10.
		name = "Terminal " + string('1'+rune(len(terms)))
	}
	ts.Name = name
	terms[ts.ID] = ts
	return ts, nil
}

// startTerminalShell launches an interactive bash on a new PTY in workDir,
// with HOME set to workDir and PATH / npm / Go / Cargo variables pointing at
// the user's persistent tool directories (falling back to workDir if those
// cannot be created). It returns the PTY file from pty.Start and the running
// command.
//
// A background goroutine calls cmd.Wait so the shell is reaped once it exits,
// either on its own or after TerminalSession.Close kills it. Without that it
// would linger as a zombie process.
func startTerminalShell(userID, workDir string) (*os.File, *exec.Cmd, error) {
	// Ensure user tool directories exist (persists across sessions)
	userToolDir, err := models.EnsureUserToolDirs(userID)
	if err != nil {
		log.Printf("Warning: failed to create user tool dirs: %v", err)
		userToolDir = workDir // Fallback to workspace
	}
	// User tool directories come first so user-installed binaries win lookups.
	userPath := strings.Join([]string{
		filepath.Join(userToolDir, ".local/bin"),
		filepath.Join(userToolDir, ".npm-global/bin"),
		filepath.Join(userToolDir, "go/bin"),
		filepath.Join(userToolDir, ".cargo/bin"),
		os.Getenv("PATH"),
	}, ":")

	cmd := exec.Command("bash")
	cmd.Dir = workDir
	cmd.Env = append(os.Environ(),
		"TERM=xterm-256color",
		"HOME="+workDir, // Set HOME to workspace so ~ and cd work correctly
		"PATH="+userPath,
		"NPM_CONFIG_PREFIX="+filepath.Join(userToolDir, ".npm-global"),
		"NPM_CONFIG_CACHE="+filepath.Join(models.CacheDir, "npm"),
		"GOPATH="+filepath.Join(userToolDir, "go"),
		"GOMODCACHE="+filepath.Join(models.CacheDir, "go/mod"),
		"CARGO_HOME="+filepath.Join(userToolDir, ".cargo"),
		"PS1=\\[\\033[32m\\]\\u@skycode\\[\\033[0m\\]:\\[\\033[34m\\]\\w\\[\\033[0m\\]$ ",
	)
	ptmx, err := pty.Start(cmd)
	if err != nil {
		return nil, nil, err
	}
	if err := pty.Setsize(ptmx, &pty.Winsize{Rows: initialPTYRows, Cols: initialPTYCols}); err != nil {
		log.Printf("Warning: failed to set initial PTY size: %v", err)
	}
	go func() {
		// The exit status is irrelevant here; a kill from Close is expected.
		_ = cmd.Wait()
	}()
	return ptmx, cmd, nil
}
// deleteTerminalSession closes and unregisters one terminal; unknown IDs are
// ignored. If the session is left with no terminals, its entry is removed from
// the registry so empty maps don't accumulate.
func deleteTerminalSession(sessionID, terminalID string) {
	terminalsMu.Lock()
	defer terminalsMu.Unlock()
	terms, found := userTerminals[sessionID]
	if !found {
		return
	}
	if ts, exists := terms[terminalID]; exists {
		ts.Close()
		delete(terms, terminalID)
	}
	if len(terms) == 0 {
		delete(userTerminals, sessionID)
	}
}
// tooManyTerminalsError is returned by createTerminalSession when the session
// already has maxTerminalsPerSession terminals open.
type tooManyTerminalsError struct{}

// Error implements the error interface.
func (*tooManyTerminalsError) Error() string {
	const msg = "maximum number of terminals reached"
	return msg
}
// resizeMsg is the JSON text frame a client sends over the terminal WebSocket
// to change the PTY window size, e.g. {"type":"resize","cols":120,"rows":40}.
type resizeMsg struct {
	Type string `json:"type"` // "resize" marks the frame as a resize request
	Cols int    `json:"cols"` // requested width in columns
	Rows int    `json:"rows"` // requested height in rows
}
// GET /api/terminals - list terminals
//
// Responds with a JSON array of {"id","name"} objects for the caller's
// session, or an empty array (never null) when there are none. The order
// follows map iteration and is therefore unspecified.
func (c *WSController) listTerminals(w http.ResponseWriter, r *http.Request) {
	terminals := getTerminals(c.getOrCreateSessionID(w, r))
	list := make([]map[string]string, 0, len(terminals))
	for id, ts := range terminals {
		list = append(list, map[string]string{"id": id, "name": ts.Name})
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(list)
}
// POST /api/terminals - create new terminal
//
// Accepts an optional JSON body {"name": "..."}. Decode errors are ignored, and
// an empty name selects the default label. Responses:
//   - 200 {"id","name"} on success;
//   - 400 {"error"} when the per-session terminal limit is reached;
//   - 500 {"error"} when the shell cannot be started;
//   - 500 plain text when the session or workspace cannot be prepared;
//   - 401 if no authenticated user is found.
func (c *WSController) createTerminal(w http.ResponseWriter, r *http.Request) {
	user, err := c.Users.Authenticate(r)
	if err != nil || user == nil {
		// requireAuth normally rejects these first; never dereference a nil user.
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}
	sessionID := c.getOrCreateSessionID(w, r)
	sess, err := models.GetOrCreateSession(sessionID, user.ID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Initialize workspace if needed (thread-safe, runs only once)
	if err := sess.InitializeOnce(); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Same recovery step as handleTerminal: the shell is started in
	// sess.WorkDir, so make sure the workspace is present first.
	if err := sess.EnsureWorkspaceReady(); err != nil {
		http.Error(w, "workspace recovery failed: "+err.Error(), http.StatusInternalServerError)
		return
	}

	var body struct {
		Name string `json:"name"`
	}
	// The body is optional (it may be empty), so decode failures are ignored.
	_ = json.NewDecoder(r.Body).Decode(&body)

	ts, err := createTerminalSession(sessionID, user.ID, sess.WorkDir, body.Name)
	if err != nil {
		status := http.StatusInternalServerError
		if _, limited := err.(*tooManyTerminalsError); limited {
			// Hitting the per-session limit is a client error, not a server fault.
			status = http.StatusBadRequest
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
		return
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{
		"id":   ts.ID,
		"name": ts.Name,
	})
}
// DELETE /api/terminals/{id} - delete terminal
//
// Closes the terminal and removes it from the caller's session. Always responds
// {"ok":true}, whether or not the ID existed.
func (c *WSController) deleteTerminal(w http.ResponseWriter, r *http.Request) {
	deleteTerminalSession(c.getOrCreateSessionID(w, r), r.PathValue("id"))
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]bool{"ok": true})
}
// GET /api/terminal?id=X - WebSocket terminal connection
//
// Attaches a WebSocket to terminal X in the caller's session, or creates a new
// terminal when id is absent (backward compatible).
//   - PTY output is sent to the client as binary frames.
//   - Client frames are written to the PTY. The exception is text frames that
//     decode as a resizeMsg with Type "resize": these resize the PTY instead,
//     and out-of-range sizes are dropped.
//
// The shell is left running when the socket closes so the client can
// reconnect. The workspace is then synced to the DB, with up to three attempts.
func (c *WSController) handleTerminal(w http.ResponseWriter, r *http.Request) {
	user, err := c.Users.Authenticate(r)
	if err != nil || user == nil {
		// requireAuth normally rejects these first; never dereference a nil user.
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}
	// Get or create session
	sessionID := c.getOrCreateSessionID(w, r)
	sess, err := models.GetOrCreateSession(sessionID, user.ID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Initialize workspace if needed (thread-safe, runs only once)
	if err := sess.InitializeOnce(); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Ensure workspace is still present (Docker Swarm recovery)
	if err := sess.EnsureWorkspaceReady(); err != nil {
		http.Error(w, "workspace recovery failed: "+err.Error(), http.StatusInternalServerError)
		return
	}

	terminalID := r.URL.Query().Get("id")
	var ts *TerminalSession
	if terminalID != "" {
		// Connect to existing terminal
		ts = getTerminal(sessionID, terminalID)
		if ts == nil {
			http.Error(w, "terminal not found", http.StatusNotFound)
			return
		}
	} else {
		// Create a new terminal (backward compatible)
		ts, err = createTerminalSession(sessionID, user.ID, sess.WorkDir, "")
		if err != nil {
			status := http.StatusInternalServerError
			if _, limited := err.(*tooManyTerminalsError); limited {
				// Hitting the per-session limit is a client error, not a server fault.
				status = http.StatusBadRequest
			}
			http.Error(w, err.Error(), status)
			return
		}
	}

	// gorilla/websocket writes the 101 handshake response itself and includes
	// only the headers passed to Upgrade. Forward any session cookie that
	// getOrCreateSessionID queued on w; otherwise the browser never learns a
	// newly issued session ID.
	var respHeader http.Header
	if cookies := w.Header().Values("Set-Cookie"); len(cookies) > 0 {
		respHeader = http.Header{"Set-Cookie": cookies}
	}
	conn, err := upgrader.Upgrade(w, r, respHeader)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}
	defer conn.Close()

	// Output pump. It is deliberately not waited for. After a disconnect it
	// stays blocked in PTY.Read until the shell next writes, which for an idle
	// shell may be never, and waiting would postpone the workspace sync below.
	// Its next write to the closed connection fails and it returns.
	// NOTE(review): until then it competes with a reconnected client's pump for
	// PTY output; a single long-lived reader per TerminalSession would avoid that.
	go pumpPTYToWebSocket(conn, ts)

	// Input pump: returns when the client disconnects or the PTY rejects input.
	// The PTY itself stays alive so the client can reconnect.
	pumpWebSocketToPTY(conn, ts)

	// Sync workspace to DB now that the connection is closed.
	syncWorkspaceWithRetry(sess.SyncWorkspaceToDB)
}

// pumpPTYToWebSocket copies PTY output to conn as binary frames until either
// side fails. If the PTY ends first (for example, the shell exited), it sends a
// normal close frame. It then closes conn after a short grace period, so the
// input pump cannot block forever waiting for the client's close reply.
func pumpPTYToWebSocket(conn *websocket.Conn, ts *TerminalSession) {
	// How long to give the client to answer our close frame.
	const closeGracePeriod = time.Second

	buf := make([]byte, 1024)
	for {
		n, err := ts.PTY.Read(buf)
		if err != nil {
			if err != io.EOF {
				log.Printf("PTY read error: %v", err)
			}
			break
		}
		if err := conn.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil {
			log.Printf("WebSocket write error: %v", err)
			return
		}
	}
	conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
	time.Sleep(closeGracePeriod)
	conn.Close()
}

// pumpWebSocketToPTY forwards client frames to the PTY until the connection
// closes or a PTY write fails. Text frames that decode as a resize request are
// applied as a window-size change instead of being typed into the shell.
func pumpWebSocketToPTY(conn *websocket.Conn, ts *TerminalSession) {
	for {
		msgType, data, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket read error: %v", err)
			}
			return
		}
		// Check if this is a resize message (JSON)
		if msgType == websocket.TextMessage {
			var msg resizeMsg
			if err := json.Unmarshal(data, &msg); err == nil && msg.Type == "resize" {
				resizeTerminalPTY(ts, msg)
				continue
			}
		}
		// Regular input - write to PTY
		if _, err := ts.PTY.Write(data); err != nil {
			log.Printf("PTY write error: %v", err)
			return
		}
	}
}

// resizeTerminalPTY applies a client resize request. It ignores dimensions
// that are non-positive or exceed what pty.Winsize's uint16 fields can hold;
// casting those would wrap around to nonsense sizes.
func resizeTerminalPTY(ts *TerminalSession, msg resizeMsg) {
	const maxWinsizeDim = 1<<16 - 1
	if msg.Rows <= 0 || msg.Cols <= 0 || msg.Rows > maxWinsizeDim || msg.Cols > maxWinsizeDim {
		return
	}
	if err := pty.Setsize(ts.PTY, &pty.Winsize{
		Rows: uint16(msg.Rows),
		Cols: uint16(msg.Cols),
	}); err != nil {
		log.Printf("PTY resize error: %v", err)
	}
}

// syncWorkspaceWithRetry calls syncFn up to three times, sleeping 1s and then
// 2s between failed attempts. If every attempt fails, it records the last
// error via models.LogError.
func syncWorkspaceWithRetry(syncFn func() error) {
	const attempts = 3
	var err error
	for attempt := 1; attempt <= attempts; attempt++ {
		if err = syncFn(); err == nil {
			return
		}
		if attempt < attempts {
			time.Sleep(time.Duration(attempt) * time.Second)
		}
	}
	models.LogError("terminal:sync_on_close_failed", "attempts=3 err="+err.Error())
}
// getOrCreateSessionID returns the value of the caller's "skycode_session"
// cookie. If the cookie is missing or empty, it generates a new ID via
// models.GenerateSessionID and adds a Set-Cookie header to w. The cookie is
// HttpOnly, Secure, SameSite=Lax, with MaxAge wsSessionCookieMaxAge.
//
// NOTE: the cookie is only added to w's headers. A response written without
// them (such as a hijacked WebSocket handshake) will not deliver it unless the
// caller forwards the header.
func (c *WSController) getOrCreateSessionID(w http.ResponseWriter, r *http.Request) string {
	if existing, err := r.Cookie("skycode_session"); err == nil && existing.Value != "" {
		return existing.Value
	}

	id := models.GenerateSessionID()
	cookie := &http.Cookie{
		Name:     "skycode_session",
		Value:    id,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   wsSessionCookieMaxAge,
	}
	http.SetCookie(w, cookie)
	return id
}
No comments yet.