diff --git a/server/src/auth/better-auth.ts b/server/src/auth/better-auth.ts index 6d8f1b0f..351d5926 100644 --- a/server/src/auth/better-auth.ts +++ b/server/src/auth/better-auth.ts @@ -1,4 +1,5 @@ import type { Request, RequestHandler } from "express"; +import type { IncomingHttpHeaders } from "node:http"; import { betterAuth } from "better-auth"; import { drizzleAdapter } from "better-auth/adapters/drizzle"; import { toNodeHandler } from "better-auth/node"; @@ -24,9 +25,9 @@ export type BetterAuthSessionResult = { type BetterAuthInstance = ReturnType; -function headersFromExpressRequest(req: Request): Headers { +function headersFromNodeHeaders(rawHeaders: IncomingHttpHeaders): Headers { const headers = new Headers(); - for (const [key, raw] of Object.entries(req.headers)) { + for (const [key, raw] of Object.entries(rawHeaders)) { if (!raw) continue; if (Array.isArray(raw)) { for (const value of raw) headers.append(key, value); @@ -37,6 +38,10 @@ function headersFromExpressRequest(req: Request): Headers { return headers; } +function headersFromExpressRequest(req: Request): Headers { + return headersFromNodeHeaders(req.headers); +} + export function createBetterAuthInstance(db: Db, config: Config): BetterAuthInstance { const baseUrl = config.authBaseUrlMode === "explicit" ? config.authPublicBaseUrl : undefined; const secret = process.env.BETTER_AUTH_SECRET ?? process.env.PAPERCLIP_AGENT_JWT_SECRET ?? "paperclip-dev-secret"; @@ -73,15 +78,15 @@ export function createBetterAuthHandler(auth: BetterAuthInstance): RequestHandle }; } -export async function resolveBetterAuthSession( +export async function resolveBetterAuthSessionFromHeaders( auth: BetterAuthInstance, - req: Request, + headers: Headers, ): Promise { const api = (auth as unknown as { api?: { getSession?: (input: unknown) => Promise } }).api; if (!api?.getSession) return null; const sessionValue = await api.getSession({ - headers: headersFromExpressRequest(req), + headers, }); if (!sessionValue || typeof sessionValue !== "object") return null; @@ -103,3 +108,10 @@ export async function resolveBetterAuthSession( if (!session || !user) return null; return { session, user }; } + +export async function resolveBetterAuthSession( + auth: BetterAuthInstance, + req: Request, +): Promise { + return resolveBetterAuthSessionFromHeaders(auth, headersFromExpressRequest(req)); +} diff --git a/server/src/index.ts b/server/src/index.ts index ca222d35..16a54e95 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -29,6 +29,7 @@ import { createBetterAuthHandler, createBetterAuthInstance, resolveBetterAuthSession, + resolveBetterAuthSessionFromHeaders, } from "./auth/better-auth.js"; type EmbeddedPostgresInstance = { @@ -324,6 +325,9 @@ let betterAuthHandler: ReturnType | undefined; let resolveSession: | ((req: ExpressRequest) => Promise>>) | undefined; +let resolveSessionFromHeaders: + | ((headers: Headers) => Promise>>) + | undefined; if (config.deploymentMode === "local_trusted") { await ensureLocalTrustedBoardPrincipal(db as any); } @@ -338,6 +342,7 @@ if (config.deploymentMode === "authenticated") { const auth = createBetterAuthInstance(db as any, config); betterAuthHandler = createBetterAuthHandler(auth); resolveSession = (req) => resolveBetterAuthSession(auth, req); + resolveSessionFromHeaders = (headers) => resolveBetterAuthSessionFromHeaders(auth, headers); await initializeBoardClaimChallenge(db as any, { deploymentMode: config.deploymentMode }); authReady = true; } @@ -362,7 +367,10 @@ if (listenPort !== config.port) { logger.warn({ requestedPort: config.port, selectedPort: listenPort }, "Requested port is busy; using next free port"); } -setupLiveEventsWebSocketServer(server, db as any, { deploymentMode: config.deploymentMode }); +setupLiveEventsWebSocketServer(server, db as any, { + deploymentMode: config.deploymentMode, + resolveSessionFromHeaders, +}); if (config.heartbeatSchedulerEnabled) { const heartbeat = heartbeatService(db as any); diff --git a/server/src/realtime/live-events-ws.ts b/server/src/realtime/live-events-ws.ts index dcc1d46e..442ac05a 100644 --- a/server/src/realtime/live-events-ws.ts +++ b/server/src/realtime/live-events-ws.ts @@ -3,9 +3,10 @@ import type { IncomingMessage, Server as HttpServer } from "node:http"; import type { Duplex } from "node:stream"; import { and, eq, isNull } from "drizzle-orm"; import type { Db } from "@paperclip/db"; -import { agentApiKeys } from "@paperclip/db"; +import { agentApiKeys, companyMemberships, instanceUserRoles } from "@paperclip/db"; import type { DeploymentMode } from "@paperclip/shared"; import { WebSocket, WebSocketServer } from "ws"; +import type { BetterAuthSessionResult } from "../auth/better-auth.js"; import { logger } from "../middleware/logger.js"; import { subscribeCompanyLiveEvents } from "../services/live-events.js"; @@ -48,26 +49,76 @@ function parseBearerToken(rawAuth: string | string[] | undefined) { return token.length > 0 ? token : null; } +function headersFromIncomingMessage(req: IncomingMessage): Headers { + const headers = new Headers(); + for (const [key, raw] of Object.entries(req.headers)) { + if (!raw) continue; + if (Array.isArray(raw)) { + for (const value of raw) headers.append(key, value); + continue; + } + headers.set(key, raw); + } + return headers; +} + async function authorizeUpgrade( db: Db, req: IncomingMessage, companyId: string, url: URL, - deploymentMode: DeploymentMode, + opts: { + deploymentMode: DeploymentMode; + resolveSessionFromHeaders?: (headers: Headers) => Promise; + }, ): Promise { const queryToken = url.searchParams.get("token")?.trim() ?? ""; const authToken = parseBearerToken(req.headers.authorization); const token = authToken ?? (queryToken.length > 0 ? queryToken : null); - // Local trusted browser board context has no bearer token in V1. + // Browser board context has no bearer token in local_trusted and authenticated modes. if (!token) { - if (deploymentMode !== "local_trusted") { + if (opts.deploymentMode === "local_trusted") { + return { + companyId, + actorType: "board", + actorId: "board", + }; + } + + if (opts.deploymentMode !== "authenticated" || !opts.resolveSessionFromHeaders) { return null; } + + const session = await opts.resolveSessionFromHeaders(headersFromIncomingMessage(req)); + const userId = session?.user?.id; + if (!userId) return null; + + const [roleRow, memberships] = await Promise.all([ + db + .select({ id: instanceUserRoles.id }) + .from(instanceUserRoles) + .where(and(eq(instanceUserRoles.userId, userId), eq(instanceUserRoles.role, "instance_admin"))) + .then((rows) => rows[0] ?? null), + db + .select({ companyId: companyMemberships.companyId }) + .from(companyMemberships) + .where( + and( + eq(companyMemberships.principalType, "user"), + eq(companyMemberships.principalId, userId), + eq(companyMemberships.status, "active"), + ), + ), + ]); + + const hasCompanyMembership = memberships.some((row) => row.companyId === companyId); + if (!roleRow && !hasCompanyMembership) return null; + return { companyId, actorType: "board", - actorId: "board", + actorId: userId, }; } @@ -97,7 +148,10 @@ async function authorizeUpgrade( export function setupLiveEventsWebSocketServer( server: HttpServer, db: Db, - opts: { deploymentMode: DeploymentMode }, + opts: { + deploymentMode: DeploymentMode; + resolveSessionFromHeaders?: (headers: Headers) => Promise; + }, ) { const wss = new WebSocketServer({ noServer: true }); const cleanupByClient = new Map void>(); @@ -162,7 +216,10 @@ export function setupLiveEventsWebSocketServer( return; } - void authorizeUpgrade(db, req, companyId, url, opts.deploymentMode) + void authorizeUpgrade(db, req, companyId, url, { + deploymentMode: opts.deploymentMode, + resolveSessionFromHeaders: opts.resolveSessionFromHeaders, + }) .then((context) => { if (!context) { rejectUpgrade(socket, "403 Forbidden", "forbidden");