diff --git a/server/src/__tests__/costs-service.test.ts b/server/src/__tests__/costs-service.test.ts index f642e566..517ada52 100644 --- a/server/src/__tests__/costs-service.test.ts +++ b/server/src/__tests__/costs-service.test.ts @@ -32,6 +32,7 @@ function makeDb(overrides: Record = {}) { const mockCompanyService = vi.hoisted(() => ({ getById: vi.fn(), + update: vi.fn(), })); const mockAgentService = vi.hoisted(() => ({ getById: vi.fn(), @@ -98,8 +99,34 @@ function createApp() { return app; } +function createAppWithActor(actor: any) { + const app = express(); + app.use(express.json()); + app.use((req, _res, next) => { + req.actor = actor; + next(); + }); + app.use("/api", costRoutes(makeDb() as any)); + app.use(errorHandler); + return app; +} + beforeEach(() => { vi.clearAllMocks(); + mockCompanyService.update.mockResolvedValue({ + id: "company-1", + name: "Paperclip", + budgetMonthlyCents: 100, + spentMonthlyCents: 0, + }); + mockAgentService.update.mockResolvedValue({ + id: "agent-1", + companyId: "company-1", + name: "Budget Agent", + budgetMonthlyCents: 100, + spentMonthlyCents: 0, + }); + mockBudgetService.upsertPolicy.mockResolvedValue(undefined); }); describe("cost routes", () => { @@ -155,4 +182,45 @@ describe("cost routes", () => { expect(res.status).toBe(200); expect(mockFinanceService.list).toHaveBeenCalledWith("company-1", undefined, 25); }); + + it("rejects company budget updates for board users outside the company", async () => { + const app = createAppWithActor({ + type: "board", + userId: "board-user", + source: "session", + isInstanceAdmin: false, + companyIds: ["company-2"], + }); + + const res = await request(app) + .patch("/api/companies/company-1/budgets") + .send({ budgetMonthlyCents: 2500 }); + + expect(res.status).toBe(403); + expect(mockCompanyService.update).not.toHaveBeenCalled(); + }); + + it("rejects agent budget updates for board users outside the agent company", async () => { + mockAgentService.getById.mockResolvedValue({ + id: "agent-1", + companyId: "company-1", + name: "Budget Agent", + budgetMonthlyCents: 100, + spentMonthlyCents: 0, + }); + const app = createAppWithActor({ + type: "board", + userId: "board-user", + source: "session", + isInstanceAdmin: false, + companyIds: ["company-2"], + }); + + const res = await request(app) + .patch("/api/agents/agent-1/budgets") + .send({ budgetMonthlyCents: 2500 }); + + expect(res.status).toBe(403); + expect(mockAgentService.update).not.toHaveBeenCalled(); + }); }); diff --git a/server/src/__tests__/monthly-spend-service.test.ts b/server/src/__tests__/monthly-spend-service.test.ts new file mode 100644 index 00000000..97b213af --- /dev/null +++ b/server/src/__tests__/monthly-spend-service.test.ts @@ -0,0 +1,90 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { companyService } from "../services/companies.ts"; +import { agentService } from "../services/agents.ts"; + +function createSelectSequenceDb(results: unknown[]) { + const pending = [...results]; + const chain = { + from: vi.fn(() => chain), + where: vi.fn(() => chain), + leftJoin: vi.fn(() => chain), + groupBy: vi.fn(() => chain), + then: vi.fn((resolve: (value: unknown[]) => unknown) => Promise.resolve(resolve(pending.shift() ?? []))), + }; + + return { + db: { + select: vi.fn(() => chain), + }, + }; +} + +describe("monthly spend hydration", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("recomputes company spentMonthlyCents from the current utc month instead of returning stale stored values", async () => { + const dbStub = createSelectSequenceDb([ + [{ + id: "company-1", + name: "Paperclip", + description: null, + status: "active", + issuePrefix: "PAP", + issueCounter: 1, + budgetMonthlyCents: 5000, + spentMonthlyCents: 999999, + requireBoardApprovalForNewAgents: false, + brandColor: null, + logoAssetId: null, + createdAt: new Date(), + updatedAt: new Date(), + }], + [{ + companyId: "company-1", + spentMonthlyCents: 420, + }], + ]); + + const companies = companyService(dbStub.db as any); + const [company] = await companies.list(); + + expect(company.spentMonthlyCents).toBe(420); + }); + + it("recomputes agent spentMonthlyCents from the current utc month instead of returning stale stored values", async () => { + const dbStub = createSelectSequenceDb([ + [{ + id: "agent-1", + companyId: "company-1", + name: "Budget Agent", + role: "general", + title: null, + reportsTo: null, + capabilities: null, + adapterType: "claude-local", + adapterConfig: {}, + runtimeConfig: {}, + budgetMonthlyCents: 5000, + spentMonthlyCents: 999999, + metadata: null, + permissions: null, + status: "idle", + pauseReason: null, + pausedAt: null, + createdAt: new Date(), + updatedAt: new Date(), + }], + [{ + agentId: "agent-1", + spentMonthlyCents: 175, + }], + ]); + + const agents = agentService(dbStub.db as any); + const agent = await agents.getById("agent-1"); + + expect(agent?.spentMonthlyCents).toBe(175); + }); +}); diff --git a/server/src/routes/costs.ts b/server/src/routes/costs.ts index 59374884..82925bd7 100644 --- a/server/src/routes/costs.ts +++ b/server/src/routes/costs.ts @@ -250,6 +250,7 @@ export function costRoutes(db: Db) { router.patch("/companies/:companyId/budgets", validate(updateBudgetSchema), async (req, res) => { assertBoard(req); const companyId = req.params.companyId as string; + assertCompanyAccess(req, companyId); const company = await companies.update(companyId, { budgetMonthlyCents: req.body.budgetMonthlyCents }); if (!company) { res.status(404).json({ error: "Company not found" }); @@ -288,6 +289,8 @@ export function costRoutes(db: Db) { return; } + assertCompanyAccess(req, agent.companyId); + if (req.actor.type === "agent") { if (req.actor.agentId !== agentId) { res.status(403).json({ error: "Agent can only change its own budget" }); diff --git a/server/src/services/agents.ts b/server/src/services/agents.ts index 4daa1dd9..17d2e46d 100644 --- a/server/src/services/agents.ts +++ b/server/src/services/agents.ts @@ -1,5 +1,5 @@ import { createHash, randomBytes } from "node:crypto"; -import { and, desc, eq, inArray, ne } from "drizzle-orm"; +import { and, desc, eq, gte, inArray, lt, ne, sql } from "drizzle-orm"; import type { Db } from "@paperclipai/db"; import { agents, @@ -8,6 +8,7 @@ import { agentRuntimeState, agentTaskSessions, agentWakeupRequests, + costEvents, heartbeatRunEvents, heartbeatRuns, } from "@paperclipai/db"; @@ -182,6 +183,15 @@ export function deduplicateAgentName( } export function agentService(db: Db) { + function currentUtcMonthWindow(now = new Date()) { + const year = now.getUTCFullYear(); + const month = now.getUTCMonth(); + return { + start: new Date(Date.UTC(year, month, 1, 0, 0, 0, 0)), + end: new Date(Date.UTC(year, month + 1, 1, 0, 0, 0, 0)), + }; + } + function withUrlKey(row: T) { return { ...row, @@ -196,13 +206,47 @@ export function agentService(db: Db) { }); } + async function getMonthlySpendByAgentIds(companyId: string, agentIds: string[]) { + if (agentIds.length === 0) return new Map(); + const { start, end } = currentUtcMonthWindow(); + const rows = await db + .select({ + agentId: costEvents.agentId, + spentMonthlyCents: sql`coalesce(sum(${costEvents.costCents}), 0)::int`, + }) + .from(costEvents) + .where( + and( + eq(costEvents.companyId, companyId), + inArray(costEvents.agentId, agentIds), + gte(costEvents.occurredAt, start), + lt(costEvents.occurredAt, end), + ), + ) + .groupBy(costEvents.agentId); + return new Map(rows.map((row) => [row.agentId, Number(row.spentMonthlyCents ?? 0)])); + } + + async function hydrateAgentSpend(rows: T[]) { + const agentIds = rows.map((row) => row.id); + const companyId = rows[0]?.companyId; + if (!companyId || agentIds.length === 0) return rows; + const spendByAgentId = await getMonthlySpendByAgentIds(companyId, agentIds); + return rows.map((row) => ({ + ...row, + spentMonthlyCents: spendByAgentId.get(row.id) ?? 0, + })); + } + async function getById(id: string) { const row = await db .select() .from(agents) .where(eq(agents.id, id)) .then((rows) => rows[0] ?? null); - return row ? normalizeAgentRow(row) : null; + if (!row) return null; + const [hydrated] = await hydrateAgentSpend([row]); + return normalizeAgentRow(hydrated); } async function ensureManager(companyId: string, managerId: string) { @@ -331,7 +375,8 @@ export function agentService(db: Db) { conditions.push(ne(agents.status, "terminated")); } const rows = await db.select().from(agents).where(and(...conditions)); - return rows.map(normalizeAgentRow); + const hydrated = await hydrateAgentSpend(rows); + return hydrated.map(normalizeAgentRow); }, getById, diff --git a/server/src/services/budgets.ts b/server/src/services/budgets.ts index bc09673e..577635a3 100644 --- a/server/src/services/budgets.ts +++ b/server/src/services/budgets.ts @@ -1,4 +1,4 @@ -import { and, desc, eq, gte, inArray, lt, sql } from "drizzle-orm"; +import { and, desc, eq, gte, inArray, lt, ne, sql } from "drizzle-orm"; import type { Db } from "@paperclipai/db"; import { agents, @@ -360,6 +360,7 @@ export function budgetService(db: Db, hooks: BudgetServiceHooks = {}) { eq(budgetIncidents.policyId, policy.id), eq(budgetIncidents.windowStart, start), eq(budgetIncidents.thresholdType, thresholdType), + ne(budgetIncidents.status, "dismissed"), ), ) .then((rows) => rows[0] ?? null); diff --git a/server/src/services/companies.ts b/server/src/services/companies.ts index 7fafb093..893bea9e 100644 --- a/server/src/services/companies.ts +++ b/server/src/services/companies.ts @@ -1,4 +1,4 @@ -import { eq, count } from "drizzle-orm"; +import { and, count, eq, gte, inArray, lt, sql } from "drizzle-orm"; import type { Db } from "@paperclipai/db"; import { companies, @@ -54,6 +54,49 @@ export function companyService(db: Db) { }; } + function currentUtcMonthWindow(now = new Date()) { + const year = now.getUTCFullYear(); + const month = now.getUTCMonth(); + return { + start: new Date(Date.UTC(year, month, 1, 0, 0, 0, 0)), + end: new Date(Date.UTC(year, month + 1, 1, 0, 0, 0, 0)), + }; + } + + async function getMonthlySpendByCompanyIds( + companyIds: string[], + database: Pick = db, + ) { + if (companyIds.length === 0) return new Map(); + const { start, end } = currentUtcMonthWindow(); + const rows = await database + .select({ + companyId: costEvents.companyId, + spentMonthlyCents: sql`coalesce(sum(${costEvents.costCents}), 0)::int`, + }) + .from(costEvents) + .where( + and( + inArray(costEvents.companyId, companyIds), + gte(costEvents.occurredAt, start), + lt(costEvents.occurredAt, end), + ), + ) + .groupBy(costEvents.companyId); + return new Map(rows.map((row) => [row.companyId, Number(row.spentMonthlyCents ?? 0)])); + } + + async function hydrateCompanySpend( + rows: T[], + database: Pick = db, + ) { + const spendByCompanyId = await getMonthlySpendByCompanyIds(rows.map((row) => row.id), database); + return rows.map((row) => ({ + ...row, + spentMonthlyCents: spendByCompanyId.get(row.id) ?? 0, + })); + } + function getCompanyQuery(database: Pick) { return database .select(companySelection) @@ -104,13 +147,20 @@ export function companyService(db: Db) { } return { - list: () => - getCompanyQuery(db).then((rows) => rows.map((row) => enrichCompany(row))), + list: async () => { + const rows = await getCompanyQuery(db); + const hydrated = await hydrateCompanySpend(rows); + return hydrated.map((row) => enrichCompany(row)); + }, - getById: (id: string) => - getCompanyQuery(db) + getById: async (id: string) => { + const row = await getCompanyQuery(db) .where(eq(companies.id, id)) - .then((rows) => (rows[0] ? enrichCompany(rows[0]) : null)), + .then((rows) => rows[0] ?? null); + if (!row) return null; + const [hydrated] = await hydrateCompanySpend([row], db); + return enrichCompany(hydrated); + }, create: async (data: typeof companies.$inferInsert) => { const created = await createCompanyWithUniquePrefix(data); @@ -118,7 +168,8 @@ export function companyService(db: Db) { .where(eq(companies.id, created.id)) .then((rows) => rows[0] ?? null); if (!row) throw notFound("Company not found after creation"); - return enrichCompany(row); + const [hydrated] = await hydrateCompanySpend([row], db); + return enrichCompany(hydrated); }, update: ( @@ -175,10 +226,12 @@ export function companyService(db: Db) { await tx.delete(assets).where(eq(assets.id, existing.logoAssetId)); } - return enrichCompany({ + const [hydrated] = await hydrateCompanySpend([{ ...updated, logoAssetId: logoAssetId === undefined ? existing.logoAssetId : logoAssetId, - }); + }], tx); + + return enrichCompany(hydrated); }), archive: (id: string) => @@ -193,7 +246,9 @@ export function companyService(db: Db) { const row = await getCompanyQuery(tx) .where(eq(companies.id, id)) .then((rows) => rows[0] ?? null); - return row ? enrichCompany(row) : null; + if (!row) return null; + const [hydrated] = await hydrateCompanySpend([row], tx); + return enrichCompany(hydrated); }), remove: (id: string) => diff --git a/server/src/services/costs.ts b/server/src/services/costs.ts index aa80e3a8..76a90f2d 100644 --- a/server/src/services/costs.ts +++ b/server/src/services/costs.ts @@ -1,4 +1,4 @@ -import { and, desc, eq, gte, isNotNull, lte, sql } from "drizzle-orm"; +import { and, desc, eq, gte, isNotNull, lt, lte, sql } from "drizzle-orm"; import type { Db } from "@paperclipai/db"; import { activityLog, agents, companies, costEvents, issues, projects } from "@paperclipai/db"; import { notFound, unprocessable } from "../errors.js"; @@ -12,6 +12,37 @@ export interface CostDateRange { const METERED_BILLING_TYPE = "metered_api"; const SUBSCRIPTION_BILLING_TYPES = ["subscription_included", "subscription_overage"] as const; +function currentUtcMonthWindow(now = new Date()) { + const year = now.getUTCFullYear(); + const month = now.getUTCMonth(); + return { + start: new Date(Date.UTC(year, month, 1, 0, 0, 0, 0)), + end: new Date(Date.UTC(year, month + 1, 1, 0, 0, 0, 0)), + }; +} + +async function getMonthlySpendTotal( + db: Db, + scope: { companyId: string; agentId?: string | null }, +) { + const { start, end } = currentUtcMonthWindow(); + const conditions = [ + eq(costEvents.companyId, scope.companyId), + gte(costEvents.occurredAt, start), + lt(costEvents.occurredAt, end), + ]; + if (scope.agentId) { + conditions.push(eq(costEvents.agentId, scope.agentId)); + } + const [row] = await db + .select({ + total: sql`coalesce(sum(${costEvents.costCents}), 0)::int`, + }) + .from(costEvents) + .where(and(...conditions)); + return Number(row?.total ?? 0); +} + export function costService(db: Db, budgetHooks: BudgetServiceHooks = {}) { const budgets = budgetService(db, budgetHooks); return { @@ -39,10 +70,15 @@ export function costService(db: Db, budgetHooks: BudgetServiceHooks = {}) { .returning() .then((rows) => rows[0]); + const [agentMonthSpend, companyMonthSpend] = await Promise.all([ + getMonthlySpendTotal(db, { companyId, agentId: event.agentId }), + getMonthlySpendTotal(db, { companyId }), + ]); + await db .update(agents) .set({ - spentMonthlyCents: sql`${agents.spentMonthlyCents} + ${event.costCents}`, + spentMonthlyCents: agentMonthSpend, updatedAt: new Date(), }) .where(eq(agents.id, event.agentId)); @@ -50,7 +86,7 @@ export function costService(db: Db, budgetHooks: BudgetServiceHooks = {}) { await db .update(companies) .set({ - spentMonthlyCents: sql`${companies.spentMonthlyCents} + ${event.costCents}`, + spentMonthlyCents: companyMonthSpend, updatedAt: new Date(), }) .where(eq(companies.id, companyId));