Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c21d55b
Benchmark the mcq medical dataset
mina5rovic Apr 16, 2026
3014cf5
lint error correction
mina5rovic Apr 17, 2026
e785cf6
add val dataset path param
mina5rovic Apr 19, 2026
c93631c
add working local train
mina5rovic Apr 20, 2026
2e34a37
add model saving to disk arg and more debug lines
mina5rovic Apr 22, 2026
4c9c84c
add debug commands
mina5rovic Apr 24, 2026
23c733d
add changes to federated approach
mina5rovic Apr 24, 2026
096393c
change round 0 payload null handling
mina5rovic Apr 24, 2026
1ef7d85
change server max payload limit to higher number
mina5rovic Apr 26, 2026
417bfa5
fix memory reads and wait for all clients to begin
mina5rovic Apr 26, 2026
77ce9a9
add server debug logs to see why ws close session
mina5rovic Apr 26, 2026
7aef628
cover whole dataset and split data to clients
mina5rovic Apr 27, 2026
d65fb4a
add validation dataset loading changes
May 4, 2026
fe7c51a
change gpt config to use whole dataset
mina5rovic May 4, 2026
0dc32aa
add arg for model saving location, change save to saveLog, change lin…
mina5rovic May 5, 2026
5c03f40
add training optimizations
mina5rovic May 5, 2026
4ee6096
change onnx converter to be able to convert different context len
mina5rovic May 6, 2026
5b5f88c
aggregate inside of an epoch for llms
mina5rovic May 11, 2026
8d409b9
fix mem leak
mina5rovic May 11, 2026
144b751
change logs
mina5rovic May 11, 2026
c046e9e
change model to 256
mina5rovic May 12, 2026
e89b9e0
back to 512
mina5rovic May 12, 2026
fbcca59
fix end of training
mina5rovic May 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"train_gpt": "npm run build && node dist/train_gpt.js",
"hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js",
"eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js",
"measure_memorization_gpt2": "npm run build && node dist/measure_memorization_gpt2.js",
"finetune_gpt": "npm run build && node dist/finetune_gpt.js",
"build": "tsc --build",
"test": ": nothing"
},
Expand Down
22 changes: 18 additions & 4 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ export interface BenchmarkArguments {
numberOfUsers: number
epochs: number
roundDuration: number
roundIterations?: number
batchSize: number
validationSplit: number
datasetPath?: string
validationDatasetPath?: string
outputPath?: string

// DP
epsilon?: number
Expand All @@ -35,12 +39,15 @@ export interface BenchmarkArguments {
// Secure aggregator
maxShareValue?: number

save: boolean
saveLogs: boolean
saveModel: boolean
host: URL
}

type BenchmarkUnsafeArguments = Omit<BenchmarkArguments, 'provider'> & {
task: string
datasetPath?: string
validationDatasetPath?: string
help?: boolean
}

Expand All @@ -53,9 +60,14 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
roundIterations: { type: Number, description: 'For GPT text tasks, aggregate every N training batches without rewinding the dataset', optional: true },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true },
validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true },
outputPath: { type: String, alias: 'o', description: 'Path to save logs and models. Defaults to ./<testID>', optional: true },
saveLogs: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
typeLabel: "URL",
Expand Down Expand Up @@ -89,18 +101,19 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(

const supportedTasks = Map(
await Promise.all(
Set.of<TaskProvider<"image" | "tabular", Network>>(
Set.of<TaskProvider<"image" | "tabular" | "text", Network>>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
defaultTasks.mnist,
defaultTasks.privacyrun,
).map(
async (t) =>
[(await t.getTask()).id, t] as [
string,
TaskProvider<"image" | "tabular", Network>,
TaskProvider<"image" | "tabular" | "text", Network>,
],
),
),
Expand All @@ -122,6 +135,7 @@ export const args: BenchmarkArguments = {
task.trainingInformation.roundDuration = unsafeArgs.roundDuration;
task.trainingInformation.epochs = unsafeArgs.epochs;
task.trainingInformation.validationSplit = unsafeArgs.validationSplit;
(task.trainingInformation as typeof task.trainingInformation & { roundIterations?: number }).roundIterations = unsafeArgs.roundIterations;

const {aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs;

Expand Down
64 changes: 51 additions & 13 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import fs from 'node:fs/promises'
import { createWriteStream } from "node:fs";
import path from "node:path";

import createDebug from "debug";
import type {
Dataset,
DataFormat,
Expand All @@ -17,50 +17,80 @@
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'

import { saveModelToDisk } from "@epfml/discojs-node";
import { getTaskData } from './data.js'
import { args } from './args.js'
import { makeUserLogFile } from "./user_log.js";
import type { UserLogFile } from "./user_log.js";

const debug = createDebug("cli:main");

function getOutputDir(): string {
return args.outputPath ?? path.join(".", `${args.testID}`);
}

async function runUser<D extends DataType, N extends Network>(
task: Task<D, N>,
provider: TaskProvider<D, N>,

Check failure on line 34 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'provider' is defined but never used. Allowed unused args must match /^_/u
url: URL,
data: Dataset<DataFormat.Raw[D]>,
validationData: Dataset<DataFormat.Raw[D]> | undefined,
userIndex: number,
numberOfUsers: number,
): Promise<List<SummaryLogs>> {
// cast as typescript isn't good with generics
debug(`Starting runUser for client ${userIndex}`);
const userStart = Date.now();

Check failure on line 42 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'userStart' is assigned a value but never used. Allowed unused vars must match /^_/u
const trainingScheme = task.trainingInformation.scheme as N
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, client, { scheme: trainingScheme });

const dir = path.join(".", `${args.testID}`);
const disco = new Disco(task, client, { scheme: trainingScheme, preprocessOnce: false });

// For local training, load model from provider before training starts
// if (trainingScheme === "local") {
// debug(`Loading model for training client ${userIndex}...`);
// const modelStart = Date.now();
// console.log("Loading model for local training...");
// disco.trainer.model = await provider.getModel();
// console.log("Model loaded successfully");
// debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`);
// }



const dir = getOutputDir();
await fs.mkdir(dir, { recursive: true });
const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`);

const finalLog: SummaryLogs[] = [];
// create a write stream that saves learning logs during the train
let jsonStream: ReturnType<typeof createWriteStream> | null = null;

if (args.save){
if (args.saveLogs){
jsonStream = createWriteStream(streamPath, {flags: "w"});
}

try{
for await (const log of disco.trainSummary(data)){
debug(`Starting training for client ${userIndex}`);
const trainStart = Date.now();
for await (const log of disco.trainSummary(data, validationData)){
finalLog.push(log);

if (jsonStream){
jsonStream.write(JSON.stringify(log) + "\n");
}
}
debug(`Training took ${Date.now() - trainStart}ms for client ${userIndex}`);

await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish

// Save the trained model if requested
if (args.saveModel) {
const modelDir = path.join(getOutputDir(), "models");
const modelFileName = `client${userIndex}_model.json`;
await saveModelToDisk(disco.trainer.model, modelDir, modelFileName);
console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`);
}
// saving the entire per-user logs
if (args.save) {
if (args.saveLogs) {
const finalPath = path.join(dir, `client${userIndex}_local_log.json`);

const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, finalLog);
Expand Down Expand Up @@ -104,14 +134,22 @@
console.log({ args })

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers, args.datasetPath))
)

let validationData: Dataset<DataFormat.Raw[D]> | undefined = undefined;
if (args.validationDatasetPath) {
validationData = (
await getTaskData(task.id, 0, 1, args.validationDatasetPath, true, args.validationDatasetPath)
).cached() as Dataset<DataFormat.Raw[D]>;
}

const logs = await Promise.all(
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset<DataFormat.Raw[D]>, validationData, i, numberOfUsers))
)

if (args.save) {
const dir = path.join(".", `${args.testID}`, `${task.id}`);
if (args.saveLogs) {
const dir = path.join(getOutputDir(), `${task.id}`);
await fs.mkdir(dir, { recursive: true });

const filePath = path.join(dir, `${task.id}_${numberOfUsers}users.json`);
Expand Down
72 changes: 71 additions & 1 deletion cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,64 @@
import path from "node:path";
import { createReadStream } from "node:fs";
import { Dataset, processing } from "@epfml/discojs";
import {
DataFormat,
DataType,
Image,
Task,
Text,
} from "@epfml/discojs";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

/**
 * Lazily streams text samples from a file, splitting on the GPT
 * `<|endoftext|>` marker (the marker is kept at the end of each sample).
 *
 * When both `userIdx` and `totalClient` are provided, samples are sharded
 * round-robin across clients: client `userIdx` keeps every sample whose
 * index satisfies `index % totalClient === userIdx`. With either argument
 * omitted, every sample is yielded.
 *
 * @param filePath path of the UTF-8 text file to read
 * @param userIdx zero-based index of this client (optional)
 * @param totalClient total number of clients sharing the file (optional)
 * @returns a Dataset yielding one trimmed, non-empty sample per item
 */
function loadTextSamples(
  filePath: string,
  userIdx?: number,
  totalClient?: number,
): Dataset<Text> {
  return new Dataset(async function* () {
    const DELIMITER = "<|endoftext|>";
    const reader = createReadStream(filePath, { encoding: "utf8" });

    let nextIndex = 0;
    // Shard predicate, evaluated for the sample about to be yielded
    // (i.e. before nextIndex is advanced past it).
    const ownsSample = (): boolean =>
      userIdx === undefined ||
      totalClient === undefined ||
      nextIndex % totalClient === userIdx;

    let pending = "";
    for await (const piece of reader) {
      // Guard against the stream ever handing back a Buffer despite
      // the utf8 encoding option.
      if (typeof piece !== "string") {
        throw new Error("Expected file stream to yield string");
      }
      pending += piece;

      // Drain every complete sample currently sitting in the buffer.
      for (
        let cut = pending.indexOf(DELIMITER);
        cut !== -1;
        cut = pending.indexOf(DELIMITER)
      ) {
        const end = cut + DELIMITER.length;
        const sample = pending.slice(0, end).trim();
        if (sample !== "" && ownsSample()) {
          yield sample;
        }
        nextIndex++;
        pending = pending.slice(end);
      }
    }

    // Anything left after the final delimiter counts as one last sample.
    const leftover = pending.trim();
    if (leftover !== "" && ownsSample()) {
      yield leftover;
    }
  });
}

async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "simple_face");

Expand Down Expand Up @@ -94,7 +144,10 @@ function loadData(dataName: string, split: number): Dataset<DataFormat.Raw["imag
export async function getTaskData<D extends DataType>(
taskID: Task.ID,
userIdx: number,
totalClient: number
totalClient: number,
datasetPath?: string,
isValidation?: boolean,
validationDatasetPath?: string
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (taskID) {
case "simple_face": // remove
Expand All @@ -118,6 +171,23 @@ export async function getTaskData<D extends DataType>(
case "mnist_federated":
case "mnist":
return loadData("mnist", userIdx) as Dataset<DataFormat.Raw[D]>;
case "privacyrun": {
const filePath =
isValidation && validationDatasetPath
? validationDatasetPath
: datasetPath ?? "../datasets/med_mcq/train.txt";

// Keep validation shared, but shard training data across clients by MCQ sample.
if (isValidation) {
return loadTextSamples(filePath) as Dataset<DataFormat.Raw[D]>;
}

return loadTextSamples(
filePath,
userIdx,
totalClient,
) as Dataset<DataFormat.Raw[D]>;
}
default:
throw new Error(`Data loader for ${taskID} not implemented.`);
}
Expand Down
Loading
Loading