From c21d55bd1b202d47df5713532642323ccb946fcb Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Thu, 16 Apr 2026 13:49:48 +0200 Subject: [PATCH 01/23] Benchmark the mcq medical dataset --- cli/package.json | 2 + cli/src/evaluate_finetuned_gpt2.ts | 243 ++++++++++++++++++++++++ discojs/src/default_tasks/index.ts | 3 +- discojs/src/default_tasks/privacyrun.ts | 66 +++++++ 4 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 cli/src/evaluate_finetuned_gpt2.ts create mode 100644 discojs/src/default_tasks/privacyrun.ts diff --git a/cli/package.json b/cli/package.json index cc0f741e2..8ed779b3f 100644 --- a/cli/package.json +++ b/cli/package.json @@ -9,6 +9,8 @@ "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", "train_gpt": "npm run build && node dist/train_gpt.js", "hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js", + "eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js", + "finetune_gpt": "npm run build && node dist/finetune_gpt.js", "build": "tsc --build", "test": ": nothing" }, diff --git a/cli/src/evaluate_finetuned_gpt2.ts b/cli/src/evaluate_finetuned_gpt2.ts new file mode 100644 index 000000000..6f031cd04 --- /dev/null +++ b/cli/src/evaluate_finetuned_gpt2.ts @@ -0,0 +1,243 @@ +import "@tensorflow/tfjs-node"; +import * as tf from "@tensorflow/tfjs"; +import fs from "node:fs/promises"; +import { parse } from "ts-command-line-args"; +import { models, Tokenizer } from "@epfml/discojs"; +import { loadModelFromDisk } from "@epfml/discojs-node"; + +interface Args { + modelPath: string; + testPath: string; + maxSamples?: number; + savePath?: string; + help?: boolean; +} + +// ========================= +// HOW TO RUN +// ========================= +// npm -w cli run eval_finetuned_gpt2 -- --modelPath absolute_path_to_model/model.json --testPath absolute_path_to_test_data/train_no_exp.txt --maxSamples 100 + +// ========================= +// LOAD DATASET +// ========================= +async function 
loadDataset(filePath: string, limit = -1): Promise { + const text = await fs.readFile(filePath, "utf-8"); + const lines = text.split("\n"); + + const samples: string[] = []; + let current = ""; + + for (const line of lines) { + const l = line.trim(); + + if (l.includes("<|startoftext|>")) { + current = ""; + } else if (l.includes("<|endoftext|>")) { + samples.push(current.trim()); + if (limit !== -1 && samples.length >= limit) break; + } else { + current += l + "\n"; + } + } + + return samples; +} + +// ========================= +// PARSE SAMPLE +// ========================= +function parseSample(sample: string) { + const lines = sample.split("\n"); + + let answer = ""; + const promptLines: string[] = []; + + for (const line of lines) { + if (line.startsWith("Answer:")) { + answer = line.replace("Answer:", "").trim(); + } else { + promptLines.push(line); + } + } + + const basePrompt = promptLines.join("\n"); + return { basePrompt, answer }; +} + +// ========================= +// SOFTMAX (for safety) +// ========================= +async function scoreText( + tfModel: tf.LayersModel, + tokenizer: Tokenizer, + text: string +): Promise { + const tokens = tokenizer.tokenize(text); + + if (tokens.size < 2) return -Infinity; + + const inputTokens = tokens.slice(0, tokens.size - 1).toArray(); + const targets = tokens.slice(1).toArray(); + + const inputTensor = tf.tensor([inputTokens], [1, inputTokens.length], "int32"); + + const logits = tfModel.predict(inputTensor) as tf.Tensor; + const logitsArray = await logits.array() as number[][][]; + + let score = 0; + + for (let i = 0; i < targets.length; i++) { + const stepLogits = logitsArray[0][i]; + + const logit = stepLogits[targets[i]] ?? 
-100; + + score += logit; + } + + inputTensor.dispose(); + logits.dispose(); + + return score; +} + +// ========================= +// SCORE OPTIONS +// ========================= +async function scoreOptions( + tfModel: tf.LayersModel, + tokenizer: Tokenizer, + texts: string[] +): Promise { + const scores: number[] = []; + + for (const t of texts) { + const s = await scoreText(tfModel, tokenizer, t); + scores.push(s); + } + + return scores; +} + +// ========================= +// BENCHMARK +// ========================= +async function benchmarkQA( + model: models.GPT, + tokenizer: Tokenizer, + dataset: string[], + savePath?: string +) { + console.log("=== QA LOGPROB BENCHMARK ==="); + + const tfModel = model.extract(); + + let correct = 0; + let total = 0; + + const options = ["A", "B", "C", "D"]; + + const confusion: Record> = { + A: { A: 0, B: 0, C: 0, D: 0 }, + B: { A: 0, B: 0, C: 0, D: 0 }, + C: { A: 0, B: 0, C: 0, D: 0 }, + D: { A: 0, B: 0, C: 0, D: 0 } + }; + + const logs: any[] = []; + + const start = Date.now(); + + for (const sample of dataset) { + const { basePrompt, answer } = parseSample(sample); + + const texts = options.map( + (opt) => `${basePrompt}\nAnswer: ${opt}` + ); + + const scores = await scoreOptions(tfModel, tokenizer, texts); + + let bestIdx = 0; + for (let i = 1; i < scores.length; i++) { + if (scores[i] > scores[bestIdx]) bestIdx = i; + } + + const predicted = options[bestIdx]; + + if (predicted === answer) correct++; + total++; + + if (confusion[answer]) { + confusion[answer][predicted]++; + } + + logs.push({ + predicted, + answer, + correct: predicted === answer + }); + + if (total % 50 === 0) { + console.log(`Processed ${total} samples...`); + } + } + + const accuracy = correct / total; + const duration = ((Date.now() - start) / 1000).toFixed(2); + + console.log("\n========================="); + console.log(`Accuracy: ${(accuracy * 100).toFixed(2)}%`); + console.log(`Time: ${duration}s`); + console.log("=========================\n"); + + 
console.log("Confusion Matrix:"); + console.table(confusion); + + console.log("\nPer-class accuracy:"); + for (const cls of options) { + const totalCls = Object.values(confusion[cls]).reduce((a, b) => a + b, 0); + const correctCls = confusion[cls][cls]; + const acc = totalCls ? (correctCls / totalCls) * 100 : 0; + + console.log(`${cls}: ${acc.toFixed(2)}%`); + } + + if (savePath) { + await fs.writeFile(savePath, JSON.stringify(logs, null, 2)); + console.log(`Saved results to ${savePath}`); + } +} + +// ========================= +// MAIN +// ========================= +async function main() { + const args = parse({ + modelPath: { type: String }, + testPath: { type: String }, + maxSamples: { type: Number, optional: true, defaultValue: 100 }, + savePath: { type: String, optional: true }, + help: { type: Boolean, optional: true } + }); + + console.log("Loading tokenizer..."); + const tokenizer = await Tokenizer.from_pretrained("Xenova/gpt2"); + + console.log("Loading model..."); + const model = await loadModelFromDisk(args.modelPath); + + if (!(model instanceof models.GPT)) { + throw new Error("Model must be GPT"); + } + + console.log("Loading dataset..."); + const dataset = await loadDataset(args.testPath, args.maxSamples); + + console.log(`Loaded ${dataset.length} samples`); + + await benchmarkQA(model, tokenizer, dataset, args.savePath); + + console.log("Done."); +} + +main().catch(console.error); \ No newline at end of file diff --git a/discojs/src/default_tasks/index.ts b/discojs/src/default_tasks/index.ts index 43adf0d3c..f46f27026 100644 --- a/discojs/src/default_tasks/index.ts +++ b/discojs/src/default_tasks/index.ts @@ -4,4 +4,5 @@ export { mnist } from './mnist.js' export { simpleFace } from './simple_face.js' export { titanic } from './titanic.js' export { wikitext } from './wikitext.js' -export { tinderDog } from './tinder_dog.js' \ No newline at end of file +export { tinderDog } from './tinder_dog.js' +export { privacyrun } from './privacyrun.js' \ No 
newline at end of file diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts new file mode 100644 index 000000000..98dc52aad --- /dev/null +++ b/discojs/src/default_tasks/privacyrun.ts @@ -0,0 +1,66 @@ +import type { TaskProvider } from "../index.js"; +import { Tokenizer, models, serialization } from "../index.js"; + +export const privacyrun: TaskProvider<"text", "local"> = { + async getTask() { + return { + id: 'privacyrun_task', + dataType: "text", + displayInformation: { + title: "GPT Privacy-Preserving Fine-tuning", + summary: { + preview: 'Fine-tune a pre-trained GPT model collaboratively and privately.', + overview: "Fine-tune a pre-trained GPT-2 model created by the ONNX converter in your browser collaboratively without sharing your raw data. The model is loaded from Google Cloud Storage and fine-tuned using federated learning." + }, + model: [ + "The model is a pre-trained GPT-2 architecture converted from ONNX and loaded from Google Cloud Storage.", + "The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer.", + "The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss.", + "Context length is kept at 1024 to match the pre-trained model, with batch size at 1.", + ].join(" "), + dataFormatInformation: 'You can use any natural language (text) dataset. 
The dataset should be formatted as a plain text file with each line representing a segment of text.', + dataExample: + "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work.", + }, + trainingInformation: { + scheme: 'local', + aggregationStrategy: 'mean', + minNbOfParticipants: 2, + epochs: 6, + validationSplit: 0.1, + roundDuration: 2, + batchSize: 8, + tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), + contextLength: 1024, + tensorBackend: 'gpt' + } + } + }, + + async getModel() { + // Load the pre-trained ONNX-converted model from Google Cloud Storage + // The model should be in DiscoJS serialization format (created by onnx-converter) + const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model.json"; + + try { + const response = await fetch(modelUrl); + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const arrayBuffer = await response.arrayBuffer(); + const encodedData = new Uint8Array(arrayBuffer); + + const model = await serialization.model.decode(encodedData); + + if (!(model instanceof models.GPT)) { + throw new Error("Loaded model is not a GPT model"); + } + + return model; + } catch (error) { + console.error("Failed to load model from Google Cloud Storage:", error); + throw new Error(`Could not load model from ${modelUrl}. 
Make sure the URL is correct and the model exists in DiscoJS serialization format.`); + } + }, +} From 3014cf53fe128b8d5490406b584736fa39b0ce66 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 17 Apr 2026 19:50:21 +0200 Subject: [PATCH 02/23] lint error correction --- cli/src/evaluate_finetuned_gpt2.ts | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cli/src/evaluate_finetuned_gpt2.ts b/cli/src/evaluate_finetuned_gpt2.ts index 6f031cd04..ecccc6bc1 100644 --- a/cli/src/evaluate_finetuned_gpt2.ts +++ b/cli/src/evaluate_finetuned_gpt2.ts @@ -144,7 +144,13 @@ async function benchmarkQA( D: { A: 0, B: 0, C: 0, D: 0 } }; - const logs: any[] = []; + type PredictionLog = { + predicted: string; + answer: string; + correct: boolean; + }; + + const logs: PredictionLog[] = []; const start = Date.now(); From e785cf64162209fd7f4e52c8c1488da34b8c8d6a Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Sun, 19 Apr 2026 13:57:13 +0200 Subject: [PATCH 03/23] add val dataset path param --- cli/src/args.ts | 11 +++++-- cli/src/cli.ts | 15 +++++++-- cli/src/data.ts | 10 ++++-- discojs/src/default_tasks/privacyrun.ts | 6 ++-- discojs/src/training/disco.ts | 44 +++++++++++++++++++++---- 5 files changed, 70 insertions(+), 16 deletions(-) diff --git a/cli/src/args.ts b/cli/src/args.ts index ced893a72..955dcf47e 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -21,6 +21,8 @@ export interface BenchmarkArguments { roundDuration: number batchSize: number validationSplit: number + datasetPath?: string + validationDatasetPath?: string // DP epsilon?: number @@ -41,6 +43,8 @@ export interface BenchmarkArguments { type BenchmarkUnsafeArguments = Omit & { task: string + datasetPath?: string + validationDatasetPath?: string help?: boolean } @@ -55,6 +59,8 @@ const unsafeArgs = parse( roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, batchSize: { type: Number, alias: 'b', description: 'Training batch size', 
defaultValue: 10 }, validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, + datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true }, + validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true }, save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, host: { type: (raw: string) => new URL(raw), @@ -89,18 +95,19 @@ const unsafeArgs = parse( const supportedTasks = Map( await Promise.all( - Set.of>( + Set.of>( defaultTasks.cifar10, defaultTasks.lusCovid, defaultTasks.simpleFace, defaultTasks.titanic, defaultTasks.tinderDog, defaultTasks.mnist, + defaultTasks.privacyrun, ).map( async (t) => [(await t.getTask()).id, t] as [ string, - TaskProvider<"image" | "tabular", Network>, + TaskProvider<"image" | "tabular" | "text", Network>, ], ), ), diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 2e23c6514..9629f5b01 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -17,6 +17,7 @@ import type { } from "@epfml/discojs"; import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' +import { loadText } from "@epfml/discojs-node"; import { getTaskData } from './data.js' import { args } from './args.js' import { makeUserLogFile } from "./user_log.js"; @@ -27,6 +28,7 @@ async function runUser( task: Task, url: URL, data: Dataset, + validationData: Dataset | undefined, userIndex: number, numberOfUsers: number, ): Promise> { @@ -49,7 +51,7 @@ async function runUser( } try{ - for await (const log of disco.trainSummary(data)){ + for await (const log of disco.trainSummary(data, validationData)){ finalLog.push(log); if (jsonStream){ @@ -104,10 +106,17 @@ async function main( console.log({ args }) const dataSplits = await Promise.all( - Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers)) + Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, 
numberOfUsers, args.datasetPath)) ) + + let validationData: Dataset | undefined = undefined; + if (args.validationDatasetPath) { + // Assume text task for now + validationData = loadText(args.validationDatasetPath).cached() as Dataset; + } + const logs = await Promise.all( - dataSplits.map((data, i) => runUser(task, args.host, data as Dataset, i, numberOfUsers)) + dataSplits.map((data, i) => runUser(task, args.host, data as Dataset, validationData, i, numberOfUsers)) ) if (args.save) { diff --git a/cli/src/data.ts b/cli/src/data.ts index aa4d0a330..f3b6ff834 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -5,8 +5,9 @@ import { DataType, Image, Task, + Text, } from "@epfml/discojs"; -import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; +import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node"; import { Repeat } from "immutable"; async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise> { @@ -94,7 +95,10 @@ function loadData(dataName: string, split: number): Dataset( taskID: Task.ID, userIdx: number, - totalClient: number + totalClient: number, + datasetPath?: string, + isValidation?: boolean, + validationDatasetPath?: string ): Promise> { switch (taskID) { case "simple_face": // remove @@ -118,6 +122,8 @@ export async function getTaskData( case "mnist_federated": case "mnist": return loadData("mnist", userIdx) as Dataset; + case "privacyrun": + return loadText(isValidation && validationDatasetPath ? validationDatasetPath : datasetPath ?? 
'../datasets/med_mcq/train.txt') as Dataset; default: throw new Error(`Data loader for ${taskID} not implemented.`); } diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts index 98dc52aad..9b54af387 100644 --- a/discojs/src/default_tasks/privacyrun.ts +++ b/discojs/src/default_tasks/privacyrun.ts @@ -4,7 +4,7 @@ import { Tokenizer, models, serialization } from "../index.js"; export const privacyrun: TaskProvider<"text", "local"> = { async getTask() { return { - id: 'privacyrun_task', + id: 'privacyrun', dataType: "text", displayInformation: { title: "GPT Privacy-Preserving Fine-tuning", @@ -25,7 +25,7 @@ export const privacyrun: TaskProvider<"text", "local"> = { trainingInformation: { scheme: 'local', aggregationStrategy: 'mean', - minNbOfParticipants: 2, + // minNbOfParticipants: 2, epochs: 6, validationSplit: 0.1, roundDuration: 2, @@ -56,6 +56,8 @@ export const privacyrun: TaskProvider<"text", "local"> = { if (!(model instanceof models.GPT)) { throw new Error("Loaded model is not a GPT model"); } + + console.log("Successfully loaded pre-trained GPT model from Google Cloud Storage"); return model; } catch (error) { diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..7d85ba8fb 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -159,16 +159,18 @@ export class Disco extends EventEmitter<{ /** Train on dataset, yielding logs of every batch. 
*/ async *trainByBatch( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator { - for await (const round of this.train(dataset)) + for await (const round of this.train(dataset, validationDataset)) for await (const epoch of round) yield* epoch; } /** Train on dataset, yielding summary logs */ async *trainSummary( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator { - for await (const [roundNum, round] of enumerate(this.train(dataset))) { + for await (const [roundNum, round] of enumerate(this.train(dataset, validationDataset))) { const [roundGen, roundLogsPromise] = async_iterator.split(round); const epochResults: Array<{epochNum: number; epochLogs: EpochLogs}> = []; @@ -190,8 +192,8 @@ export class Disco extends EventEmitter<{ } /** Run whole train on dataset. */ - async trainFully(dataset: Dataset): Promise { - for await (const round of this.train(dataset)) + async trainFully(dataset: Dataset, validationDataset?: Dataset): Promise { + for await (const round of this.train(dataset, validationDataset)) for await (const epoch of round) for await (const _ of epoch); } @@ -203,20 +205,23 @@ export class Disco extends EventEmitter<{ **/ async *train( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator< AsyncGenerator, RoundLogs> > { this.#logger.success("Training started"); - const [trainingDataset, validationDataset] = - await this.#preprocessSplitAndBatch(dataset); + const [trainingDataset, validationDataset_] = + validationDataset !== undefined + ? 
await this.#preprocessDatasets(dataset, validationDataset) + : await this.#preprocessSplitAndBatch(dataset); // the client fetches the latest weights upon connection // TODO unsafe cast this.trainer.model = (await this.#client.connect()) as Model; for await (const [roundNum, round] of enumerate( - this.trainer.train(trainingDataset, validationDataset), + this.trainer.train(trainingDataset, validationDataset_), )) { yield async function* (this: Disco) { const [roundGen, roundLogsPromise] = split(round); @@ -297,6 +302,31 @@ export class Disco extends EventEmitter<{ validation.batch(batchSize).cached(), ]; } + + async #preprocessDatasets( + trainingDataset: Dataset, + validationDataset: Dataset, + ): Promise< + [ + Dataset>, + Dataset> | undefined, + ] + > { + const { batchSize } = this.#task.trainingInformation; + + let preprocessedTraining = processing.preprocess(this.#task, trainingDataset); + let preprocessedValidation = processing.preprocess(this.#task, validationDataset); + + if (this.#preprocessOnce) { + preprocessedTraining = new Dataset(await arrayFromAsync(preprocessedTraining)); + preprocessedValidation = new Dataset(await arrayFromAsync(preprocessedValidation)); + } + + return [ + preprocessedTraining.batch(batchSize).cached(), + preprocessedValidation.batch(batchSize).cached(), + ]; + } } // Array.fromAsync not yet widely used (2024) From c93631cd476a3b8ab11dc634878b82518be1517a Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 20 Apr 2026 16:50:10 +0200 Subject: [PATCH 04/23] add working local train --- cli/src/cli.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 9629f5b01..9c291474b 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -26,6 +26,7 @@ import type { UserLogFile } from "./user_log.js"; async function runUser( task: Task, + provider: TaskProvider, url: URL, data: Dataset, validationData: Dataset | undefined, @@ -38,6 +39,13 @@ async function runUser( const client = 
clients.getClient(trainingScheme, url, task, aggregator) const disco = new Disco(task, client, { scheme: trainingScheme }); + // For local training, load model from provider before training starts + if (trainingScheme === "local") { + console.log("Loading model for local training..."); + disco.trainer.model = await provider.getModel(); + console.log("Model loaded successfully"); + } + const dir = path.join(".", `${args.testID}`); await fs.mkdir(dir, { recursive: true }); const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`); @@ -116,7 +124,7 @@ async function main( } const logs = await Promise.all( - dataSplits.map((data, i) => runUser(task, args.host, data as Dataset, validationData, i, numberOfUsers)) + dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset, validationData, i, numberOfUsers)) ) if (args.save) { From 2e34a3737a5d71006d014d672629583ce0dcb67f Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 22 Apr 2026 16:13:52 +0200 Subject: [PATCH 05/23] add model saving to disk arg and more debug lines --- cli/src/args.ts | 2 ++ cli/src/cli.ts | 22 ++++++++++++++++++---- discojs/src/training/trainer.ts | 9 +++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/cli/src/args.ts b/cli/src/args.ts index 955dcf47e..00d72102e 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -38,6 +38,7 @@ export interface BenchmarkArguments { maxShareValue?: number save: boolean + saveModel: boolean host: URL } @@ -62,6 +63,7 @@ const unsafeArgs = parse( datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true }, validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true }, save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, + saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false }, host: { type: (raw: string) => new URL(raw), typeLabel: 
"URL", diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 9c291474b..3caf6d5cf 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -5,7 +5,7 @@ import { List, Range } from 'immutable' import fs from 'node:fs/promises' import { createWriteStream } from "node:fs"; import path from "node:path"; - +import createDebug from "debug"; import type { Dataset, DataFormat, @@ -17,12 +17,13 @@ import type { } from "@epfml/discojs"; import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' -import { loadText } from "@epfml/discojs-node"; +import { loadText, saveModelToDisk } from "@epfml/discojs-node"; import { getTaskData } from './data.js' import { args } from './args.js' import { makeUserLogFile } from "./user_log.js"; import type { UserLogFile } from "./user_log.js"; +const debug = createDebug("cli:main"); async function runUser( task: Task, @@ -33,7 +34,8 @@ async function runUser( userIndex: number, numberOfUsers: number, ): Promise> { - // cast as typescript isn't good with generics + debug(`Starting runUser for client ${userIndex}`); + const userStart = Date.now(); const trainingScheme = task.trainingInformation.scheme as N const aggregator = aggregators.getAggregator(task) const client = clients.getClient(trainingScheme, url, task, aggregator) @@ -41,9 +43,12 @@ async function runUser( // For local training, load model from provider before training starts if (trainingScheme === "local") { + debug(`Loading model for local training client ${userIndex}...`); + const modelStart = Date.now(); console.log("Loading model for local training..."); disco.trainer.model = await provider.getModel(); console.log("Model loaded successfully"); + debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`); } const dir = path.join(".", `${args.testID}`); @@ -59,6 +64,8 @@ async function runUser( } try{ + debug(`Starting training for client ${userIndex}`); + const trainStart = Date.now(); for await (const log of disco.trainSummary(data, 
validationData)){ finalLog.push(log); @@ -66,9 +73,16 @@ async function runUser( jsonStream.write(JSON.stringify(log) + "\n"); } } + debug(`Training took ${Date.now() - trainStart}ms for client ${userIndex}`); await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish - + // Save the trained model if requested + if (args.saveModel) { + const modelDir = path.join(".", `${args.testID}`, "models"); + const modelFileName = `client${userIndex}_model.json`; + await saveModelToDisk(disco.trainer.model, modelDir, modelFileName); + console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`); + } // saving the entire per-user logs if (args.save) { const finalPath = path.join(dir, `client${userIndex}_local_log.json`); diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 68e716bcc..73cae1f7c 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -16,8 +16,11 @@ import { } from "../index.js"; import { privacy } from "../index.js"; import { Client } from "../client/index.js"; +import createDebug from "debug"; import * as async_iterator from "../utils/async_iterator.js"; +const debug = createDebug("discojs:training:trainer"); + export interface RoundLogs { epochs: List; participants: number; @@ -88,6 +91,7 @@ export class Trainer { AsyncGenerator, RoundLogs>, void > { + debug("Start train") if (this.#training !== undefined) throw new Error( "training already running, stop it before launching a new one", @@ -109,6 +113,9 @@ export class Trainer { void > { const totalRound = Math.trunc(this.#epochs / this.#roundDuration); + + debug("Run rounds") + for (let round = 0; round < totalRound; round++) { await this.#client.onRoundBeginCommunication(); @@ -150,6 +157,8 @@ export class Trainer { ): AsyncGenerator, RoundLogs> { let epochsLogs = List(); + debug("Run round") + // Before starting the training, get the validation of global model const validation = 
validationDataset !== undefined ? await this.model.evaluate(validationDataset) : undefined; From 4c9c84cc14c632e7a77b50af9b05c363f0584158 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 24 Apr 2026 17:55:42 +0200 Subject: [PATCH 06/23] add debug commands --- cli/src/cli.ts | 20 +++++++++++--------- discojs/src/client/client.ts | 5 ++++- discojs/src/default_tasks/privacyrun.ts | 6 +++--- discojs/src/models/gpt/index.ts | 8 ++++++++ discojs/src/models/gpt/model.ts | 6 ++++++ discojs/src/serialization/model.ts | 23 +++++++++++++++++++++++ discojs/src/training/disco.ts | 11 +++++++++++ 7 files changed, 66 insertions(+), 13 deletions(-) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 3caf6d5cf..14523d44f 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -42,15 +42,17 @@ async function runUser( const disco = new Disco(task, client, { scheme: trainingScheme }); // For local training, load model from provider before training starts - if (trainingScheme === "local") { - debug(`Loading model for local training client ${userIndex}...`); - const modelStart = Date.now(); - console.log("Loading model for local training..."); - disco.trainer.model = await provider.getModel(); - console.log("Model loaded successfully"); - debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`); - } - + // if (trainingScheme === "local") { + // debug(`Loading model for training client ${userIndex}...`); + // const modelStart = Date.now(); + // console.log("Loading model for local training..."); + // disco.trainer.model = await provider.getModel(); + // console.log("Model loaded successfully"); + // debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`); + // } + + + const dir = path.join(".", `${args.testID}`); await fs.mkdir(dir, { recursive: true }); const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`); diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..e2168afa3 100644 --- 
a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -186,8 +186,11 @@ export abstract class Client extends EventEmitter<{ } url.pathname += `tasks/${this.task.id}/model.json` + debug("fetching latest model from server at {} for task {}...", url.href, this.task.id) + const response = await fetch(url); - if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`) + else debug("response ok, decoding model...") const encoded = new Uint8Array(await response.arrayBuffer()) return await serialization.model.decode(encoded) diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts index 9b54af387..d667e3c9c 100644 --- a/discojs/src/default_tasks/privacyrun.ts +++ b/discojs/src/default_tasks/privacyrun.ts @@ -1,7 +1,7 @@ import type { TaskProvider } from "../index.js"; import { Tokenizer, models, serialization } from "../index.js"; -export const privacyrun: TaskProvider<"text", "local"> = { +export const privacyrun: TaskProvider<"text", "federated"> = { async getTask() { return { id: 'privacyrun', @@ -23,9 +23,9 @@ export const privacyrun: TaskProvider<"text", "local"> = { "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work.", }, trainingInformation: { - scheme: 'local', + scheme: 'federated', aggregationStrategy: 'mean', - // minNbOfParticipants: 2, + minNbOfParticipants: 2, epochs: 6, validationSplit: 0.1, roundDuration: 2, diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts index 228d60cc4..886421892 100644 --- a/discojs/src/models/gpt/index.ts +++ b/discojs/src/models/gpt/index.ts @@ -228,8 +228,16 @@ export class GPT extends Model<"text"> { } static deserialize(data: GPTSerialization): Model<"text"> { + + debug("GPT model deserialization started") + const model = new 
GPT(data.config); + + debug("GPT model config initialized: %O", data.config) + model.weights = data.weights; + + debug("GPT model weights initialized") return model; } diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index 01ee51e92..9207e8d47 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -64,8 +64,12 @@ export class GPTModel extends tf.LayersModel { let accuracyFraction: [number, number] = [0, 0]; let averageLoss = 0 let iteration = 1 + + debug("before iterator init") const iterator = await dataset.iterator() + debug("after getting iterator, before next") let next = await iterator.next() + debug("after next of iterator") while (next.done !== true && iteration <= this.config.maxIter) { let weightUpdateTime = performance.now() @@ -73,7 +77,9 @@ export class GPTModel extends tf.LayersModel { const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } let preprocessingTime = performance.now() + debug("await batch data before {} iteration", iteration) await Promise.all([xs.data(), ys.data()]) + debug("after await batch data {} iteration", iteration) preprocessingTime = performance.now() - preprocessingTime // TODO include as a tensor inside the model diff --git a/discojs/src/serialization/model.ts b/discojs/src/serialization/model.ts index 020d147af..4df0fab0c 100644 --- a/discojs/src/serialization/model.ts +++ b/discojs/src/serialization/model.ts @@ -7,6 +7,10 @@ import { GPTConfig } from '../models/index.js' import * as coder from "./coder.js"; import { Encoded, isEncoded } from "./coder.js"; +import createDebug from "debug" + +const debug = createDebug("discojs:serialization:model"); + const Type = { TFJS: 0, GPT: 1 @@ -16,11 +20,13 @@ export async function encode(model: Model): Promise { switch (true) { case model instanceof models.TFJS: { const serialized = await model.serialize(); + debug("TFJS model serialized"); return coder.encode([Type.TFJS, ...serialized]); } case model 
instanceof models.GPT: { const { weights, config } = model.serialize(); const serializedWeights = await serialization.weights.encode(weights); + debug("GPT model weights serialized"); return coder.encode([Type.GPT, serializedWeights, config]); } default: @@ -30,23 +36,34 @@ export async function encode(model: Model): Promise { export async function decode(encoded: Encoded): Promise> { const raw = coder.decode(encoded) + + debug("IMPORTANT:model decoded") if (!Array.isArray(raw) || raw.length < 2) { throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values") } + + debug("model encoding array length: %d", raw.length) + const type = raw[0] as unknown if (typeof type !== 'number') { throw new Error('invalid encoding, first encoding field should be the model type') } + + debug("model type: %d", type) + const rawModel = raw[1] as unknown switch (type) { case Type.TFJS: { + debug("TFJS model decoding started"); if (raw.length !== 3) throw new Error( "invalid TFJS model encoding: should be an array of length 3", ); const [rawDatatype, rawModel] = raw.slice(1) as unknown[]; + debug("TFJS model datatype: %s", rawDatatype); + let datatype; switch (rawDatatype) { case "image": @@ -70,6 +87,7 @@ export async function decode(encoded: Encoded): Promise> { if (raw.length == 2) { config = undefined } else if (raw.length == 3) { + debug("GPT model config decoding") config = raw[2] as GPTConfig } else { throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3') @@ -79,7 +97,12 @@ export async function decode(encoded: Encoded): Promise> { throw new Error( "invalid encoding, gpt-tfjs model weights should be an encoding of its weights", ); + + debug("GPT model weights decoding...") const weights = serialization.weights.decode(rawModel) + + debug("GPT model weights decoded, deserializing model... 
CONFIG MIGHT BE WRONG") + debug("GPT model config: %O", config || "undefined, using default config") return models.GPT.deserialize({weights, config}) } default: diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7d85ba8fb..33c63c244 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -21,8 +21,12 @@ import { getAggregator } from "../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; import { EventEmitter } from "../utils/event_emitter.js"; +import createDebug from "debug" + import { RoundLogs, Trainer } from "./trainer.js"; +const debug = createDebug("discojs:training:disco"); + interface DiscoConfig { scheme: N; logger: Logger; @@ -175,6 +179,8 @@ export class Disco extends EventEmitter<{ const epochResults: Array<{epochNum: number; epochLogs: EpochLogs}> = []; + debug("Starting round %d", roundNum) + for await (const [epochNum, epoch] of enumerate(roundGen)) { const [epochGen, epochLogsPromise] = async_iterator.split(epoch); for await (const _ of epochGen); @@ -218,7 +224,12 @@ export class Disco extends EventEmitter<{ // the client fetches the latest weights upon connection // TODO unsafe cast + debug("Connecting to client and fetching initial model..."); this.trainer.model = (await this.#client.connect()) as Model; + debug("Initial model fetched successfully"); + if (this.trainer.model === null) { + debug(`No pre-trained model provided for client, initializing randomly...`); + } for await (const [roundNum, round] of enumerate( this.trainer.train(trainingDataset, validationDataset_), From 23c733d5a4817b81e17c9a51da782b402d128ee9 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 24 Apr 2026 18:21:56 +0200 Subject: [PATCH 07/23] add cnahges to federated approach --- discojs/src/client/client.ts | 2 +- discojs/src/client/event_connection.ts | 4 ++++ discojs/src/client/federated/federated_client.ts | 4 +++- discojs/src/client/federated/messages.ts | 2 +- 
server/src/controllers/federated_controller.ts | 2 +- 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index e2168afa3..5c8aa1304 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -186,7 +186,7 @@ export abstract class Client extends EventEmitter<{ } url.pathname += `tasks/${this.task.id}/model.json` - debug("fetching latest model from server at {} for task {}...", url.href, this.task.id) + debug("fetching latest model from server at %0 for task %1...", url.href, this.task.id) const response = await fetch(url); if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`) diff --git a/discojs/src/client/event_connection.ts b/discojs/src/client/event_connection.ts index 3e3aec409..82722d111 100644 --- a/discojs/src/client/event_connection.ts +++ b/discojs/src/client/event_connection.ts @@ -118,6 +118,10 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage { + debug("websocket closed: code=%o reason=%o wasClean=%o", event.code, event.reason, event.wasClean) + } + return await new Promise((resolve, reject) => { ws.onerror = (err: WebSocket.ErrorEvent) => { reject(new Error(`Server unreachable: ${err.message}`)) diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index a89c65ad6..daaa1df80 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -88,7 +88,9 @@ export class FederatedClient extends Client<"federated"> { // Upon connecting, the server answers with a boolean // which indicates whether there are enough participants or not debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants) - model.weights = serialization.weights.decode(payload) + if (payload !== undefined) { + model.weights = serialization.weights.decode(payload) + } return model 
} diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index 3733d2c1c..d8961d97d 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -18,7 +18,7 @@ export interface NewFederatedNodeInfo { type: type.NewFederatedNodeInfo id: NodeID waitForMoreParticipants: boolean - payload: serialization.Encoded; + payload?: serialization.Encoded; round: number nbOfParticipants: number } diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 1e52fe539..cf2df0fa8 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -89,7 +89,7 @@ export class FederatedController extends TrainingController< type: MessageTypes.NewFederatedNodeInfo, id: clientId, waitForMoreParticipants: this.connections.size < minNbOfParticipants, - payload: this.#latestGlobalWeights, + payload: this.#aggregator.round === 0 ? 
undefined : this.#latestGlobalWeights, round: this.#aggregator.round, nbOfParticipants: this.connections.size } From 096393c8b70d57348e99f6bc5ecbd966d5804c26 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 24 Apr 2026 18:31:27 +0200 Subject: [PATCH 08/23] change round 0 payload null handling --- discojs/src/client/federated/federated_client.ts | 2 +- discojs/src/client/federated/messages.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index daaa1df80..b6c2c59d5 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -88,7 +88,7 @@ export class FederatedClient extends Client<"federated"> { // Upon connecting, the server answers with a boolean // which indicates whether there are enough participants or not debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants) - if (payload !== undefined) { + if (payload != null) { model.weights = serialization.weights.decode(payload) } return model diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index d8961d97d..c6dee6698 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -18,7 +18,7 @@ export interface NewFederatedNodeInfo { type: type.NewFederatedNodeInfo id: NodeID waitForMoreParticipants: boolean - payload?: serialization.Encoded; + payload?: serialization.Encoded | null; round: number nbOfParticipants: number } From 1ef7d856b18a5a1a08f59e44336623c82b82d3f3 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Sun, 26 Apr 2026 22:51:57 +0200 Subject: [PATCH 09/23] chnage server max payload limit to higher number --- server/src/controllers/federated_controller.ts | 4 ++++ server/src/server.ts | 4 ++++ 2 files changed, 8 insertions(+) diff --git 
a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index cf2df0fa8..9da1e589a 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -66,6 +66,10 @@ export class FederatedController extends TrainingController< } const shortId = clientId.slice(0, 4) + ws.on('error', (err) => { + debug("websocket error for client [%s]: %o", shortId, err) + }) + // Setup callbacks triggered upon receiving the different client messages ws.on('message', (data: Buffer) => { const msg: unknown = msgpack.decode(data) diff --git a/server/src/server.ts b/server/src/server.ts index 06641cff0..b862c8272 100644 --- a/server/src/server.ts +++ b/server/src/server.ts @@ -38,6 +38,10 @@ export class Server { async serve(port?: number): Promise<[http.Server, URL]> { const wsApplier = expressWS(express(), undefined, { leaveRouterUntouched: true, + wsOptions: { + // GPT-sized federated updates can exceed the ws default payload limit. 
+ maxPayload: 1024 * 1024 * 1024, + }, }); const app = wsApplier.app; From 417bfa5967d64e6ef0121dcea471775e75bfd97c Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Sun, 26 Apr 2026 23:03:19 +0200 Subject: [PATCH 10/23] fix memory reads and wait for all clients to begin --- cli/src/cli.ts | 2 +- discojs/src/client/federated/federated_client.ts | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 14523d44f..51697cde1 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -39,7 +39,7 @@ async function runUser( const trainingScheme = task.trainingInformation.scheme as N const aggregator = aggregators.getAggregator(task) const client = clients.getClient(trainingScheme, url, task, aggregator) - const disco = new Disco(task, client, { scheme: trainingScheme }); + const disco = new Disco(task, client, { scheme: trainingScheme, preprocessOnce: true }); // For local training, load model from provider before training starts // if (trainingScheme === "local") { diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index b6c2c59d5..9caea57f0 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -105,9 +105,16 @@ export class FederatedClient extends Client<"federated"> { this.aggregator.setNodes(this.aggregator.nodes.delete(SERVER_NODE_ID)); } - override onRoundBeginCommunication(): Promise { + // override onRoundBeginCommunication(): Promise { + // // Prepare the result promise for the incoming round + // this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) + // this.saveAndEmit("local training") + // return Promise.resolve(); + // } + override async onRoundBeginCommunication(): Promise { // Prepare the result promise for the incoming round this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) + await 
this.waitForParticipantsIfNeeded() // In case we are waiting for more participants, we wait before starting the local training this.saveAndEmit("local training") return Promise.resolve(); } From 77ce9a9cf24c0ade32ceb346d81570e85d3f3ab5 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 27 Apr 2026 01:28:54 +0200 Subject: [PATCH 11/23] add server debug logs to see why ws close session --- discojs/src/client/event_connection.ts | 6 +++++- server/src/controllers/federated_controller.ts | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/discojs/src/client/event_connection.ts b/discojs/src/client/event_connection.ts index 82722d111..a8c562a74 100644 --- a/discojs/src/client/event_connection.ts +++ b/discojs/src/client/event_connection.ts @@ -98,7 +98,10 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage msg is Message, validateSent: (msg: Message) => boolean): Promise { - const ws = new WebSocket(url) + const ws = new WebSocket(url, { + // Federated GPT updates can exceed the default ws payload limit. 
+ maxPayload: 1024 * 1024 * 1024, + }) ws.binaryType = 'arraybuffer' const server: WebSocketServer = new WebSocketServer(ws, validateSent) @@ -124,6 +127,7 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage { ws.onerror = (err: WebSocket.ErrorEvent) => { + debug("websocket error while connecting/receiving: %o", err.message) reject(new Error(`Server unreachable: ${err.message}`)) } ws.onopen = () => { resolve(server) } diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 9da1e589a..bef9a1909 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -108,6 +108,12 @@ export class FederatedController extends TrainingController< case MessageTypes.SendPayload: { const { payload, round } = msg if (this.#aggregator.isValidContribution(clientId, round)) { + debug( + "Received valid contribution from client [%s] for round %d (participants=%d)", + shortId, + round, + this.connections.size, + ) const weights = serialization.weights.decode(payload) // Create a callback to send the aggregated weight to the client @@ -120,9 +126,12 @@ export class FederatedController extends TrainingController< payload: await serialization.weights.encode(weightUpdate), nbOfParticipants: this.connections.size } + debug("Prepared aggregated payload for client [%s] at round %o", shortId, this.#aggregator.round) ws.send(msgpack.encode(msg)) + debug("Aggregated payload sent to client [%s] for round %o", shortId, this.#aggregator.round) }) // Add the contribution + debug("Adding contribution from client [%s] to aggregator for round %d", shortId, round) this.#aggregator.add(clientId, weights, round) debug(`Successfully added contribution from client [%s] for round ${round}`, shortId) } else { From 7aef628633d7c1220060d58efa1f42598f713b70 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 27 Apr 2026 02:14:56 +0200 Subject: [PATCH 12/23] cover whole 
dataset and split data to clients --- cli/src/data.ts | 58 ++++++++++++++++++++++++++++++-- discojs/src/models/gpt/config.ts | 3 +- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/cli/src/data.ts b/cli/src/data.ts index f3b6ff834..0f3d22e62 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -1,4 +1,5 @@ import path from "node:path"; +import { createReadStream } from "node:fs"; import { Dataset, processing } from "@epfml/discojs"; import { DataFormat, @@ -10,6 +11,44 @@ import { import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node"; import { Repeat } from "immutable"; +function loadShardedTextSamples( + filePath: string, + userIdx: number, + totalClient: number, +): Dataset { + return new Dataset(async function* () { + const stream = createReadStream(filePath, { encoding: "utf8" }); + const sampleDelimiter = "<|endoftext|>"; + let buffer = ""; + let sampleIndex = 0; + + for await (const chunk of stream) { + if (typeof chunk !== "string") { + throw new Error("Expected file stream to yield string"); + } + + buffer += chunk; + + let delimiterIndex = buffer.indexOf(sampleDelimiter); + while (delimiterIndex !== -1) { + const sample = buffer.slice(0, delimiterIndex + sampleDelimiter.length).trim(); + if (sample !== "" && sampleIndex % totalClient === userIdx) { + yield sample; + } + + sampleIndex++; + buffer = buffer.slice(delimiterIndex + sampleDelimiter.length); + delimiterIndex = buffer.indexOf(sampleDelimiter); + } + } + + const trailingSample = buffer.trim(); + if (trailingSample !== "" && sampleIndex % totalClient === userIdx) { + yield trailingSample; + } + }); +} + async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise> { const folder = path.join("..", "datasets", "simple_face"); @@ -122,8 +161,23 @@ export async function getTaskData( case "mnist_federated": case "mnist": return loadData("mnist", userIdx) as Dataset; - case "privacyrun": - return loadText(isValidation && 
validationDatasetPath ? validationDatasetPath : datasetPath ?? '../datasets/med_mcq/train.txt') as Dataset; + case "privacyrun": { + const filePath = + isValidation && validationDatasetPath + ? validationDatasetPath + : datasetPath ?? "../datasets/med_mcq/train.txt"; + + // Keep validation shared, but shard training data across clients by MCQ sample. + if (isValidation) { + return loadText(filePath) as Dataset; + } + + return loadShardedTextSamples( + filePath, + userIdx, + totalClient, + ) as Dataset; + } default: throw new Error(`Data loader for ${taskID} not implemented.`); } diff --git a/discojs/src/models/gpt/config.ts b/discojs/src/models/gpt/config.ts index 0d157a7fd..b2a86340f 100644 --- a/discojs/src/models/gpt/config.ts +++ b/discojs/src/models/gpt/config.ts @@ -31,7 +31,8 @@ export type GPTConfig = { export const DefaultGPTConfig: Required = { lr: 0.001, weightDecay: 0, - maxIter: 10, + // By default, iterate through the whole dataset and let dataset exhaustion stop the epoch. 
+ maxIter: Number.MAX_SAFE_INTEGER, verbose: 0, modelType: 'gpt-nano', evaluate: true, From d65fb4a6692977a6063bfb3dbc0d7bbc8653d271 Mon Sep 17 00:00:00 2001 From: Mina Petrovic Date: Mon, 4 May 2026 14:58:28 +0200 Subject: [PATCH 13/23] add validation dataset loading changes --- cli/src/cli.ts | 7 ++++--- cli/src/data.ts | 26 ++++++++++++++++++-------- datasets/.gitignore | 1 + 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 51697cde1..5540974ae 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -17,7 +17,7 @@ import type { } from "@epfml/discojs"; import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' -import { loadText, saveModelToDisk } from "@epfml/discojs-node"; +import { saveModelToDisk } from "@epfml/discojs-node"; import { getTaskData } from './data.js' import { args } from './args.js' import { makeUserLogFile } from "./user_log.js"; @@ -135,8 +135,9 @@ async function main( let validationData: Dataset | undefined = undefined; if (args.validationDatasetPath) { - // Assume text task for now - validationData = loadText(args.validationDatasetPath).cached() as Dataset; + validationData = ( + await getTaskData(task.id, 0, 1, args.validationDatasetPath, true, args.validationDatasetPath) + ).cached() as Dataset; } const logs = await Promise.all( diff --git a/cli/src/data.ts b/cli/src/data.ts index 0f3d22e62..a1f41e9aa 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -8,13 +8,13 @@ import { Task, Text, } from "@epfml/discojs"; -import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node"; +import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; import { Repeat } from "immutable"; -function loadShardedTextSamples( +function loadTextSamples( filePath: string, - userIdx: number, - totalClient: number, + userIdx?: number, + totalClient?: number, ): Dataset { return new Dataset(async function* () { const stream = createReadStream(filePath, 
{ encoding: "utf8" }); @@ -32,7 +32,12 @@ function loadShardedTextSamples( let delimiterIndex = buffer.indexOf(sampleDelimiter); while (delimiterIndex !== -1) { const sample = buffer.slice(0, delimiterIndex + sampleDelimiter.length).trim(); - if (sample !== "" && sampleIndex % totalClient === userIdx) { + const shouldYield = + userIdx === undefined || + totalClient === undefined || + sampleIndex % totalClient === userIdx; + + if (sample !== "" && shouldYield) { yield sample; } @@ -43,7 +48,12 @@ function loadShardedTextSamples( } const trailingSample = buffer.trim(); - if (trailingSample !== "" && sampleIndex % totalClient === userIdx) { + const shouldYieldTrailing = + userIdx === undefined || + totalClient === undefined || + sampleIndex % totalClient === userIdx; + + if (trailingSample !== "" && shouldYieldTrailing) { yield trailingSample; } }); @@ -169,10 +179,10 @@ export async function getTaskData( // Keep validation shared, but shard training data across clients by MCQ sample. if (isValidation) { - return loadText(filePath) as Dataset; + return loadTextSamples(filePath) as Dataset; } - return loadShardedTextSamples( + return loadTextSamples( filePath, userIdx, totalClient, diff --git a/datasets/.gitignore b/datasets/.gitignore index e644c626a..bdbf3ed6d 100644 --- a/datasets/.gitignore +++ b/datasets/.gitignore @@ -7,6 +7,7 @@ /simple_face-example.png /titanic* /mnist* +/medicalMCQtxtNoExplanation # wikitext /wikitext/ From fe7c51a98eb8e8c707a2703a98af5f887789713b Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 4 May 2026 15:05:44 +0200 Subject: [PATCH 14/23] change gpt config to use whole dadataset --- discojs/src/models/gpt/index.ts | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts index 886421892..082f79766 100644 --- a/discojs/src/models/gpt/index.ts +++ b/discojs/src/models/gpt/index.ts @@ -231,9 +231,17 @@ export class GPT extends Model<"text"> { 
debug("GPT model deserialization started") - const model = new GPT(data.config); + const config = + data.config === undefined + ? undefined + : { + ...data.config, + maxIter: DefaultGPTConfig.maxIter, + }; - debug("GPT model config initialized: %O", data.config) + const model = new GPT(config); + + debug("GPT model config initialized: %O", config) model.weights = data.weights; From 0dc32aabf4fcda8595a9fc816bd728e183200c68 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 5 May 2026 17:51:51 +0200 Subject: [PATCH 15/23] add arg for model saving location, cnahge save to saveLog, change link for model to contxt 512 --- cli/src/args.ts | 6 ++++-- cli/src/cli.ts | 18 +++++++++++------- discojs/src/default_tasks/privacyrun.ts | 11 +++++++---- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/cli/src/args.ts b/cli/src/args.ts index 00d72102e..984db0156 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -23,6 +23,7 @@ export interface BenchmarkArguments { validationSplit: number datasetPath?: string validationDatasetPath?: string + outputPath?: string // DP epsilon?: number @@ -37,7 +38,7 @@ export interface BenchmarkArguments { // Secure aggregator maxShareValue?: number - save: boolean + saveLogs: boolean saveModel: boolean host: URL } @@ -62,7 +63,8 @@ const unsafeArgs = parse( validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true }, validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true }, - save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, + outputPath: { type: String, alias: 'o', description: 'Path to save logs and models. 
Defaults to ./', optional: true }, + saveLogs: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false }, host: { type: (raw: string) => new URL(raw), diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 5540974ae..e780f6b65 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -25,6 +25,10 @@ import type { UserLogFile } from "./user_log.js"; const debug = createDebug("cli:main"); +function getOutputDir(): string { + return args.outputPath ?? path.join(".", `${args.testID}`); +} + async function runUser( task: Task, provider: TaskProvider, @@ -39,7 +43,7 @@ async function runUser( const trainingScheme = task.trainingInformation.scheme as N const aggregator = aggregators.getAggregator(task) const client = clients.getClient(trainingScheme, url, task, aggregator) - const disco = new Disco(task, client, { scheme: trainingScheme, preprocessOnce: true }); + const disco = new Disco(task, client, { scheme: trainingScheme, preprocessOnce: false }); // For local training, load model from provider before training starts // if (trainingScheme === "local") { @@ -53,7 +57,7 @@ async function runUser( - const dir = path.join(".", `${args.testID}`); + const dir = getOutputDir(); await fs.mkdir(dir, { recursive: true }); const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`); @@ -61,7 +65,7 @@ async function runUser( // create a write stream that saves learning logs during the train let jsonStream: ReturnType | null = null; - if (args.save){ + if (args.saveLogs){ jsonStream = createWriteStream(streamPath, {flags: "w"}); } @@ -80,13 +84,13 @@ async function runUser( await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish // Save the trained model if requested if (args.saveModel) { - const modelDir = path.join(".", `${args.testID}`, "models"); + const modelDir = path.join(getOutputDir(), 
"models"); const modelFileName = `client${userIndex}_model.json`; await saveModelToDisk(disco.trainer.model, modelDir, modelFileName); console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`); } // saving the entire per-user logs - if (args.save) { + if (args.saveLogs) { const finalPath = path.join(dir, `client${userIndex}_local_log.json`); const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, finalLog); @@ -144,8 +148,8 @@ async function main( dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset, validationData, i, numberOfUsers)) ) - if (args.save) { - const dir = path.join(".", `${args.testID}`, `${task.id}`); + if (args.saveLogs) { + const dir = path.join(getOutputDir(), `${task.id}`); await fs.mkdir(dir, { recursive: true }); const filePath = path.join(dir, `${task.id}_${numberOfUsers}users.json`); diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts index d667e3c9c..dcdaa7640 100644 --- a/discojs/src/default_tasks/privacyrun.ts +++ b/discojs/src/default_tasks/privacyrun.ts @@ -26,12 +26,13 @@ export const privacyrun: TaskProvider<"text", "federated"> = { scheme: 'federated', aggregationStrategy: 'mean', minNbOfParticipants: 2, - epochs: 6, + epochs: 1, validationSplit: 0.1, - roundDuration: 2, + roundDuration: 1, batchSize: 8, tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), - contextLength: 1024, + // contextLength: 1024, + contextLength: 512, tensorBackend: 'gpt' } } @@ -40,8 +41,10 @@ export const privacyrun: TaskProvider<"text", "federated"> = { async getModel() { // Load the pre-trained ONNX-converted model from Google Cloud Storage // The model should be in DiscoJS serialization format (created by onnx-converter) - const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model.json"; + // const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model.json"; + const modelUrl = 
"https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_512.json"; + try { const response = await fetch(modelUrl); if (!response.ok) { From 5c03f40a98bde019ffda88b96a70b8556a939680 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 5 May 2026 18:02:42 +0200 Subject: [PATCH 16/23] add training optimizations --- discojs/src/models/gpt/model.ts | 36 +++++++++++++++++---------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index 9207e8d47..a69e498b4 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -78,27 +78,29 @@ export class GPTModel extends tf.LayersModel { let preprocessingTime = performance.now() debug("await batch data before {} iteration", iteration) - await Promise.all([xs.data(), ys.data()]) + // await Promise.all([xs.data(), ys.data()]) + await Promise.resolve() debug("after await batch data {} iteration", iteration) preprocessingTime = performance.now() - preprocessingTime // TODO include as a tensor inside the model - const accTensor = tf.tidy(() => { - const logits = this.apply(xs) - if (Array.isArray(logits)) - throw new Error('model outputs too many tensor') - if (logits instanceof tf.SymbolicTensor) - throw new Error('model outputs symbolic tensor') - return tf.metrics.categoricalAccuracy(ys, logits) - }) - const accSize = accTensor.shape.reduce((l, r) => l * r, 1) - const accSumTensor = accTensor.sum() - const accSum = await accSumTensor.array() - tf.dispose(accSumTensor) - if (typeof accSum !== 'number') - throw new Error('got multiple accuracy sum') - accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize]; - tf.dispose([accTensor]) + // const accTensor = tf.tidy(() => { + // const logits = this.apply(xs) + // if (Array.isArray(logits)) + // throw new Error('model outputs too many tensor') + // if (logits instanceof tf.SymbolicTensor) + // throw new Error('model outputs symbolic tensor') + 
// return tf.metrics.categoricalAccuracy(ys, logits) + // }) + // const accSize = accTensor.shape.reduce((l, r) => l * r, 1) + // const accSumTensor = accTensor.sum() + // const accSum = await accSumTensor.array() + // tf.dispose(accSumTensor) + // if (typeof accSum !== 'number') + // throw new Error('got multiple accuracy sum') + // accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize]; + // tf.dispose([accTensor]) + accuracyFraction = [Number.NaN, Number.NaN]; const lossTensor = tf.tidy(() => { const { grads, value: lossTensor } = this.optimizer.computeGradients(() => { From 4ee6096b4ba35e3621af3a29c82323ccf8d5c6df Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 6 May 2026 19:13:21 +0200 Subject: [PATCH 17/23] change onnx converter to be able to convert different context len --- onnx-converter/src/convert_onnx.ts | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnx-converter/src/convert_onnx.ts b/onnx-converter/src/convert_onnx.ts index cf45a76ed..dd3d1aab2 100644 --- a/onnx-converter/src/convert_onnx.ts +++ b/onnx-converter/src/convert_onnx.ts @@ -7,6 +7,7 @@ import { models, serialization } from "@epfml/discojs"; const OUTPUT_FILENAME = "model.json"; const GPT2_N_LAYER = 12; +const GPT2_CONTEXT_LENGTH = 1024; const ONNX_URL = "https://huggingface.co/Xenova/gpt2/resolve/main/onnx/decoder_model.onnx?download=true" @@ -29,8 +30,7 @@ async function main() { // Init empty TF.js model - // Context length value from https://huggingface.co/Xenova/gpt2/blob/main/config.json - const gptModel = new models.GPT({ modelType: 'gpt2', contextLength: 1024 }); + const gptModel = new models.GPT({ modelType: 'gpt2', contextLength: GPT2_CONTEXT_LENGTH }); if (gptModel.config.nLayer != GPT2_N_LAYER) throw new Error(`ONNX conversion only supports GPT-2 with 12 layers, instead found ${gptModel.config.nLayer}.`); const gptLayersModel = gptModel.extract(); @@ -54,7 +54,14 @@ async function main() { throw new 
Error(`Undefined layer dimensions for ${tensor.name}`) const dims = tensor.dims.map((d) => Number(d)); const flatData = parseTensorData(tensor); - const tfTensor = tf.tensor(flatData).reshape(dims) + let tfTensor = tf.tensor(flatData).reshape(dims) + if (tensor.name === "transformer.wpe.weight") { + if (dims.length !== 2) + throw new Error(`Expected transformer.wpe.weight to be a 2D tensor, got ${dims.length}D.`); + if (dims[0] < GPT2_CONTEXT_LENGTH) + throw new Error(`ONNX positional embeddings only support context length ${dims[0]}, requested ${GPT2_CONTEXT_LENGTH}.`); + tfTensor = tfTensor.slice([0, 0], [GPT2_CONTEXT_LENGTH, dims[1]]); + } preTrainedWeights = preTrainedWeights.set(tfjsName, tfTensor); } @@ -133,4 +140,4 @@ function createWeightNameMap(): Map { } -await main().catch(console.error); \ No newline at end of file +await main().catch(console.error); From 5b5f88cc892ad582a9dd3455c27b1daa8e82f950 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 11 May 2026 15:06:08 +0200 Subject: [PATCH 18/23] aggregate inside of an epoch for llms --- cli/package.json | 1 + cli/src/args.ts | 3 + cli/src/measure_memorization_gpt2.ts | 316 +++++++++++++++++++++++ discojs/src/models/gpt/index.ts | 36 ++- discojs/src/models/gpt/model.ts | 16 +- discojs/src/task/training_information.ts | 2 + discojs/src/training/trainer.ts | 133 +++++++++- 7 files changed, 499 insertions(+), 8 deletions(-) create mode 100644 cli/src/measure_memorization_gpt2.ts diff --git a/cli/package.json b/cli/package.json index 8ed779b3f..6c4f10bb3 100644 --- a/cli/package.json +++ b/cli/package.json @@ -10,6 +10,7 @@ "train_gpt": "npm run build && node dist/train_gpt.js", "hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js", "eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js", + "measure_memorization_gpt2": "npm run build && node dist/measure_memorization_gpt2.js", "finetune_gpt": "npm run build && node dist/finetune_gpt.js", "build": "tsc --build", "test": ": 
nothing" diff --git a/cli/src/args.ts b/cli/src/args.ts index 984db0156..f7755b43d 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -19,6 +19,7 @@ export interface BenchmarkArguments { numberOfUsers: number epochs: number roundDuration: number + roundIterations?: number batchSize: number validationSplit: number datasetPath?: string @@ -59,6 +60,7 @@ const unsafeArgs = parse( numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 }, epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 }, roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, + roundIterations: { type: Number, description: 'For GPT text tasks, aggregate every N training batches without rewinding the dataset', optional: true }, batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 }, validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true }, @@ -133,6 +135,7 @@ export const args: BenchmarkArguments = { task.trainingInformation.roundDuration = unsafeArgs.roundDuration; task.trainingInformation.epochs = unsafeArgs.epochs; task.trainingInformation.validationSplit = unsafeArgs.validationSplit; + (task.trainingInformation as typeof task.trainingInformation & { roundIterations?: number }).roundIterations = unsafeArgs.roundIterations; const {aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs; diff --git a/cli/src/measure_memorization_gpt2.ts b/cli/src/measure_memorization_gpt2.ts new file mode 100644 index 000000000..63f02d1ac --- /dev/null +++ b/cli/src/measure_memorization_gpt2.ts @@ -0,0 +1,316 @@ +import "@tensorflow/tfjs-node"; +import * as tf from "@tensorflow/tfjs"; +import fs from "node:fs/promises"; +import { parse } from "ts-command-line-args"; + +import { models, 
Tokenizer } from "@epfml/discojs"; +import { loadModelFromDisk } from "@epfml/discojs-node"; + +interface Args { + modelPath: string; + dataPath: string; + maxRecords: number; + promptLengths: string; + suffixLength: number; + bleuThreshold: number; + seed: number; + savePath?: string; + help?: boolean; +} + +type PromptResult = { + recordIndex: number; + promptLength: number; + splitIndex: number; + exactMatch: boolean; + bleu: number; + memorizedByBleu: boolean; + promptText: string; + referenceText: string; + generatedText: string; +}; + +function parseIntegerList(raw: string): number[] { + const values = raw + .split(",") + .map((v) => Number.parseInt(v.trim(), 10)) + .filter((v) => !Number.isNaN(v)); + + if (values.length === 0 || values.some((v) => v <= 0)) { + throw new Error("promptLengths must be a comma-separated list of positive integers"); + } + + return values; +} + +function seededRandom(seed: number): () => number { + let state = seed >>> 0; + return () => { + state = (1664525 * state + 1013904223) >>> 0; + return state / 0x100000000; + }; +} + +function randomInt(random: () => number, minInclusive: number, maxInclusive: number): number { + if (maxInclusive < minInclusive) { + throw new Error("invalid random integer range"); + } + + return minInclusive + Math.floor(random() * (maxInclusive - minInclusive + 1)); +} + +async function loadRecords(filePath: string, limit: number): Promise { + const text = await fs.readFile(filePath, "utf8"); + const delimiter = "<|endoftext|>"; + const rawRecords = text.includes(delimiter) + ? text.split(delimiter) + : text.split(/\n\s*\n/g); + + const records = rawRecords + .map((record) => + record + .replaceAll("<|startoftext|>", "") + .replaceAll("<|endoftext|>", "") + .trim(), + ) + .filter((record) => record.length > 0); + + return limit > 0 ? 
records.slice(0, limit) : records; +} + +function ngrams(tokens: number[], n: number): Map { + const counts = new Map(); + if (tokens.length < n) return counts; + + for (let i = 0; i <= tokens.length - n; i++) { + const key = tokens.slice(i, i + n).join(","); + counts.set(key, (counts.get(key) ?? 0) + 1); + } + + return counts; +} + +function bleu1to4(reference: number[], candidate: number[]): number { + if (candidate.length === 0) return 0; + + const precisions: number[] = []; + for (let n = 1; n <= 4; n++) { + const referenceCounts = ngrams(reference, n); + const candidateCounts = ngrams(candidate, n); + let overlap = 0; + let total = 0; + + for (const [key, count] of candidateCounts) { + overlap += Math.min(count, referenceCounts.get(key) ?? 0); + total += count; + } + + precisions.push(total === 0 ? 0 : overlap / total); + } + + if (precisions.some((precision) => precision === 0)) return 0; + + const brevityPenalty = + candidate.length > reference.length + ? 1 + : Math.exp(1 - reference.length / candidate.length); + const geometricMean = Math.exp( + precisions.reduce((sum, precision) => sum + Math.log(precision), 0) / precisions.length, + ); + + return brevityPenalty * geometricMean; +} + +async function greedyGenerateGPT2( + model: models.GPT, + inputIds: number[], + maxNewTokens: number, + maxContextLength: number, +): Promise { + const generated = [...inputIds]; + const tfModel = model.extract(); + + for (let i = 0; i < maxNewTokens; i++) { + const modelInput = generated.slice(-maxContextLength); + const input = tf.tensor2d([modelInput], [1, modelInput.length], "int32"); + + const logits = tf.tidy(() => { + const output = tfModel.predict(input); + if (Array.isArray(output)) { + return output[0] as tf.Tensor; + } + return output as tf.Tensor; + }); + + const nextTokenTensor = tf.tidy(() => { + const last = logits.slice([0, modelInput.length - 1, 0], [1, 1, -1]); + return last.squeeze().argMax(); + }); + + const nextTokenData = await nextTokenTensor.data(); + 
const nextToken = nextTokenData[0]; + + input.dispose(); + logits.dispose(); + nextTokenTensor.dispose(); + + generated.push(nextToken); + } + + return generated; +} + +function summarize(results: PromptResult[]) { + const byPromptLength = new Map(); + for (const result of results) { + byPromptLength.set( + result.promptLength, + [...(byPromptLength.get(result.promptLength) ?? []), result], + ); + } + + const summarizeGroup = (group: PromptResult[]) => ({ + count: group.length, + exactMatchRate: group.filter((r) => r.exactMatch).length / group.length, + bleuMemorizationRate: group.filter((r) => r.memorizedByBleu).length / group.length, + averageBleu: group.reduce((sum, r) => sum + r.bleu, 0) / group.length, + }); + + return { + overall: summarizeGroup(results), + byPromptLength: Object.fromEntries( + [...byPromptLength.entries()].map(([promptLength, group]) => [ + promptLength, + summarizeGroup(group), + ]), + ), + }; +} + +async function main() { + const args = parse( + { + modelPath: { type: String, description: "Path to a saved Disco GPT model.json" }, + dataPath: { type: String, description: "Path to records/canaries text file" }, + maxRecords: { type: Number, description: "Maximum records to evaluate; -1 for all", defaultValue: 100 }, + promptLengths: { type: String, description: "Comma-separated prompt lengths", defaultValue: "10,50,100,200,500" }, + suffixLength: { type: Number, description: "Number of suffix tokens to generate and compare", defaultValue: 50 }, + bleuThreshold: { type: Number, description: "BLEU threshold for approximate memorization", defaultValue: 0.75 }, + seed: { type: Number, description: "Random seed for choosing record split positions", defaultValue: 42 }, + savePath: { type: String, description: "Optional JSON output path", optional: true }, + help: { type: Boolean, optional: true, alias: "h", description: "Prints this usage guide" }, + }, + { + helpArg: "help", + headerContentSections: [ + { + header: "GPT-2 Unintended 
Memorization", + content: "Measures extractable memorization via greedy suffix generation.", + }, + ], + }, + ); + + const promptLengths = parseIntegerList(args.promptLengths); + const maxPromptLength = Math.max(...promptLengths); + const random = seededRandom(args.seed); + + console.log("Loading tokenizer..."); + const tokenizer = await Tokenizer.from_pretrained("Xenova/gpt2"); + + console.log("Loading model..."); + const loadedModel = await loadModelFromDisk(args.modelPath); + if (!(loadedModel instanceof models.GPT)) { + throw new Error("modelPath must point to a Disco GPT model"); + } + + console.log("Loading records..."); + const records = await loadRecords(args.dataPath, args.maxRecords); + console.log(`Loaded ${records.length} records`); + + const results: PromptResult[] = []; + let skipped = 0; + + for (let recordIndex = 0; recordIndex < records.length; recordIndex++) { + const record = records[recordIndex]; + const ids = tokenizer.tokenize(record).toArray(); + + if (ids.length < maxPromptLength + args.suffixLength + 1) { + skipped++; + continue; + } + + const splitIndex = randomInt( + random, + maxPromptLength, + ids.length - args.suffixLength, + ); + const reference = ids.slice(splitIndex, splitIndex + args.suffixLength); + + for (const promptLength of promptLengths) { + const prompt = ids.slice(splitIndex - promptLength, splitIndex); + const generated = await greedyGenerateGPT2( + loadedModel, + prompt, + args.suffixLength, + loadedModel.config.contextLength, + ); + const generatedSuffix = generated.slice(prompt.length, prompt.length + args.suffixLength); + const exactMatch = + generatedSuffix.length === reference.length && + generatedSuffix.every((token, i) => token === reference[i]); + const bleu = bleu1to4(reference, generatedSuffix); + + results.push({ + recordIndex, + promptLength, + splitIndex, + exactMatch, + bleu, + memorizedByBleu: bleu > args.bleuThreshold, + promptText: tokenizer.decode(prompt), + referenceText: tokenizer.decode(reference), + 
generatedText: tokenizer.decode(generatedSuffix), + }); + } + + if ((recordIndex + 1) % 10 === 0) { + console.log(`Processed ${recordIndex + 1}/${records.length} records`); + } + } + + if (results.length === 0) { + throw new Error("No records were long enough to evaluate"); + } + + const summary = { + config: { + modelPath: args.modelPath, + dataPath: args.dataPath, + maxRecords: args.maxRecords, + promptLengths, + suffixLength: args.suffixLength, + bleuThreshold: args.bleuThreshold, + seed: args.seed, + modelContextLength: loadedModel.config.contextLength, + }, + skippedRecords: skipped, + ...summarize(results), + }; + + console.log("\n=== Memorization Summary ==="); + console.log(JSON.stringify(summary, null, 2)); + + if (args.savePath !== undefined) { + await fs.writeFile( + args.savePath, + JSON.stringify({ summary, results }, null, 2), + ); + console.log(`Saved detailed results to ${args.savePath}`); + } +} + +main().catch((err) => { + console.error(err); + process.exitCode = 1; +}); diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts index 082f79766..458dad7e6 100644 --- a/discojs/src/models/gpt/index.ts +++ b/discojs/src/models/gpt/index.ts @@ -30,6 +30,7 @@ export class GPT extends Model<"text"> { readonly #contextLength: number; readonly #maxBatchCount: number; readonly #vocabSize: number; + #iterationCount = 0; constructor(partialConfig?: Partial, layersModel?: tf.LayersModel) { super(); @@ -62,7 +63,7 @@ export class GPT extends Model<"text"> { for await (const [batch, _] of trainingDataset.zip( Range(0, this.#maxBatchCount), )) { - const batchLogs = await this.#runBatch(batch); + const batchLogs = await this.#runBatch(batch, ++this.#iterationCount); yield batchLogs; batchesLogs = batchesLogs.push(batchLogs); @@ -75,14 +76,47 @@ export class GPT extends Model<"text"> { return new EpochLogs(batchesLogs, epochTime, validation); } + async *trainNextBatches( + trainingIterator: AsyncIterator>, + maxBatchCount: number, + 
validationDataset?: Dataset>, + setDone?: (done: boolean) => void, + ): AsyncGenerator { + let batchesLogs = List(); + let epochTime = performance.now(); + let done = false; + + for (let batchCount = 0; batchCount < maxBatchCount; batchCount++) { + const next = await trainingIterator.next(); + if (next.done === true) { + done = true; + break; + } + + const batchLogs = await this.#runBatch(next.value, ++this.#iterationCount); + + yield batchLogs; + batchesLogs = batchesLogs.push(batchLogs); + } + + const validation = + validationDataset && (await this.evaluate(validationDataset)); + epochTime = performance.now() - epochTime; + setDone?.(done); + + return new EpochLogs(batchesLogs, epochTime, validation); + } + async #runBatch( batch: Batched, + iterationNumber: number, ): Promise { const tfBatch = this.#batchToTF(batch); let logs: tf.Logs | undefined; await this.model.fitDataset(tf.data.array([tfBatch]), { epochs: 1, + iterationOffset: iterationNumber - 1, verbose: 0, // don't pollute callbacks: { onEpochEnd: (_, cur) => { diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index a69e498b4..f11ffbbcf 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -55,9 +55,10 @@ export class GPTModel extends tf.LayersModel { : tf.train.adam(this.config.lr) } - override async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise { + override async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs & { iterationOffset?: number }): Promise { const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> + const iterationOffset = trainingArgs.iterationOffset ?? 
0 await callbacks.onTrainBegin?.() for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) { @@ -72,15 +73,18 @@ export class GPTModel extends tf.LayersModel { debug("after next of iterator") while (next.done !== true && iteration <= this.config.maxIter) { + const reportedIteration = iterationOffset + iteration let weightUpdateTime = performance.now() await callbacks.onEpochBegin?.(epoch) const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } let preprocessingTime = performance.now() - debug("await batch data before {} iteration", iteration) + // debug("await batch data before {} iteration", iteration) + debug("await batch data before {} iteration", reportedIteration) // await Promise.all([xs.data(), ys.data()]) await Promise.resolve() - debug("after await batch data {} iteration", iteration) + // debug("after await batch data {} iteration", iteration) + debug("after await batch data {} iteration", reportedIteration) preprocessingTime = performance.now() - preprocessingTime // TODO include as a tensor inside the model @@ -125,7 +129,8 @@ export class GPTModel extends tf.LayersModel { if ( evalDataset !== undefined && this.config.evaluateEvery !== undefined && - iteration % this.config.evaluateEvery == 0 + // iteration % this.config.evaluateEvery == 0 + reportedIteration % this.config.evaluateEvery == 0 ){ const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches) debug('evaluation metrics: %O', iterationLogs); @@ -133,7 +138,8 @@ export class GPTModel extends tf.LayersModel { const memory = tf.memory().numBytes / 1024 / 1024 / 1024 debug("training metrics: %O", { epoch, - iteration, + // iteration, + iteration: reportedIteration, loss, memory, allocated: tf.memory().numTensors, diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index fa01cdf2e..aaca2f117 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -62,6 +62,8 @@ export 
namespace TrainingInformation { // number of epochs between each weight sharing round. // e.g.if 3 then weights are shared every 3 epochs (in the distributed setting). roundDuration: z.number().positive().int(), + // for GPT text tasks, number of training batches between each weight sharing round. + roundIterations: z.number().positive().int().optional(), // fraction of data to keep for validation, note this only works for image data validationSplit: z.number().min(0).max(1), // batch size of training data diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 73cae1f7c..f46690b69 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -30,6 +30,15 @@ export interface RoundLogs { /** List of weight update norms */ export type WeightNormHistory = List>; +type IterationTrainableTextModel = Model<"text"> & { + trainNextBatches( + trainingIterator: AsyncIterator>, + maxBatchCount: number, + validationDataset?: Dataset>, + setDone?: (done: boolean) => void, + ): AsyncGenerator; +}; + function appendWeightHistory(weightNormHistory: WeightNormHistory, wc: number[]){ return wc.reduce((hist, t, i) => { const arr = hist.get(i, List()); @@ -53,6 +62,7 @@ export class Trainer { AsyncGenerator, RoundLogs>, void >; + readonly #roundIterations?: number; // Map of weight Index and weight update #weightNormHistory : WeightNormHistory = List(); #previousRoundWeights?: WeightsContainer; @@ -71,10 +81,18 @@ export class Trainer { this.#client = client; this.#roundDuration = task.trainingInformation.roundDuration; this.#epochs = task.trainingInformation.epochs; + this.#roundIterations = task.trainingInformation.roundIterations; if ("privacy" in task.trainingInformation) this.#privacy = task.trainingInformation.privacy; - if (!Number.isInteger(this.#epochs / this.#roundDuration)) + if (this.#roundIterations !== undefined && (task.dataType !== "text" || task.trainingInformation.tensorBackend !== "gpt")) + throw new 
Error("roundIterations is only supported for GPT text tasks"); + + if (this.#roundIterations !== undefined && (!Number.isInteger(this.#roundIterations) || this.#roundIterations < 1)) + throw new Error("roundIterations must be a positive integer"); + + // if (!Number.isInteger(this.#epochs / this.#roundDuration)) + if (this.#roundIterations === undefined && !Number.isInteger(this.#epochs / this.#roundDuration)) throw new Error( `round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`, ); @@ -98,7 +116,11 @@ export class Trainer { ); try { - this.#training = this.#runRounds(dataset, validationDataset); + // this.#training = this.#runRounds(dataset, validationDataset); + this.#training = + this.#roundIterations === undefined + ? this.#runRounds(dataset, validationDataset) + : this.#runIterationRounds(dataset, validationDataset); yield* this.#training; } finally { this.#training = undefined; @@ -151,6 +173,77 @@ export class Trainer { } } + async *#runIterationRounds( + dataset: Dataset>, + validationDataset?: Dataset>, + ): AsyncGenerator< + AsyncGenerator, RoundLogs>, + void + > { + if (this.#roundIterations === undefined) + throw new Error("roundIterations was not set"); + + for (let epoch = 0; epoch < this.#epochs; epoch++) { + const trainingIterator = dataset[Symbol.asyncIterator](); + let next = await trainingIterator.next(); + let pendingBatch: Batched | undefined = + next.done === true ? 
undefined : next.value; + + while (pendingBatch !== undefined) { + await this.#client.onRoundBeginCommunication(); + + this.#previousRoundWeights = new WeightsContainer(this.model.weights.weights.map(t => t.clone())); + + let firstBatch: Batched | undefined = pendingBatch; + pendingBatch = undefined; + let done = false; + const prefixedIterator: AsyncIterator> = { + next: async () => { + if (firstBatch !== undefined) { + const value = firstBatch; + firstBatch = undefined; + return { value, done: false }; + } + + return await trainingIterator.next(); + }, + }; + + yield this.#runIterationRound( + prefixedIterator, + this.#roundIterations, + validationDataset, + (roundDone) => done = roundDone, + ); + + let roundWeights = this.model.weights; + + if (this.#privacy !== undefined){ + const roundUpdate = roundWeights.sub(this.#previousRoundWeights); + const updateNorm = await Promise.all( + roundUpdate.weights.map(privacy.frobeniusNorm) + ); + this.#weightNormHistory = appendWeightHistory(this.#weightNormHistory, updateNorm); + + roundWeights = await applyOptimalPrivacy( + this.#previousRoundWeights, + roundWeights, + this.#privacy, + this.#weightNormHistory, + Number.MAX_SAFE_INTEGER, + ) + } + + const networkWeights = await this.#client.onRoundEndCommunication(roundWeights); + this.model.weights = networkWeights; + + if (done) break; + next = await trainingIterator.next(); + pendingBatch = next.done === true ? undefined : next.value; + } + } + } + async *#runRound( dataset: Dataset>, validationDataset?: Dataset>, @@ -177,6 +270,42 @@ export class Trainer { preRoundValidation: validation, }; } + + async *#runIterationRound( + datasetIterator: AsyncIterator>, + maxBatchCount: number, + validationDataset?: Dataset>, + setDone?: (done: boolean) => void, + ): AsyncGenerator, RoundLogs> { + let epochsLogs = List(); + + debug("Run iteration-based round") + + const validation = validationDataset !== undefined ? 
await this.model.evaluate(validationDataset) : undefined; + + const model = this.model as unknown as IterationTrainableTextModel; + if (typeof model.trainNextBatches !== "function") + throw new Error("model does not support iteration-based training"); + + const [gen, result] = async_iterator.split( + model.trainNextBatches( + datasetIterator as AsyncIterator>, + maxBatchCount, + validationDataset as Dataset> | undefined, + setDone, + ), + ); + + yield gen; + const epochLogs = await result; + epochsLogs = epochsLogs.push(epochLogs); + + return { + epochs: epochsLogs, + participants: this.#client.nbOfParticipants, + preRoundValidation: validation, + }; + } } /** ALDP-FL implementation */ From 8d409b9cfef22d8631b51ab827f96a3a5b0f76d2 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 11 May 2026 16:53:36 +0200 Subject: [PATCH 19/23] fix mem leak --- discojs/src/models/gpt/model.ts | 1 + discojs/src/training/trainer.ts | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index f11ffbbcf..c0ec63f2b 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -117,6 +117,7 @@ export class GPTModel extends tf.LayersModel { }) const gradsClipped = clipByGlobalNormObj(grads, 1) this.optimizer.applyGradients(gradsClipped) + tf.dispose(Object.values(gradsClipped)) return lossTensor }) diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index f46690b69..1a74cecbd 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -170,6 +170,9 @@ export class Trainer { // Update the local weights this.model.weights = networkWeights; + networkWeights.dispose(); + this.#previousRoundWeights.dispose(); + this.#previousRoundWeights = undefined; } } @@ -236,6 +239,9 @@ export class Trainer { const networkWeights = await this.#client.onRoundEndCommunication(roundWeights); this.model.weights = networkWeights; + networkWeights.dispose(); + 
this.#previousRoundWeights.dispose(); + this.#previousRoundWeights = undefined; if (done) break; next = await trainingIterator.next(); From 144b751f0bc6c09559156214468475c1cbd1421c Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Mon, 11 May 2026 17:17:18 +0200 Subject: [PATCH 20/23] change ligs --- discojs/src/models/gpt/model.ts | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index c0ec63f2b..8be9d7379 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -79,12 +79,8 @@ export class GPTModel extends tf.LayersModel { const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } let preprocessingTime = performance.now() - // debug("await batch data before {} iteration", iteration) - debug("await batch data before {} iteration", reportedIteration) - // await Promise.all([xs.data(), ys.data()]) - await Promise.resolve() - // debug("after await batch data {} iteration", iteration) - debug("after await batch data {} iteration", reportedIteration) + await Promise.all([xs.data(), ys.data()]) + // await Promise.resolve() preprocessingTime = performance.now() - preprocessingTime // TODO include as a tensor inside the model From c046e9edecae68ea0437e4738399e74e7d73b372 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 12 May 2026 15:50:21 +0200 Subject: [PATCH 21/23] change model to 256 --- discojs/src/default_tasks/privacyrun.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts index dcdaa7640..2b743b98b 100644 --- a/discojs/src/default_tasks/privacyrun.ts +++ b/discojs/src/default_tasks/privacyrun.ts @@ -32,7 +32,7 @@ export const privacyrun: TaskProvider<"text", "federated"> = { batchSize: 8, tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), // contextLength: 1024, - contextLength: 512, + contextLength: 256, tensorBackend: 
'gpt' } } @@ -43,7 +43,9 @@ export const privacyrun: TaskProvider<"text", "federated"> = { // The model should be in DiscoJS serialization format (created by onnx-converter) // const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model.json"; - const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_512.json"; + // const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_512.json"; + + const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_256.json"; try { const response = await fetch(modelUrl); From e89b9e0c020aab242b80298f92971d95fefe1c82 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 12 May 2026 17:47:51 +0200 Subject: [PATCH 22/23] back to 512 --- discojs/src/default_tasks/privacyrun.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts index 2b743b98b..37b6b951c 100644 --- a/discojs/src/default_tasks/privacyrun.ts +++ b/discojs/src/default_tasks/privacyrun.ts @@ -32,7 +32,7 @@ export const privacyrun: TaskProvider<"text", "federated"> = { batchSize: 8, tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), // contextLength: 1024, - contextLength: 256, + contextLength: 512, tensorBackend: 'gpt' } } @@ -45,7 +45,7 @@ export const privacyrun: TaskProvider<"text", "federated"> = { // const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_512.json"; - const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_256.json"; + const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model_ctx_512.json"; try { const response = await fetch(modelUrl); From fbcca59330a9ad240731990dd34c333b2d0913a0 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 13 May 2026 01:31:27 +0200 Subject: [PATCH 23/23] fix end of training --- discojs/src/processing/index.spec.ts | 35 ++++++++++++++++++++++++++++ 
discojs/src/processing/index.ts | 1 + 2 files changed, 36 insertions(+) diff --git a/discojs/src/processing/index.spec.ts b/discojs/src/processing/index.spec.ts index 19aa2b47a..4bb3eec48 100644 --- a/discojs/src/processing/index.spec.ts +++ b/discojs/src/processing/index.spec.ts @@ -5,6 +5,12 @@ import { Dataset } from "../index.js"; import { preprocess } from "./index.js"; +async function arrayFromAsync(iter: AsyncIterable): Promise { + const ret: T[] = []; + for await (const e of iter) ret.push(e); + return ret; +} + describe("preprocess", () => { it("throws on missing column in tabular", async () => { const task: Task<"tabular", "local"> = { @@ -41,4 +47,33 @@ describe("preprocess", () => { expect(false, "should have thrown").to.be.true; }); + + it("drops incomplete text windows", async () => { + const task = { + id: "task", + dataType: "text", + displayInformation: { + title: "", + summary: { preview: "", overview: "" }, + }, + trainingInformation: { + tensorBackend: "gpt", + scheme: "local", + aggregationStrategy: "mean", + epochs: 1, + roundDuration: 1, + batchSize: 2, + validationSplit: 0, + contextLength: 4, + tokenizer: { + tokenize: () => [0, 1, 2, 3, 4, 5, 6], + }, + }, + } as unknown as Task<"text", "local">; + + const dataset = new Dataset(["ignored"]); + const preprocessed = await arrayFromAsync(preprocess(task, dataset)); + + expect(preprocessed.map(([tokens]) => tokens.size)).to.deep.equal([4]); + }); }); diff --git a/discojs/src/processing/index.ts b/discojs/src/processing/index.ts index 4011f40d1..fa73618d8 100644 --- a/discojs/src/processing/index.ts +++ b/discojs/src/processing/index.ts @@ -58,6 +58,7 @@ export function preprocess( .map((text) => tokenizer.tokenize(text)) .flatten() .batch(contextLength + 1, 1) + .filter((tokens) => tokens.size === contextLength + 1) .map((tokens) => [tokens.pop(), tokens.last()]) as Dataset< DataFormat.ModelEncoded[D] >;