Add documentation for using Milkdown with various frameworks

- Created a new document for using components in Milkdown.
- Added a guide for using plugins in Milkdown, including toggling plugins programmatically and listing official plugins.
- Introduced a recipe for integrating Milkdown with Angular, including installation steps and component creation.
- Added a recipe for using Milkdown with Next.js, detailing installation and component setup.
- Created a guide for integrating Milkdown with NuxtJS, including installation and component creation.
- Added a comprehensive guide for using Milkdown with React, covering both Crepe and core Milkdown usage.
- Introduced a recipe for SolidJS integration with Milkdown, including installation and component creation.
- Added a guide for using Milkdown with Svelte, detailing installation and component setup.
- Created a comprehensive guide for integrating Milkdown with Vue, covering both Crepe and core Milkdown usage.
- Added a recipe for using Milkdown with Vue2, including installation and component creation.
This commit is contained in:
2026-01-17 14:18:08 +08:00
parent 4de3dfdd8d
commit d9ab341223
381 changed files with 125356 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import {
FunctionComponent,
PromptComponentChild,
PromptElement,
PromptElementProps,
PromptFragment,
} from '../src/components/components';
/**
 * JSX factory function called for any JSX element.
 *
 * Normalizes `props.children` into an array and returns a `PromptElement` that stores the component function,
 * so that the component can be rendered later in the virtual prompt.
 *
 * @param type The function that instantiates the prompt component.
 * @param props Properties of the element, including its children (a single child, an array of children, or none).
 * @param key Optional key of the element. It is copied onto the resulting props whenever it is provided.
 * @returns The element description `{ type, props }`, where `props.children` is always an array.
 */
function functionComponentFunction(
	type: FunctionComponent,
	props: PromptElementProps,
	key?: string | number
): PromptElement {
	let children: PromptComponentChild[] = [];
	if (Array.isArray(props.children)) {
		children = props.children;
	} else if (props.children || props.children === 0) {
		// A lone numeric `0` child is falsy, but it is still content and must not be dropped.
		children = [props.children];
	}
	const componentProps = { ...props, children };
	// Compare against `undefined` so that a falsy key such as `0` is not silently discarded.
	if (key !== undefined) {
		componentProps.key = key;
	}
	return { type, props: componentProps };
}
/**
 * JSX factory function used when the JSX element is a fragment.
 * The reconciler invokes it when it encounters a fragment, and gets back a `PromptFragment`
 * that wraps the given children.
 *
 * @param children Children of the fragment
 */
function fragmentFunction(children: PromptComponentChild[]): PromptFragment {
	const fragment: PromptFragment = { type: 'f', children };
	return fragment;
}
// Marker property that distinguishes the fragment factory from regular function components.
fragmentFunction.isFragmentFunction = true;
/* JSX namespace is used by TypeScript to type JSX:
* https://www.typescriptlang.org/docs/handbook/jsx.html#the-jsx-namespace
*/
export namespace JSX {
// Intrinsic element names are not restricted: any name maps to `unknown`.
export interface IntrinsicElements {
[s: string]: unknown;
}
// Attributes accepted on every JSX element, in addition to the component's own props.
export interface IntrinsicAttributes {
key?: string | number;
weight?: number;
source?: unknown;
}
/* any type necessary for component prop types */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ElementType<P = any> = FunctionComponent<P>;
// A JSX expression evaluates to a `PromptElement`.
export type Element = PromptElement;
// Names the property (`props`) that TypeScript uses to look up the attribute types of an element.
export interface ElementAttributesProperty {
props: unknown;
}
// Names the property (`children`) that TypeScript uses to type-check JSX children.
export interface ElementChildrenAttribute {
children: unknown;
}
}
export { fragmentFunction as Fragment, functionComponentFunction as jsx, functionComponentFunction as jsxs };

View File

@@ -0,0 +1,163 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DataConsumer, Dispatch, StateUpdater, TypePredicate } from './hooks';
import { TokenizerName } from '../tokenization';
import { CancellationToken } from 'vscode-languageserver-protocol';
// --------- Prompt component types
/** A single child of a prompt component: another element, literal text, a number, or nothing. */
export type PromptComponentChild = PromptElement | string | number | undefined;
/** Either one child or a list of children. */
type PromptComponentChildren = PromptComponentChild[] | PromptComponentChild;
/** Attributes every prompt element accepts; arbitrary additional attributes are allowed as `unknown`. */
interface PromptAttributes {
[key: string]: unknown;
// Optional key of the element among its siblings.
key?: string | number;
// Optional weight of the element.
weight?: number;
// Optional, opaque source information attached to the element.
source?: unknown;
}
/** Props received by a component: its own props `P` plus the read-only common attributes and optional children. */
export type PromptElementProps<P = object> = P & Readonly<PromptAttributes & { children?: PromptComponentChildren }>;
/**
* Context passed as second argument to every function component. It exposes the hooks a component can use.
*/
export interface ComponentContext {
/**
* Hook to manage component state that can change over time.
* When called without an initial state, the state starts out as `undefined`.
* @param initialState - Initial state value or function that returns initial state
* @returns A tuple containing current state and setter function
* @example
* function Counter(props: PromptElementProps, context: ComponentContext) {
* const [count, setCount] = context.useState(0);
* return <Text>Count: {count}</Text>;
* }
*/
useState<S = undefined>(): [S | undefined, Dispatch<StateUpdater<S | undefined>>];
useState<S>(initialState: S | (() => S)): [S, Dispatch<StateUpdater<S>>];
/**
* Hook to subscribe to typed external data streams with type checking.
* @param typePredicate - TypeScript type predicate function for runtime type checking
* @param consumer - Callback function that receives type-checked data
* @example
* function DataViewer(props: PromptElementProps, context: ComponentContext) {
* interface MessageData {
* message: string;
* }
*
* function isMessageData(data: unknown): data is MessageData {
* return typeof data === 'object' && data !== null &&
* 'message' in data && typeof (data as any).message === 'string';
* }
*
* context.useData(
* isMessageData,
* (data) => console.log(data.message)
* );
* }
*/
useData<T>(typePredicate: TypePredicate<T>, consumer: DataConsumer<T>): void;
}
/** Result of calling a fragment function: the literal type tag `'f'` and the fragment's children. */
export interface PromptFragment {
type: 'f';
children: PromptComponentChild[];
}
/** Function that wraps children into a `PromptFragment`. */
export interface FragmentFunction {
(children: PromptComponentChildren): PromptFragment;
}
/** A prompt component: a function from props and component context to the children it renders. */
export interface FunctionComponent<P = PromptAttributes> {
(props: PromptElementProps<P>, context: ComponentContext): PromptComponentChildren;
}
/**
* Data structure returned by prompt component functions and used by the `virtualize` function to construct a virtual prompt.
*/
export interface PromptElement<P = PromptAttributes> {
// The component (or fragment) function to call when the element is rendered.
type: FunctionComponent<P> | FragmentFunction;
// The props to call it with, including its children.
props: P & { children: PromptComponentChildren };
}
// --------- Prompt snapshot and rendering types
/** Per-node statistics attached to a snapshot node. */
export interface PromptSnapshotNodeStatistics {
// Time in milliseconds spent updating the node with data, if it was measured.
updateDataTimeMs?: number;
}
/**
* A prompt snapshot node is a node in the virtual prompt tree in its immutable form.
*/
export interface PromptSnapshotNode {
name: string;
path: string;
value?: string;
props?: PromptElementProps;
children?: PromptSnapshotNode[];
statistics: PromptSnapshotNodeStatistics;
}
/** Turns a prompt snapshot into a prompt of type `T`, using render options of type `P`. */
export interface PromptRenderer<T extends Prompt, P extends PromptRenderOptions> {
render(snapshot: PromptSnapshotNode, options: P, cancellationToken?: CancellationToken): T;
}
/** Metadata describing one rendering of a prompt. */
export type PromptMetadata = {
renderId: number;
rendererName?: string;
tokenizer: string;
elisionTimeMs: number;
renderTimeMs: number;
updateDataTimeMs: number;
componentStatistics: ComponentStatistics[];
};
/** Statistics of a single component, identified by its path. */
export type ComponentStatistics = {
componentPath: string;
expectedTokens?: number;
actualTokens?: number;
updateDataTimeMs?: number;
// This field is only used internally, and even though we send it to CTS it's not telemetrized
source?: unknown;
};
// Outcome of an operation: `ok`, `cancelled`, or `error` (which carries the error).
type StatusOk = { status: 'ok' };
export type StatusNotOk = { status: 'cancelled' } | { status: 'error'; error: Error };
export type Status = StatusOk | StatusNotOk;
/** A successful result, together with its metadata. */
export type PromptOk = StatusOk & {
metadata: PromptMetadata;
};
type Prompt = PromptOk | StatusNotOk;
/** Options common to all prompt renderers. */
export interface PromptRenderOptions {
tokenizer?: TokenizerName;
delimiter?: string;
}
// --------- Components
/** A child of a text component: literal text, a number, or nothing (no nested elements). */
type TextPromptComponentChild = string | number | undefined;
/** Props of a text component: the common element props, with children restricted to text-like values. */
interface TextPromptElementProps extends PromptElementProps {
children?: TextPromptComponentChild[] | TextPromptComponentChild;
}
/**
 * Basic component to represent text in a prompt.
 *
 * @param props Props whose `children` are the text fragments (strings or numbers) to render.
 * @returns The concatenation of all children when `children` is an array, the single child when it is a
 * number or a non-empty string, and `undefined` when there is nothing to render.
 */
export function Text(props: TextPromptElementProps) {
	const { children } = props;
	if (Array.isArray(children)) {
		return children.join('');
	}
	// Numbers are always content: a plain truthiness test would drop a lone `0`,
	// although the array form `[0]` renders it as '0'.
	if (typeof children === 'number' || children) {
		return children;
	}
	return;
}
/**
 * Basic component to represent a group of components that gets elided all together or not at all.
 * The component itself only passes its children through.
 *
 * @param props Props of the chunk; only `children` is used
 * @returns The children exactly as they were received
 */
export function Chunk(props: PromptElementProps) {
	const { children } = props;
	return children;
}

View File

@@ -0,0 +1,67 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/** Function used to dispatch a value, for instance a state update. */
export type Dispatch<A> = (value: A) => void;
/** A state update: either the next state itself, or a function that computes it from the previous state. */
export type StateUpdater<S> = S | ((prevState: S) => S);

/**
 * Positional state hook. Every `useState` call of one `UseState` instance is bound to the next slot of the
 * `states` array given to the constructor. Creating a new instance over the same array replays the slots
 * from the start, which is how state survives from one rendering of a component to the next.
 */
export class UseState {
	private currentIndex: number = 0;
	private stateChanged: boolean = false;

	/**
	 * @param states Backing storage for the state slots. It is mutated in place.
	 */
	constructor(private readonly states: unknown[]) { }

	/**
	 * Returns the current value of the next state slot together with its setter.
	 *
	 * @param initialState Initial value, or a function producing it. It is only used the first time the slot is
	 * requested; a slot that was later set to `undefined` on purpose is not re-initialized.
	 * @returns A tuple with the current state and a setter. The setter accepts the next state or a function of
	 * the previous state, stores the result and marks this hook as changed.
	 */
	useState<S = undefined>(): [S | undefined, Dispatch<StateUpdater<S | undefined>>];
	useState<S>(initialState: S | (() => S)): [S, Dispatch<StateUpdater<S>>];
	useState<S>(initialState?: S | (() => S)): [S | undefined, Dispatch<StateUpdater<S | undefined>>] {
		const index = this.currentIndex;

		// Initialize the slot only if it was never written. Testing the stored value against `undefined`
		// would run the initializer again, and overwrite the state, whenever the state is `undefined`.
		if (!(index in this.states)) {
			const initial = typeof initialState === 'function' ? (initialState as () => S)() : initialState;
			this.states[index] = initial;
		}

		const setState = (newState: StateUpdater<S | undefined>) => {
			const nextState =
				typeof newState === 'function' ? (newState as (prevState: S) => S)(this.states[index] as S) : newState;
			this.states[index] = nextState;
			this.stateChanged = true;
		};

		this.currentIndex++;
		return [this.states[index] as S, setState];
	}

	/** Tells whether any setter handed out by this instance has been called. */
	hasChanged(): boolean {
		return this.stateChanged;
	}
}
/** Runtime type check that narrows unknown data to `T`. */
export type TypePredicate<T> = (data: unknown) => data is T;
/** Callback receiving type-checked data; it may be asynchronous. */
export type DataConsumer<T> = (data: T) => void | Promise<void>;

/**
 * Data hook: collects consumers guarded by a type predicate and feeds them with data.
 */
export class UseData {
	private consumers: DataConsumer<unknown>[] = [];

	/**
	 * @param measureUpdateTime Called after each data update that reached at least one registered consumer,
	 * with the time in milliseconds the update took.
	 */
	constructor(private readonly measureUpdateTime: (updateTimeMs: number) => void) { }

	/**
	 * Registers a consumer that is only invoked for data accepted by `typePredicate`.
	 */
	useData<T>(typePredicate: TypePredicate<T>, consumer: DataConsumer<T>): void {
		const guardedConsumer: DataConsumer<unknown> = (data: unknown) =>
			typePredicate(data) ? consumer(data) : undefined;
		this.consumers.push(guardedConsumer);
	}

	/**
	 * Passes `data` to all registered consumers, one after the other and in registration order, then reports
	 * the elapsed time. Nothing happens (and nothing is reported) when no consumer is registered.
	 */
	async updateData(data: unknown) {
		if (this.consumers.length === 0) {
			return;
		}
		const startedAt = performance.now();
		for (const notify of this.consumers) {
			await notify(data);
		}
		this.measureUpdateTime(performance.now() - startedAt);
	}
}

View File

@@ -0,0 +1,270 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CancellationToken } from 'vscode-languageserver-protocol';
import {
FragmentFunction,
FunctionComponent,
type ComponentContext,
type PromptComponentChild,
type PromptElement,
type PromptElementProps,
} from './components';
import { DataConsumer, Dispatch, StateUpdater, TypePredicate, UseData, UseState } from './hooks';
import { DataPipe } from './virtualPrompt';
/**
* A virtual prompt node is an in-memory representation of a prompt component in its rendered form.
* It is constructed from a `PromptElement` and contains the name of the component that it was constructed from, and resolved external context and state.
*/
export type VirtualPromptNode = {
// Name of the component function, the fragment type, or `typeof` the value for string and number nodes.
name: string;
// Path of the node in the tree, starting at `$`.
path: string;
props?: PromptElementProps;
children?: VirtualPromptNode[];
// The element or primitive this node was virtualized from; it is used to virtualize the node again.
component?: PromptComponentChild;
// Hooks of the node; only set for nodes created from function components.
lifecycle?: PromptElementLifecycle;
};
// Result of virtualizing a child: `undefined` children produce no node.
type VirtualPromptNodeChild = VirtualPromptNode | undefined;
/**
 * Translates a `PromptComponentChild` object into a tree of virtual prompt nodes and keeps that tree up to date.
 *
 * The tree is built once in the constructor. `reconcile` then runs (remounts) every function component again
 * whose state changed since it was last virtualized. Hook state is stored per component path, so a remounted
 * component finds its previous state again.
 */
export class VirtualPromptReconciler {
	/** Hook state per component path. It outlives the nodes, so that state survives a remount. */
	private lifecycleData: Map<string, PromptElementLifecycleData> = new Map();
	/**
	 * Index each node was virtualized with. Undefined siblings are filtered out of `children`, so the position
	 * of a node in its parent's `children` array can be smaller than the index that is part of its path.
	 */
	private readonly mountIndices = new WeakMap<VirtualPromptNode, number>();
	private vTree: VirtualPromptNode | undefined;

	/**
	 * @param prompt Root element of the prompt; it is virtualized immediately.
	 */
	constructor(prompt: PromptElement) {
		// Initial virtualization
		this.vTree = this.virtualizeElement(prompt, '$', 0);
	}

	/**
	 * Brings the virtual tree up to date by remounting all components whose state changed.
	 *
	 * @param cancellationToken When cancellation is already requested, the current tree is returned untouched.
	 * @returns The reconciled tree
	 * @throws Error when there is no tree to reconcile
	 */
	reconcile(cancellationToken?: CancellationToken): VirtualPromptNode | undefined {
		if (!this.vTree) {
			throw new Error('No tree to reconcile, make sure to pass a valid prompt');
		}
		if (cancellationToken?.isCancellationRequested) {
			return this.vTree;
		}
		this.vTree = this.reconcileNode(this.vTree, '$', 0, cancellationToken);
		return this.vTree;
	}

	/**
	 * Reconciles a single node: remounts it when its state changed, otherwise reconciles its children.
	 *
	 * @param node The node to reconcile
	 * @param parentNodePath Path of the parent of the node
	 * @param nodeIndex Index the node was virtualized with below its parent
	 */
	private reconcileNode(
		node: VirtualPromptNode,
		parentNodePath: string,
		nodeIndex: number,
		cancellationToken?: CancellationToken
	): VirtualPromptNodeChild {
		// If the node has no children or does not have a lifecycle, return it as is (primitive nodes)
		if (!node.children && !node.lifecycle) {
			return node;
		}

		let newNode: VirtualPromptNodeChild = node;
		const needsReconciliation = node.lifecycle?.isRemountRequired();

		// If the node needs reconciliation, virtualize it again
		if (needsReconciliation) {
			const oldChildrenPaths = this.collectChildPaths(node);
			newNode = this.virtualizeElement(node.component, parentNodePath, nodeIndex);
			const newChildrenPaths = this.collectChildPaths(newNode);
			this.cleanupState(oldChildrenPaths, newChildrenPaths);
			// Otherwise, check if the children need reconciliation
		} else if (node.children) {
			const children: VirtualPromptNode[] = [];
			for (let i = 0; i < node.children.length; i++) {
				const child = node.children[i];
				if (child) {
					// Use the index the child was virtualized with, not its position in the filtered array.
					// Otherwise a child that follows an undefined sibling would be remounted under a
					// different path and lose its state.
					const mountIndex = this.mountIndices.get(child) ?? i;
					const reconciledChild = this.reconcileNode(child, node.path, mountIndex, cancellationToken);
					if (reconciledChild !== undefined) {
						children.push(reconciledChild);
					}
				}
			}
			newNode.children = children;
		}
		return newNode;
	}

	/**
	 * Translates a component child into a virtual node and remembers the index it was virtualized with.
	 *
	 * @returns The virtual node, or `undefined` for an undefined child
	 */
	private virtualizeElement(
		component: PromptComponentChild,
		parentNodePath: string,
		nodeIndex: number
	): VirtualPromptNodeChild {
		const node = this.createNode(component, parentNodePath, nodeIndex);
		if (node) {
			this.mountIndices.set(node, nodeIndex);
		}
		return node;
	}

	/**
	 * Creates the virtual node for a component child: nothing for `undefined`, a leaf node for strings and
	 * numbers, a fragment node for fragments, and a component node for function components.
	 */
	private createNode(
		component: PromptComponentChild,
		parentNodePath: string,
		nodeIndex: number
	): VirtualPromptNodeChild {
		if (typeof component === 'undefined') {
			return undefined;
		}

		if (typeof component === 'string' || typeof component === 'number') {
			return {
				name: typeof component,
				path: `${parentNodePath}[${nodeIndex}]`,
				props: { value: component },
				component,
			};
		}

		if (isFragmentFunction(component.type)) {
			const fragment = component.type(component.props.children);
			// A fragment directly below the root carries no index in its path
			const indexIndicator = parentNodePath !== '$' ? `[${nodeIndex}]` : ``;
			const componentPath = `${parentNodePath}${indexIndicator}.${fragment.type}`;
			const children = fragment.children.map((c, i) => this.virtualizeElement(c, componentPath, i));
			this.ensureUniqueKeys(children);
			return {
				name: fragment.type,
				path: componentPath,
				children: children.flat().filter(c => c !== undefined),
				component,
			};
		}

		return this.virtualizeFunctionComponent(parentNodePath, nodeIndex, component, component.type);
	}

	/**
	 * Runs a function component with its lifecycle (hooks) and virtualizes everything it returns.
	 * The path of the component uses its key when it has one, and its index otherwise.
	 */
	private virtualizeFunctionComponent(
		parentNodePath: string,
		nodeIndex: number,
		component: PromptElement,
		functionComponent: FunctionComponent
	) {
		const indexIndicator = component.props.key ? `["${component.props.key}"]` : `[${nodeIndex}]`;
		const componentPath = `${parentNodePath}${indexIndicator}.${functionComponent.name}`;
		const lifecycle = new PromptElementLifecycle(this.getOrCreateLifecycleData(componentPath));
		const element = functionComponent(component.props, lifecycle);
		const elementToVirtualize = Array.isArray(element) ? element : [element];
		const virtualizedChildren = elementToVirtualize.map((e, i) => this.virtualizeElement(e, componentPath, i));
		const children = virtualizedChildren.flat().filter(e => e !== undefined);
		this.ensureUniqueKeys(children);
		return {
			name: functionComponent.name,
			path: componentPath,
			props: component.props,
			children,
			component,
			lifecycle,
		};
	}

	/**
	 * Checks that no key is used by more than one of the given sibling nodes.
	 *
	 * @throws Error listing the duplicated keys
	 */
	private ensureUniqueKeys(nodes: VirtualPromptNodeChild[]) {
		const keyCount = new Map<string | number, number>();
		for (const node of nodes) {
			if (!node) {
				continue;
			}
			const key = node.props?.key;
			if (key) {
				keyCount.set(key, (keyCount.get(key) || 0) + 1);
			}
		}

		// Find all duplicates
		const duplicates = Array.from(keyCount.entries())
			.filter(([_, count]) => count > 1)
			.map(([key]) => key);
		if (duplicates.length > 0) {
			throw new Error(`Duplicate keys found: ${duplicates.join(', ')}`);
		}
	}

	/** Collects the paths of all descendants of a node (the node itself is not included). */
	private collectChildPaths(node: VirtualPromptNode | undefined) {
		const paths: string[] = [];
		if (node?.children) {
			for (const child of node.children) {
				if (child) {
					paths.push(child.path);
					paths.push(...this.collectChildPaths(child));
				}
			}
		}
		return paths;
	}

	/** Drops the hook state of all paths that existed before a remount but no longer exist after it. */
	private cleanupState(oldChildrenPaths: string[], newChildrenPaths: string[]) {
		// A set keeps the lookup constant-time instead of scanning the array for every old path
		const retainedPaths = new Set(newChildrenPaths);
		for (const path of oldChildrenPaths) {
			if (!retainedPaths.has(path)) {
				this.lifecycleData.delete(path);
			}
		}
	}

	/** Returns the hook state stored for a path, creating empty state on first use. */
	private getOrCreateLifecycleData(path: string) {
		if (!this.lifecycleData.has(path)) {
			this.lifecycleData.set(path, new PromptElementLifecycleData([]));
		}
		return this.lifecycleData.get(path)!;
	}

	/** Creates a pipe that pumps data into the data hooks of all components in the tree. */
	createPipe(): DataPipe {
		return {
			pump: async (data: unknown) => {
				await this.pumpData(data);
			},
		};
	}

	/**
	 * @throws Error when there is no tree to pump data into
	 */
	private async pumpData<T>(data: T) {
		if (!this.vTree) {
			throw new Error('No tree to pump data into. Pumping data before initializing?');
		}
		await this.recursivelyPumpData(data, this.vTree);
	}

	/** Updates the data hook of the node first, then those of its children, one after the other. */
	private async recursivelyPumpData<T>(data: T, node: VirtualPromptNode) {
		if (!node) {
			throw new Error(`Can't pump data into undefined node.`);
		}
		await node.lifecycle?.dataHook.updateData(data);
		for (const child of node.children || []) {
			await this.recursivelyPumpData(data, child);
		}
	}
}
/**
 * Data of a component that is kept across remounts: the state slots of its state hook and the time its
 * last data update took.
 */
class PromptElementLifecycleData {
	/** Duration of the last data update in milliseconds; 0 when nothing was recorded since the last reset. */
	_updateTimeMs = 0;

	/**
	 * @param state Storage for the state slots of the component
	 */
	constructor(public state: unknown[]) { }

	/** Returns the recorded update time and resets it to 0. */
	getUpdateTimeMsAndReset() {
		const recorded = this._updateTimeMs;
		this._updateTimeMs = 0;
		return recorded;
	}
}
/**
 * Component context handed to a function component. It wires the state hook to the state stored in the
 * lifecycle data, and the data hook to the update time recorded in the lifecycle data.
 */
class PromptElementLifecycle implements ComponentContext {
	private readonly stateHook: UseState;
	readonly dataHook: UseData;

	/**
	 * @param lifecycleData Data of the component that is kept across remounts
	 */
	constructor(readonly lifecycleData: PromptElementLifecycleData) {
		const recordUpdateTime = (elapsedMs: number) => {
			lifecycleData._updateTimeMs = elapsedMs;
		};
		this.stateHook = new UseState(lifecycleData.state);
		this.dataHook = new UseData(recordUpdateTime);
	}

	/** A remount is required as soon as one of the state setters has been called. */
	isRemountRequired(): boolean {
		return this.stateHook.hasChanged();
	}

	useState<S = undefined>(): [S | undefined, Dispatch<StateUpdater<S | undefined>>];
	useState<S>(initialState: S | (() => S)): [S, Dispatch<StateUpdater<S>>];
	useState<S>(initialState?: S | (() => S)): [S | undefined, Dispatch<StateUpdater<S | undefined>>] {
		return this.stateHook.useState(initialState);
	}

	useData<T>(typePredicate: TypePredicate<T>, consumer: DataConsumer<T>): void {
		this.dataHook.useData(typePredicate, consumer);
	}
}
/**
 * Type guard telling the fragment function apart from regular function components: the fragment function is
 * a function that carries an `isFragmentFunction` property.
 */
function isFragmentFunction(element: FragmentFunction | FunctionComponent): element is FragmentFunction {
	if (typeof element !== 'function') {
		return false;
	}
	return 'isFragmentFunction' in element;
}

View File

@@ -0,0 +1,90 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import type { PromptElement, PromptSnapshotNode, Status } from './components';
import { VirtualPromptNode, VirtualPromptReconciler } from './reconciler';
import { CancellationToken } from 'vscode-languageserver-protocol';
/** Result of taking a snapshot: the status, plus the snapshot tree when the status is `ok`. */
type PromptSnapshot = Status & { snapshot: PromptSnapshotNode | undefined };

/**
 * The `VirtualPrompt` class holds the in-memory representation of the prompt, and is responsible for updating it with context, and generating immutable snapshots which can be passed to a prompt renderer.
 */
export class VirtualPrompt {
	private reconciler: VirtualPromptReconciler;

	/**
	 * @param prompt Root element of the prompt
	 */
	constructor(prompt: PromptElement) {
		this.reconciler = new VirtualPromptReconciler(prompt);
	}

	/**
	 * Converts a virtual node and all its descendants into snapshot nodes.
	 *
	 * @returns The snapshot node, `'cancelled'` as soon as cancellation is requested, or `undefined` when there is no node
	 */
	private snapshotNode(
		node: VirtualPromptNode,
		cancellationToken?: CancellationToken
	): PromptSnapshotNode | 'cancelled' | undefined {
		if (!node) {
			return;
		}
		if (cancellationToken?.isCancellationRequested) {
			return 'cancelled';
		}

		const children: PromptSnapshotNode[] = [];
		for (const child of node.children ?? []) {
			const result = this.snapshotNode(child, cancellationToken);
			if (result === 'cancelled') {
				return 'cancelled';
			}
			if (result !== undefined) {
				children.push(result);
			}
		}

		return {
			value: node.props?.value?.toString(),
			name: node.name,
			path: node.path,
			props: node.props,
			children,
			statistics: {
				// Reading the update time resets it, so each measurement is reported in one snapshot only
				updateDataTimeMs: node.lifecycle?.lifecycleData.getUpdateTimeMsAndReset(),
			},
		};
	}

	/**
	 * Reconciles the virtual prompt and returns a snapshot of it.
	 *
	 * @returns Status `ok` with the snapshot, status `cancelled` when cancellation was requested, or status
	 * `error` with the error when reconciling or snapshotting threw. This method itself does not throw.
	 */
	snapshot(cancellationToken?: CancellationToken): PromptSnapshot {
		try {
			const vTree = this.reconciler.reconcile(cancellationToken);
			if (cancellationToken?.isCancellationRequested) {
				return { snapshot: undefined, status: 'cancelled' };
			}
			if (!vTree) {
				throw new Error('Invalid virtual prompt tree');
			}
			const snapshotNode = this.snapshotNode(vTree, cancellationToken);
			if (snapshotNode === 'cancelled' || cancellationToken?.isCancellationRequested) {
				return { snapshot: undefined, status: 'cancelled' };
			}
			return { snapshot: snapshotNode, status: 'ok' };
		} catch (e) {
			// Components may throw anything; make sure that `error` really is an `Error`
			const error = e instanceof Error ? e : new Error(String(e));
			return { snapshot: undefined, status: 'error', error };
		}
	}

	/** Creates a pipe to pump external data into the prompt. */
	createPipe(): DataPipe {
		return this.reconciler.createPipe();
	}
}
/**
* A data pipe is a one-way channel to get external data into the prompt. Pumping unsupported data types into the pipe will result in no-op.
*/
export interface DataPipe {
/**
* Sends data into the prompt.
* @param data - The data to hand over
* @returns A promise for the completion of the hand-over
*/
pump(data: unknown): Promise<void>;
}

View File

@@ -0,0 +1,115 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { Chunk, PromptSnapshotNode } from './components';
/**
* Represents the context during the traversal of a prompt snapshot tree.
* This context is passed to every node and can be modified by transformers.
*/
interface WalkContext {
/**
* Context properties that can be added by custom transformers.
*/
[key: string]: unknown;
}
/**
* A function that transforms the walking context as the tree is traversed.
* Transformers are applied in sequence before visiting each node: each transformer
* receives the context returned by the previous one.
*
* @param node - The current node being visited
* @param parent - The parent of the current node (undefined for root)
* @param context - The current context
* @returns A new context to be used for this node and its children
*/
export type WalkContextTransformer = (
node: PromptSnapshotNode,
parent: PromptSnapshotNode | undefined,
context: WalkContext
) => WalkContext;
/**
 * A utility class for traversing a prompt snapshot tree.
 * For every node, the walker first derives the context of the node by applying all transformers
 * to the context of its parent, and then calls the visitor with the node and this context.
 */
export class SnapshotWalker {
	/**
	 * Creates a new SnapshotWalker.
	 *
	 * @param snapshot - The root node of the snapshot tree to walk
	 * @param transformers - Optional array of context transformers to apply during traversal; defaults to `defaultTransformers()`
	 */
	constructor(
		private readonly snapshot: PromptSnapshotNode,
		private readonly transformers: WalkContextTransformer[] = defaultTransformers()
	) { }

	/**
	 * Walks the snapshot tree, parents before children, and applies the visitor function to each node.
	 *
	 * @param visitor - Function called for each node during traversal. Return false to skip traversing children.
	 */
	walkSnapshot(
		visitor: (n: PromptSnapshotNode, parent: PromptSnapshotNode | undefined, context: WalkContext) => boolean
	) {
		const root = this.snapshot;
		this.walkSnapshotNode(root, undefined, visitor, {});
	}

	/**
	 * Visits a node and, if the visitor accepts it, its children.
	 *
	 * @param context - The context of the parent node; the walker works on a copy of it
	 */
	private walkSnapshotNode(
		node: PromptSnapshotNode,
		parent: PromptSnapshotNode | undefined,
		visitor: (n: PromptSnapshotNode, parent: PromptSnapshotNode | undefined, context: WalkContext) => boolean,
		context: WalkContext
	) {
		const nodeContext = this.contextFor(node, parent, context);
		// Children are only visited, with the context of this node, when the visitor accepts the node
		if (visitor(node, parent, nodeContext)) {
			for (const child of node.children ?? []) {
				this.walkSnapshotNode(child, node, visitor, nodeContext);
			}
		}
	}

	/** Runs all transformers in sequence, starting from a copy of the inherited context. */
	private contextFor(
		node: PromptSnapshotNode,
		parent: PromptSnapshotNode | undefined,
		inherited: WalkContext
	): WalkContext {
		const initial: WalkContext = { ...inherited };
		return this.transformers.reduce((current, transform) => transform(node, parent, current), initial);
	}
}
/**
 * Returns the transformers a `SnapshotWalker` uses by default. They maintain these context properties:
 * - `weight`: product of the weights (each clamped to [0, 1], default 1) of the node and all its ancestors
 * - `chunks`: set of the paths of all `Chunk` nodes among the node and its ancestors
 * - `source`: the `source` prop of the closest node (the node itself or an ancestor) that defines one
 *
 * None of the transformers modifies the context it receives.
 */
export function defaultTransformers(): WalkContextTransformer[] {
	return [
		// Weight transformer - computes the weight of the current relative to the parent
		(node, _, context) => {
			// Read the inherited weight without writing the default back into the incoming context
			const inheritedWeight = context.weight === undefined ? 1 : (context.weight as number);
			const weight = node.props?.weight ?? 1;
			// Anything that is not a real number (including NaN) counts as the neutral weight 1, so that a
			// single bad value cannot turn the weights of all descendants into NaN
			const clampedWeight =
				typeof weight === 'number' && !Number.isNaN(weight) ? Math.max(0, Math.min(1, weight)) : 1;
			return { ...context, weight: clampedWeight * inheritedWeight };
		},
		// Chunk transformer
		(node, _, context) => {
			if (node.name === Chunk.name) {
				// Initialize chunk set if it doesn't exist; copy it otherwise, so that siblings do not share it
				const chunks = context.chunks ? new Set<string>(context.chunks as Set<string>) : new Set<string>();
				// Add current node path to the set
				chunks.add(node.path);
				return { ...context, chunks };
			}
			return context;
		},
		// Source transformer
		(node, _, context) => {
			if (node.props?.source !== undefined) {
				return { ...context, source: node.props.source };
			}
			return context;
		},
	];
}

View File

@@ -0,0 +1,10 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/**
 * Error that carries the fixed code `CopilotPromptLoadFailure` and, optionally, the underlying cause.
 */
export class CopilotPromptLoadFailure extends Error {
	readonly code: 'CopilotPromptLoadFailure';

	/**
	 * @param message Description of the failure
	 * @param cause The underlying error or value, passed on to `Error` as its `cause`
	 */
	constructor(message: string, cause?: unknown) {
		super(message, { cause });
		this.code = 'CopilotPromptLoadFailure';
	}
}

View File

@@ -0,0 +1,31 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as fs from 'node:fs/promises';
import path from 'node:path';
/**
 * Reads a file that is addressed relative to the directory `locateFile` resolves against.
 *
 * @param filename Name of the file, as accepted by `locateFile`
 * @returns The content of the file
 */
export async function readFile(filename: string): Promise<Uint8Array> {
	const location = locateFile(filename);
	const content = await fs.readFile(location);
	return content;
}
/**
 * Resolves a file name against the `dist` directory. This works both for the TypeScript source, which lives
 * under `/src`, and for the transpiled JavaScript, which lives under `/dist`.
 *
 * @param filename Name of the file, relative to the `dist` directory
 * @returns The absolute path of the file
 */
export function locateFile(filename: string): string {
	const runningFromSource = path.extname(__filename) === '.ts';
	// From source: `dist` is the sibling of the enclosing `src` directory.
	// Transpiled: `dist` is one of the directories this file is located in.
	const distDirectory = runningFromSource
		? path.join(locationInPath(path.dirname(__dirname), 'src'), '..', 'dist')
		: locationInPath(__dirname, 'dist');
	return path.resolve(distDirectory, filename);
}
/**
 * Walks up from `filePath` to the closest directory, `filePath` itself included, that is named `directoryName`.
 *
 * @param filePath The path to start from
 * @param directoryName Name of the directory to look for
 * @returns The path of that directory, or `filePath` unchanged when no directory of that name is found
 */
function locationInPath(filePath: string, directoryName: string): string {
	let current = filePath;
	for (; ;) {
		if (path.basename(current) === directoryName) {
			return current;
		}
		const parent = path.dirname(current);
		if (parent === current) {
			// Reached the top without finding the directory
			return filePath;
		}
		current = parent;
	}
}

View File

@@ -0,0 +1,132 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/** A node of an indentation tree; `L` is the type of the labels that nodes can carry. */
export type IndentationTree<L> = TopNode<L> | VirtualNode<L> | LineNode<L> | BlankNode<L>;
/** Any node that can appear below another node, i.e. every node except the top node. */
export type IndentationSubTree<L> = Exclude<IndentationTree<L>, TopNode<L>>;
/** Properties shared by all nodes: an optional label and the list of sub-nodes. */
interface NodeBase<L> {
label?: L;
subs: IndentationSubTree<L>[];
}
/**
* Virtual nodes represent groupings that are not directly visible in indentation.
**/
export interface VirtualNode<L> extends NodeBase<L> {
type: 'virtual';
indentation: number;
}
/**
* The root of a tree. Its indentation is always -1.
*/
export interface TopNode<L> extends NodeBase<L> {
type: 'top';
indentation: -1;
}
/**
* A line of source code and its sub-nodes
* */
export interface LineNode<L> extends NodeBase<L> {
type: 'line';
indentation: number;
lineNumber: number;
// Text of the line; never empty when the node is created with `lineNode`
sourceLine: string;
}
/**
* A blank line
*/
interface BlankNode<L> extends NodeBase<L> {
type: 'blank';
lineNumber: number;
subs: never[]; // Type trick to make it easier to code
}
/**
 * Construct a virtual node
 *
 * @param indentation Indentation of the node
 * @param subs Sub-nodes of the node
 * @param label Optional label of the node
 */
export function virtualNode<L>(indentation: number, subs: IndentationSubTree<L>[], label?: L): VirtualNode<L> {
	const node: VirtualNode<L> = { type: 'virtual', indentation, subs, label };
	return node;
}
/**
 * Construct a line node
 *
 * @param indentation Indentation of the line
 * @param lineNumber Number of the line
 * @param sourceLine Text of the line; must not be empty
 * @param subs Sub-nodes of the line
 * @param label Optional label of the node
 * @throws Error when `sourceLine` is the empty string
 */
export function lineNode<L>(
	indentation: number,
	lineNumber: number,
	sourceLine: string,
	subs: IndentationSubTree<L>[],
	label?: L
): LineNode<L> {
	const isEmptyLine = sourceLine === '';
	if (isEmptyLine) {
		throw new Error('Cannot create a line node with an empty source line');
	}
	const node: LineNode<L> = { type: 'line', indentation, lineNumber, sourceLine, subs, label };
	return node;
}
/**
 * Return a blank node
 *
 * @param line Number of the blank line
 */
export function blankNode(line: number): BlankNode<never> {
	const node: BlankNode<never> = { type: 'blank', lineNumber: line, subs: [] };
	return node;
}
/**
 * Return a node representing the top node
 *
 * @param subs Sub-nodes of the top node; an empty list is used when omitted
 */
export function topNode<L>(subs?: IndentationSubTree<L>[]): TopNode<L> {
	const children = subs ?? [];
	return { type: 'top', indentation: -1, subs: children };
}
/** Type guard for blank nodes, i.e. nodes of type `'blank'`. */
export function isBlank<L>(tree: IndentationTree<L>): tree is BlankNode<L> {
	const { type } = tree;
	return type === 'blank';
}
/** Type guard for line nodes, i.e. nodes of type `'line'`. */
export function isLine<L>(tree: IndentationTree<L>): tree is LineNode<L> {
	const { type } = tree;
	return type === 'line';
}
/** Type guard for virtual nodes, i.e. nodes of type `'virtual'`. */
export function isVirtual<L>(tree: IndentationTree<L>): tree is VirtualNode<L> {
	const { type } = tree;
	return type === 'virtual';
}
/** Type guard for the top node, i.e. the node of type `'top'`. */
export function isTop<L>(tree: IndentationTree<L>): tree is TopNode<L> {
	const { type } = tree;
	return type === 'top';
}
/**
 * Cut the tree after the line (or blank) node with the given line number: the sub-nodes of that node
 * are removed, and so are all later siblings of that node and of each of its ancestors.
 *
 * This function does not assume the line numbers appear contiguously: everything that comes before the
 * numbered line in the tree is kept, whether its line number is greater or not.
 *
 * This is destructive and modifies the tree. The tree is left untouched if no node has the line number.
 */
export function cutTreeAfterLine(tree: IndentationTree<unknown>, lineNumber: number) {
	/** Cuts below `node`; tells whether the numbered line was found in `node` or below it. */
	const truncate = (node: IndentationTree<unknown>): boolean => {
		// Only line and blank nodes carry a line number
		if (node.type !== 'virtual' && node.type !== 'top' && node.lineNumber === lineNumber) {
			node.subs = [];
			return true;
		}
		let index = 0;
		while (index < node.subs.length) {
			if (truncate(node.subs[index])) {
				// The line is inside this sub-node: drop all siblings that follow it
				node.subs = node.subs.slice(0, index + 1);
				return true;
			}
			index += 1;
		}
		return false;
	};
	truncate(tree);
}
/**
* A type expressing that a value survives a JSON round trip: JSON.parse(JSON.stringify(x)) is
* structurally equal to x.
*/
export type JsonStable = string | number | JsonStable[] | { [key: string]: JsonStable };
/**
 * Return a deep duplicate of the tree, made by serializing the tree to JSON and parsing it again.
 * This will only work if the labels can be stringified to parseable JSON.
 */
export function duplicateTree<L extends JsonStable>(tree: IndentationTree<L>): IndentationTree<L> {
	const serialized = JSON.stringify(tree);
	return JSON.parse(serialized) as IndentationTree<L>;
}

View File

@@ -0,0 +1,162 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { IndentationTree, isBlank, isLine, isTop, isVirtual, JsonStable, LineNode } from './classes';
import { foldTree } from './manipulation';
/**
 * Format only the given line node, and *NOT* its subnodes.
 * The result consists of the indentation, the source line and a trailing newline.
 */
export function deparseLine<T>(node: LineNode<T>): string {
	const indent = ' '.repeat(node.indentation);
	return `${indent}${node.sourceLine}\n`;
}
/**
 * Return a flat string representation of the indentation tree: line nodes contribute their formatted
 * line, blank nodes a newline, and all other nodes nothing. The tree is folded top down.
 */
export function deparseTree<L>(tree: IndentationTree<L>): string {
	const appendNode = (node: IndentationTree<L>, soFar: string): string => {
		if (isLine(node)) {
			return soFar + deparseLine(node);
		}
		if (isBlank(node)) {
			return soFar + '\n';
		}
		return soFar;
	};
	return foldTree(tree, '', appendNode, 'topDown');
}
/**
 * Return a list of flat strings whose concatenation equals `deparseTree`.
 * The source is cut at the nodes whose labels appear in `cutAt`. In other
 * words, if a node has a label `A` that appears in `cutAt`, then there will
 * be up to three strings in the result: the concatenation of lines before
 * the node `A`, the lines covered by node `A`, and lines after the node `A`.
 * Strings for the lines between cuts are only emitted when they are not empty.
 *
 * FIXME: The cuts are *not* applied recursively: If e.g. node `A` has a
 * sub-node labelled `B` which is also in `cutAt`, then the result will still
 * contain only a single string for node `A`.
 *
 * @param tree The tree to deparse
 * @param cutAt The labels of the nodes to cut at
 * @returns The pieces in source order; pieces that stem from a cut node carry its label, all others `undefined`
 */
export function deparseAndCutTree<L>(tree: IndentationTree<L>, cutAt: L[]): { label: L | undefined; source: string }[] {
	const cutAtSet = new Set(cutAt);
	const cuts: { label: L | undefined; source: string }[] = [];
	// Source collected since the last cut that does not belong to any cut node
	let curUndef = '';

	// Reimplement visitTree to avoid descending into cut nodes.
	function visit(tree: IndentationTree<L>) {
		if (tree.label !== undefined && cutAtSet.has(tree.label)) {
			if (curUndef !== '') {
				cuts.push({ label: undefined, source: curUndef });
			}
			cuts.push({
				label: tree.label,
				source: deparseTree(tree),
			});
			curUndef = '';
		} else {
			if (isLine(tree)) {
				curUndef += deparseLine(tree);
			} else if (isBlank(tree)) {
				// Blank lines are part of the source as well: without them the concatenation of all
				// pieces would not equal `deparseTree`
				curUndef += '\n';
			}
			tree.subs.forEach(visit);
		}
	}

	visit(tree);
	if (curUndef !== '') {
		cuts.push({ label: undefined, source: curUndef });
	}
	return cuts;
}
/**
 * Return a readable string representation of the tree.
 *
 * The output is closely related to building trees using the helper functions in
 * `indentation.test.ts`: virtual and top nodes are rendered as `vnode(...)`,
 * blank nodes as `blank(...)` and line nodes as `lnode(...)`. Line and blank
 * nodes are prefixed with their line number; sub-nodes are rendered with
 * `indent + 2`.
 *
 * An `undefined` tree or `undefined` subs are rendered as placeholder text
 * instead of throwing.
 *
 * @param tree The tree (or subtree) to describe.
 * @param indent Indentation level used for this node's rendering.
 */
export function describeTree<L>(tree: IndentationTree<L>, indent = 0): string {
    const ind = ' '.repeat(indent);
    if (tree === undefined) {
        return 'UNDEFINED NODE';
    }
    const joinedSubs =
        tree.subs === undefined
            ? 'UNDEFINED SUBS'
            : tree.subs.map(child => describeTree(child, indent + 2)).join(',\n');
    const children = joinedSubs === '' ? '[]' : `[\n${joinedSubs}\n ${ind}]`;
    const labelString = tree.label === undefined ? '' : JSON.stringify(tree.label);
    // Virtual and top nodes carry no line number of their own.
    if (isVirtual(tree) || isTop(tree)) {
        const prefix = ' ' + `: ${ind}`;
        return `${prefix}vnode(${tree.indentation}, ${labelString}, ${children})`;
    }
    const prefix = String(tree.lineNumber).padStart(3, ' ') + `: ${ind}`;
    if (isBlank(tree)) {
        return `${prefix}blank(${labelString ?? ''})`;
    }
    return `${prefix}lnode(${tree.indentation}, ${labelString}, ${JSON.stringify(tree.sourceLine)}, ${children})`;
}
/**
 * Return a string that mimics the call that would construct the tree.
 * This is less readable than describeTree, but useful to write code.
 *
 * The source line of a line node is emitted with `JSON.stringify`, so quotes,
 * backslashes and control characters are escaped and the generated text stays
 * a valid string literal.
 *
 * @param tree The (sub)tree to encode.
 * @param indent Whitespace prefix for this node; it grows with each nesting
 * level. The top node itself is never prefixed.
 * @returns Source text made of nested `topNode` / `virtualNode` / `lineNode` /
 * `blankNode` calls.
 */
export function encodeTree<T extends JsonStable>(tree: IndentationTree<T>, indent = ''): string {
    const labelString = tree.label === undefined ? '' : `, ${JSON.stringify(tree.label)}`;
    const subString =
        !isBlank(tree) && tree.subs.length > 0
            ? `[\n${tree.subs.map(node => encodeTree(node, indent + ' ')).join(', \n')}\n${indent}]`
            : '[]';
    switch (tree.type) {
        case 'blank':
            return `${indent}blankNode(${tree.lineNumber}${labelString})`;
        case 'top':
            return `topNode(${subString}${labelString})`;
        case 'virtual':
            return `${indent}virtualNode(${tree.indentation}, ${subString}${labelString})`;
        case 'line': {
            // Escape the line so that `"` or `\` in the source cannot break the literal.
            const sourceLiteral = JSON.stringify(tree.sourceLine);
            return `${indent}lineNode(${tree.indentation}, ${tree.lineNumber}, ${sourceLiteral}, ${subString}${labelString})`;
        }
    }
}
/**
 * Return the first line number of the given tree.
 *
 * Line and blank nodes report their own line number. Any other node reports
 * the first line number found among its subs, searched in order.
 *
 * @returns The line number, or `undefined` if the tree contains no line or
 * blank node.
 */
export function firstLineOf<L>(tree: IndentationTree<L>): number | undefined {
    if (isLine(tree) || isBlank(tree)) {
        return tree.lineNumber;
    }
    for (let i = 0; i < tree.subs.length; i++) {
        const candidate = firstLineOf(tree.subs[i]);
        if (candidate !== undefined) {
            return candidate;
        }
    }
    return undefined;
}
/**
 * Return the last line number of the given tree.
 *
 * The subs are searched from last to first; the first sub that reports a line
 * number wins. If no sub does, a node that is neither virtual nor top reports
 * its own line number.
 *
 * @returns The line number, or `undefined` if neither the subs nor the node
 * itself provide one.
 */
export function lastLineOf<L>(tree: IndentationTree<L>): number | undefined {
    for (let i = tree.subs.length - 1; i >= 0; i--) {
        const candidate = lastLineOf(tree.subs[i]);
        if (candidate !== undefined) {
            return candidate;
        }
    }
    // No sub has a line number: fall back to the node's own, if it has one.
    if (isVirtual(tree) || isTop(tree)) {
        return undefined;
    }
    return tree.lineNumber;
}

View File

@@ -0,0 +1,17 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { processJava } from './java';
import { processMarkdown } from './markdown';
import { registerLanguageSpecificParser } from './parsing';
// Side effect of importing this module: register the language-specific
// processors for the 'markdown' and 'java' language ids.
registerLanguageSpecificParser('markdown', processMarkdown);
registerLanguageSpecificParser('java', processJava);
export * from './classes';
export * from './description';
export * from './manipulation';
export * from './parsing';

View File

@@ -0,0 +1,72 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { IndentationTree, isBlank } from './classes';
import { visitTree } from './manipulation';
import {
LabelRule,
buildLabelRules,
combineClosersAndOpeners,
flattenVirtual,
labelLines,
labelVirtualInherited,
} from './parsing';
/**
 * Java labels.
 *
 * * package: A package declaration;
 * * import: An import statement
 * * class: A line containing the word `class` followed by a space
 * * interface: A line containing the word `interface` followed by a space
 * * javadoc: A line starting with a slash and two asterisks
 * * comment_single: Single-line comments starting with //
 * * comment_multi: Multi-line comments starting with /*, or a vnode of
 * multiple single-line comments.
 * * annotation: A line starting with "@". Note that fields are habitually
 * declared on one line, even if they have an annotation. In this case, the
 * field will have the label "annotation" rather than "member".
 * * opener: A line starting with an opening parens, square bracket, or curly brace
 * * closer: A line starting with a closing parens, square bracket, or curly brace
 * * closeBrace: A closing brace alone on a line.
 * * member: Anything inside a class or interface that does not have a more
 * specific label.
 *
 * NOTE(review): there is no `closeBrace` key in the rule map below; closing
 * braces are matched by `closer`. `member` is not produced by a rule either;
 * it is assigned in `processJava`.
 */
const _javaLabelRules = {
package: /^package /,
import: /^import /,
class: /\bclass /,
interface: /\binterface /,
javadoc: /^\/\*\*/,
comment_multi: /^\/\*[^*]/,
comment_single: /^\/\//,
annotation: /^@/,
opener: /^[[({]/,
closer: /^[\])}]/,
} as const;
// The rule map converted to the list form expected by `labelLines`.
const javaLabelRules: LabelRule<string>[] = buildLabelRules(_javaLabelRules);
/**
 * processJava(parseRaw(text)) is supposed to serve as a superior alternative
 * to parseTree(text, "generic").
 *
 * The tree is labelled with the Java rules, closers and openers are combined,
 * superfluous virtual nodes are flattened, and virtual nodes inherit labels.
 * Finally, direct subs of a `class` or `interface` node that are unlabelled or
 * labelled `annotation` are relabelled as `member` (blank nodes are left alone).
 *
 * @param originalTree The raw tree to process.
 * @returns The processed tree.
 */
export function processJava<L>(originalTree: IndentationTree<L>): IndentationTree<L | string> {
    let tree = originalTree as IndentationTree<L | string>;
    labelLines(tree, javaLabelRules);
    tree = flattenVirtual(combineClosersAndOpeners(tree));
    labelVirtualInherited(tree);
    // Relabel the direct subs of a class or interface node as members.
    const markMembers = (node: IndentationTree<L | string>): void => {
        if (node.label !== 'class' && node.label !== 'interface') {
            return;
        }
        node.subs.forEach(sub => {
            if (isBlank(sub)) {
                return;
            }
            if (sub.label === undefined || sub.label === 'annotation') {
                sub.label = 'member';
            }
        });
    };
    visitTree(tree, markMembers, 'bottomUp');
    return tree;
}

View File

@@ -0,0 +1,190 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { IndentationSubTree, IndentationTree, TopNode, isTop, isVirtual, topNode } from './classes';
/**
 * Remove every label from the tree.
 *
 * The tree is modified in place; the same object is returned, retyped as a
 * tree without labels.
 *
 * @param tree The tree whose labels are cleared.
 * @returns The same tree, typed as `IndentationTree<never>`.
 */
export function clearLabels<L>(tree: IndentationTree<L>): IndentationTree<never> {
    const dropLabel = (node: IndentationTree<L>): void => {
        node.label = undefined;
    };
    visitTree(tree, dropLabel, 'bottomUp');
    return tree as IndentationTree<never>;
}
/**
 * Clear the labels for which `condition` holds.
 *
 * The tree is modified in place; the same object is returned, retyped to the
 * remaining label type. Only `undefined` counts as "no label": falsy labels
 * such as `''` or `0` are passed to `condition` like any other label and are
 * kept when it returns false.
 *
 * @param tree The tree to modify.
 * @param condition Type guard selecting the labels to remove.
 * @returns The same tree, typed as `IndentationTree<L>`.
 */
export function clearLabelsIf<L, S>(
    tree: IndentationTree<L | S>,
    condition: (arg: L | S) => arg is S
): IndentationTree<L> {
    visitTree(
        tree,
        (node: IndentationTree<L | S>) => {
            if (node.label !== undefined && condition(node.label)) {
                node.label = undefined;
            }
        },
        'bottomUp'
    );
    return tree as IndentationTree<L>;
}
export function mapLabels<L1, L2>(
    tree: IndentationSubTree<L1>,
    map: (arg: L1) => L2 | undefined
): IndentationSubTree<L2>;
export function mapLabels<L1, L2>(tree: TopNode<L1>, map: (arg: L1) => L2 | undefined): TopNode<L2>;
export function mapLabels<L1, L2>(tree: IndentationTree<L1>, map: (arg: L1) => L2 | undefined): IndentationTree<L2>;
/**
 * Apply a type changing function to all labels.
 * This will return a new, retyped tree; the nodes of the input tree are
 * copied, not modified.
 * (For applying a type keeping function to a tree
 * that modifies it in place, use `visitTree`.)
 *
 * Only `undefined` counts as "no label": falsy labels such as `''` or `0` are
 * passed to `map` like any other label. Within a node, the subs are mapped
 * before the node's own label.
 *
 * @param tree The tree whose labels are mapped.
 * @param map Function computing the new label from an existing one.
 * @returns A copy of the tree carrying the mapped labels.
 */
export function mapLabels<L1, L2>(tree: IndentationTree<L1>, map: (arg: L1) => L2 | undefined): IndentationTree<L2> {
    // Unlabelled nodes stay unlabelled; everything else goes through `map`.
    const mapLabel = (label: L1 | undefined): L2 | undefined => (label === undefined ? undefined : map(label));
    switch (tree.type) {
        case 'line':
        case 'virtual': {
            const newSubs = tree.subs.map(sub => mapLabels(sub, map));
            return { ...tree, subs: newSubs, label: mapLabel(tree.label) };
        }
        case 'blank':
            return { ...tree, label: mapLabel(tree.label) };
        case 'top':
            return {
                ...tree,
                subs: tree.subs.map(sub => mapLabels(sub, map)),
                label: mapLabel(tree.label),
            };
    }
}
/**
 * Renumber the line numbers of the tree contiguously from 0 and up.
 *
 * The tree is walked top-down and modified in place. Virtual and top nodes
 * are skipped; every other node receives the next number.
 */
export function resetLineNumbers<L>(tree: IndentationTree<L>): void {
    let nextLineNumber = 0;
    const renumber = (node: IndentationTree<L>): void => {
        if (isVirtual(node) || isTop(node)) {
            return;
        }
        node.lineNumber = nextLineNumber;
        nextLineNumber += 1;
    };
    visitTree(tree, renumber, 'topDown');
}
/**
 * Call `visitor` once on every node of the tree.
 *
 * With direction `topDown`, a node is visited before its subs. With direction
 * `bottomUp`, the subs are visited (in order) before the node itself, so that
 * leaf nodes come first.
 *
 * @param tree Root of the traversal.
 * @param visitor Function called on each node.
 * @param direction Whether parents or children are visited first.
 */
export function visitTree<L>(
    tree: IndentationTree<L>,
    visitor: (tree: IndentationTree<L>) => void,
    direction: 'topDown' | 'bottomUp'
): void {
    const parentsFirst = direction === 'topDown';
    const childrenFirst = direction === 'bottomUp';
    const walk = (node: IndentationTree<L>): void => {
        if (parentsFirst) {
            visitor(node);
        }
        node.subs.forEach(child => walk(child));
        if (childrenFirst) {
            visitor(node);
        }
    };
    walk(tree);
}
/**
 * Visit the tree with a function that is called on each node and can stop
 * the traversal by returning false.
 *
 * If direction is topDown, parents are visited before their children; when
 * the visitor returns false for a node, its children are not visited.
 * If direction is bottomUp, children are visited in order before their
 * parents, so that leaf nodes are visited first; when the visitor returns
 * false for a node, its parent is not visited.
 *
 * In both directions a false result also propagates upwards: the remaining
 * (younger) siblings of the node are skipped as well.
 *
 * @param tree Root of the traversal.
 * @param visitor Function called on each node; return false to stop.
 * @param direction Whether parents or children are visited first.
 */
export function visitTreeConditionally<L>(
    tree: IndentationTree<L>,
    visitor: (tree: IndentationTree<L>) => boolean,
    direction: 'topDown' | 'bottomUp'
): void {
    // IDEA: rewrite visitTree to reuse this code
    const walk = (node: IndentationTree<L>): boolean => {
        if (direction === 'topDown' && !visitor(node)) {
            return false;
        }
        // `every` stops at the first sub whose walk reports false.
        const subsCompleted = node.subs.every(sub => walk(sub));
        if (direction === 'bottomUp') {
            return subsCompleted && visitor(node);
        }
        return subsCompleted;
    };
    walk(tree);
}
/**
 * Fold an accumulator function over the tree.
 *
 * Every node is visited once via `visitTree`; the accumulator receives the
 * node together with the value accumulated so far and returns the new value.
 *
 * If direction is topDown, then parents are visited before their children.
 * If direction is bottomUp, children are visited in order before their parents,
 * so that leaf nodes are visited first.
 *
 * @param tree Root of the traversal.
 * @param init Initial accumulator value.
 * @param accumulator Function combining a node with the current value.
 * @param direction Visiting order.
 * @returns The accumulated value after the last node.
 */
export function foldTree<T, L>(
    tree: IndentationTree<L>,
    init: T,
    accumulator: (tree: IndentationTree<L>, acc: T) => T,
    direction: 'topDown' | 'bottomUp'
): T {
    let result = init;
    const step = (node: IndentationTree<L>): void => {
        result = accumulator(node, result);
    };
    visitTree(tree, step, direction);
    return result;
}
/** A function that maps a node to its replacement, or to `undefined` to delete it. */
export type Rebuilder<L> = (tree: IndentationTree<L>) => IndentationTree<L> | undefined;
/**
 * Rebuild the tree from the bottom up by applying a function to each node.
 *
 * For every node, the subs are rebuilt first and written back to the node
 * (the input tree is modified in place); deleted subs are dropped. The
 * visitor is then called on the node and returns the node that replaces it,
 * or `undefined` if it should be deleted.
 *
 * @param tree Root of the tree to rebuild.
 * @param visitor Called on each node after its subs have been rebuilt.
 * @param skip Optional predicate; a node for which it returns true is kept
 * as is, without visiting it or its sub-nodes.
 * @returns The rebuilt root, or a fresh top node if the root was deleted.
 */
export function rebuildTree<L>(
    tree: IndentationTree<L>,
    visitor: Rebuilder<L>,
    skip?: (tree: IndentationTree<L>) => boolean
): IndentationTree<L> {
    const rebuild: Rebuilder<L> = (node: IndentationTree<L>) => {
        if (skip !== undefined && skip(node)) {
            return node;
        }
        const survivingSubs = node.subs.map(rebuild).filter(sub => sub !== undefined) as IndentationSubTree<L>[];
        node.subs = survivingSubs;
        return visitor(node);
    };
    const rebuiltRoot = rebuild(tree);
    return rebuiltRoot === undefined ? topNode() : rebuiltRoot;
}

View File

@@ -0,0 +1,75 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { IndentationTree, isBlank, LineNode, TopNode, VirtualNode } from './classes';
import {
buildLabelRules,
flattenVirtual,
groupBlocks,
labelLines,
LabelRule,
labelVirtualInherited,
} from './parsing';
/**
 * Markdown labels.
 *
 * * heading: A line starting with `# `
 * * subheading: A line starting with `## `
 * * subsubheading: A line starting with three or more `#` followed by a space
 *
 * All rules are anchored at the start of the line, so a `### ` that merely
 * occurs somewhere inside a line does not turn that line into a heading.
 * Headings deeper than three levels are labelled as subsubheading.
 */
const _MarkdownLabelRules = {
    heading: /^# /,
    subheading: /^## /,
    subsubheading: /^#{3,} /,
} as const;
// The rule map converted to the list form expected by `labelLines`.
const MarkdownLabelRules: LabelRule<string>[] = buildLabelRules(_MarkdownLabelRules);
/**
 * processMarkdown(parseRaw(text)) is supposed to serve as a superior
 * alternative to parseTree(text, "generic").
 *
 * After labelling the lines, the direct subs of the tree are re-nested by
 * heading level: each heading becomes a sub of the closest preceding heading
 * of a higher level (or of the tree itself), and every other sub is appended
 * to the most recent heading. Blocks are then grouped, superfluous virtual
 * nodes flattened, and virtual nodes inherit labels.
 *
 * @param originalTree The raw tree to process; it is modified in place.
 * @returns The processed tree.
 */
export function processMarkdown<L>(originalTree: IndentationTree<L>): IndentationTree<L | string> {
    let tree = originalTree as IndentationTree<L | string>;
    labelLines(tree, MarkdownLabelRules);
    // We'll want to refer to the tree's subs, so let the type checker know it won't be blank
    if (isBlank(tree)) {
        return tree;
    }
    // Nesting depth of a heading label; 0 would be the tree itself, so headings start at 1.
    const headingLevel = (sub: IndentationTree<L | string>): number | undefined => {
        switch (sub.label) {
            case 'heading':
                return 1;
            case 'subheading':
                return 2;
            case 'subsubheading':
                return 3;
            default:
                return undefined;
        }
    };
    // hierarchy[i] is the node currently open at level i; hierarchy[0] is the tree.
    const hierarchy: (TopNode<L | string> | LineNode<L | string> | VirtualNode<L | string>)[] = [tree];
    const previousSubs = tree.subs.slice();
    tree.subs = [];
    for (const sub of previousSubs) {
        const level = headingLevel(sub);
        if (level === undefined || isBlank(sub)) {
            // Not a heading: it belongs to the innermost open heading.
            hierarchy[hierarchy.length - 1].subs.push(sub);
            continue;
        }
        // Take care of "dangling" levels, e.g. a subsubheading directly after a
        // heading: the missing levels are filled with the innermost open node.
        while (hierarchy.length < level) {
            hierarchy.push(hierarchy[hierarchy.length - 1]);
        }
        // Attach the heading to its parent and make it the tip of the hierarchy.
        hierarchy[level - 1].subs.push(sub);
        hierarchy[level] = sub;
        // Close all deeper levels.
        hierarchy.splice(level + 1);
    }
    // now group paragraphs
    tree = flattenVirtual(groupBlocks(tree));
    labelVirtualInherited(tree);
    return tree;
}

View File

@@ -0,0 +1,332 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import {
blankNode,
IndentationSubTree,
IndentationTree,
isBlank,
isLine,
isVirtual,
lineNode,
LineNode,
TopNode,
topNode,
virtualNode,
VirtualNode,
} from './classes';
import { clearLabelsIf, Rebuilder, rebuildTree, visitTree } from './manipulation';
/**
 * Perform a raw indentation-tree parse of a string. This is completely
 * language-agnostic and the returned tree is unlabeled.
 *
 * The source is split at `\n`. The indentation of a line is the number of its
 * leading whitespace characters (a tab counts as one); the stored source line
 * has that leading whitespace removed. A line becomes a sub of the closest
 * preceding line with a smaller indentation.
 *
 * - Blank lines pertain to the top-most node that they may, as restricted
 * by next non-blank line. So e.g.
 *
 *     E
 *       e1
 *         e2
 *
 *       e3
 *
 * Then e1.subs = [e2], and E.subs = [ e1, blank, e3 ].
 *
 * @param source The text to parse.
 * @returns A top node whose subs are the line and blank nodes of the source.
 * @throws Error if parsing stopped before the last line (a sanity check).
 */
export function parseRaw(source: string): IndentationTree<never> {
const rawLines = source.split('\n');
// TODO: How to handle mix of tabs and spaces?
// Number of leading whitespace characters per line.
const indentations = rawLines.map(line => line.match(/^\s*/)![0].length);
// Line contents without the leading whitespace (`trimLeft` is the legacy alias of `trimStart`).
const lines = rawLines.map(line => line.trimLeft());
// Parse the line at index `line` together with all of its subs.
// Returns the node and the index of the first line that is not part of it.
function parseNode(line: number): [LineNode<never>, number] {
const [subs, nextLine] = parseSubs(line + 1, indentations[line]);
const node: LineNode<never> = lineNode(indentations[line], line, lines[line], subs);
return [node, nextLine];
}
// Parse consecutive lines, starting at `initialLine`, that are blank or more
// indented than `parentIndentation`. Returns the parsed subs and the index of
// the first line that was not consumed.
function parseSubs(initialLine: number, parentIndentation: number): [IndentationSubTree<never>[], number] {
let sub: IndentationTree<never> | undefined;
const subs: IndentationSubTree<never>[] = [];
let line = initialLine;
// Index of the first line of the current run of blank lines, if any.
let lastBlank: number | undefined = undefined;
while (line < lines.length && (lines[line] === '' || indentations[line] > parentIndentation)) {
if (lines[line] === '') {
// Only remember where the run of blanks starts; whether the blanks
// belong to this level is decided by the next non-blank line.
if (lastBlank === undefined) {
lastBlank = line;
}
line += 1;
} else {
// A non-blank line at this level follows: the pending blanks belong here.
if (lastBlank !== undefined) {
for (let i = lastBlank; i < line; i++) {
subs.push(blankNode(i));
}
lastBlank = undefined;
}
[sub, line] = parseNode(line);
subs.push(sub);
}
}
// Trailing blanks are left for the grandparent
if (lastBlank !== undefined) {
line = lastBlank;
}
return [subs, line];
}
// A parent indentation of -1 makes every non-blank line a descendant of the top node.
const [subs, parsedLine] = parseSubs(0, -1);
let line = parsedLine;
// Special case: trailing blank lines at end of file
while (line < lines.length && lines[line] === '') {
subs.push(blankNode(line));
line += 1;
}
if (line < lines.length) {
throw new Error(`Parsing did not go to end of file. Ended at ${line} out of ${lines.length}`);
}
return topNode(subs);
}
/** Predicate deciding whether a source line matches a rule. */
type LineMatcher = (sourceLine: string) => boolean;
/** A labelling rule: lines accepted by `matches` receive `label`. */
export interface LabelRule<L> {
// Predicate applied to the source line of a line node.
matches: LineMatcher;
// Label assigned to a matching line; may be `undefined`.
label: L | undefined;
}
/**
 * Label the line nodes of the tree in place according to the given rules.
 *
 * Each line node receives the label of the first rule (in list order) whose
 * matcher accepts its source line. Line nodes without a matching rule and all
 * other kinds of nodes keep their current label.
 *
 * @param tree The tree to label.
 * @param labelRules The rules, in order of precedence.
 */
export function labelLines<L>(tree: IndentationTree<L>, labelRules: LabelRule<L>[]): void {
    const applyFirstMatch = (node: IndentationTree<L>): void => {
        if (!isLine(node)) {
            return;
        }
        const line = node.sourceLine;
        const firstMatch = labelRules.find(rule => rule.matches(line));
        if (firstMatch) {
            node.label = firstMatch.label;
        }
    };
    visitTree(tree, applyFirstMatch, 'bottomUp');
}
/**
 * Let unlabelled virtual nodes inherit the label of their only real sub.
 *
 * For each virtual node without a label: if exactly one of its subs is not
 * blank, the virtual node receives that sub's label. The tree is modified in
 * place, bottom-up.
 */
export function labelVirtualInherited<L>(tree: IndentationTree<L>): void {
    const inheritLabel = (node: IndentationTree<L>): void => {
        if (!isVirtual(node) || node.label !== undefined) {
            return;
        }
        const nonBlankSubs = node.subs.filter(sub => !isBlank(sub));
        if (nonBlankSubs.length === 1) {
            node.label = nonBlankSubs[0].label;
        }
    };
    visitTree(tree, inheritLabel, 'bottomUp');
}
/**
 * Function to convert a mapped object to a list of rules.
 * This allows some type magic for extracting a label type from a mapping of rules.
 *
 * Each key of the map becomes the label of one rule, in key order. A value
 * with a `test` member is treated as a regular expression, any other value as
 * a matcher function.
 *
 * Regular expressions with the `g` or `y` flag keep state in `lastIndex`
 * between calls to `test`; that state is reset before every match, so the
 * result depends only on the line being tested.
 *
 * @param ruleMap Mapping from label to regular expression or matcher function.
 * @returns One rule per key of `ruleMap`.
 */
export function buildLabelRules<L extends { [key: string]: RegExp | LineMatcher }>(ruleMap: L): LabelRule<keyof L>[] {
    return (Object.keys(ruleMap) as (keyof L)[]).map(key => {
        const rule = ruleMap[key];
        let matches: (sourceLine: string) => boolean;
        if ((rule as RegExp).test) {
            const regex = rule as RegExp;
            matches = sourceLine => {
                // Global and sticky regexes would otherwise resume at the
                // position where the previous match ended.
                if (regex.global || regex.sticky) {
                    regex.lastIndex = 0;
                }
                return regex.test(sourceLine);
            };
        } else {
            matches = rule as LineMatcher;
        }
        return {
            matches,
            label: key,
        };
    });
}
/**
 * Fills the opener and closer indentation spec
 * 1. Openers alone in a line whose older sibling is a line are moved into that sibling's children,
 * and their children integrated as subsequent children of their new parent.
 * 2. Closers following an older sibling (maybe with blanks in between) are moved to be the last of that sibling.
 * 3. If the closer in 2 has children themselves, their older siblings are wrapped in a virtual node
 *
 * The tree is rebuilt bottom-up via `rebuildTree` and modified in place.
 *
 * NOTE(review): rule 1 is implemented with `push`, i.e. the opener is appended
 * after any children the older sibling already has; confirm whether "first of
 * that sibling's children" was ever intended.
 *
 * @param tree Tree whose lines are already labelled 'opener' / 'closer'.
 * @returns The rebuilt tree, without the temporary 'newVirtual' label.
 */
export function combineClosersAndOpeners<L>(
tree: IndentationTree<L | 'opener' | 'closer'>
): IndentationTree<L | 'opener' | 'closer'> {
// We'll make new virtual nodes, which comprise older siblings of a closer and get a temporary label
type S = L | 'opener' | 'closer' | 'newVirtual';
const rebuilder: Rebuilder<S> = function (tree: IndentationTree<S>) {
// Nothing to combine if no direct sub is an opener or a closer.
if (
tree.subs.length === 0 ||
tree.subs.findIndex(sub => sub.label === 'closer' || sub.label === 'opener') === -1
) {
return tree;
}
// The subs that remain direct children of `tree`.
const newSubs: IndentationSubTree<S>[] = [];
// The most recent non-blank node that was added to newSubs.
let lastNew: TopNode<S> | VirtualNode<S> | LineNode<S> | undefined;
for (let i = 0; i < tree.subs.length; i++) {
const sub = tree.subs[i];
const directOlderSibling = tree.subs[i - 1];
// 1. if opener whose direct older sibling is a line, move it into that sibling's children
if (sub.label === 'opener' && directOlderSibling !== undefined && isLine(directOlderSibling)) {
// Move the bracket to be the last child of it
directOlderSibling.subs.push(sub);
// The opener's own subs follow it as children of the older sibling.
sub.subs.forEach(sub => directOlderSibling.subs.push(sub));
sub.subs = [];
}
// 2. if a closer following an older sibling
else if (
sub.label === 'closer' &&
lastNew !== undefined &&
(isLine(sub) || isVirtual(sub)) &&
sub.indentation >= lastNew.indentation
) {
// Move intervening blanks from newSubs to lastNew.subs
let j = newSubs.length - 1;
while (j > 0 && isBlank(newSubs[j])) {
j -= 1;
}
lastNew.subs.push(...newSubs.splice(j + 1));
// 3.if the closer in 2 has children themselves, their older siblings are wrapped in a virtual node to distinguish them
// Except for leading blocks of virtual nodes which have already been wrapped that way
// i.e. take the longest initial subsequence of lastNew.subs that are all labeled 'newVirtual' and don't wrap those again
if (sub.subs.length > 0) {
const firstNonVirtual = lastNew.subs.findIndex(sub => sub.label !== 'newVirtual');
const subsToKeep = lastNew.subs.slice(0, firstNonVirtual);
const subsToWrap = lastNew.subs.slice(firstNonVirtual);
const wrappedSubs =
subsToWrap.length > 0 ? [virtualNode(sub.indentation, subsToWrap, 'newVirtual')] : [];
lastNew.subs = [...subsToKeep, ...wrappedSubs, sub];
} else {
lastNew.subs.push(sub);
}
} else {
// nothing to do here, just add it normally
newSubs.push(sub);
if (!isBlank(sub)) {
lastNew = sub;
}
}
}
tree.subs = newSubs;
return tree;
};
const returnTree = rebuildTree(tree, rebuilder);
// Remove the temporary helper label again. This is applied to `tree`; the
// rebuilder above always returns the node it was given, modified in place.
clearLabelsIf<S, 'newVirtual'>(tree, (arg: S): arg is 'newVirtual' => arg === 'newVirtual');
// now returnTree does not have the helper label 'newVirtual' anymore
return returnTree as IndentationTree<L | 'opener' | 'closer'>;
}
/**
 * If there are more than 1 consecutive sibling separated from others by delimiters,
 * combine them into a virtual node.
 * The possibly several consecutive delimiters will be put with the preceding siblings into the virtual node.
 * Note that offside groupings should be done before this.
 *
 * The tree is rebuilt bottom-up via `rebuildTree` and modified in place.
 *
 * @param tree The tree to group.
 * @param isDelimiter Predicate identifying delimiter nodes; defaults to blank nodes.
 * @param label Optional label given to the created virtual nodes.
 * @returns The rebuilt tree.
 */
export function groupBlocks<L>(
tree: IndentationTree<L>,
isDelimiter: (node: IndentationTree<L>) => boolean = isBlank,
label?: L
): IndentationTree<L> {
const rebuilder: Rebuilder<L> = function (tree: IndentationTree<L>) {
// Zero or one sub: there is nothing to group.
if (tree.subs.length <= 1) {
return tree;
}
const newSubs: IndentationSubTree<L>[] = [];
let nodesSinceLastFlush: IndentationSubTree<L>[] = [];
// Indentation of the first non-blank sub seen; it is set once and used
// for every virtual node created for this parent.
let currentBlockIndentation: number | undefined;
let lastNodeWasDelimiter = false;
// we write to nodesSinceLastFlush as cache
// if we have a non-delimiter after a delimiter, we flush
// to a new virtual node appended to the newSubs array
// Note that the cache is not cleared here; the caller resets it after flushing.
function flushBlockIntoNewSubs(
final: boolean = false // if final, only wrap in virtual if there are newSubs already
): void {
if (currentBlockIndentation !== undefined && (newSubs.length > 0 || !final)) {
const virtual = virtualNode(currentBlockIndentation, nodesSinceLastFlush, label);
newSubs.push(virtual);
} else {
// No block to wrap: keep the cached nodes as direct subs.
nodesSinceLastFlush.forEach(node => newSubs.push(node));
}
}
for (let i = 0; i < tree.subs.length; i++) {
const sub = tree.subs[i];
const subIsDelimiter = isDelimiter(sub);
// A non-delimiter directly after a delimiter starts a new block.
if (!subIsDelimiter && lastNodeWasDelimiter) {
flushBlockIntoNewSubs();
nodesSinceLastFlush = [];
}
lastNodeWasDelimiter = subIsDelimiter;
nodesSinceLastFlush.push(sub);
if (!isBlank(sub)) {
currentBlockIndentation = currentBlockIndentation ?? sub.indentation;
}
}
// treat the end of node like a block end, and make the virtual block if it wouldn't be a singleton
flushBlockIntoNewSubs(true);
tree.subs = newSubs;
return tree;
};
return rebuildTree(tree, rebuilder);
}
/**
 * Remove unlabeled virtual nodes which either:
 * - Have one or no children
 * - Are the only child of their parent
 * In either case, the node is replaced by its children; an unlabeled virtual
 * node without children is deleted.
 *
 * The tree is rebuilt bottom-up via `rebuildTree` and modified in place.
 *
 * @param tree The tree to flatten.
 * @returns The rebuilt tree.
 */
export function flattenVirtual<L>(tree: IndentationTree<L>): IndentationTree<L> {
    const isUnlabeledVirtual = (node: IndentationTree<L>): boolean => isVirtual(node) && node.label === undefined;
    const collapse: Rebuilder<L> = node => {
        if (isUnlabeledVirtual(node) && node.subs.length <= 1) {
            // Replace the virtual node by its only sub, or delete it if it is empty.
            return node.subs.length === 0 ? undefined : node.subs[0];
        }
        if (node.subs.length === 1) {
            const onlySub = node.subs[0];
            if (isUnlabeledVirtual(onlySub)) {
                // The only sub is a superfluous wrapper: adopt its subs directly.
                node.subs = onlySub.subs;
            }
        }
        return node;
    };
    return rebuildTree(tree, collapse);
}
/**
 * Generic labels.
 *
 * * opener: A line starting with an opening parens, square bracket, or curly brace
 * * closer: A line starting with a closing parens, square bracket, or curly brace
 */
const _genericLabelRules = {
opener: /^[[({]/,
closer: /^[\])}]/,
} as const;
// The generic rule map converted to the list form expected by `labelLines`.
const genericLabelRules: LabelRule<'opener' | 'closer'>[] = buildLabelRules(_genericLabelRules);
// Registry of language-specific parsers, keyed by language id. Entries are
// added by `registerLanguageSpecificParser` and looked up by `parseTree`.
const LANGUAGE_SPECIFIC_PARSERS: { [key: string]: (raw: IndentationTree<never>) => IndentationTree<string> } = {};
/**
 * Register a language-specific parser for a language.
 * This should normally be called in index.ts.
 *
 * A parser registered earlier for the same language is replaced.
 *
 * @param language The language id the parser is registered for.
 * @param parser Function turning a raw tree into a labelled tree.
 */
export function registerLanguageSpecificParser(
language: string,
parser: (raw: IndentationTree<never>) => IndentationTree<string>
): void {
LANGUAGE_SPECIFIC_PARSERS[language] = parser;
}
/**
 * Parse the source into a labelled indentation tree.
 *
 * The source is parsed with `parseRaw`. If a language-specific parser has
 * been registered for `languageId`, the raw tree is handed to it. Otherwise
 * the lines are labelled with the generic opener / closer rules and closers
 * and openers are combined.
 *
 * Only parsers that were actually registered are used: a `languageId` that
 * happens to equal the name of a member inherited from `Object.prototype`
 * (e.g. `toString`) falls back to the generic handling.
 *
 * @param source The text to parse.
 * @param languageId Optional language id selecting a language-specific parser.
 * @returns The labelled tree.
 */
export function parseTree(source: string, languageId?: string): IndentationTree<string> {
    const raw = parseRaw(source);
    const registryKey = languageId ?? '';
    const languageSpecificParser = Object.prototype.hasOwnProperty.call(LANGUAGE_SPECIFIC_PARSERS, registryKey)
        ? LANGUAGE_SPECIFIC_PARSERS[registryKey]
        : undefined;
    if (languageSpecificParser) {
        return languageSpecificParser(raw);
    } else {
        labelLines(raw, genericLabelRules);
        const processedTree = combineClosersAndOpeners(raw);
        return processedTree;
    }
}

View File

@@ -0,0 +1,468 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DocumentInfo } from './prompt';
/**
 * Interface for writing single-line comments in a given language.
 * Does not include the terminal new-line character (i.e. for many languages,
 * `end` will just be the empty string).
 */
interface CommentMarker {
// Text that opens the comment, e.g. `//`.
start: string;
// Text that closes the comment, e.g. `-->`; empty if the comment runs to the end of the line.
end: string;
}
/** Per-language information: comment syntax and markdown code-block ids. */
interface ILanguageInfo {
// Markers used to write a single-line comment in the language.
readonly lineComment: CommentMarker;
/**
 * if not set, defaults to the language id
 */
readonly markdownLanguageIds?: string[];
}
/** Language information together with the id of the language it describes. */
interface ILanguage extends ILanguageInfo {
readonly languageId: string;
}
/**
 * Comment markers and markdown code-block ids per language id.
 *
 * Language ids in VSCode:
 * https://code.visualstudio.com/docs/languages/identifiers#_known-language-identifiers
 *
 * Missing below from this list are:
 * Diff diff
 * Git git-commit and git-rebase
 * JSON json
 * ShaderLab shaderlab
 * Additional to that list are:
 * Erlang
 * Haskell
 * Kotlin
 * QL
 * Scala
 * Verilog
 *
 * Markdown ids from https://raw.githubusercontent.com/highlightjs/highlight.js/refs/heads/main/SUPPORTED_LANGUAGES.md
 * Also refer to [vscode-copilot-chat](https://github.com/microsoft/vscode-copilot-chat/blob/main/src/util/common/languages.ts)
 */
export const languageMarkers: { [language: string]: ILanguageInfo } = {
    abap: {
        lineComment: { start: '"', end: '' },
        markdownLanguageIds: ['abap', 'sap-abap'],
    },
    aspdotnet: {
        lineComment: { start: '<%--', end: '--%>' },
    },
    bat: {
        lineComment: { start: 'REM', end: '' },
    },
    bibtex: {
        lineComment: { start: '%', end: '' },
        markdownLanguageIds: ['bibtex'],
    },
    blade: {
        lineComment: { start: '#', end: '' },
    },
    BluespecSystemVerilog: {
        lineComment: { start: '//', end: '' },
    },
    c: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['c', 'h'],
    },
    clojure: {
        lineComment: { start: ';', end: '' },
        markdownLanguageIds: ['clojure', 'clj'],
    },
    coffeescript: {
        // CoffeeScript comments start with `#`; `//` is an operator there.
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['coffeescript', 'coffee', 'cson', 'iced'],
    },
    cpp: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['cpp', 'hpp', 'cc', 'hh', 'c++', 'h++', 'cxx', 'hxx'],
    },
    csharp: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['csharp', 'cs'],
    },
    css: {
        lineComment: { start: '/*', end: '*/' },
    },
    cuda: {
        lineComment: { start: '//', end: '' },
    },
    dart: {
        lineComment: { start: '//', end: '' },
    },
    dockerfile: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['dockerfile', 'docker'],
    },
    dotenv: {
        lineComment: { start: '#', end: '' },
    },
    elixir: {
        lineComment: { start: '#', end: '' },
    },
    erb: {
        lineComment: { start: '<%#', end: '%>' },
    },
    erlang: {
        lineComment: { start: '%', end: '' },
        markdownLanguageIds: ['erlang', 'erl'],
    },
    fsharp: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['fsharp', 'fs', 'fsx', 'fsi', 'fsscript'],
    },
    go: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['go', 'golang'],
    },
    graphql: {
        lineComment: { start: '#', end: '' },
    },
    groovy: {
        lineComment: { start: '//', end: '' },
    },
    haml: {
        lineComment: { start: '-#', end: '' },
    },
    handlebars: {
        lineComment: { start: '{{!', end: '}}' },
        markdownLanguageIds: ['handlebars', 'hbs', 'html.hbs', 'html.handlebars'],
    },
    haskell: {
        lineComment: { start: '--', end: '' },
        markdownLanguageIds: ['haskell', 'hs'],
    },
    hlsl: {
        lineComment: { start: '//', end: '' },
    },
    html: {
        lineComment: { start: '<!--', end: '-->' },
        markdownLanguageIds: ['html', 'xhtml'],
    },
    ini: {
        lineComment: { start: ';', end: '' },
    },
    java: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['java', 'jsp'],
    },
    javascript: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['javascript', 'js'],
    },
    javascriptreact: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['jsx'],
    },
    jsonc: {
        lineComment: { start: '//', end: '' },
    },
    jsx: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['jsx'],
    },
    julia: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['julia', 'jl'],
    },
    kotlin: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['kotlin', 'kt'],
    },
    latex: {
        lineComment: { start: '%', end: '' },
        markdownLanguageIds: ['tex'],
    },
    legend: {
        lineComment: { start: '//', end: '' },
    },
    less: {
        lineComment: { start: '//', end: '' },
    },
    lua: {
        lineComment: { start: '--', end: '' },
        markdownLanguageIds: ['lua', 'pluto'],
    },
    makefile: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['makefile', 'mk', 'mak', 'make'],
    },
    markdown: {
        lineComment: { start: '[]: #', end: '' },
        markdownLanguageIds: ['markdown', 'md', 'mkdown', 'mkd'],
    },
    'objective-c': {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['objectivec', 'mm', 'objc', 'obj-c'],
    },
    'objective-cpp': {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['objectivec++', 'objc+'],
    },
    perl: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['perl', 'pl', 'pm'],
    },
    php: {
        lineComment: { start: '//', end: '' },
    },
    powershell: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['powershell', 'ps', 'ps1'],
    },
    pug: {
        lineComment: { start: '//', end: '' },
    },
    python: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['python', 'py', 'gyp'],
    },
    ql: {
        lineComment: { start: '//', end: '' },
    }, // QL is a query language for CodeQL
    r: {
        lineComment: { start: '#', end: '' },
    },
    razor: {
        lineComment: { start: '<!--', end: '-->' },
        markdownLanguageIds: ['cshtml', 'razor', 'razor-cshtml'],
    },
    ruby: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['ruby', 'rb', 'gemspec', 'podspec', 'thor', 'irb'],
    },
    rust: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['rust', 'rs'],
    },
    sass: {
        lineComment: { start: '//', end: '' },
    },
    scala: {
        lineComment: { start: '//', end: '' },
    },
    scss: {
        lineComment: { start: '//', end: '' },
    },
    shellscript: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['bash', 'sh', 'zsh'],
    },
    slang: {
        lineComment: { start: '//', end: '' },
    },
    slim: {
        lineComment: { start: '/', end: '' },
    },
    solidity: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['solidity', 'sol'],
    },
    sql: {
        lineComment: { start: '--', end: '' },
    },
    stylus: {
        lineComment: { start: '//', end: '' },
    },
    svelte: {
        lineComment: { start: '<!--', end: '-->' },
    },
    swift: {
        lineComment: { start: '//', end: '' },
    },
    systemverilog: {
        lineComment: { start: '//', end: '' },
    },
    terraform: {
        lineComment: { start: '#', end: '' },
    },
    tex: {
        lineComment: { start: '%', end: '' },
    },
    typescript: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['typescript', 'ts'],
    },
    typescriptreact: {
        lineComment: { start: '//', end: '' },
        markdownLanguageIds: ['tsx'],
    },
    vb: {
        lineComment: { start: `'`, end: '' },
        markdownLanguageIds: ['vb', 'vbscript'],
    },
    verilog: {
        lineComment: { start: '//', end: '' },
    },
    'vue-html': {
        lineComment: { start: '<!--', end: '-->' },
    },
    vue: {
        lineComment: { start: '//', end: '' },
    },
    xml: {
        lineComment: { start: '<!--', end: '-->' },
    },
    xsl: {
        lineComment: { start: '<!--', end: '-->' },
    },
    yaml: {
        lineComment: { start: '#', end: '' },
        markdownLanguageIds: ['yaml', 'yml'],
    },
};
/**
 * Reverse lookup from a markdown code-fence language name to the language id
 * that owns it. A language that declares `markdownLanguageIds` is registered
 * under each of those names; a language that declares none is registered
 * under its own id.
 */
const mdLanguageIdToLanguageId: { [markdownLanguageId: string]: string } = {};
for (const [languageId, { markdownLanguageIds }] of Object.entries(languageMarkers)) {
	const fenceNames = markdownLanguageIds ? markdownLanguageIds : [languageId];
	for (const fenceName of fenceNames) {
		mdLanguageIdToLanguageId[fenceName] = languageId;
	}
}
/**
 * Translate the language name of a markdown code fence into a language id.
 *
 * @param mdLanguageId The name written after the opening fence, e.g. `ts` or `py`.
 * @returns The matching language id, or `undefined` when the name is not registered.
 */
export function mdCodeBlockLangToLanguageId(mdLanguageId: string): string | undefined {
	// The lookup table is a plain object, so a bare index would also resolve inherited
	// members such as `constructor` or `toString` and return a non-string value.
	// Fence names are arbitrary text, so only own entries count as a match.
	if (!Object.prototype.hasOwnProperty.call(mdLanguageIdToLanguageId, mdLanguageId)) {
		return undefined;
	}
	return mdLanguageIdToLanguageId[mdLanguageId];
}
// Fallback line-comment marker (`//` with no closing marker) that `comment` uses
// for language ids that have no entry in `languageMarkers`.
const defaultCommentMarker: CommentMarker = { start: '//', end: '' };
// Language ids for which `getLanguageMarker` always returns an empty string.
const dontAddLanguageMarker: string[] = [
'php', // We don't know if the file starts with `<?php` or not
'plaintext', // Doesn't admit comments
];
// Per-language marker lines. `getLanguageMarker` returns the entry for a language
// instead of a generic `Language: <id>` line, and `isShebangLine` recognises any of
// these values. Despite the name, not every entry is a `#!` shebang: html uses a
// doctype and yaml uses a comment.
// prettier-ignore
const shebangLines: { [language: string]: string } = {
'html': '<!DOCTYPE html>',
'python': '#!/usr/bin/env python3',
'ruby': '#!/usr/bin/env ruby',
'shellscript': '#!/bin/sh',
'yaml': '# YAML data'
};
/**
 * Check whether a line is one of the known marker lines listed in `shebangLines`.
 * Leading and trailing whitespace on the line is ignored.
 *
 * @param line The line to check
 * @returns True if the trimmed line exactly equals one of the known marker lines
 */
export function isShebangLine(line: string): boolean {
	const candidate = line.trim();
	return Object.values(shebangLines).some(marker => marker === candidate);
}
/**
 * Best-effort check of whether the top of the source already carries a
 * discernible language marker: a `#!` shebang or a `<!DOCTYPE` declaration.
 *
 * @param doc The document to inspect; only its `source` text is read.
 * @returns True iff the source starts with one of the recognised markers
 */
export function hasLanguageMarker({ source }: DocumentInfo): boolean {
	const recognisedPrefixes = ['#!', '<!DOCTYPE'];
	return recognisedPrefixes.some(prefix => source.startsWith(prefix));
}
/**
 * Wrap one line of text in the line-comment markers of a language,
 * e.g. for python "hello there" becomes "# hello there".
 *
 * Languages without an entry in `languageMarkers` use the default `//` marker.
 * A closing marker, when the language has one, is appended after a space.
 *
 * Note: `text` is treated as a single line. For multi-line text use
 * {@link commentBlockAsSingles} instead.
 *
 * @param text The line to comment out.
 * @param languageId The language whose comment syntax should be used.
 * @returns The commented line, or an empty string if no marker could be resolved.
 */
export function comment(text: string, languageId: string) {
	const info = languageMarkers[languageId];
	const markers = info ? info.lineComment : defaultCommentMarker;
	if (!markers) {
		return '';
	}
	const closing = markers.end === '' ? '' : ' ' + markers.end;
	return markers.start + ' ' + text + closing;
}
/**
 * Comment out every line of a block with single-line comments.
 *
 * The result ends with a newline exactly when the input does. An empty input
 * yields an empty string rather than a lone comment marker.
 *
 * @param text The (possibly multi-line) text to comment out.
 * @param languageId The language whose comment syntax should be used.
 */
export function commentBlockAsSingles(text: string, languageId: string) {
	if (text === '') {
		// Avoid emitting a comment marker for nothing
		return '';
	}
	const endsWithNewline = text.endsWith('\n');
	const body = endsWithNewline ? text.slice(0, -1) : text;
	let result = body
		.split('\n')
		.map(line => comment(line, languageId))
		.join('\n');
	if (endsWithNewline) {
		result += '\n';
	}
	return result;
}
/**
 * Produce a one-line piece of text describing the language of a document:
 * the language's entry in `shebangLines` when there is one, otherwise
 * `Language: <languageId>`.
 *
 * Nothing is produced for languages listed in `dontAddLanguageMarker`, or when
 * the document source already starts with a recognisable marker.
 *
 * @param doc The document we want the marker for.
 * @returns A one-line string that describes the language, or an empty string.
 */
export function getLanguageMarker(doc: DocumentInfo): string {
	const { languageId } = doc;
	if (dontAddLanguageMarker.includes(languageId) || hasLanguageMarker(doc)) {
		return '';
	}
	return languageId in shebangLines ? shebangLines[languageId] : `Language: ${languageId}`;
}
/**
 * Produce a one-line piece of text naming the relative path of the document, if known.
 * The text is not wrapped in comment markers here.
 *
 * @param doc The document we want the marker for.
 * @returns `Path: <relativePath>`, or an empty string when the document has no relative path.
 */
export function getPathMarker(doc: DocumentInfo): string {
	return doc.relativePath ? `Path: ${doc.relativePath}` : '';
}
/**
 * Ensure a non-empty string ends with a newline.
 *
 * @param str String to terminate
 *
 * @returns `str` unchanged when it is empty or already ends with `\n`;
 * otherwise `str` with a single `\n` appended.
 */
export function newLineEnded(str: string): string {
	if (str.length === 0 || str.endsWith('\n')) {
		return str;
	}
	return `${str}\n`;
}
/**
 * Look up the language description for a language identifier.
 *
 * @param languageId - The identifier of the language. Anything that is not a string
 * (in particular `undefined`) is treated as 'plaintext'.
 * @returns The language description for the resolved identifier.
 */
export function getLanguage(languageId: string | undefined): ILanguage {
	const resolvedId = typeof languageId === 'string' ? languageId : 'plaintext';
	return _getLanguage(resolvedId);
}
/**
 * Build the language description for `languageId`: the matching entry of
 * `languageMarkers` with the id added, or a description with the `//` line
 * comment marker when the id has no entry.
 */
function _getLanguage(languageId: string): ILanguage {
	const known = languageMarkers[languageId];
	if (known === undefined) {
		return { languageId, lineComment: { start: '//', end: '' } };
	}
	return { languageId, ...known };
}

View File

@@ -0,0 +1,172 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import Parser from 'web-tree-sitter';
import { CopilotPromptLoadFailure } from './error';
import { locateFile, readFile } from './fileLoader';
/**
 * Tree-sitter grammars known to this module. The string value of each member is
 * interpolated into the grammar file name `tree-sitter-<value>.wasm` by
 * `loadWasmLanguage`, so values must match the shipped file names.
 */
export enum WASMLanguage {
Python = 'python',
JavaScript = 'javascript',
TypeScript = 'typescript',
TSX = 'tsx',
Go = 'go',
Ruby = 'ruby',
CSharp = 'c-sharp',
Java = 'java',
Php = 'php',
Cpp = 'cpp',
}
// Maps language ids to the tree-sitter grammar used to parse them. Several ids share
// a grammar: `javascriptreact` and `jsx` use the JavaScript grammar, and `c` uses the
// C++ grammar.
const languageIdToWasmLanguageMapping: { [language: string]: WASMLanguage } = {
python: WASMLanguage.Python,
javascript: WASMLanguage.JavaScript,
javascriptreact: WASMLanguage.JavaScript,
jsx: WASMLanguage.JavaScript,
typescript: WASMLanguage.TypeScript,
typescriptreact: WASMLanguage.TSX,
go: WASMLanguage.Go,
ruby: WASMLanguage.Ruby,
csharp: WASMLanguage.CSharp,
java: WASMLanguage.Java,
php: WASMLanguage.Php,
c: WASMLanguage.Cpp,
cpp: WASMLanguage.Cpp,
};
/**
 * Report whether a language id is currently supported.
 *
 * A language is supported when it has an entry in `languageIdToWasmLanguageMapping`
 * and is not one of the ids excluded below.
 *
 * @param languageId The language id to check.
 * @returns True iff the id is mapped to a grammar and not excluded.
 */
export function isSupportedLanguageId(languageId: string): boolean {
	// Only own keys of the mapping count. The `in` operator would also accept
	// inherited names such as `constructor` or `toString`.
	if (!Object.prototype.hasOwnProperty.call(languageIdToWasmLanguageMapping, languageId)) {
		return false;
	}
	// These ids have a grammar mapping but are excluded for now. The original note
	// here only gave a reason for C#: it is temporarily disabled until the tree-sitter
	// parser for it is fully spec-ed. No reason is recorded for the other ids.
	switch (languageId) {
		case 'csharp':
		case 'java':
		case 'php':
		case 'c':
		case 'cpp':
			return false;
		default:
			return true;
	}
}
/**
 * Resolve a language id to the tree-sitter grammar used to parse it.
 *
 * @param languageId The language id to resolve.
 * @returns The grammar mapped to `languageId`.
 * @throws Error if the id has no entry in `languageIdToWasmLanguageMapping`.
 */
export function languageIdToWasmLanguage(languageId: string): WASMLanguage {
	// Check own keys only. With `in`, inherited names such as `constructor` would
	// pass the guard and a non-`WASMLanguage` value would be returned.
	if (!Object.prototype.hasOwnProperty.call(languageIdToWasmLanguageMapping, languageId)) {
		throw new Error(`Unrecognized language: ${languageId}`);
	}
	return languageIdToWasmLanguageMapping[languageId];
}
// Cache of grammar loads keyed by grammar. It stores the promise itself (not the
// resolved value) and is populated by `getLanguage`.
const languageLoadPromises = new Map<WASMLanguage, Promise<Parser.Language>>();
/**
 * Read the `tree-sitter-<language>.wasm` grammar via `readFile` and load it into
 * tree-sitter.
 *
 * A read failure that looks like a plain `Error` carrying a string `code` is
 * rethrown as `CopilotPromptLoadFailure`; any other failure is rethrown unchanged.
 *
 * @param language The grammar to load.
 * @returns The loaded tree-sitter language.
 */
async function loadWasmLanguage(language: WASMLanguage): Promise<Parser.Language> {
	const wasmFile = `tree-sitter-${language}.wasm`;
	let wasmBytes;
	try {
		wasmBytes = await readFile(wasmFile);
	} catch (e: unknown) {
		if (e instanceof Error && 'code' in e && typeof e.code === 'string' && e.name === 'Error') {
			throw new CopilotPromptLoadFailure(`Could not load ${wasmFile}`, e);
		}
		throw e;
	}
	return Parser.Language.load(wasmBytes);
}
/**
 * Get the tree-sitter language for a language id, loading the grammar on first
 * use and reusing the same promise afterwards.
 *
 * IMPORTANT: This function deliberately has no async signature. The cache entry
 * must be stored synchronously, before any yield/await, so that interleaved
 * callers cannot trigger duplicate loads of the same language.
 *
 * @param language The language id to resolve.
 * @returns A promise for the loaded tree-sitter language.
 */
export function getLanguage(language: string): Promise<Parser.Language> {
	const wasmLanguage = languageIdToWasmLanguage(language);
	const cached = languageLoadPromises.get(wasmLanguage);
	if (cached !== undefined) {
		return cached;
	}
	const pending = loadWasmLanguage(wasmLanguage);
	languageLoadPromises.set(wasmLanguage, pending);
	return pending;
}
/**
 * Error that carries the underlying failure as its `cause`.
 * Used by `parseTreeSitterIncludingVersion` when constructing the parser fails.
 */
class WrappedError extends Error {
constructor(message: string, cause: unknown) {
super(message, { cause });
}
}
/**
 * Parse `source` with the grammar for `language`.
 *
 * The caller owns the returned tree and must call `.delete()` on it before it
 * goes out of scope.
 */
export async function parseTreeSitter(language: string, source: string): Promise<Parser.Tree> {
	const [tree] = await parseTreeSitterIncludingVersion(language, source);
	return tree;
}
/**
 * Parse `source` with the grammar for `language` and also report the grammar's
 * `version`.
 *
 * The caller owns the returned tree and must call `.delete()` on it before it
 * goes out of scope. The parser created here is always released before this
 * function settles, including when loading the grammar or parsing fails.
 *
 * @param language The language id of the source.
 * @param source The text to parse.
 * @returns A tuple of the parsed tree and the `version` of the loaded language.
 * @throws WrappedError if constructing the parser fails with a
 * "table index is out of bounds" message; other failures propagate unchanged.
 */
export async function parseTreeSitterIncludingVersion(language: string, source: string): Promise<[Parser.Tree, number]> {
	// `Parser.init` needs to be called before `new Parser()` below
	await Parser.init({
		locateFile: (filename: string) => locateFile(filename),
	});
	let parser: Parser;
	try {
		parser = new Parser();
	} catch (e: unknown) {
		if (
			e &&
			typeof e === 'object' &&
			'message' in e &&
			typeof e.message === 'string' &&
			e.message.includes('table index is out of bounds')
		) {
			throw new WrappedError(`Could not init Parse for language <${language}>`, e);
		}
		throw e;
	}
	try {
		const treeSitterLanguage = await getLanguage(language);
		parser.setLanguage(treeSitterLanguage);
		const parsedTree = parser.parse(source);
		return [parsedTree, treeSitterLanguage.version];
	} finally {
		// Parser objects need to be deleted directly. Doing it in `finally` ensures
		// the parser is not leaked when loading the grammar or parsing throws.
		parser.delete();
	}
}
/**
 * Return the token that closes a block in the given language.
 *
 * @param language The language id; resolved through `languageIdToWasmLanguage`,
 * which throws for unrecognised ids.
 * @returns `'end'` for Ruby, `'}'` for the brace-delimited grammars, and `null`
 * for Python, which has no closing token.
 */
export function getBlockCloseToken(language: string): string | null {
	switch (languageIdToWasmLanguage(language)) {
		case WASMLanguage.Ruby:
			return 'end';
		case WASMLanguage.Python:
			return null;
		case WASMLanguage.Cpp:
		case WASMLanguage.CSharp:
		case WASMLanguage.Go:
		case WASMLanguage.Java:
		case WASMLanguage.JavaScript:
		case WASMLanguage.Php:
		case WASMLanguage.TSX:
		case WASMLanguage.TypeScript:
			return '}';
	}
}
/**
 * Run each query against `root` and return all matches, in query order.
 *
 * Each entry is a `[source, compiled?]` pair. The compiled query is created on
 * first use from the language of `root`'s tree and written back into the pair,
 * so later calls reuse it.
 *
 * @param queries The query pairs to run; their second slot may be filled in.
 * @param root The node to match against.
 * @returns The concatenated matches of all queries.
 */
function innerQuery(queries: [string, Parser.Query?][], root: Parser.SyntaxNode): Parser.QueryMatch[] {
	const collected: Parser.QueryMatch[] = [];
	for (const entry of queries) {
		let compiled = entry[1];
		if (!compiled) {
			// first use: compile the query and cache it in the pair
			compiled = root.tree.getLanguage().query(entry[0]);
			entry[1] = compiled;
		}
		collected.push(...compiled.matches(root));
	}
	return collected;
}
// Tree-sitter query matching a class or function definition whose block contains a
// string expression statement. The optional second tuple slot holds the compiled
// query once `innerQuery` has used this entry.
const docstringQuery: [string, Parser.Query?] = [
`[
(class_definition (block (expression_statement (string))))
(function_definition (block (expression_statement (string))))
]`,
];
/**
 * Report whether running the docstring query on `blockNode` yields exactly one match.
 *
 * @param blockNode The node to run the docstring query against.
 */
export function queryPythonIsDocstring(blockNode: Parser.SyntaxNode): boolean {
	const docstringMatches = innerQuery([docstringQuery], blockNode);
	return docstringMatches.length === 1;
}

View File

@@ -0,0 +1,975 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as Parser from 'web-tree-sitter';
import {
WASMLanguage,
isSupportedLanguageId,
languageIdToWasmLanguage,
parseTreeSitter,
parseTreeSitterIncludingVersion,
queryPythonIsDocstring,
} from './parse';
/**
 * Language-specific queries about blocks around a position in a document.
 */
interface BlockParser {
/**
 * Given a document text and a cursor offset, determines whether the cursor is at
 * the start of a block whose body is still empty.
 */
isEmptyBlockStart: (text: string, offset: number) => Promise<boolean>;
/**
 * Given a document prefix, offset, and a proposed completion, determines how much of the
 * completion to keep in order to "finish" the following block when the completion is appended
 * to the document prefix.
 *
 * If there is no such block, or the completion doesn't close the block, returns undefined.
 */
isBlockBodyFinished: (prefix: string, completion: string, offset: number) => Promise<number | undefined>;
/**
 * Given a document text and offset, determines the start index of the matching node
 * that encloses the offset.
 *
 * If there is no such node, returns undefined.
 */
getNodeStart: (text: string, offset: number) => Promise<number | undefined>;
}
/**
 * Shared tree-sitter plumbing for the block parsers in this file. It locates the
 * node of interest around an offset, finds that node's block child, and implements
 * `isBlockBodyFinished` and `getNodeStart` on top of that. Subclasses provide
 * `isEmptyBlockStart`.
 */
abstract class BaseBlockParser implements BlockParser {
	abstract isEmptyBlockStart(text: string, offset: number): Promise<boolean>;

	/**
	 * @param languageId Language id handed to `parseTreeSitter`.
	 * @param nodeMatch Maps a parent node type to the node type of the block child it owns.
	 * @param nodeTypesWithBlockOrStmtChild See the parameter comment below.
	 */
	constructor(
		protected readonly languageId: string,
		protected readonly nodeMatch: { [parent: string]: string },
		/**
		 * A map from node types that have a block or a statement as a child
		 * to the field label of the child node that is a block or statement.
		 * For example, an if statement in a braced language.
		 */
		protected readonly nodeTypesWithBlockOrStmtChild: Map<string, string>
	) { }

	/**
	 * Parse `text`, walk up from the node at `offset` to the nearest ancestor whose
	 * type is a key of `nodeMatch` (see the loop comment for the one exception), and
	 * invoke `cb` on it.
	 *
	 * The parse tree is deleted before this method settles, so `cb` must extract
	 * what it needs synchronously and must not retain the node.
	 *
	 * @returns The result of `cb`, or undefined when no matching ancestor exists.
	 */
	protected async getNodeMatchAtPosition<T>(
		text: string,
		offset: number,
		cb: (nd: Parser.SyntaxNode) => T
	): Promise<T | undefined> {
		const tree = await parseTreeSitter(this.languageId, text);
		try {
			// TODO:(hponde) It seems that we have an issue if it's at the end of the block:
			// https://github.com/tree-sitter/tree-sitter/issues/407
			const nodeAtPos = tree.rootNode.descendantForIndex(offset);
			let nodeToComplete: Parser.SyntaxNode | null = nodeAtPos;
			// find target element by looking at parent of cursor node
			// don't stop at node types that may have a block child, but don't actually in this
			// parse tree
			while (nodeToComplete) {
				const blockNodeType = this.nodeMatch[nodeToComplete.type];
				if (blockNodeType) {
					if (!this.nodeTypesWithBlockOrStmtChild.has(nodeToComplete.type)) {
						break;
					}
					// An empty field label means "look at the first named child".
					const fieldLabel = this.nodeTypesWithBlockOrStmtChild.get(nodeToComplete.type)!;
					const childToCheck =
						fieldLabel === ''
							? nodeToComplete.namedChildren[0]
							: nodeToComplete.childForFieldName(fieldLabel);
					if (childToCheck?.type === blockNodeType) {
						break;
					}
				}
				nodeToComplete = nodeToComplete.parent;
			}
			if (!nodeToComplete) {
				// No nodes we're interested in
				return;
			}
			return cb(nodeToComplete);
		} finally {
			tree.delete();
		}
	}

	/**
	 * Like {@link getNodeMatchAtPosition}, but invokes `cb` on the block child of the
	 * matched node (the last child whose type is the `nodeMatch` entry of the node).
	 *
	 * @returns The result of `cb`, or undefined when there is no matched node, no
	 * block child yet, or the block looks like the product of earlier parse errors.
	 */
	protected getNextBlockAtPosition<T>(
		text: string,
		offset: number,
		cb: (nd: Parser.SyntaxNode) => T
	): Promise<T | undefined> {
		return this.getNodeMatchAtPosition(text, offset, nodeToComplete => {
			// FIXME: childForFieldName always returns null
			// const block = nodeToComplete.childForFieldName(fieldToComplete);
			// Instead, find child nodes of the language's nodeMatch type for
			// nodeToComplete.
			// Look in reverse order, in case of nodes with multiple blocks defined,
			// such as try/catch/finally.
			// Reverse a copy: `Array.prototype.reverse` works in place and would
			// otherwise reorder the array handed out by `children`. This matches how
			// TreeSitterBasedBlockParser searches children.
			const blockType = this.nodeMatch[nodeToComplete.type];
			let block = [...nodeToComplete.children].reverse().find(x => x.type === blockType);
			if (!block) {
				// child of matching type isn't defined yet
				return;
			}
			if (this.languageId === 'python' && block.parent) {
				// handle empty block's parent being the colon (!)
				const parent = block.parent.type === ':' ? block.parent.parent : block.parent;
				// tree-sitter handles comments in a weird way, so we need to
				// consume them.
				let nextComment = parent?.nextSibling;
				while (nextComment && nextComment.type === 'comment') {
					// next comment is inline at the end of the block
					// see issue: https://github.com/tree-sitter/tree-sitter-python/issues/113
					const commentInline =
						nextComment.startPosition.row === block.endPosition.row &&
						nextComment.startPosition.column >= block.endPosition.column;
					// next comment is on subsequent line and indented > parent's indentation
					// see issue: https://github.com/tree-sitter/tree-sitter-python/issues/112
					const commentAtEnd =
						nextComment.startPosition.row > parent!.endPosition.row &&
						nextComment.startPosition.column > parent!.startPosition.column;
					if (commentInline || commentAtEnd) {
						block = nextComment;
						nextComment = nextComment.nextSibling;
					} else {
						break;
					}
				}
			}
			if (block.endIndex >= block.tree.rootNode.endIndex - 1 && (block.hasError || block.parent!.hasError)) {
				// TODO:(hponde) improve this logic
				// block is the whole document, and has errors, most likely doc has
				// preceding errors.
				return;
			}
			// Return first block if not empty
			return cb(block);
		});
	}

	/**
	 * Determine how much of `completion` to keep so that the block following `offset`
	 * is finished once the completion is appended to `prefix`.
	 *
	 * @returns The number of completion characters up to the end of the block, or
	 * undefined when there is no block, the block is not closed before the end of
	 * the (right-trimmed) combined text, or the block ends inside the prefix.
	 */
	async isBlockBodyFinished(prefix: string, completion: string, offset: number): Promise<number | undefined> {
		const solution = (prefix + completion).trimEnd();
		const endIndex = await this.getNextBlockAtPosition(solution, offset, block => block.endIndex);
		if (endIndex === undefined) {
			// no block, not finished yet
			return;
		}
		if (endIndex < solution.length) {
			// descendant block is finished, stop at end of block
			const lengthOfBlock = endIndex - prefix.length;
			return lengthOfBlock > 0 ? lengthOfBlock : undefined;
		}
		return undefined;
	}

	/**
	 * Find the start index of the matching node around `offset` in the right-trimmed
	 * text.
	 *
	 * @returns The node's start index, or undefined when there is no matching node.
	 */
	getNodeStart(text: string, offset: number): Promise<number | undefined> {
		const solution = text.trimEnd();
		return this.getNodeMatchAtPosition(solution, offset, block => block.startIndex);
	}
}
/**
 * Block parser that recognises the start of a block by matching the current line
 * against a regular expression, and decides emptiness by comparing the block text
 * with a language-defined "empty block" string.
 */
class RegexBasedBlockParser extends BaseBlockParser {
	/**
	 * @param languageId Language id forwarded to the base parser.
	 * @param blockEmptyMatch The text of an empty block with all whitespace removed.
	 * @param lineMatch Pattern tested against a line (leading whitespace removed) to
	 * decide whether it starts a block.
	 * @param nodeMatch Forwarded to the base parser.
	 * @param nodeTypesWithBlockOrStmtChild Forwarded to the base parser.
	 */
	constructor(
		languageId: string,
		protected readonly blockEmptyMatch: string,
		private readonly lineMatch: RegExp,
		nodeMatch: { [parent: string]: string },
		nodeTypesWithBlockOrStmtChild: Map<string, string>
	) {
		super(languageId, nodeMatch, nodeTypesWithBlockOrStmtChild);
	}

	/** Whether `line`, ignoring its indentation, matches the block-start pattern. */
	private isBlockStart(line: string): boolean {
		const withoutIndent = line.trimStart();
		return this.lineMatch.test(withoutIndent);
	}

	/**
	 * Whether the block following `offset` is empty. A block that cannot be found
	 * counts as empty.
	 */
	private async isBlockBodyEmpty(text: string, offset: number): Promise<boolean> {
		const isEmpty = await this.getNextBlockAtPosition(text, offset, block => {
			// Note that for Ruby, `block` is the closing `end` token, while for other
			// languages it is the whole block. So the text inspected runs from the
			// earlier of block.startIndex and offset up to block.endIndex.
			const from = Math.min(block.startIndex, offset);
			const body = text.substring(from, block.endIndex).trim();
			// Empty when nothing is left, or when only the language-defined
			// empty-block text is left once whitespace is stripped.
			return body === '' || body.replace(/\s/g, '') === this.blockEmptyMatch;
		});
		return isEmpty === undefined ? true : isEmpty;
	}

	/**
	 * Whether the cursor (rewound past preceding whitespace) sits on a line that
	 * starts a block whose body is empty.
	 */
	async isEmptyBlockStart(text: string, offset: number): Promise<boolean> {
		const cursor = rewindToNearestNonWs(text, offset);
		if (!this.isBlockStart(getLineAtOffset(text, cursor))) {
			return false;
		}
		return this.isBlockBodyEmpty(text, cursor);
	}
}
/**
 * Return the line of `text` that contains `offset`, without its line terminator.
 * The line starts after the last `\n` found at or before `offset - 1` and ends at
 * the first `\n` found at or after `offset` (or at the end of the text).
 */
function getLineAtOffset(text: string, offset: number): string {
	const lineStart = text.lastIndexOf('\n', offset - 1) + 1;
	const newlineAt = text.indexOf('\n', offset);
	const lineEnd = newlineAt < 0 ? text.length : newlineAt;
	return text.slice(lineStart, lineEnd);
}
/**
 * Move a cursor position backwards over whitespace.
 *
 * @returns The position immediately after the nearest non-whitespace character
 * before `offset`, or 0 if everything before `offset` is whitespace.
 */
function rewindToNearestNonWs(text: string, offset: number): number {
	let cursor = offset;
	for (; cursor > 0; cursor--) {
		if (!/\s/.test(text.charAt(cursor - 1))) {
			break;
		}
	}
	return cursor;
}
/**
 * Return the whitespace that precedes `nd` on the line where it starts, provided
 * nothing but whitespace precedes it on that line; otherwise return undefined.
 *
 * @param nd The node to inspect.
 * @param source The source text from which `nd` was parsed.
 */
function indent(nd: Parser.SyntaxNode, source: string): string | undefined {
	const lineStart = nd.startIndex - nd.startPosition.column;
	const leading = source.substring(lineStart, nd.startIndex);
	return /^\s*$/.test(leading) ? leading : undefined;
}
/**
 * Check whether `snd` is "outdented" with respect to `fst`: it starts on a later
 * line, both nodes are preceded only by whitespace on their lines, and the
 * indentation of `snd` is a prefix of the indentation of `fst`.
 */
function outdented(fst: Parser.SyntaxNode, snd: Parser.SyntaxNode, source: string): boolean {
	const startsLater = snd.startPosition.row > fst.startPosition.row;
	if (!startsLater) {
		return false;
	}
	const [outerWs, innerWs] = [indent(fst, source), indent(snd, source)];
	if (outerWs === undefined || innerWs === undefined) {
		return false;
	}
	return outerWs.startsWith(innerWs);
}
/**
 * Block parser that decides whether the cursor is at the start of an empty block by
 * inspecting the tree-sitter parse tree around the cursor, with language-specific
 * special cases for Python, JavaScript and TypeScript.
 */
class TreeSitterBasedBlockParser extends BaseBlockParser {
/**
 * @param languageId Language id forwarded to the base parser; also used to select
 * the language-specific special cases.
 * @param nodeMatch Forwarded to the base parser.
 * @param nodeTypesWithBlockOrStmtChild Forwarded to the base parser.
 * @param startKeywords Anonymous node types (keywords) that can introduce a block;
 * looked up among the children of ERROR nodes.
 * @param blockNodeType Node type of a block in this language.
 * @param emptyStatementType See the parameter comment below.
 * @param curlyBraceLanguage Whether blocks are delimited by `{` and `}`.
 */
constructor(
languageId: string,
nodeMatch: { [parent: string]: string },
nodeTypesWithBlockOrStmtChild: Map<string, string>,
private readonly startKeywords: string[],
private readonly blockNodeType: string,
/**
 * The language-specific node type of an empty statement, that is,
 * a statement with no text except possibly the statement terminator.
 * For example, `;` is an empty statement in a braced language, but
 * `pass` is not in Python.
 */
private readonly emptyStatementType: string | null,
private readonly curlyBraceLanguage: boolean
) {
super(languageId, nodeMatch, nodeTypesWithBlockOrStmtChild);
}
/**
 * Whether `block` has no content: its text is blank (after stripping one leading
 * `{` and one trailing `}` in curly-brace languages), or, in Python, it is the
 * single-child block of a class or function definition for which the docstring
 * query reports a match.
 *
 * Note: `offset` is not used by the current implementation.
 */
private isBlockEmpty(block: Parser.SyntaxNode, offset: number): boolean {
let trimmed = block.text.trim();
if (this.curlyBraceLanguage) {
if (trimmed.startsWith('{')) {
trimmed = trimmed.slice(1);
}
if (trimmed.endsWith('}')) {
trimmed = trimmed.slice(0, -1);
}
trimmed = trimmed.trim();
}
if (trimmed.length === 0) {
return true;
}
// Python: Consider a block that contains only a docstring empty.
if (
this.languageId === 'python' &&
(block.parent?.type === 'class_definition' || block.parent?.type === 'function_definition') &&
block.children.length === 1 &&
queryPythonIsDocstring(block.parent)
) {
return true;
}
return false;
}
/**
 * Determine whether the cursor at `offset` is at the start of an empty block.
 *
 * Returns false right away when non-whitespace text follows the cursor on the same
 * line. Otherwise the cursor is rewound past preceding whitespace, the text is
 * parsed, and the tree around the cursor is inspected as described by the comments
 * in the body. The parse tree is deleted before the method settles.
 *
 * @throws RangeError if `offset` is greater than the length of `text`.
 */
async isEmptyBlockStart(text: string, offset: number): Promise<boolean> {
if (offset > text.length) {
throw new RangeError('Invalid offset');
}
// Ensure that the cursor is at the end of a line, ignoring trailing whitespace.
for (let i = offset; i < text.length; i++) {
if (text.charAt(i) === '\n') {
break;
} else if (/\S/.test(text.charAt(i))) {
return false;
}
}
// This lets e.g. "def foo():\n█" give a multiline suggestion.
offset = rewindToNearestNonWs(text, offset);
const [tree, version] = await parseTreeSitterIncludingVersion(this.languageId, text);
try {
// offset here is the cursor position immediately after a non-whitespace
// character, but tree-sitter expects the index of the node to search for.
// Therefore we adjust the offset when we call into tree-sitter.
const nodeAtPos = tree.rootNode.descendantForIndex(offset - 1);
if (nodeAtPos === null) {
return false;
}
// Because of rewinding to the previous non-whitespace character, nodeAtPos may be
// "}". That's not a good place to show multiline ghost text.
if (this.curlyBraceLanguage && nodeAtPos.type === '}') {
return false;
}
// JS/TS: half open, empty blocks are sometimes parsed as objects
if (
(this.languageId === 'javascript' || this.languageId === 'typescript') &&
nodeAtPos.parent &&
nodeAtPos.parent.type === 'object' &&
nodeAtPos.parent.text.trim() === '{'
) {
return true;
}
// TS: a function_signature/method_signature is a prefix of a
// function_declaration/method_declaration, so if nodeAtPos is a descendant of one of
// those node types and the signature looks incomplete, return true
if (this.languageId === 'typescript') {
let currNode = nodeAtPos;
while (currNode.parent) {
if (currNode.type === 'function_signature' || currNode.type === 'method_signature') {
// if the next node is outdented, the signature is probably incomplete and
// TreeSitter may just have done some fanciful error correction, so we'll
// assume that this is really meant to be an incomplete function
// NOTE(review): `next` is the sibling of nodeAtPos, not of currNode — confirm
// this is intended.
const next = nodeAtPos.nextSibling;
if (next && currNode.hasError && outdented(currNode, next, text)) {
return true;
}
// if, on the other hand, there is a semicolon, then the signature is
// probably complete, and we should not show a multiline suggestion
const semicolon = currNode.children.find(c => c.type === ';');
return !semicolon && currNode.endIndex <= offset;
}
currNode = currNode.parent;
}
}
// Ignoring special cases, there are three situations when we want to return true:
//
// 1. nodeAtPos is in a block or a descendant of a block, the parent of the block is one of the node types
// in this.nodeMatch, and the block is empty.
// 2. nodeAtPos is somewhere below an ERROR node, and that ERROR node has an anonymous child
// matching one of the keywords we care about. If that ERROR node also has a block child, the
// block must be empty.
// 3. nodeAtPos is somewhere below a node type that we know can contain a block, and the block is either
// not present or empty.
//
// Walk up from nodeAtPos and stop at the first ancestor that is a block, a block
// parent, or an ERROR node; exactly one of the three variables below gets set.
let errorNode = null;
let blockNode = null;
let blockParentNode = null;
let currNode: Parser.SyntaxNode | null = nodeAtPos;
while (currNode !== null) {
if (currNode.type === this.blockNodeType) {
blockNode = currNode;
break;
}
if (this.nodeMatch[currNode.type]) {
blockParentNode = currNode;
break;
}
if (currNode.type === 'ERROR') {
errorNode = currNode;
break;
}
currNode = currNode.parent;
}
// Situation 1: the cursor is inside a block.
if (blockNode !== null) {
if (!blockNode.parent || !this.nodeMatch[blockNode.parent.type]) {
return false;
}
// Python: hack for unclosed docstrings. There's no rhyme or reason to how the actual
// docstring comments are parsed, but overall the parse tree looks like:
// function_definition
// - def
// - identifier
// - parameters
// - :
// - ERROR with text that starts with """ or '''
// - block
// - junk
//
// We do best effort here to detect that we're in an unclosed docstring and return true.
// Note that this won't work (we won't give a multiline suggestion) if the docstring uses single
// quotes, which is allowed by the language standard but not idiomatic (see PEP 257,
// Docstring Conventions).
if (this.languageId === 'python') {
const prevSibling = blockNode.previousSibling;
if (
prevSibling !== null &&
prevSibling.hasError &&
(prevSibling.text.startsWith('"""') || prevSibling.text.startsWith(`'''`))
) {
return true;
}
}
return this.isBlockEmpty(blockNode, offset);
}
// Situation 2: the cursor is below an ERROR node.
if (errorNode !== null) {
// TS: In a module such as "module 'foo' {" or internal_module such as "namespace 'foo' {"
// the open brace is parsed as an error node, like so:
// - expression_statement
// - [internal_]module
// - string
// - ERROR
if (
errorNode.previousSibling?.type === 'module' ||
errorNode.previousSibling?.type === 'internal_module' ||
errorNode.previousSibling?.type === 'def'
) {
return true;
}
// @dbaeumer The way how unfinished docstrings are handled changed in version 14 for Python.
if (this.languageId === 'python' && version >= 14) {
// In version 14 and later, we need to account for the possibility of
// an unfinished docstring being represented as an ERROR node.
if (errorNode.hasError && (errorNode.text.startsWith('"') || errorNode.text.startsWith(`'`))) {
const parentType = errorNode.parent?.type;
if (
parentType === 'function_definition' ||
parentType === 'class_definition' ||
parentType === 'module'
) {
return true;
}
}
}
// Search in reverse order so we get the latest block or keyword node.
const children = [...errorNode.children].reverse();
const keyword = children.find(child => this.startKeywords.includes(child.type));
let block = children.find(child => child.type === this.blockNodeType);
if (keyword) {
switch (this.languageId) {
case 'python': {
// Python: try-except-finally
// If the cursor is in either "except" or "finally," but the try-except-finally isn't finished,
// nodeAtPos will be parsed as an identifier. If > 4 characters of "except" or "finally" have been
// typed, it will be parsed as:
// ERROR
// - try
// - :
// - ERROR
// - block
// - expression_statement
// - identifier
//
// In this case, we have to special-case finding the right block to check whether it's empty.
if (keyword.type === 'try' && nodeAtPos.type === 'identifier' && nodeAtPos.text.length > 4) {
block = children
.find(child => child.hasError)
?.children.find(child => child.type === 'block');
}
// Python: sometimes nodes that are morally part of a block are parsed as statements
// that are all children of an ERROR node. Detect this by looking for ":" and inspecting
// its nextSibling. Skip over ":" inside parentheses because those could be part of a
// typed parameter.
let colonNode;
let parenCount = 0;
for (const child of errorNode.children) {
if (child.type === ':' && parenCount === 0) {
colonNode = child;
break;
}
if (child.type === '(') {
parenCount += 1;
}
if (child.type === ')') {
parenCount -= 1;
}
}
if (colonNode && keyword.endIndex <= colonNode.startIndex && colonNode.nextSibling) {
// horrible hack to handle unfinished docstrings :(
if (keyword.type === 'def') {
const sibling = colonNode.nextSibling;
if (sibling.type === '"' || sibling.type === `'`) {
return true;
}
if (sibling.type === 'ERROR' && (sibling.text === '"""' || sibling.text === `'''`)) {
return true;
}
}
return false;
}
break;
}
case 'javascript': {
// JS: method definition within a class, e.g. "class C { foo()"
if (keyword.type === 'class') {
if (version <= 13) {
const formalParameters = children.find(child => child.type === 'formal_parameters');
if (formalParameters) {
return true;
}
} else {
// Scan the children in source order (this `children` shadows the reversed
// copy above): the parameter list must be the last child, or be followed
// only by an opening brace.
const children = errorNode.children;
for (let i = 0; i < children.length; i++) {
const child = children[i];
if (child.type === 'formal_parameters') {
return (
i + 1 === children.length ||
(children[i + 1]?.type === '{' && i + 2 === children.length)
);
}
}
}
}
// JS: Don't mistake a half-open curly brace after a keyword under an error node for an empty
// block. If it has a nextSibling, then it's not empty. e.g. in "do {\n\t;█", the ";" is an
// empty_statement and the nextSibling of the "{".
const leftCurlyBrace = children.find(child => child.type === '{');
if (
leftCurlyBrace &&
leftCurlyBrace.startIndex > keyword.endIndex &&
leftCurlyBrace.nextSibling !== null
) {
return false;
}
// JS: do-while: don't give a multiline suggestion after the "while" keyword
const doNode = children.find(child => child.type === 'do');
if (doNode && keyword.type === 'while') {
return false;
}
// JS: In an arrow function, if there is a next sibling of the arrow and it's not an open brace, we're not in a
// block context and we should return false.
if (keyword.type === '=>' && keyword.nextSibling && keyword.nextSibling.type !== '{') {
return false;
}
break;
}
case 'typescript': {
// TS: Don't mistake a half-open curly brace after a keyword under an error node for an empty
// block. If it has a nextSibling, then it's not empty. e.g. in "do {\n\t;█", the ";" is an
// empty_statement and the nextSibling of the "{".
const leftCurlyBrace = children.find(child => child.type === '{');
if (
leftCurlyBrace &&
leftCurlyBrace.startIndex > keyword.endIndex &&
leftCurlyBrace.nextSibling !== null
) {
return false;
}
// TS: do-while: don't give a multiline suggestion after the "while" keyword
const doNode = children.find(child => child.type === 'do');
if (doNode && keyword.type === 'while') {
return false;
}
// TS: In an arrow function, if there is a next sibling of the arrow and it's not an open brace, we're not in a
// block context and we should return false.
if (keyword.type === '=>' && keyword.nextSibling && keyword.nextSibling.type !== '{') {
return false;
}
break;
}
}
// A block that follows the keyword must be empty; no block after the keyword
// means the block has not been started yet.
if (block && block.startIndex > keyword.endIndex) {
return this.isBlockEmpty(block, offset);
}
return true;
}
}
// Situation 3: the cursor is below a node type that can contain a block.
if (blockParentNode !== null) {
const expectedType = this.nodeMatch[blockParentNode.type];
// Search a reversed copy so the last matching child wins.
const block = blockParentNode.children
.slice()
.reverse()
.find(x => x.type === expectedType);
if (!block) {
// Some node types have a child that is either a block or a statement, e.g. "if (foo)".
// If the user has started typing a non-block statement, then this is not the start of an
// empty block.
if (this.nodeTypesWithBlockOrStmtChild.has(blockParentNode.type)) {
const fieldLabel = this.nodeTypesWithBlockOrStmtChild.get(blockParentNode.type)!;
const child =
fieldLabel === ''
? blockParentNode.children[0]
: blockParentNode.childForFieldName(fieldLabel);
if (child && child.type !== this.blockNodeType && child.type !== this.emptyStatementType) {
return false;
}
}
return true;
} else {
return this.isBlockEmpty(block, offset);
}
}
return false;
} finally {
tree.delete();
}
}
}
/**
* Registry of block parsers, one per `WASMLanguage` key.
*
* - python, javascript, typescript and tsx are configured `TreeSitterBasedBlockParser`s.
* - go and ruby use `RegexBasedBlockParser`s.
* - c-sharp, java, php and cpp are registered with empty node/keyword tables
*   (marked "TODO -- unused in the current usage" below).
*
* `getBlockParser` indexes this table with the result of `languageIdToWasmLanguage`.
*/
const wasmLanguageToBlockParser: { [languageId in WASMLanguage]: BlockParser } = {
python: new TreeSitterBasedBlockParser(
/* languageId */ 'python',
/* nodeMatch */ {
// Generated with script/tree-sitter-super-types tree-sitter-python block
class_definition: 'block',
elif_clause: 'block',
else_clause: 'block',
except_clause: 'block',
finally_clause: 'block',
for_statement: 'block',
function_definition: 'block',
if_statement: 'block',
try_statement: 'block',
while_statement: 'block',
with_statement: 'block',
},
/* nodeTypesWithBlockOrStmtChild */ new Map(),
/* startKeywords */['def', 'class', 'if', 'elif', 'else', 'for', 'while', 'try', 'except', 'finally', 'with'],
/* blockNodeType */ 'block',
/* emptyStatementType */ null,
/* curlyBraceLanguage */ false
),
javascript: new TreeSitterBasedBlockParser(
/* languageId */ 'javascript',
/* nodeMatch */ {
// Generated with script/tree-sitter-super-types tree-sitter-javascript statement_block
arrow_function: 'statement_block',
catch_clause: 'statement_block',
do_statement: 'statement_block',
else_clause: 'statement_block',
finally_clause: 'statement_block',
for_in_statement: 'statement_block',
for_statement: 'statement_block',
function: 'statement_block',
function_expression: 'statement_block',
function_declaration: 'statement_block',
generator_function: 'statement_block',
generator_function_declaration: 'statement_block',
if_statement: 'statement_block',
method_definition: 'statement_block',
try_statement: 'statement_block',
while_statement: 'statement_block',
with_statement: 'statement_block',
// Generated with script/tree-sitter-super-types tree-sitter-javascript class_body
class: 'class_body',
class_declaration: 'class_body',
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
['arrow_function', 'body'],
['do_statement', 'body'],
['else_clause', ''],
['for_in_statement', 'body'],
['for_statement', 'body'],
['if_statement', 'consequence'],
['while_statement', 'body'],
['with_statement', 'body'],
]),
/* startKeywords */[
'=>',
'try',
'catch',
'finally',
'do',
'for',
'if',
'else',
'while',
'with',
'function',
'function*',
'class',
],
/* blockNodeType */ 'statement_block',
/* emptyStatementType */ 'empty_statement',
/* curlyBraceLanguage */ true
),
typescript: new TreeSitterBasedBlockParser(
/* languageId */ 'typescript',
/* nodeMatch */ {
// Generated with script/tree-sitter-super-types tree-sitter-typescript/typescript statement_block
ambient_declaration: 'statement_block',
arrow_function: 'statement_block',
catch_clause: 'statement_block',
do_statement: 'statement_block',
else_clause: 'statement_block',
finally_clause: 'statement_block',
for_in_statement: 'statement_block',
for_statement: 'statement_block',
function: 'statement_block',
function_expression: 'statement_block',
function_declaration: 'statement_block',
generator_function: 'statement_block',
generator_function_declaration: 'statement_block',
if_statement: 'statement_block',
internal_module: 'statement_block',
method_definition: 'statement_block',
module: 'statement_block',
try_statement: 'statement_block',
while_statement: 'statement_block',
// NOTE(review): unlike the javascript table there is no with_statement entry here, although
// 'with' is in startKeywords and with_statement is in nodeTypesWithBlockOrStmtChild below.
// Confirm this is what the generator script produces for TypeScript.
// Generated with script/tree-sitter-super-types tree-sitter-typescript/typescript class_body
abstract_class_declaration: 'class_body',
class: 'class_body',
class_declaration: 'class_body',
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
['arrow_function', 'body'],
['do_statement', 'body'],
['else_clause', ''],
['for_in_statement', 'body'],
['for_statement', 'body'],
['if_statement', 'consequence'],
['while_statement', 'body'],
['with_statement', 'body'],
]),
/* startKeywords */[
'declare',
'=>',
'try',
'catch',
'finally',
'do',
'for',
'if',
'else',
'while',
'with',
'function',
'function*',
'class',
],
/* blockNodeType */ 'statement_block',
/* emptyStatementType */ 'empty_statement',
/* curlyBraceLanguage */ true
),
tsx: new TreeSitterBasedBlockParser(
/* languageId */ 'typescriptreact',
/* nodeMatch */ {
// Generated with script/tree-sitter-super-types tree-sitter-typescript/typescript statement_block
ambient_declaration: 'statement_block',
arrow_function: 'statement_block',
catch_clause: 'statement_block',
do_statement: 'statement_block',
else_clause: 'statement_block',
finally_clause: 'statement_block',
for_in_statement: 'statement_block',
for_statement: 'statement_block',
function: 'statement_block',
function_expression: 'statement_block',
function_declaration: 'statement_block',
generator_function: 'statement_block',
generator_function_declaration: 'statement_block',
if_statement: 'statement_block',
internal_module: 'statement_block',
method_definition: 'statement_block',
module: 'statement_block',
try_statement: 'statement_block',
while_statement: 'statement_block',
// Generated with script/tree-sitter-super-types tree-sitter-typescript/typescript class_body
abstract_class_declaration: 'class_body',
class: 'class_body',
class_declaration: 'class_body',
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
['arrow_function', 'body'],
['do_statement', 'body'],
['else_clause', ''],
['for_in_statement', 'body'],
['for_statement', 'body'],
['if_statement', 'consequence'],
['while_statement', 'body'],
['with_statement', 'body'],
]),
/* startKeywords */[
'declare',
'=>',
'try',
'catch',
'finally',
'do',
'for',
'if',
'else',
'while',
'with',
'function',
'function*',
'class',
],
/* blockNodeType */ 'statement_block',
/* emptyStatementType */ 'empty_statement',
/* curlyBraceLanguage */ true
),
go: new RegexBasedBlockParser(
/* languageId */ 'go',
/* blockEmptyMatch */ '{}',
/* lineMatch */ /\b(func|if|else|for)\b/,
/* nodeMatch */ {
// Generated with script/tree-sitter-super-types tree-sitter-go block
communication_case: 'block',
default_case: 'block',
expression_case: 'block',
for_statement: 'block',
func_literal: 'block',
function_declaration: 'block',
if_statement: 'block',
labeled_statement: 'block',
method_declaration: 'block',
type_case: 'block',
},
/* nodeTypesWithBlockOrStmtChild */ new Map() // Go always requires braces
),
ruby: new RegexBasedBlockParser(
/* languageId */ 'ruby',
/* blockEmptyMatch */ 'end',
// Regex \b matches word boundaries - `->{}` has no word boundary.
/* lineMatch */ /\b(BEGIN|END|case|class|def|do|else|elsif|for|if|module|unless|until|while)\b|->/,
/* nodeMatch */ {
// Ruby works differently from other languages because there is no
// block-level node, instead we use the literal 'end' node to
// represent the end of a block.
begin_block: '}',
block: '}',
end_block: '}',
lambda: 'block',
for: 'do',
until: 'do',
while: 'do',
case: 'end',
do: 'end',
if: 'end',
method: 'end',
module: 'end',
unless: 'end',
do_block: 'end',
},
// TODO(eaftan): Scour Ruby grammar for these
/* nodeTypesWithBlockOrStmtChild */ new Map()
),
'c-sharp': new TreeSitterBasedBlockParser(
/* languageId */ 'csharp',
/* nodeMatch */ {
// TODO -- unused in the current usage.
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
// TODO -- unused in the current usage.
]),
/* startKeywords */[
// TODO -- unused in the current usage.
],
/* blockNodeType */ 'block',
/* emptyStatementType */ null,
/* curlyBraceLanguage */ true
),
java: new TreeSitterBasedBlockParser(
/* languageId */ 'java',
/* nodeMatch */ {
// TODO -- unused in the current usage.
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
// TODO -- unused in the current usage.
]),
/* startKeywords */[
// TODO -- unused in the current usage.
],
/* blockNodeType */ 'block',
/* emptyStatementType */ null,
/* curlyBraceLanguage */ true
),
php: new TreeSitterBasedBlockParser(
/* languageId */ 'php',
/* nodeMatch */ {
// TODO -- unused in the current usage.
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
// TODO -- unused in the current usage.
]),
/* startKeywords */[
// TODO -- unused in the current usage.
],
/* blockNodeType */ 'block',
/* emptyStatementType */ null,
/* curlyBraceLanguage */ true
),
cpp: new TreeSitterBasedBlockParser(
/* languageId */ 'cpp',
/* nodeMatch */ {
// TODO -- unused in the current usage.
},
/* nodeTypesWithBlockOrStmtChild */ new Map([
// TODO -- unused in the current usage.
]),
/* startKeywords */[
// TODO -- unused in the current usage.
],
/* blockNodeType */ 'block',
/* emptyStatementType */ null,
/* curlyBraceLanguage */ true
),
};
/**
 * Look up the block parser registered for a language.
 *
 * @param languageId Language identifier to resolve.
 * @returns The parser stored in `wasmLanguageToBlockParser` for that language.
 * @throws Error when `isSupportedLanguageId` rejects the identifier.
 */
export function getBlockParser(languageId: string): BlockParser {
    if (isSupportedLanguageId(languageId)) {
        const wasmLanguage = languageIdToWasmLanguage(languageId);
        return wasmLanguageToBlockParser[wasmLanguage];
    }
    throw new Error(`Language ${languageId} is not supported`);
}
/**
 * Delegates to the language's block parser `isEmptyBlockStart(text, offset)`.
 *
 * @returns `false` for unsupported languages, otherwise whatever the parser reports.
 */
export async function isEmptyBlockStart(languageId: string, text: string, offset: number) {
    return isSupportedLanguageId(languageId)
        ? getBlockParser(languageId).isEmptyBlockStart(text, offset)
        : false;
}
/**
 * Delegates to the language's block parser `isBlockBodyFinished(prefix, completion, offset)`.
 *
 * @returns `undefined` for unsupported languages, otherwise the parser's result.
 */
export async function isBlockBodyFinished(languageId: string, prefix: string, completion: string, offset: number) {
    return isSupportedLanguageId(languageId)
        ? getBlockParser(languageId).isBlockBodyFinished(prefix, completion, offset)
        : undefined;
}
/**
 * Delegates to the language's block parser `getNodeStart(text, offset)`.
 *
 * @returns `undefined` for unsupported languages, otherwise the parser's result.
 */
export async function getNodeStart(languageId: string, text: string, offset: number) {
    return isSupportedLanguageId(languageId)
        ? getBlockParser(languageId).getNodeStart(text, offset)
        : undefined;
}

View File

@@ -0,0 +1,99 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { SimilarFilesOptions } from './snippetInclusion/similarFiles';
/**
* Default PromptOptions are defined as constants to ensure the same values are shared
* between:
* - the class constructor
* - the EXP default flags
*/
/** The maximum number of tokens in a completion. */
export const DEFAULT_MAX_COMPLETION_LENGTH = 500;
/**
* The maximum number of tokens in a prompt: 8192 minus the completion budget above (7692).
* NOTE(review): 8192 is presumably the model's total token window — confirm.
*/
export const DEFAULT_MAX_PROMPT_LENGTH = 8192 - DEFAULT_MAX_COMPLETION_LENGTH;
/** The maximum number of final snippets to return. */
export const DEFAULT_NUM_SNIPPETS = 4;
/**
* The default threshold for choosing a cached suffix.
*/
export const DEFAULT_SUFFIX_MATCH_THRESHOLD = 10;
/**
* The default allocation of the prompt to different components.
* The four values are percentages and add up to 100.
*/
export const DEFAULT_PROMPT_ALLOCATION_PERCENT = {
prefix: 35,
suffix: 15,
stableContext: 35,
volatileContext: 15,
} as const;
/** Name of a prompt component: one of the keys of `DEFAULT_PROMPT_ALLOCATION_PERCENT`. */
export type PromptComponentId = keyof typeof DEFAULT_PROMPT_ALLOCATION_PERCENT;
/** A number assigned to every prompt component. */
export type PromptComponentAllocation = Record<PromptComponentId, number>;
/**
* Information about a document, not including the offset.
*/
export interface DocumentInfo {
/** The file path of the document relative to its containing project or folder, if known. */
relativePath?: string;
/** The URI of the document. We can't pass URI class instances directly due to limitations of passing objects to the worker thread. */
uri: string;
/** The source text of the document. */
source: string;
/** The language identifier of the document. */
languageId: string;
}
/**
* Information about a document, including an offset corresponding to
* the cursor position.
*/
export interface DocumentInfoWithOffset extends DocumentInfo {
/** The offset in the document where we want the completion (0-indexed, between characters). */
offset: number;
}
/**
* Information about a similar file.
* Same shape as `DocumentInfo` with `languageId` omitted.
*/
export type SimilarFileInfo = Omit<DocumentInfo, 'languageId'>;
/**
* Options controlling prompt size and snippet selection; see the per-field comments.
*/
export type PromptOptions = {
/** The maximum prompt length in tokens */
maxPromptLength: number;
/** The number of snippets to include */
numberOfSnippets: number;
/** The percent of `maxPromptLength` to reserve for the suffix */
suffixPercent: number;
/** The threshold (in percent) for declaring match of new suffix with existing suffix */
suffixMatchThreshold: number;
/** The default parameters for the similar-files provider, for any language. */
similarFilesOptions: SimilarFilesOptions;
};
/**
 * A map that normalises common aliases of languageIds.
 *
 * Held in a `Map` rather than an object literal so that only the keys listed
 * here resolve: an object lookup would also find inherited `Object.prototype`
 * members such as `constructor` or `__proto__`.
 */
const languageNormalizationMap: ReadonlyMap<string, string> = new Map([
    ['javascriptreact', 'javascript'],
    ['jsx', 'javascript'],
    ['typescriptreact', 'typescript'],
    ['jade', 'pug'],
    ['cshtml', 'razor'],
    ['c', 'cpp'],
]);

/**
 * Return a normalized form of a language id, by lower casing and combining
 * certain languageId's that are not considered distinct by promptlib.
 *
 * @param languageId Language identifier in any letter case.
 * @returns The canonical id for a known alias, otherwise the lower-cased input.
 */
export function normalizeLanguageId(languageId: string): string {
    const lowered = languageId.toLowerCase();
    return languageNormalizationMap.get(lowered) ?? lowered;
}

View File

@@ -0,0 +1,94 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/**
* Cursor contexts used by snippet providers, e.g. similar files.
*
* A 'cursor context' is quite similar to a prompt, but it is meant as a more
* basic, lightweight and ultimately myopic look at what the user is currently doing.
*/
import { DocumentInfoWithOffset } from '../prompt';
import { getTokenizer, TokenizerName } from '../tokenization';
/**
 * Options for cursor context generation.
 */
type CursorContextOptions = {
    /** The maximum cursor context length in tokens */
    maxTokenLength?: number;
    /** The maximum number of lines in a cursor context */
    maxLineCount?: number;
    /** TokenizerName for the tokenization */
    tokenizerName: TokenizerName;
};

/** Defaults applied underneath caller-supplied options: only the tokenizer is preset. */
const defaultCursorContextOptions: CursorContextOptions = {
    tokenizerName: TokenizerName.o200k,
};

/**
 * Fill in missing cursor-context options from the defaults.
 * Properties present in `options` take precedence over the defaults.
 */
function cursorContextOptions(options?: Partial<CursorContextOptions>): CursorContextOptions {
    const overrides = options ?? {};
    return { ...defaultCursorContextOptions, ...overrides };
}
/**
* Result of `getCursorContext`: the context text together with its
* token and line counts and the tokenizer used to count the tokens.
*/
export interface CursorContextInfo {
/** The compiled context as a string */
context: string;
/** The number of tokens in the context */
tokenLength: number;
/** The number of lines in the context */
lineCount: number;
/** TokenizerName for the tokenization */
tokenizerName: TokenizerName;
}
/**
 * Build the cursor context for a document: the text before the cursor,
 * optionally trimmed to its last lines and/or last tokens.
 *
 * - A limit of 0 for either `maxLineCount` or `maxTokenLength` yields an empty context.
 * - When only one limit is given, only that limit is applied.
 * - When both are given, the line limit is applied first and the token limit second,
 *   so the shorter of the two wins.
 * - When neither is given, everything up to the cursor is returned.
 *
 * @throws Error when either limit is negative.
 */
export function getCursorContext(
    doc: DocumentInfoWithOffset,
    options: Partial<CursorContextOptions> = {}
): CursorContextInfo {
    const { maxLineCount, maxTokenLength, tokenizerName } = cursorContextOptions(options);
    const tokenizer = getTokenizer(tokenizerName);

    if (maxLineCount !== undefined && maxLineCount < 0) {
        throw new Error('maxLineCount must be non-negative if defined');
    }
    if (maxTokenLength !== undefined && maxTokenLength < 0) {
        throw new Error('maxTokenLength must be non-negative if defined');
    }
    if (maxLineCount === 0 || maxTokenLength === 0) {
        return { context: '', lineCount: 0, tokenLength: 0, tokenizerName };
    }

    // `offset` is a character position, so slicing keeps exactly the text before the cursor.
    const beforeCursor = doc.source.slice(0, doc.offset);
    const lineLimited =
        maxLineCount === undefined ? beforeCursor : beforeCursor.split('\n').slice(-maxLineCount).join('\n');
    const context =
        maxTokenLength === undefined ? lineLimited : tokenizer.takeLastLinesTokens(lineLimited, maxTokenLength);

    return {
        context,
        lineCount: context.split('\n').length,
        tokenLength: tokenizer.tokenLength(context),
        tokenizerName,
    };
}

View File

@@ -0,0 +1,56 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DocumentInfoWithOffset } from '../prompt';
import { CursorContextInfo, getCursorContext } from './cursorContext';
import { WindowedMatcher } from './selectRelevance';
import { getBasicWindowDelineations } from './windowDelineations';
/**
 * A `WindowedMatcher` that uses windows of a fixed number of lines and scores
 * each window against the reference tokens with `computeScore`.
 */
export class FixedWindowSizeJaccardMatcher extends WindowedMatcher {
    private constructor(
        referenceDoc: DocumentInfoWithOffset,
        private windowLength: number
    ) {
        super(referenceDoc);
    }

    /** Creates a factory bound to `windowLength`; call `.to(doc)` to get a matcher for a reference document. */
    static FACTORY = (windowLength: number) => ({
        to: (referenceDoc: DocumentInfoWithOffset) => new FixedWindowSizeJaccardMatcher(referenceDoc, windowLength),
    });

    /** Identifier of this matcher configuration; includes the window length. */
    protected id(): string {
        return `fixed:${this.windowLength}`;
    }

    /** Window boundaries as produced by `getBasicWindowDelineations` for this window length. */
    protected getWindowsDelineations(lines: string[]): [number, number][] {
        return getBasicWindowDelineations(this.windowLength, lines);
    }

    /** The reference context is limited to the last `windowLength` lines before the cursor. */
    protected _getCursorContextInfo(referenceDoc: DocumentInfoWithOffset): CursorContextInfo {
        const options = { maxLineCount: this.windowLength };
        return getCursorContext(referenceDoc, options);
    }

    /** Scores two token sets with `computeScore`. */
    protected similarityScore(a: Set<string>, b: Set<string>): number {
        return computeScore(a, b);
    }
}
/**
 * Compute the Jaccard metric of number of elements in the intersection
 * divided by number of elements in the union.
 *
 * @param a First token set.
 * @param b Second token set.
 * @returns A value in [0, 1]. Two empty sets score 0 rather than NaN (0/0),
 * so the result is always safe to compare and sort.
 */
export function computeScore(a: Set<string>, b: Set<string>): number {
    // Walk the smaller set and probe the larger one; the overlap count is the same either way.
    const [smaller, larger] = a.size <= b.size ? [a, b] : [b, a];
    let shared = 0;
    for (const token of smaller) {
        if (larger.has(token)) {
            shared++;
        }
    }
    const unionSize = a.size + b.size - shared;
    return unionSize === 0 ? 0 : shared / unionSize;
}

View File

@@ -0,0 +1,397 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DocumentInfo, DocumentInfoWithOffset, SimilarFileInfo } from '../prompt';
import { CursorContextInfo } from './cursorContext';
import { SnippetProviderType, SnippetSemantics, SnippetWithProviderInfo } from './snippets';
/**
 * A bounded string-keyed cache with first-in-first-out eviction.
 *
 * At most `size` entries are retained. Inserting a new key into a full cache
 * evicts the entry that was inserted first. Overwriting an existing key
 * replaces its value without changing its position in the eviction order.
 */
class FifoCache<T> {
    /** A Map iterates keys in insertion order, so its first key is always the oldest entry. */
    private readonly entries = new Map<string, T>();
    private readonly size: number;

    /**
     * @param size Maximum number of entries to retain; with a size of 0 or less nothing is stored.
     */
    constructor(size: number) {
        this.size = size;
    }

    /**
     * Store `value` under `key`, evicting the oldest entries if the cache is full.
     */
    put(key: string, value: T) {
        if (this.entries.has(key)) {
            this.entries.set(key, value);
            return;
        }
        if (this.size <= 0) {
            return;
        }
        while (this.entries.size >= this.size) {
            const oldest = this.entries.keys().next();
            if (oldest.done) {
                break;
            }
            this.entries.delete(oldest.value);
        }
        this.entries.set(key, value);
    }

    /**
     * @returns The value stored under `key`, or `undefined` if it was never stored or has been evicted.
     */
    get(key: string): T | undefined {
        return this.entries.get(key);
    }
}
/**
* A scored window of lines in a document, identified only by its line range
* (the snippet text itself is not included).
*/
export interface ScoredSnippetMarker {
/** Relevance score of the window; bigger means more relevant. */
score: number;
/** Index of the first line of the window (inclusive). */
startLine: number;
/** Index one past the last line of the window (exclusive). */
endLine: number;
}
/**
* A snippet of code together with a relevance score.
*
* The scoring system assumes that a snippet with a **bigger** score is **more** relevant.
*/
export interface ScoredSnippet extends ScoredSnippetMarker {
/** The text of the snippet. */
snippet: string;
/** Relative path of the file the snippet came from, if known. */
relativePath?: string;
}
/**
* Ordering applied by `WindowedMatcher.sortScoredSnippets`:
* by score ascending, by score descending, or left as-is.
*/
export enum SortOptions {
Ascending = 'ascending',
Descending = 'descending',
None = 'none',
}
/**
 * Splits text into a set of words with stop words removed.
 * This is a simple word splitter for similarity scoring, not a model tokenizer.
 */
class Tokenizer {
    private readonly stopsForLanguage: Set<string>;

    /**
     * @param doc Only `languageId` is read: it selects the language-specific stop list,
     * falling back to the generic one when none is registered.
     */
    constructor(doc: DocumentInfo) {
        const languageStops = SPECIFIC_STOPS.get(doc.languageId);
        this.stopsForLanguage = languageStops ?? GENERIC_STOPS;
    }

    /**
     * @returns The distinct words of `a` that are not stop words.
     */
    tokenize(a: string): Set<string> {
        const kept = new Set<string>();
        for (const word of splitIntoWords(a)) {
            if (!this.stopsForLanguage.has(word)) {
                kept.add(word);
            }
        }
        return kept;
    }
}
/**
* For a number of documents (the similar files),
* associate to each document and its kind of window computation (as key)
* the sequence b_1, ..., b_n, where
* b_i is the set of tokens in the ith window --
* e.g. for window length 10,
* WINDOWED_TOKEN_SET_CACHE(doc)[0]
* holds the tokens in the first 10 lines of the document.
*
* The key used by `WindowedMatcher.retrieveAllSnippets` is the matcher's `id()`,
* a ':' and the full source text of the document. The cache is created with size 20.
*/
const WINDOWED_TOKEN_SET_CACHE = new FifoCache<Set<string>[]>(20);
/**
 * For a given document, extracts the best matching snippets from other documents
 * by comparing all of a set of windows in the object doc.
 *
 * Subclasses decide how windows are delineated, which part of the reference
 * document is tokenized, and how two token sets are scored.
 */
export abstract class WindowedMatcher {
    protected referenceDoc: DocumentInfoWithOffset;
    protected tokenizer: Tokenizer;

    /** Identifier of the matcher configuration; used as part of the window-token cache key. */
    protected abstract id(): string;

    /** Score a window's token set `a` against the reference token set `b`; bigger means more similar. */
    protected abstract similarityScore(a: Set<string>, b: Set<string>): number;

    /**
     * Given an array of lines, returns an array of pairs <startLine, endLine> of indices,
     * such that each pair is a window of lines to consider adding.
     * startLine is inclusive, endLine is exclusive.
     * @param lines Lines of a source text, in order
     */
    protected abstract getWindowsDelineations(lines: string[]): [number, number][];

    /**
     * Subclasses should implement this method to return the desired context info for tokenization
     * from the reference document. Will only be called after constructor is finished.
     * The tokenizer used in WindowedMatcher is a simple tokenizer for Jaccard similarity, NOT an
     * OpenAI model tokenizer.
     */
    protected abstract _getCursorContextInfo(referenceDoc: DocumentInfoWithOffset): CursorContextInfo;

    protected constructor(referenceDoc: DocumentInfoWithOffset) {
        this.referenceDoc = referenceDoc;
        this.tokenizer = new Tokenizer(referenceDoc); // Just uses language info from referenceDoc
    }

    /**
     * Lazy getter for referenceTokens since it relies on properties
     * that are not initialized in the constructor of WindowedMatcher
     * but in the constructor of its subclasses.
     */
    protected referenceTokensCache: Set<string> | undefined;

    get referenceTokens(): Promise<Set<string>> {
        return Promise.resolve(this.createReferenceTokens());
    }

    /** Tokenizes the reference cursor context on first use and reuses the result afterwards. */
    private createReferenceTokens(): Set<string> {
        return (this.referenceTokensCache ??= this.tokenizer.tokenize(
            this._getCursorContextInfo(this.referenceDoc).context
        ));
    }

    /**
     * Sorts the given snippets by score according to the sort option.
     * The array is sorted in place and returned; `SortOptions.None` returns it untouched.
     * Snippets with equal scores keep their relative order.
     *
     * @param snippets Snippet markers to order.
     * @param sortOption Direction of the ordering; defaults to descending.
     */
    sortScoredSnippets(snippets: ScoredSnippetMarker[], sortOption = SortOptions.Descending): ScoredSnippetMarker[] {
        // The comparators return 0 for equal scores so that they are consistent, as Array.prototype.sort requires.
        switch (sortOption) {
            case SortOptions.Ascending:
                return snippets.sort((snippetA, snippetB) => snippetA.score - snippetB.score);
            case SortOptions.Descending:
                return snippets.sort((snippetA, snippetB) => snippetB.score - snippetA.score);
            default:
                return snippets;
        }
    }

    /**
     * Returns all snippet markers with their scores.
     *
     * Every window of `objectDoc` is scored against the reference tokens. A window that
     * overlaps the previously kept one does not produce a new marker: the previous marker
     * is replaced by it if it scores higher, otherwise the window is dropped.
     *
     * @param objectDoc The document to search for snippets.
     * @param sortOption Ordering of the returned markers; defaults to descending by score.
     * @returns The scored markers; empty when the document or the reference token set is empty.
     */
    async retrieveAllSnippets(
        objectDoc: SimilarFileInfo,
        sortOption = SortOptions.Descending
    ): Promise<ScoredSnippetMarker[]> {
        const snippets: ScoredSnippetMarker[] = [];
        if (objectDoc.source.length === 0) {
            return snippets;
        }
        // Resolved once up front: the reference tokens do not change while the windows are scored.
        const referenceTokens = await this.referenceTokens;
        if (referenceTokens.size === 0) {
            return snippets;
        }

        const lines = objectDoc.source.split('\n');
        const key = this.id() + ':' + objectDoc.source;
        const tokensInWindows = WINDOWED_TOKEN_SET_CACHE.get(key) ?? [];
        // if the tokens are not cached, we need to compute them
        const needToComputeTokens = tokensInWindows.length === 0;
        const tokenizedLines = needToComputeTokens ? lines.map(l => this.tokenizer.tokenize(l)) : [];

        // Compute the windows with the score
        for (const [index, [startLine, endLine]] of this.getWindowsDelineations(lines).entries()) {
            if (needToComputeTokens) {
                const tokensInWindow = new Set<string>();
                for (const lineTokens of tokenizedLines.slice(startLine, endLine)) {
                    for (const token of lineTokens) {
                        tokensInWindow.add(token);
                    }
                }
                tokensInWindows.push(tokensInWindow);
            }

            // Now tokensInWindows[index] contains the tokens in the window, whether we just computed them or not
            const tokensInWindow = tokensInWindows[index];
            const score = this.similarityScore(tokensInWindow, referenceTokens);

            // If snippets overlap, keep the one with highest score.
            // Note: Assuming the getWindowsDelineations function returns windows in sorted ascending line ranges.
            const previous = snippets[snippets.length - 1];
            if (previous !== undefined && startLine > 0 && previous.endLine > startLine) {
                if (previous.score < score) {
                    previous.score = score;
                    previous.startLine = startLine;
                    previous.endLine = endLine;
                }
                continue;
            }

            snippets.push({
                score,
                startLine,
                endLine,
            });
        }

        // If we didn't get the token sets from the cache, time to put them there!
        if (needToComputeTokens) {
            WINDOWED_TOKEN_SET_CACHE.put(key, tokensInWindows);
        }
        return this.sortScoredSnippets(snippets, sortOption);
    }

    /**
     * Finds the best snippets of `objectDoc`; delegates to `findBestMatch`.
     */
    findMatches(objectDoc: SimilarFileInfo, maxSnippetsPerFile: number): Promise<SnippetWithProviderInfo[]> {
        return this.findBestMatch(objectDoc, maxSnippetsPerFile);
    }

    /**
     * Returns the snippets from the object document
     * that are most similar to the reference document,
     * together with their scores.
     *
     * @param objectDoc The document to search for snippets.
     * @param maxSnippetsPerFile Only the top `maxSnippetsPerFile` markers are considered.
     * @returns Snippets with provider info, best first; markers with a score of 0 are left out.
     */
    async findBestMatch(objectDoc: SimilarFileInfo, maxSnippetsPerFile: number): Promise<SnippetWithProviderInfo[]> {
        if (objectDoc.source.length === 0 || (await this.referenceTokens).size === 0) {
            return [];
        }
        const lines = objectDoc.source.split('\n');
        const snippets = await this.retrieveAllSnippets(objectDoc, SortOptions.Descending);

        // safe guard against empty lists
        if (snippets.length === 0) {
            return [];
        }

        const bestSnippets: SnippetWithProviderInfo[] = [];
        for (let i = 0; i < snippets.length && i < maxSnippetsPerFile; i++) {
            // Skip null scored snippets.
            if (snippets[i].score !== 0) {
                // Get the snippet's text.
                const snippetCode = lines.slice(snippets[i].startLine, snippets[i].endLine).join('\n');
                bestSnippets.push({
                    snippet: snippetCode,
                    semantics: SnippetSemantics.Snippet,
                    provider: SnippetProviderType.SimilarFiles,
                    ...snippets[i],
                });
            }
        }
        return bestSnippets;
    }
}
/**
 * Extract the maximal runs of ASCII letters and digits from a string.
 * Every other character acts as a separator.
 *
 * @returns The runs in order of appearance; an empty array when there are none.
 */
export function splitIntoWords(a: string): string[] {
    const words = a.match(/[a-zA-Z0-9]+/g);
    return words ?? [];
}
/**
 * Common English words that are ignored when comparing token sets.
 * Matching is case-sensitive; all entries are lower case.
 */
const ENGLISH_STOPS = new Set<string>([
    // pronouns and demonstratives
    'we', 'our', 'you', 'it', 'its', 'they', 'them', 'their', 'this', 'that', 'these', 'those',
    // verbs, auxiliaries and contraction fragments (splitting on the apostrophe turns "don't" into 'don' + 't')
    'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having',
    'do', 'does', 'did', 'doing', 'can', 'don', 't', 's', 'will', 'would', 'should',
    // wh-words
    'what', 'which', 'who', 'when', 'where', 'why', 'how',
    // articles
    'a', 'an', 'the',
    // conjunctions, negations, quantifiers, prepositions and adverbs
    'and', 'or', 'not', 'no', 'but', 'because', 'as', 'until', 'again', 'further', 'then', 'once',
    'here', 'there', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such',
    'above', 'below', 'to', 'during', 'before', 'after', 'of', 'at', 'by', 'about', 'between', 'into',
    'through', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'only', 'own', 'same',
    'so', 'than', 'too', 'very', 'just', 'now',
]);
/**
 * Stop words used for any programming language that has no specific list:
 * common keywords plus the English stop words. Matching is case-sensitive.
 */
const GENERIC_STOPS = new Set<string>([
    // words that are common in programming languages
    'if', 'then', 'else', 'for', 'while', 'with', 'def', 'function', 'return', 'TODO', 'import',
    'try', 'catch', 'raise', 'finally', 'repeat', 'switch', 'case', 'match', 'assert', 'continue',
    'break', 'const', 'class', 'enum', 'struct', 'static', 'new', 'super', 'this', 'var',
    // words that are common in English comments:
    ...ENGLISH_STOPS,
]);
/**
 * Per-language stop lists, keyed by language id. Currently empty.
 * A list registered here replaces the generic one for that language, so
 * ENGLISH_STOPS must be added to it explicitly if they are to be included.
 */
const SPECIFIC_STOPS = new Map<string, Set<string>>();

View File

@@ -0,0 +1,119 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DocumentInfoWithOffset, SimilarFileInfo } from '../prompt';
import { FixedWindowSizeJaccardMatcher } from './jaccardMatching';
import { SnippetWithProviderInfo } from './snippets';
import { BlockTokenSubsetMatcher } from './subsetMatching';
// Defaults used by `defaultSimilarFilesOptions` (and partly by `conservativeFilesOptions`) below.
const DEFAULT_SNIPPET_THRESHOLD = 0.0;
const DEFAULT_SNIPPET_WINDOW_SIZE = 60;
const DEFAULT_MAX_TOP_SNIPPETS = 4;
const DEFAULT_MAX_SNIPPETS_PER_FILE = 1;
const DEFAULT_MAX_NUMBER_OF_FILES = 20;
const DEFAULT_MAX_CHARACTERS_PER_FILE = 10000;
/**
* Parameters of the similar-files snippet search performed by `getSimilarSnippets`.
*/
export interface SimilarFilesOptions {
/** Window length handed to the matcher factory. */
snippetLength: number;
/** Only snippets whose score is strictly greater than this value are kept. */
threshold: number;
/** Maximum number of snippets returned overall; 0 disables the search. */
maxTopSnippets: number;
/** Only files whose source is strictly shorter than this many characters are considered. */
maxCharPerFile: number;
/** Maximum number of (non-empty, short enough) files that are searched. */
maxNumberOfFiles: number;
/** Maximum number of snippets requested from each file. */
maxSnippetsPerFile: number;
/** When true, `BlockTokenSubsetMatcher` is used instead of `FixedWindowSizeJaccardMatcher`. */
useSubsetMatching?: boolean;
}
/** Options built from the DEFAULT_* constants above, with subset matching off. */
export const defaultSimilarFilesOptions: SimilarFilesOptions = {
snippetLength: DEFAULT_SNIPPET_WINDOW_SIZE,
threshold: DEFAULT_SNIPPET_THRESHOLD,
maxTopSnippets: DEFAULT_MAX_TOP_SNIPPETS,
maxCharPerFile: DEFAULT_MAX_CHARACTERS_PER_FILE,
maxNumberOfFiles: DEFAULT_MAX_NUMBER_OF_FILES,
maxSnippetsPerFile: DEFAULT_MAX_SNIPPETS_PER_FILE,
useSubsetMatching: false,
};
/** A stricter preset: shorter windows, a higher score threshold and a single snippet overall. */
export const conservativeFilesOptions: SimilarFilesOptions = {
snippetLength: 10,
threshold: 0.3,
maxTopSnippets: 1,
maxCharPerFile: DEFAULT_MAX_CHARACTERS_PER_FILE,
maxNumberOfFiles: DEFAULT_MAX_NUMBER_OF_FILES,
maxSnippetsPerFile: 1,
};
/** A preset that yields no snippets: `maxTopSnippets` is 0 and every other limit is 0 as well. */
export const nullSimilarFilesOptions: SimilarFilesOptions = {
snippetLength: 0,
threshold: 1,
maxTopSnippets: 0,
maxCharPerFile: 0,
maxNumberOfFiles: 0,
maxSnippetsPerFile: 0,
};
// Default similarity parameters for languageId === 'cpp'.
export const defaultCppSimilarFilesOptions: SimilarFilesOptions = {
snippetLength: 60,
threshold: 0.0,
maxTopSnippets: 16,
maxCharPerFile: 100000,
maxNumberOfFiles: 200,
maxSnippetsPerFile: 4,
};
/**
 * Create the matcher for a reference document: a `BlockTokenSubsetMatcher` when
 * `useSubsetMatching` is set, otherwise a `FixedWindowSizeJaccardMatcher`.
 * Both are built with `selection.snippetLength` as the window length.
 */
function getMatcher(doc: DocumentInfoWithOffset, selection: SimilarFilesOptions) {
    if (selection.useSubsetMatching) {
        return BlockTokenSubsetMatcher.FACTORY(selection.snippetLength).to(doc);
    }
    return FixedWindowSizeJaccardMatcher.FACTORY(selection.snippetLength).to(doc);
}
/**
 * Collect the best-matching snippets for `doc` from a list of similar files.
 *
 * Files are searched one after another. Empty files and files that are not shorter than
 * `options.maxCharPerFile` are skipped, and at most `options.maxNumberOfFiles` files are searched.
 *
 * @returns Up to `options.maxTopSnippets` snippets scoring above `options.threshold`,
 * ordered by ascending score (best last), each tagged with the relative path of its file.
 */
export async function getSimilarSnippets(
    doc: DocumentInfoWithOffset,
    similarFiles: SimilarFileInfo[],
    options: SimilarFilesOptions
): Promise<SnippetWithProviderInfo[]> {
    const matcher = getMatcher(doc, options);
    if (options.maxTopSnippets === 0) {
        return [];
    }

    // filter out absurdly long or absurdly many open files
    const candidateFiles = similarFiles
        .filter(similarFile => similarFile.source.length < options.maxCharPerFile && similarFile.source.length > 0)
        .slice(0, options.maxNumberOfFiles);

    // accumulator of all snippets from all similarFiles
    const collected: SnippetWithProviderInfo[] = [];
    for (const similarFile of candidateFiles) {
        const matches = await matcher.findMatches(similarFile, options.maxSnippetsPerFile);
        for (const snippet of matches) {
            collected.push({
                relativePath: similarFile.relativePath,
                ...snippet,
            });
        }
    }

    return (
        collected
            .filter(
                candidate =>
                    // remove files that had no match at all
                    candidate.score &&
                    candidate.snippet &&
                    // remove files that had a low score
                    candidate.score > options.threshold
            )
            // order them with best (highest scores) last
            .sort((a, b) => a.score - b.score)
            // take the best options from the end
            .slice(-options.maxTopSnippets)
    );
}

View File

@@ -0,0 +1,76 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ScoredSnippet } from './selectRelevance';
/** Indicates what provider produced a given snippet. */
export enum SnippetProviderType {
SimilarFiles = 'similar-files',
Path = 'path',
}
/**
* The semantics of a snippet. For example, some providers
* might always produce a snippet that is a complete function
* whereas others might produce snippets that are inherently
* partial.
*/
export enum SnippetSemantics {
/** The contents of the snippet is a function. */
Function = 'function',
/** The contents of the snippet is an unspecified snippet. */
Snippet = 'snippet',
/** Contains multiple snippets of type snippet */
Snippets = 'snippets',
/** The following are from hover text */
Variable = 'variable',
Parameter = 'parameter',
Method = 'method',
Class = 'class',
Module = 'module',
Alias = 'alias',
Enum = 'enum member',
Interface = 'interface',
}
/** Extends a ScoredSnippet with information about its provider. */
export interface SnippetWithProviderInfo extends ScoredSnippet {
/** The provider that created this snippet. */
provider: SnippetProviderType;
/** The semantic meaning of the snippet's contents. */
semantics: SnippetSemantics;
}
/** The subset of snippet fields that `announceSnippet` reads. */
type SnippetToAnnounce = Pick<SnippetWithProviderInfo, 'snippet' | 'semantics' | 'relativePath'>;
/**
* A map from semantics enum to a human / LLM-readable label that we
* include when announcing a snippet.
*
* Every label is currently identical to the string value of its enum member;
* the mapped type forces an entry for each member of `SnippetSemantics`.
*/
const snippetSemanticsToString: { [key in SnippetSemantics]: string } = {
[SnippetSemantics.Function]: 'function',
[SnippetSemantics.Snippet]: 'snippet',
[SnippetSemantics.Snippets]: 'snippets',
[SnippetSemantics.Variable]: 'variable',
[SnippetSemantics.Parameter]: 'parameter',
[SnippetSemantics.Method]: 'method',
[SnippetSemantics.Class]: 'class',
[SnippetSemantics.Module]: 'module',
[SnippetSemantics.Alias]: 'alias',
[SnippetSemantics.Enum]: 'enum member',
[SnippetSemantics.Interface]: 'interface',
};
/**
 * Formats a snippet for inclusion in the prompt.
 *
 * @returns The snippet text together with a headline of the form
 * "Compare this <label>:" or "Compare this <label> from <path>:" when a relative path is known.
 * The plural `Snippets` semantics uses "these" instead of "this".
 */
export function announceSnippet(snippet: SnippetToAnnounce) {
    const semantics = snippetSemanticsToString[snippet.semantics];
    const isPlural = snippet.semantics === SnippetSemantics.Snippets;
    const pluralizedSemantics = isPlural ? 'these' : 'this';

    let headline: string;
    if (snippet.relativePath) {
        headline = `Compare ${pluralizedSemantics} ${semantics} from ${snippet.relativePath}:`;
    } else {
        headline = `Compare ${pluralizedSemantics} ${semantics}:`;
    }
    return { headline, snippet: snippet.snippet };
}

View File

@@ -0,0 +1,159 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { parseTreeSitter } from '../parse';
import { DocumentInfoWithOffset } from '../prompt';
import { CursorContextInfo, getCursorContext } from './cursorContext';
import { WindowedMatcher } from './selectRelevance';
import { getBasicWindowDelineations } from './windowDelineations';
import Parser from 'web-tree-sitter';
/**
 * Implements an evolution of the FixedWindowSizeJaccardMatcher that is different in two ways.
 * 1. The source tokens window is the enclosing class member, as determined by Tree-Sitter.
 * 2. The scoring algorithm is a unidirectional set membership check (count of items from A that exist in B)
 *    rather than a set difference.
 */
export class BlockTokenSubsetMatcher extends WindowedMatcher {
	private windowLength: number;

	private constructor(referenceDoc: DocumentInfoWithOffset, windowLength: number) {
		super(referenceDoc);
		this.windowLength = windowLength;
	}

	/**
	 * Returns a factory whose `to` method creates a matcher for a reference document,
	 * using `windowLength` as the window size in lines.
	 */
	static FACTORY = (windowLength: number) => {
		return {
			to: (referenceDoc: DocumentInfoWithOffset) => new BlockTokenSubsetMatcher(referenceDoc, windowLength),
		};
	};

	protected id(): string {
		return 'fixed:' + this.windowLength;
	}

	protected getWindowsDelineations(lines: string[]): [number, number][] {
		return getBasicWindowDelineations(this.windowLength, lines);
	}

	protected _getCursorContextInfo(referenceDoc: DocumentInfoWithOffset): CursorContextInfo {
		return getCursorContext(referenceDoc, {
			maxLineCount: this.windowLength,
		});
	}

	/** The reference token set; computed on first access and cached afterwards. */
	override get referenceTokens(): Promise<Set<string>> {
		return this.createReferenceTokensForLanguage();
	}

	private async createReferenceTokensForLanguage(): Promise<Set<string>> {
		if (this.referenceTokensCache) {
			return this.referenceTokensCache;
		}

		// Syntax aware reference tokens uses tree-sitter based parsing to identify the bounds of the current
		// method and extracts tokens from just that span for use as the reference set.
		// Languages without syntax aware support use the base class behavior.
		this.referenceTokensCache = BlockTokenSubsetMatcher.syntaxAwareSupportsLanguage(this.referenceDoc.languageId)
			? await this.syntaxAwareReferenceTokens()
			: await super.referenceTokens;
		return this.referenceTokensCache;
	}

	/**
	 * Tokenizes the text between the start of the enclosing member (or type declaration)
	 * and the cursor offset of the reference document.
	 */
	private async syntaxAwareReferenceTokens(): Promise<Set<string>> {
		// See if there is an enclosing class or type member.
		const start = (await this.getEnclosingMemberStart(this.referenceDoc.source, this.referenceDoc.offset))
			?.startIndex;
		const end = this.referenceDoc.offset;

		// Compare against undefined explicitly: a start index of 0 (a declaration that begins at the
		// very start of the document) is a valid enclosing member and must not trigger the fallback.
		// If there is no enclosing member, fall back to the cursor context limited to `windowLength` lines.
		const text =
			start !== undefined
				? this.referenceDoc.source.slice(start, end)
				: getCursorContext(this.referenceDoc, {
						maxLineCount: this.windowLength,
					}).context;

		// Extract the tokens.
		return this.tokenizer.tokenize(text);
	}

	/** Whether syntax aware reference tokens are available for the language. Currently C# only. */
	private static syntaxAwareSupportsLanguage(languageId: string): boolean {
		switch (languageId) {
			case 'csharp':
				return true;
			default:
				return false;
		}
	}

	protected similarityScore(a: Set<string>, b: Set<string>): number {
		return computeScore(a, b);
	}

	/**
	 * Parses `text` and walks up from the named node at `offset` until a member or a
	 * type declaration is found.
	 *
	 * @param text The source text to parse, using the language of the reference document.
	 * @param offset The index in `text` to start the search from.
	 * @returns The closest enclosing member or type declaration node, or `undefined` if no ancestor qualifies.
	 */
	async getEnclosingMemberStart(text: string, offset: number): Promise<Parser.SyntaxNode | undefined> {
		let tree: Parser.Tree | undefined;
		try {
			tree = await parseTreeSitter(this.referenceDoc.languageId, text);
			let nodeAtPos: Parser.SyntaxNode | undefined = tree.rootNode.namedDescendantForIndex(offset);
			while (nodeAtPos) {
				// For now, hard code for C#.
				if (BlockTokenSubsetMatcher.isMember(nodeAtPos) || BlockTokenSubsetMatcher.isBlock(nodeAtPos)) {
					break;
				}
				nodeAtPos = nodeAtPos.parent ?? undefined;
			}
			// NOTE(review): the returned node outlives the tree that is deleted below; the only use
			// in this class reads `startIndex` — confirm that is safe after `tree.delete()`.
			return nodeAtPos;
		} finally {
			tree?.delete();
		}
	}

	/** Whether the node is a method, property, field or constructor declaration. */
	static isMember(node: Parser.SyntaxNode | undefined): boolean {
		// For now, hard code for C#.
		switch (node?.type) {
			case 'method_declaration':
			case 'property_declaration':
			case 'field_declaration':
			case 'constructor_declaration':
				return true;
			default:
				return false;
		}
	}

	/** Whether the node is a class, struct, record, enum or interface declaration. */
	static isBlock(node: Parser.SyntaxNode | undefined): boolean {
		// For now, hard code for C#.
		switch (node?.type) {
			case 'class_declaration':
			case 'struct_declaration':
			case 'record_declaration':
			case 'enum_declaration':
			case 'interface_declaration':
				return true;
			default:
				return false;
		}
	}
}
/**
 * Returns how many of the (unique) tokens in `b` are also contained in `a`.
 */
function computeScore(a: Set<string>, b: Set<string>) {
	let shared = 0;
	for (const token of b) {
		if (a.has(token)) {
			shared++;
		}
	}
	return shared;
}

View File

@@ -0,0 +1,145 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { IndentationTree } from '../indentation/classes';
import { clearLabels, visitTree } from '../indentation/manipulation';
import { parseTree } from '../indentation/parsing';
/**
 * Lists every fixed size window over the given lines as a (startline, endline) pair,
 * where endline is startline + windowLength.
 *
 * @param windowLength length of fixed size window
 * @param lines lines to extract fixed size windows from
 * @returns list of (startline, endline) pairs; empty if there are no lines, and a single
 * pair covering the full document if it is shorter than one window
 */
export function getBasicWindowDelineations(windowLength: number, lines: string[]): [number, number][] {
	const lineCount = lines.length;
	if (lineCount === 0) {
		return [];
	}
	if (lineCount < windowLength) {
		// The document cannot fill a single window, so the whole document is the only window.
		return [[0, lineCount]];
	}

	// One window per start line for which the window still fits into the document.
	const windowCount = lineCount - windowLength + 1;
	const delineations: [number, number][] = [];
	let first = 0;
	while (first < windowCount) {
		delineations.push([first, first + windowLength]);
		first++;
	}
	return delineations;
}
/**
 * Calculate all windows with the following properties:
 * - they are all of length <= maxLength
 * - they are all of length >= minLength
 * - except if they are followed by enough blank lines to reach length >= minLength
 * - they are a contiguous subsequence from [parentline, child1, child2, ..., childn]
 * - which neither starts nor ends with a blank line
 * Note that windows of the form "parent with all its children" could
 * appear in different ways with that definition,
 * e.g. as "childi" of its parent, and as "parent, child1, ..., childn" where the parent is itself.
 * Nevertheless, it will only be listed once.
 *
 * @param lines lines of the document; they are joined with newlines and parsed into an indentation tree
 * @param languageId language id passed to the indentation parser
 * @param minLength minimum length of a window in lines (trailing blank lines count towards it)
 * @param maxLength maximum length of a window in lines
 * @returns list of [startLine, endLine] pairs (start inclusive, end exclusive),
 * sorted by start and then end line, without duplicates
 */
export function getIndentationWindowsDelineations(
lines: string[],
languageId: string,
minLength: number,
maxLength: number
): [number, number][] {
// Deal with degenerate cases
if (lines.length < minLength || maxLength === 0) {
return [];
}
const windows: [number, number][] = [];
// For each node, keep track of how long its children extend, or whether it can't be included in a window anyhow
type TreeLabel = { totalLength: number; firstLineAfter: number };
// Todo: add groupBlocks here as well
const labeledTree = clearLabels(parseTree(lines.join('\n'), languageId)) as IndentationTree<TreeLabel>;
// Visit bottom up, so that the labels of all children are set before their parent reads them.
visitTree(
labeledTree,
node => {
// A blank node is a single line and never starts a window of its own.
if (node.type === 'blank') {
node.label = { totalLength: 1, firstLineAfter: node.lineNumber + 1 };
return;
}
// Statistics to gather on the way, to be consumed by parents
let totalLength = node.type === 'line' ? 1 : 0;
let firstLineAfter = node.type === 'line' ? node.lineNumber + 1 : NaN;
// we consider intervals [a, b] which correspond to including children number a (-1 means parent) through b exclusive.
// the window start and end lines are computed here, such that startLine (inclusive) to endLine (exclusive) covers the window
function getStartLine(a: number) {
return a === -1
? firstLineAfter - totalLength
: node.subs[a].label!.firstLineAfter - node.subs[a].label!.totalLength;
}
function getEndLine(b: number, startLine: number) {
return b === 0 ? startLine + 1 : node.subs[b - 1].label!.firstLineAfter;
}
// iteratively go through candidates for [a, b[:
// if from a to including b would be too long, add the window a to b exclusive and increase a as far as necessary, otherwise increase b
// a = -1 will mean: include the parent
let a = node.type === 'line' ? -1 : 0; // if the parent is a line, consider using it
let lengthFromAToBInclusive = node.type === 'line' ? 1 : 0; // if so, the length is 1, otherwise 0
// index of the last non-blank child seen so far, used to trim trailing blank lines from a window
let lastBThatWasntABlank = 0;
for (let b = 0; b < node.subs.length; b++) {
// don't let the window start with blank lines
while (a >= 0 && a < node.subs.length && node.subs[a].type === 'blank') {
lengthFromAToBInclusive -= node.subs[a].label!.totalLength;
a++;
}
if (node.subs[b].type !== 'blank') {
lastBThatWasntABlank = b;
}
// add subs[b] to the window
firstLineAfter = node.subs[b].label!.firstLineAfter;
totalLength += node.subs[b].label!.totalLength;
lengthFromAToBInclusive += node.subs[b].label!.totalLength;
if (lengthFromAToBInclusive > maxLength) {
// including subs[b] is too long: emit the window from a up to (excluding) b
const startLine = getStartLine(a);
const endLine = getEndLine(b, startLine);
const endLineTrimmedForBlanks =
lastBThatWasntABlank === b ? endLine : getEndLine(lastBThatWasntABlank, startLine);
// for the test, note that blanks count for getting us over the minLength:
if (minLength <= endLine - startLine) {
windows.push([startLine, endLineTrimmedForBlanks]);
}
// shrink the window from the front until it fits into maxLength again
while (lengthFromAToBInclusive > maxLength) {
// remove subs[a] from the window
lengthFromAToBInclusive -=
a === -1
? node.type === 'line'
? 1
: // this cannot happen: if not a line, we start with a = 0 unless it's a line
0
: node.subs[a].label!.totalLength;
a++;
}
}
}
// if there's anything left to add (a < b), do it
if (a < node.subs.length) {
const startLine = getStartLine(a);
const endLine = firstLineAfter;
const endLineTrimmedForBlanks =
a === -1 ? endLine : node.subs[lastBThatWasntABlank].label!.firstLineAfter;
// note: even if fillUpWindowWithPartOfNextNeighbor is true,
// there is no next similar file here, so nothing to extend the window to
if (minLength <= endLine - startLine) {
windows.push([startLine, endLineTrimmedForBlanks]);
}
}
// Set the node's label
node.label = { totalLength, firstLineAfter };
},
'bottomUp'
);
// windows is an array of [start, end] pairs,
// but some may appear twice, and should be removed
return windows
.sort((a, b) => a[0] - b[0] || a[1] - b[1])
.filter((a, i, arr) => i === 0 || a[0] !== arr[i - 1][0] || a[1] !== arr[i - 1][1]);
}

View File

@@ -0,0 +1,34 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/** Result of an edit distance computation. */
interface ScoredSuffix {
	/** The computed distance; lower means more similar. */
	score: number;
}

/**
 * Computes an edit distance between two number sequences, where moving one step
 * along either sequence costs 1 and a diagonal step costs 1 unless the elements are equal.
 *
 * @param a first sequence
 * @param b second sequence
 * @returns the distance between `a` and `b`; if either sequence is empty, the sum of both lengths
 */
export function findEditDistanceScore(a: number[], b: number[]): ScoredSuffix {
	const rows = a.length;
	const cols = b.length;
	if (rows === 0 || cols === 0) {
		return { score: rows + cols };
	}

	// distances[i][j] is the distance between a[0..i] and b[0..j], both ends inclusive.
	// Every cell is written by the loop below before it is read, so no separate
	// initialisation of the first row and column is needed.
	const distances: number[][] = Array.from({ length: rows }, () => new Array<number>(cols).fill(0));

	for (let j = 0; j < cols; j++) {
		for (let i = 0; i < rows; i++) {
			// Neighbours outside of the matrix are replaced by their boundary values.
			const above = i === 0 ? j : distances[i - 1][j];
			const left = j === 0 ? i : distances[i][j - 1];
			const diagonal = i === 0 || j === 0 ? Math.max(i, j) : distances[i - 1][j - 1];
			const mismatchCost = a[i] === b[j] ? 0 : 1;
			distances[i][j] = Math.min(above + 1, left + 1, diagonal + mismatchCost);
		}
	}

	return { score: distances[rows - 1][cols - 1] };
}

View File

@@ -0,0 +1,226 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { UseData, UseState } from '../../components/hooks';
import * as assert from 'assert';
import { isNumber, isString } from './testHelpers';
// Tests for the UseState and UseData hook implementations.
suite('Hooks', function () {
// UseState wraps a raw state array; useState returns a [value, setter] pair.
suite('Use State', function () {
test('stores state', function () {
const state = new UseState([]);
const [value] = state.useState(0);
assert.deepStrictEqual(value, 0);
});
test('accepts undefined as initial state', function () {
const state = new UseState([]);
const [value] = state.useState(undefined);
assert.deepStrictEqual(value, undefined);
});
test('accepts no value as initial state', function () {
const state = new UseState([]);
const [value] = state.useState();
assert.deepStrictEqual(value, undefined);
});
test('marks state as changed when updating state', function () {
const state = new UseState([]);
const [_, setValue] = state.useState(0);
setValue(1);
assert.strictEqual(state.hasChanged(), true);
});
test('stores state across use state instances', function () {
// Two UseState instances over the same raw state array: the second one sees the update of the first.
const rawState: number[] = [];
const [value, setValue] = new UseState(rawState).useState(0);
setValue(1);
const [newValue] = new UseState(rawState).useState(0);
assert.deepStrictEqual(value, 0);
assert.deepStrictEqual(newValue, 1);
});
test('multiple use state invocations produce separate state', function () {
const rawState: number[] = [];
const state = new UseState(rawState);
const [value1] = state.useState(0);
const [value2] = state.useState('test');
assert.deepStrictEqual(value1, 0);
assert.deepStrictEqual(value2, 'test');
});
test('accepts function as initial state', function () {
// The initializer function is invoked and its result becomes the initial value.
const state = new UseState([]);
const initializer = () => 42;
const [value] = state.useState(initializer);
assert.deepStrictEqual(value, 42);
});
test('setState accepts function to update state', function () {
const rawState: number[] = [];
const state1 = new UseState(rawState);
const [value, setValue] = state1.useState(1);
const state2 = new UseState(rawState);
// The updater receives the previous value (1) and its result (2) is stored.
setValue(prev => prev + 1);
const [updatedValue] = state2.useState(0);
assert.deepStrictEqual(value, 1);
assert.deepStrictEqual(updatedValue, 2);
assert.strictEqual(state1.hasChanged(), true);
});
// NOTE(review): despite the title, this asserts that the update made through state1 is
// visible through state2 (the raw state is shared); only the already destructured
// count1 keeps its initial value.
test('maintains separate states when multiple instances share raw state', function () {
const rawState: number[] = [];
const state1 = new UseState(rawState);
const state2 = new UseState(rawState);
const [count1, setCount1] = state1.useState(0);
setCount1(5);
const [count2] = state2.useState(0);
assert.strictEqual(count1, 0);
assert.strictEqual(count2, 5);
});
test('hasChanged returns false before any setState calls', function () {
const state = new UseState([]);
state.useState(0);
assert.strictEqual(state.hasChanged(), false);
});
});
// UseData registers callbacks together with a type predicate; updateData passes a value
// to the callbacks whose predicate accepts it.
suite('Use Data', function () {
test('stores data callback for type', async function () {
const useData = new UseData(() => { });
let data = '';
useData.useData(isString, (value: string) => {
data = value;
});
await useData.updateData('test');
assert.deepStrictEqual(data, 'test');
});
test('stores async data callback for type', async function () {
const useData = new UseData(() => { });
let data = '';
useData.useData(isString, async (value: string) => {
await Promise.resolve();
data = value;
});
// Awaiting updateData is expected to wait for the async callback as well.
await useData.updateData('test');
assert.deepStrictEqual(data, 'test');
});
test('stores multiple data callbacks for type', async function () {
const useData = new UseData(() => { });
let data1 = '';
let data2 = '';
useData.useData(isString, (value: string) => {
data1 = value;
});
useData.useData(isString, (value: string) => {
data2 = value;
});
await useData.updateData('test');
assert.deepStrictEqual(data1, 'test');
assert.deepStrictEqual(data2, 'test');
});
test('stores multiple async data callbacks for type', async function () {
const useData = new UseData(() => { });
let data1 = '';
let data2 = '';
useData.useData(isString, async (value: string) => {
await Promise.resolve();
data1 = value;
});
useData.useData(isString, async (value: string) => {
await Promise.resolve();
data2 = value;
});
await useData.updateData('test');
assert.deepStrictEqual(data1, 'test');
assert.deepStrictEqual(data2, 'test');
});
test('stores multiple data callbacks for different types', async function () {
const useData = new UseData(() => { });
let data1 = '';
let data2 = 0;
useData.useData(isString, (value: string) => {
data1 = value;
});
useData.useData(isNumber, (value: number) => {
data2 = value;
});
await useData.updateData('test');
await useData.updateData(23);
assert.deepStrictEqual(data1, 'test');
assert.deepStrictEqual(data2, 23);
});
test('updates data for subscribed types only', async function () {
const useData = new UseData(() => { });
let data = '';
useData.useData(isString, (value: string) => {
data = value;
});
// A number does not match the string predicate, so the callback must not run.
await useData.updateData(23);
assert.deepStrictEqual(data, '');
});
test('updates data measures time to update', async function () {
let time = 0;
const useData = new UseData(updateTimeMs => {
time = updateTimeMs;
});
let data = '';
useData.useData(isString, (value: string) => {
data = value;
});
await useData.updateData(23);
assert.deepStrictEqual(data, '');
// NOTE(review): requires a strictly positive duration even though no callback matched;
// presumably relies on a high-resolution timer — confirm this cannot be flaky.
assert.ok(time > 0);
});
test('updates data measures time only if data hooks are present', async function () {
// Without any registered data hook the measurement callback must never be invoked.
const useData = new UseData(updateTimeMs => {
throw new Error('Should not be called');
});
await useData.updateData(23);
});
});
});

View File

@@ -0,0 +1,33 @@
import {Fragment, jsx} from '#jsx/jsx-runtime';
import {ComponentContext} from '#prompt/components/components';
import * as assert from 'assert';
// Tests for the element and fragment factories of the JSX runtime.
suite('JSX/TSX', function () {
	test('Produces element from functional component', function () {
		const component = (_props: unknown, _context: ComponentContext) => [];

		const { type, props } = jsx(component, { children: ['Hello world'] });

		// The element references the component itself and carries the props as given.
		assert.deepStrictEqual(type, component);
		assert.deepStrictEqual(props, { children: ['Hello world'] });
	});

	test('Produces element from functional component with key', function () {
		const component = (_props: unknown, _context: ComponentContext) => [];

		const { type, props } = jsx(component, { children: ['Hello world'] }, 'key');

		// The key passed as third argument is expected to show up in the props.
		assert.deepStrictEqual(type, component);
		assert.deepStrictEqual(props, { children: ['Hello world'], key: 'key' });
	});

	test('Produces fragment function', function () {
		const { type, children } = Fragment(['Hello world']);

		// Fragments are identified by the type 'f' and keep their children.
		assert.deepStrictEqual(type, 'f');
		assert.deepStrictEqual(children, ['Hello world']);
	});
});

View File

@@ -0,0 +1,762 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/** @jsxRuntime automatic */
/** @jsxImportSource ../../../jsx-runtime */
import { Chunk, ComponentContext, PromptElement, PromptElementProps, Text } from '../../components/components';
import { Dispatch, StateUpdater } from '../../components/hooks';
import { VirtualPromptReconciler } from '../../components/reconciler';
import * as assert from 'assert';
import { CancellationTokenSource } from 'vscode-languageserver-protocol';
import { extractNodesWitPath, isNumber, isString } from './testHelpers';
suite('Virtual prompt reconciler', function () {
test('computes paths for virtual prompt nodes', function () {
const MyNestedComponent = () => {
return (
<>
<Text>Hola</Text>
<Text>Adios</Text>
</>
);
};
const prompt = (
<>
<MyNestedComponent />
<Text>Intermediate</Text>
<MyNestedComponent />
</>
);
const reconciler = new VirtualPromptReconciler(prompt);
const result = reconciler.reconcile();
const orderedPaths = extractNodesWitPath(result!);
// Assert expected paths
assert.deepStrictEqual(orderedPaths, [
'$.f',
'$.f[0].MyNestedComponent',
'$.f[0].MyNestedComponent[0].f',
'$.f[0].MyNestedComponent[0].f[0].Text',
'$.f[0].MyNestedComponent[0].f[0].Text[0]',
'$.f[0].MyNestedComponent[0].f[1].Text',
'$.f[0].MyNestedComponent[0].f[1].Text[0]',
'$.f[1].Text',
'$.f[1].Text[0]',
'$.f[2].MyNestedComponent',
'$.f[2].MyNestedComponent[0].f',
'$.f[2].MyNestedComponent[0].f[0].Text',
'$.f[2].MyNestedComponent[0].f[0].Text[0]',
'$.f[2].MyNestedComponent[0].f[1].Text',
'$.f[2].MyNestedComponent[0].f[1].Text[0]',
]);
// Assert uniqueness of paths
assert.deepStrictEqual([...new Set(orderedPaths)], orderedPaths);
});
test('computes paths for virtual prompt nodes with keys', function () {
const MyNestedComponent = () => {
return (
<>
<Text>Hola</Text>
<Text key={23}>Adios</Text>
</>
);
};
const prompt = (
<>
<MyNestedComponent />
<Chunk>
<Text key={'key-1'}>Text with key</Text>
</Chunk>
<MyNestedComponent />
</>
);
const reconciler = new VirtualPromptReconciler(prompt);
const result = reconciler.reconcile();
const orderedPaths = extractNodesWitPath(result!);
assert.deepStrictEqual(orderedPaths, [
'$.f',
'$.f[0].MyNestedComponent',
'$.f[0].MyNestedComponent[0].f',
'$.f[0].MyNestedComponent[0].f[0].Text',
'$.f[0].MyNestedComponent[0].f[0].Text[0]',
'$.f[0].MyNestedComponent[0].f["23"].Text',
'$.f[0].MyNestedComponent[0].f["23"].Text[0]',
'$.f[1].Chunk',
'$.f[1].Chunk["key-1"].Text',
'$.f[1].Chunk["key-1"].Text[0]',
'$.f[2].MyNestedComponent',
'$.f[2].MyNestedComponent[0].f',
'$.f[2].MyNestedComponent[0].f[0].Text',
'$.f[2].MyNestedComponent[0].f[0].Text[0]',
'$.f[2].MyNestedComponent[0].f["23"].Text',
'$.f[2].MyNestedComponent[0].f["23"].Text[0]',
]);
// Assert uniqueness of paths
assert.deepStrictEqual([...new Set(orderedPaths)], orderedPaths);
});
test('rejects duplicate keys on same level in initial prompt', function () {
const prompt = (
<>
<Text key={23}>Hola</Text>
<Text key={23}>Adios</Text>
</>
);
try {
new VirtualPromptReconciler(prompt);
assert.fail('Should have thrown an error');
} catch (e) {
assert.equal((e as Error).message, 'Duplicate keys found: 23');
}
});
test('rejects multiple duplicate keys on same level in initial prompt', function () {
const prompt = (
<>
<Text key={23}>Hola</Text>
<Text key={23}>Adios</Text>
<Text key={'aKey'}>Hola</Text>
<Text key={'aKey'}>Adios</Text>
</>
);
try {
new VirtualPromptReconciler(prompt);
assert.fail('Should have thrown an error');
} catch (e) {
assert.equal((e as Error).message, 'Duplicate keys found: 23, aKey');
}
});
test('rejects duplicate keys on same level during reconciliation', function () {
let outerSetCount: Dispatch<StateUpdater<number>>;
const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
const [count, setCount] = context.useState(1);
outerSetCount = setCount;
return (
<>
{Array.from({ length: count }).map((_, i) => (
<Text key={23}>Text {i}</Text>
))}
</>
);
};
const reconciler = new VirtualPromptReconciler(<MyTestComponent />);
outerSetCount!(2);
try {
reconciler.reconcile();
assert.fail('Should have thrown an error');
} catch (e) {
assert.equal((e as Error).message, 'Duplicate keys found: 23');
}
});
test('accepts same keys on different level', function () {
const prompt = (
<>
<>
<Text key={23}>Hola</Text>
</>
<>
<Text key={23}>Adios</Text>
</>
</>
);
const reconciler = new VirtualPromptReconciler(prompt);
const result = reconciler.reconcile();
const orderedPaths = extractNodesWitPath(result!);
assert.deepStrictEqual(orderedPaths, [
'$.f',
'$.f[0].f',
'$.f[0].f["23"].Text',
'$.f[0].f["23"].Text[0]',
'$.f[1].f',
'$.f[1].f["23"].Text',
'$.f[1].f["23"].Text[0]',
]);
// Assert uniqueness of paths
assert.deepStrictEqual([...new Set(orderedPaths)], orderedPaths);
});
test('Should re-render if the state of the component changed', function () {
let outerShouldRenderChildren: Dispatch<StateUpdater<boolean>>;
const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
const [shouldRenderChildren, setShouldRenderChildren] = context.useState(false);
outerShouldRenderChildren = setShouldRenderChildren;
if (shouldRenderChildren) {
return <Text>This is my child</Text>;
}
};
const reconciler = new VirtualPromptReconciler(<MyTestComponent />);
const resultOne = reconciler.reconcile();
assert.deepStrictEqual(resultOne!.children?.length, 0);
outerShouldRenderChildren!(true);
// Should re-render since the state changed
const resultTwo = reconciler.reconcile();
assert.deepStrictEqual(resultTwo!.children?.length, 1);
});
test('Should re-render if the state of a nested component changed', function () {
let outerSetShouldRenderChildren: Dispatch<StateUpdater<boolean>>;
const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
const [shouldRenderChildren, setShouldRenderChildren] = context.useState(false);
outerSetShouldRenderChildren = setShouldRenderChildren;
if (shouldRenderChildren) {
return <Text>This is my child</Text>;
}
};
const reconciler = new VirtualPromptReconciler(
(
<>
<MyTestComponent />
</>
)
);
const resultOne = reconciler.reconcile();
assert.deepStrictEqual(resultOne!.children?.length, 1);
assert.deepStrictEqual(resultOne!.children[0].children?.length, 0);
outerSetShouldRenderChildren!(true);
// Should re-render since the state changed
const resultTwo = reconciler.reconcile();
assert.deepStrictEqual(resultTwo!.children?.length, 1);
assert.deepStrictEqual(resultTwo!.children[0].children?.length, 1);
});
test('Should not re-render if the state did not change', function () {
let created = false;
const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
const [count, _] = context.useState(0);
if (created) {
throw new Error('Component was created more than once');
}
created = true;
return <Text>This is my component {count}</Text>;
};
const reconciler = new VirtualPromptReconciler(<MyTestComponent />);
try {
reconciler.reconcile();
reconciler.reconcile();
} catch (e) {
assert.fail('Component was created more than once, which should not happen');
}
});
test('Should preserve child state if position and type within parent are the same', function () {
let outerSetParentState: Dispatch<StateUpdater<string>>;
const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
const [parentState, setParentState] = context.useState('BEFORE');
outerSetParentState = setParentState;
return (
<>
<Text>This is the parent count: {parentState}</Text>
<ChildComponent parentState={parentState} />
</>
);
};
type ChildComponentProps = { parentState: string };
let childState = 'UNINITIALIZED';
const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childState = childComponentState;
return <Text>This is the child state {childComponentState}</Text>;
};
const reconciler = new VirtualPromptReconciler(<ParentComponent />);
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
outerSetParentState!('AFTER');
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
});
test('Should not preserve child state if position and type change and switch back', function () {
let outerSetParentState: Dispatch<StateUpdater<string>>;
const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
const [parentState, setParentState] = context.useState('BEFORE');
outerSetParentState = setParentState;
if (parentState === 'BEFORE') {
return (
<>
<Text>This is the parent count: {parentState}</Text>
<ChildComponent parentState={parentState} />
</>
);
}
return (
<>
<ChildComponent parentState={parentState} />
<Text>This is the parent count: {parentState}</Text>
</>
);
};
type ChildComponentProps = { parentState: string };
let childState = 'UNINITIALIZED';
const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childState = childComponentState;
return <Text>This is the child state {childComponentState}</Text>;
};
const reconciler = new VirtualPromptReconciler(<ParentComponent />);
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
outerSetParentState!('AFTER');
reconciler.reconcile();
assert.strictEqual(childState, 'AFTER');
outerSetParentState!('BEFORE');
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
});
test('Should preserve child state if position changes but key stays the same', function () {
let outerSetParentState: Dispatch<StateUpdater<string>>;
const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
const [parentState, setParentState] = context.useState('BEFORE');
outerSetParentState = setParentState;
if (parentState === 'BEFORE') {
return (
<>
<Text>This is the parent count: {parentState}</Text>
<ChildComponent key='child' parentState={parentState} />
</>
);
}
return (
<>
<ChildComponent key='child' parentState={parentState} />
<Text>This is the parent count: {parentState}</Text>
</>
);
};
type ChildComponentProps = { parentState: string };
let childState = 'UNINITIALIZED';
const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childState = childComponentState;
return <Text>This is the child state {childComponentState}</Text>;
};
const reconciler = new VirtualPromptReconciler(<ParentComponent />);
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
outerSetParentState!('AFTER');
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
outerSetParentState!('BEFORE');
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
});
test('Should preserve child state if position and type within parent are the same with deep nesting', function () {
let outerSetParentState: Dispatch<StateUpdater<string>>;
const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
const [parentState, setParentState] = context.useState('BEFORE');
outerSetParentState = setParentState;
return (
<>
<Text>This is the parent count: {parentState}</Text>
<ChildComponent parentState={parentState} />
</>
);
};
type ChildComponentProps = { parentState: string };
let childState = 'UNINITIALIZED';
const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childState = childComponentState;
return (
<>
<Text>This is the child state {childComponentState}</Text>
<ChildChildComponent parentState={childComponentState} />
</>
);
};
let childChildState = 'UNINITIALIZED';
const ChildChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childChildState = childComponentState;
return <Text>This is the child state {childComponentState}</Text>;
};
const reconciler = new VirtualPromptReconciler(<ParentComponent />);
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
assert.strictEqual(childChildState, 'BEFORE');
outerSetParentState!('AFTER');
reconciler.reconcile();
assert.strictEqual(childState, 'BEFORE');
assert.strictEqual(childChildState, 'BEFORE');
});
test('Should preserve child state if position and type within parent are the same with multiple children of same type', function () {
let outerSetParentState: Dispatch<StateUpdater<string>>;
const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
const [parentState, setParentState] = context.useState('BEFORE');
outerSetParentState = setParentState;
return (
<>
<Text>This is the parent count: {parentState}</Text>
<ChildComponent parentState={parentState + '_A'} />
<ChildComponent parentState={parentState + '_B'} />
</>
);
};
type ChildComponentProps = { parentState: string };
let childState: string[] = [];
const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
const [childComponentState, _] = context.useState(props.parentState);
childState.push(childComponentState);
return <Text>This is the child state {childComponentState}</Text>;
};
const reconciler = new VirtualPromptReconciler(<ParentComponent />);
reconciler.reconcile();
assert.deepStrictEqual(childState, ['BEFORE_A', 'BEFORE_B']);
childState = [];
outerSetParentState!('AFTER');
reconciler.reconcile();
assert.deepStrictEqual(childState, ['BEFORE_A', 'BEFORE_B']);
});
test('Should initialize child state if position changes on reconciliation', function () {
    type ChildComponentProps = { parentState: string };

    // Last state value observed by a render of the child.
    let childState = 'UNINITIALIZED';
    const ChildComponent = (props: ChildComponentProps, context: ComponentContext) => {
        const [childComponentState] = context.useState(props.parentState);
        childState = childComponentState;
        return <Text>This is the child state {childComponentState}</Text>;
    };

    let outerSetParentCount: Dispatch<StateUpdater<number>>;
    let outerSetParentState: Dispatch<StateUpdater<string>>;
    const ParentComponent = (props: PromptElementProps, context: ComponentContext) => {
        const [parentState, setParentState] = context.useState('FIRST');
        const [count, setCount] = context.useState(0);
        outerSetParentCount = setCount;
        outerSetParentState = setParentState;
        // `count` text siblings precede the child, so raising `count` shifts the child to a later index.
        const leadingTexts = Array.from({ length: count }, () => <Text>This is the parent count: {parentState}</Text>);
        return <>{[...leadingTexts, <ChildComponent parentState={parentState} />]}</>;
    };

    const reconciler = new VirtualPromptReconciler(<ParentComponent />);
    reconciler.reconcile();
    assert.strictEqual(childState, 'FIRST');

    // Move the child (one extra sibling in front) and change the value that seeds its state.
    outerSetParentCount!(1);
    outerSetParentState!('SECOND');
    reconciler.reconcile();
    // The child at the new position is expected to be initialised from the new prop.
    assert.strictEqual(childState, 'SECOND');
});
test('Should support cancellation', function () {
    // No-op default so the setter can be invoked without a non-null assertion.
    let outerSetCount: Dispatch<StateUpdater<number>> = () => 0;
    const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
        const [count, setCount] = context.useState(0);
        outerSetCount = setCount;
        return <Text>This is my component {count}</Text>;
    };

    const tokenSource = new CancellationTokenSource();
    const reconciler = new VirtualPromptReconciler(<MyTestComponent />);
    const resultBeforeCancellation = reconciler.reconcile(tokenSource.token);

    // Queue a state change, then cancel before the next reconcile gets to apply it.
    outerSetCount(1);
    tokenSource.cancel();
    const resultAfterCancellation = reconciler.reconcile(tokenSource.token);

    // A cancelled reconcile is expected to yield the same result as the one before it.
    assert.deepStrictEqual(resultBeforeCancellation, resultAfterCancellation);
});
test('Creates a pipe to route data to a component', async function () {
    // Holds the most recent payload handed to the component's data handler.
    let componentData = '';
    const DataComponent = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            componentData = data;
        });
        return <></>;
    };

    const reconciler = new VirtualPromptReconciler(<DataComponent />);
    await reconciler.createPipe().pump('test');

    assert.deepStrictEqual(componentData, 'test');
});
test('Fails to pump data before initialization', async function () {
    // The reconciler is built without a prompt element, so the pipe has no tree to deliver data to.
    const reconciler = new VirtualPromptReconciler(undefined as unknown as PromptElement);
    const pipe = reconciler.createPipe();

    // assert.rejects fails with a clear "missing expected rejection" error when pump() succeeds.
    // The previous try/catch caught its own assert.fail() and reported a confusing message mismatch instead.
    // The async wrapper also turns a synchronous throw from pump() into a rejection, so both cases are checked.
    await assert.rejects(
        async () => {
            await pipe.pump('test');
        },
        { message: 'No tree to pump data into. Pumping data before initializing?' }
    );
});
test('Creates a pipe to route data to a component after previous reconciliation has been cancelled', async function () {
    // Holds the most recent payload handed to the component's data handler.
    let componentData = '';
    const DataComponent = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            componentData = data;
        });
        return <></>;
    };

    const tokenSource = new CancellationTokenSource();
    const reconciler = new VirtualPromptReconciler(<DataComponent />);
    const pipe = reconciler.createPipe();

    // Run a reconcile with an already-cancelled token before any data is pumped.
    tokenSource.cancel();
    reconciler.reconcile(tokenSource.token);

    // The pipe created earlier must still reach the component.
    await pipe.pump('test');
    assert.deepStrictEqual(componentData, 'test');
});
test('Computes node statistics on reconcile', async function () {
    // Mirrors every string pumped through the pipe into component state.
    const DataComponent = (props: PromptElementProps, context: ComponentContext) => {
        const [state, setState] = context.useState('');
        context.useData(isString, (data: string) => {
            setState(data);
        });
        return <>{state}</>;
    };

    const reconciler = new VirtualPromptReconciler(<DataComponent />);
    await reconciler.createPipe().pump('test');

    const reconciledTree = reconciler.reconcile();
    const elapsedMs = reconciledTree?.lifecycle?.lifecycleData.getUpdateTimeMsAndReset();
    // The first check narrows away `undefined`; the second requires a positive measurement.
    assert.ok(elapsedMs);
    assert.ok(elapsedMs > 0);
});
test('Computes node statistics on reconcile with measurements from data pumping', async function () {
    // Mirrors every string pumped through the pipe into component state.
    const DataComponent = (props: PromptElementProps, context: ComponentContext) => {
        const [state, setState] = context.useState('');
        context.useData(isString, (data: string) => {
            setState(data);
        });
        return <>{state}</>;
    };

    const reconciler = new VirtualPromptReconciler(<DataComponent />);
    const pipe = reconciler.createPipe();
    await pipe.pump('test');

    // First read after the pump: a positive update time is expected.
    const firstTree = reconciler.reconcile();
    const firstUpdateTime = firstTree?.lifecycle?.lifecycleData.getUpdateTimeMsAndReset();
    assert.ok(firstUpdateTime);
    assert.ok(firstUpdateTime > 0);

    // Nothing was pumped since the previous read, so the next read must be exactly 0.
    const secondTree = reconciler.reconcile();
    const secondUpdateTime = secondTree?.lifecycle?.lifecycleData.getUpdateTimeMsAndReset();
    // strictEqual reports the actual value on failure, which assert.ok(x === 0) cannot do.
    assert.strictEqual(secondUpdateTime, 0);
});
// Pumps two numbers through a pipe and checks that the update-time reading taken after each pump
// is positive and that the two readings differ.
test('Updates data time is updated on every data update', async function () {
const DataComponent = (props: PromptElementProps, context: ComponentContext) => {
const [count, setCount] = context.useState(0);
context.useData(isNumber, async (newCount: number) => {
// Asynchronous handler: waits on a timer whose delay is the `count` captured by this render.
await new Promise(resolve => setTimeout(resolve, count));
setCount(newCount);
});
return <>{count}</>;
};
const reconciler = new VirtualPromptReconciler(<DataComponent />);
const pipe = reconciler.createPipe();
await pipe.pump(1);
const tree = reconciler.reconcile();
const lifeCycleData = tree?.lifecycle?.lifecycleData;
assert.ok(lifeCycleData);
const timeFirstPump = lifeCycleData?.getUpdateTimeMsAndReset();
assert.ok(timeFirstPump > 0);
await pipe.pump(2);
// No reconcile happens between the second pump and this read: the same lifecycle data object is queried again.
const timeSecondPump = lifeCycleData?.getUpdateTimeMsAndReset();
assert.ok(timeSecondPump > 0);
// NOTE(review): this relies on the two measured durations being different values — confirm it cannot flake.
assert.notDeepStrictEqual(timeFirstPump, timeSecondPump);
});
test('Creates a pipe to route data to many components', async function () {
    // Latest payload seen by each component's data handler.
    const received = { a: '', b: '' };

    const DataComponentA = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            received.a = data;
        });
        return <></>;
    };
    const DataComponentB = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            received.b = data;
        });
        return <></>;
    };

    const prompt = (
        <>
            <DataComponentA />
            <DataComponentB />
        </>
    );
    const reconciler = new VirtualPromptReconciler(prompt);
    await reconciler.createPipe().pump('test');

    // A single pump must reach both sibling components.
    assert.deepStrictEqual(received.a, 'test');
    assert.deepStrictEqual(received.b, 'test');
});
test('Creates a pipe to route data async to many components', async function () {
    // Latest payload stored by each component's asynchronous data handler.
    const received = { a: '', b: '' };

    const DataComponentA = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, async (data: string) => {
            // Yield once so the assignment happens after an await.
            await Promise.resolve();
            received.a = data;
        });
        return <></>;
    };
    const DataComponentB = (props: PromptElementProps, context: ComponentContext) => {
        context.useData(isString, async (data: string) => {
            await Promise.resolve();
            received.b = data;
        });
        return <></>;
    };

    const prompt = (
        <>
            <DataComponentA />
            <DataComponentB />
        </>
    );
    const reconciler = new VirtualPromptReconciler(prompt);
    await reconciler.createPipe().pump('test');

    // Awaiting pump() must be enough for both async handlers to have stored the payload.
    assert.deepStrictEqual(received.a, 'test');
    assert.deepStrictEqual(received.b, 'test');
});
test('Pumps data to components with any pipe independently', async function () {
    // Every payload each component received, in arrival order.
    const receivedByA: string[] = [];
    const receivedByB: string[] = [];

    const DataComponentA = (props: unknown, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            receivedByA.push(data);
        });
        return <></>;
    };
    const DataComponentB = (props: unknown, context: ComponentContext) => {
        context.useData(isString, (data: string) => {
            receivedByB.push(data);
        });
        return <></>;
    };

    const prompt = (
        <>
            <DataComponentA />
            <DataComponentB />
        </>
    );
    const reconciler = new VirtualPromptReconciler(prompt);

    // The second pipe is created only after the first one has finished pumping.
    const firstPipe = reconciler.createPipe();
    await firstPipe.pump('test');
    const secondPipe = reconciler.createPipe();
    await secondPipe.pump('test2');

    const expectedPayloads = ['test', 'test2'];
    assert.deepStrictEqual(receivedByA, expectedPayloads);
    assert.deepStrictEqual(receivedByB, expectedPayloads);
});
});

View File

@@ -0,0 +1,22 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { PromptSnapshotNode } from '../../components/components';
import { VirtualPromptNode } from '../../components/reconciler';
/**
 * Collects the `path` of a node and of all of its descendants in pre-order
 * (a node first, then each child's subtree in child order).
 *
 * @param node Root of the (sub)tree to flatten.
 * @returns The paths of `node` and every descendant; a node without children yields just its own path.
 */
export function extractNodesWitPath(node: VirtualPromptNode | PromptSnapshotNode): string[] {
    const collectedPaths: string[] = [node.path];
    // A missing `children` property is treated the same as an empty list.
    for (const child of node.children ?? []) {
        collectedPaths.push(...extractNodesWitPath(child));
    }
    return collectedPaths;
}
/**
 * Type guard for string primitives.
 *
 * @param value Any value.
 * @returns `true` only when `typeof value` is `'string'`; boxed `String` objects are therefore rejected.
 */
export function isString(value: unknown): value is string {
    const valueType = typeof value;
    return valueType === 'string';
}
/**
 * Type guard for number primitives.
 *
 * @param value Any value.
 * @returns `true` only when `typeof value` is `'number'` (this includes `NaN` and the infinities).
 */
export function isNumber(value: unknown): value is number {
    const valueType = typeof value;
    return valueType === 'number';
}

View File

@@ -0,0 +1,175 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/** @jsxRuntime automatic */
/** @jsxImportSource ../../../jsx-runtime */
import {
ComponentContext,
PromptElement,
PromptElementProps,
PromptSnapshotNode,
Text,
} from '../../components/components';
import { Dispatch, StateUpdater } from '../../components/hooks';
import { VirtualPrompt } from '../../components/virtualPrompt';
import * as assert from 'assert';
import { CancellationTokenSource } from 'vscode-languageserver-protocol';
suite('Virtual prompt', function () {
    test('The virtual prompt should return a snapshot tree of a prompt', function () {
        const virtualPrompt = new VirtualPrompt(
            (
                <>
                    <Text>This is text</Text>
                    <Text>This is more text</Text>
                </>
            )
        );
        const { snapshot } = virtualPrompt.snapshot();

        // Each <Text> is expected to appear as a 'Text' node wrapping one 'string' leaf.
        const textWithStringLeaf = (): NodeName => ({
            name: 'Text',
            children: [{ name: 'string', children: [] }],
        });
        const expected: NodeName = {
            name: 'f',
            children: [textWithStringLeaf(), textWithStringLeaf()],
        };

        assert.deepStrictEqual(getNodeNames(snapshot!), expected);
    });

    test('The virtual prompt should return an updated snapshot if the inner state changed', function () {
        let outerSetCount: Dispatch<StateUpdater<number>>;
        // Number of times the component function has been invoked.
        let renderCount = 0;
        const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
            const [count, setCount] = context.useState(0);
            outerSetCount = setCount;
            renderCount++;
            return <Text>This is my component {count}</Text>;
        };

        const virtualPrompt = new VirtualPrompt(<MyTestComponent />);
        const snapshotBefore = virtualPrompt.snapshot().snapshot;
        outerSetCount!(1);
        const snapshotAfter = virtualPrompt.snapshot().snapshot;

        // One initial render plus exactly one re-render caused by the state change.
        assert.strictEqual(renderCount, 2);
        assert.notDeepStrictEqual(snapshotBefore, snapshotAfter);
    });

    test('Should cancel while snapshotting', function () {
        const tokenSource = new CancellationTokenSource();
        let outerCancelCount: Dispatch<StateUpdater<number>>;
        // Renders completed so far; the component cancels from its second render onwards.
        let completedRenders = 0;
        const CancellingComponent = (props: PromptElementProps, context: ComponentContext) => {
            const [, setCount] = context.useState(0);
            outerCancelCount = setCount;
            if (completedRenders > 0) {
                tokenSource.cancel();
            }
            completedRenders += 1;
            return <Text>CancellingComponent</Text>;
        };

        const virtualPrompt = new VirtualPrompt(
            (
                <>
                    <CancellingComponent />
                </>
            )
        );
        // Invalidate the component so that snapshot() has to render it a second time.
        outerCancelCount!(1);
        const result = virtualPrompt.snapshot(tokenSource.token);

        assert.deepStrictEqual(result, { snapshot: undefined, status: 'cancelled' });
    });

    test('Should return an error if there was an error during snapshot', function () {
        // There is no prompt element at all, so snapshot() has nothing to reconcile.
        const result = new VirtualPrompt(undefined as unknown as PromptElement).snapshot();

        assert.deepStrictEqual(result.snapshot, undefined);
        assert.deepStrictEqual(result.status, 'error');
        assert.deepStrictEqual(result.error?.message, 'No tree to reconcile, make sure to pass a valid prompt');
    });

    test('Should return an error if there was an error during reconciliation', function () {
        let outerSetCount: Dispatch<StateUpdater<number>>;
        // Set by the first render; any later render throws.
        let alreadyRendered = false;
        const MyTestComponent = (props: PromptElementProps, context: ComponentContext) => {
            const [count, setCount] = context.useState(0);
            if (alreadyRendered) {
                throw new Error('Component was recreated');
            }
            alreadyRendered = true;
            outerSetCount = setCount;
            return <Text>This is my component {count}</Text>;
        };

        const virtualPrompt = new VirtualPrompt(
            (
                <>
                    <MyTestComponent />
                </>
            )
        );
        // Force a second render, which throws from inside the component.
        outerSetCount!(1);
        const result = virtualPrompt.snapshot();

        assert.deepStrictEqual(result.snapshot, undefined);
        assert.deepStrictEqual(result.status, 'error');
        assert.deepStrictEqual(result.error?.message, 'Component was recreated');
    });

    test('Should create a pipe', function () {
        const pipe = new VirtualPrompt(<>test</>).createPipe();
        assert.ok(pipe);
    });
});
/** Name-only projection of a snapshot tree, used to compare tree shapes in assertions. */
type NodeName = { name: string; children: NodeName[] };

/**
 * Reduces a snapshot node to its `name` and the recursively reduced children.
 *
 * @param node Snapshot node to project.
 * @returns A `NodeName` tree; a node without children gets an empty `children` array.
 */
function getNodeNames(node: PromptSnapshotNode): NodeName {
    const children: NodeName[] = [];
    for (const child of node.children ?? []) {
        children.push(getNodeNames(child));
    }
    return { name: node.name, children };
}

View File

@@ -0,0 +1,233 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { PromptSnapshotNode } from '../../components/components';
import { SnapshotWalker } from '../../components/walker';
import * as assert from 'assert';
suite('Snapshot Walker', function () {
test('walks snapshot recursively', function () {
    // Depth 1 with one child per level: a root ('0') and a single leaf ('0.0').
    const walker = new SnapshotWalker(createTestSnapshot(1, 1));
    const visitedValues: string[] = [];

    walker.walkSnapshot((node, _parent, _context) => {
        const visitedPath = node.path ?? 'undefined';
        visitedValues.push(visitedPath);
        return true;
    });

    assert.deepStrictEqual(visitedValues, ['0', '0.0']);
});
test('stops walking after visitor returns false', function () {
    const walker = new SnapshotWalker(createTestSnapshot(2, 2));
    const visitedPaths: string[] = [];

    // Returning false for the first node must keep the walker from visiting anything else.
    walker.walkSnapshot((node, _parent, _context) => {
        visitedPaths.push(node.path);
        return false;
    });

    assert.deepStrictEqual(visitedPaths, ['0']);
});
test('walks deeper nested snapshot', function () {
    // Three levels below the root with two children each: 1 + 2 + 4 + 8 = 15 nodes.
    const walker = new SnapshotWalker(createTestSnapshot(3, 2));
    const paths: string[] = [];

    walker.walkSnapshot((node, _parent, _context) => {
        paths.push(node.path);
        return true;
    });

    // Expected order is depth-first, parents before children, siblings in index order.
    const expectedPaths = [
        '0',
        '0.0',
        '0.0.0',
        '0.0.0.0',
        '0.0.0.1',
        '0.0.1',
        '0.0.1.0',
        '0.0.1.1',
        '0.1',
        '0.1.0',
        '0.1.0.0',
        '0.1.0.1',
        '0.1.1',
        '0.1.1.0',
        '0.1.1.1',
    ];
    assert.deepStrictEqual(paths, expectedPaths);
});
test('carries weight relative to parent weight', function () {
    const child: PromptSnapshotNode = {
        name: 'child',
        path: '0.0',
        value: '1',
        props: { weight: 0.5 },
        statistics: {},
    };
    const root: PromptSnapshotNode = {
        name: 'root',
        path: '0',
        value: '0',
        props: { weight: 0.5 },
        children: [child],
        statistics: {},
    };

    const weights: number[] = [];
    new SnapshotWalker(root).walkSnapshot((_node, _parent, context) => {
        weights.push(context.weight as number);
        return true;
    });

    // root: 0.5, child: 0.5 * 0.5
    assert.deepStrictEqual(weights, [0.5, 0.25]);
});
test('propagates chunks to children', function () {
    const root: PromptSnapshotNode = {
        name: 'Chunk',
        path: '0',
        value: 'chunk1',
        statistics: {},
        children: [{ name: 'child', path: '0.0', value: 'child1', statistics: {} }],
    };

    // Chunk sets seen by the visitor, one entry per visited node.
    const chunks: Set<string>[] = [];
    new SnapshotWalker(root).walkSnapshot((_node, _parent, context) => {
        chunks.push(context.chunks as Set<string>);
        return true;
    });

    assert.deepStrictEqual(chunks.length, 2);
    // Both the 'Chunk' node and its child are expected to carry the chunk's path.
    const expectedChunk = new Set<string>(['0']);
    assert.deepStrictEqual(chunks[0], expectedChunk);
    assert.deepStrictEqual(chunks[1], expectedChunk);
});
test('propagates nested chunks', function () {
    const innerChunk: PromptSnapshotNode = {
        name: 'Chunk',
        path: '0.1',
        value: 'chunk2',
        statistics: {},
        children: [{ name: 'child', path: '0.1.0', value: 'child2', statistics: {} }],
    };
    const outerChunk: PromptSnapshotNode = {
        name: 'Chunk',
        path: '0',
        value: 'chunk1',
        statistics: {},
        children: [{ name: 'child', path: '0.0', value: 'child1', statistics: {} }, innerChunk],
    };

    // Chunk sets seen by the visitor, one entry per visited node.
    const chunks: Set<string>[] = [];
    new SnapshotWalker(outerChunk).walkSnapshot((_node, _parent, context) => {
        chunks.push(context.chunks as Set<string>);
        return true;
    });

    assert.deepStrictEqual(chunks.length, 4);
    const outerOnly = new Set<string>(['0']);
    const outerAndInner = new Set<string>(['0', '0.1']);
    // Outer chunk and its plain child belong to the outer chunk only.
    assert.deepStrictEqual(chunks[0], outerOnly);
    assert.deepStrictEqual(chunks[1], outerOnly);
    // Inner chunk and its child belong to both enclosing chunks.
    assert.deepStrictEqual(chunks[2], outerAndInner);
    assert.deepStrictEqual(chunks[3], outerAndInner);
});
test('propagates source to children', function () {
    const root: PromptSnapshotNode = {
        name: 'root',
        path: '0',
        value: 'root',
        props: { source: 'source1' },
        statistics: {},
        // The child declares no source of its own.
        children: [{ name: 'child', path: '0.0', value: 'child', statistics: {} }],
    };

    const sources: unknown[] = [];
    new SnapshotWalker(root).walkSnapshot((_node, _parent, context) => {
        sources.push(context.source);
        return true;
    });

    // The child is expected to report the source declared on the root.
    assert.deepStrictEqual(sources, ['source1', 'source1']);
});
/**
 * Builds a uniform snapshot tree for walker tests.
 *
 * Every node's `path` and `value` are its dotted index path (root is '0'). Inner nodes are
 * named `node-<path>` and leaves are named 'leaf'.
 *
 * @param depth Number of levels below the returned node; `<= 0` yields a leaf.
 * @param childrenCount Number of children created for every inner node.
 * @param currentPath Path of the node being built; empty for the root.
 * @returns The root of the generated tree.
 */
function createTestSnapshot(
    depth: number,
    childrenCount: number = 3,
    currentPath: string = ''
): PromptSnapshotNode {
    // An empty path denotes the root, which is addressed as '0'.
    const nodePath = currentPath || '0';
    if (depth <= 0) {
        return { name: 'leaf', path: nodePath, value: nodePath, statistics: {} };
    }

    const children: PromptSnapshotNode[] = [];
    let childIndex = 0;
    while (childIndex < childrenCount) {
        children.push(createTestSnapshot(depth - 1, childrenCount, `${nodePath}.${childIndex}`));
        childIndex += 1;
    }

    return {
        name: `node-${nodePath}`,
        path: nodePath,
        value: nodePath,
        children,
        statistics: {},
    };
}
});

View File

@@ -0,0 +1,453 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as assert from 'assert';
import dedent from 'ts-dedent';
import {
blankNode,
clearLabels,
clearLabelsIf,
cutTreeAfterLine,
deparseAndCutTree,
deparseLine,
deparseTree,
duplicateTree,
firstLineOf,
foldTree,
IndentationTree,
isBlank,
isLine,
lastLineOf,
lineNode,
LineNode,
mapLabels,
parseRaw,
parseTree,
resetLineNumbers,
topNode,
virtualNode,
visitTree,
visitTreeConditionally,
} from '../indentation';
import { compareTreeWithSpec } from './testHelpers';
/**
 * Parses `source` as Python, strips all labels from the resulting tree and compares it
 * against `expectedTree` using `compareTreeWithSpec`.
 *
 * @param source Source text to parse.
 * @param expectedTree Spec the label-free parsed tree is compared with.
 */
function doParseTest<T>(source: string, expectedTree: IndentationTree<T>) {
    const parsed = parseTree(source, 'python');
    compareTreeWithSpec(clearLabels(parsed), expectedTree);
}
// Baseline fixture shared by the suites in this file: two `f…:` header lines followed by `a…` lines.
// The template text is whitespace-sensitive (it is parsed as an indentation tree); do not reformat it.
const SOURCE = {
source: dedent`
f1:
a1
f2:
a2
a3
`,
name: '',
};
// Self-test for the compareTreeWithSpec helper: the baseline source must match `expected`, and each
// deliberately altered source must make the comparison throw an AssertionError.
suite('Test compareTreeWithSpec', function () {
// Fixtures: each is a variation of SOURCE, and `name` states how it is meant to differ from `expected`.
// The template texts are whitespace-sensitive; do not reformat them.
const SOURCE_MISSING_CHILD = {
source: dedent`
f1:
a1
f2:
a2
`,
name: 'missing child',
};
const SOURCE_EXTRA_CHILD = {
source: dedent`
f1:
a1
f2:
a2
a3
a4
`,
name: 'extra_child',
};
const SOURCE_MISSING_SIBLING = {
source: dedent`
f1:
a1
`,
name: 'missing sibling',
};
const SOURCE_EXTRA_SIBLING = {
source: dedent`
f1:
a1
f2:
a2
a3
f3:
a4
`,
name: 'extra_sibling',
};
const SOURCE_EXTRA_MIDDLE_BLANK_LINE = {
source: dedent`
f1:
a1
f2:
a2
a3
`,
name: 'extra middle blank line',
};
const SOURCE_EXTRA_TRAILING_BLANK_LINE = {
source: dedent`
f1:
a1
f2:
a2
a3
`,
name: 'extra trailing blank line',
};
const SOURCE_EXTRA_INDENTATION = {
source: dedent`
f1:
a1
f2:
a2
a3
`,
name: 'extra indentation',
};
// Spec for the baseline: two top-level lines (f1, f2) at indent 0 with children at indent 4, lines 0..4.
const expected = topNode([
lineNode(0, 0, 'f1:', [lineNode(4, 1, 'a1', [])]),
lineNode(0, 2, 'f2:', [lineNode(4, 3, 'a2', []), lineNode(4, 4, 'a3', [])]),
]);
test('Test compareTreeWithSpec with good input', function () {
doParseTest(SOURCE.source, expected);
});
// Loop over all bad inputs where we expect a failure from compareTreeWithSpec
// (the middle-blank-line fixture is not in this list; it has its own test below).
for (const badInput of [
SOURCE_MISSING_CHILD,
SOURCE_EXTRA_CHILD,
SOURCE_MISSING_SIBLING,
SOURCE_EXTRA_SIBLING,
SOURCE_EXTRA_INDENTATION,
SOURCE_EXTRA_TRAILING_BLANK_LINE,
]) {
test(`Test compareTreeWithSpec with bad input ${badInput.name}`, function () {
assert.throws(
() => doParseTest(badInput.source, expected),
assert.AssertionError,
`Expected to fail with ${JSON.stringify(badInput)}`
);
});
}
// Do we want extra blank lines to be children?
test('Test compareTreeWithSpec with extra blank line input', function () {
assert.throws(
() => doParseTest(SOURCE_EXTRA_MIDDLE_BLANK_LINE.source, expected),
assert.AssertionError,
'Expected to fail with extra blank line, actually fails with extra child'
);
});
});
// Tests for the label helpers: clearLabels, clearLabelsIf and mapLabels.
suite('Tree core functions: label manipulation', function () {
/**
 * Collects the distinct labels found in `tree` (visited top-down).
 * A node without a label contributes the string 'undefined' instead.
 */
function setOfLabels<L>(tree: IndentationTree<L>): Set<L | 'undefined'> {
const labels = new Set<L | 'undefined'>();
visitTree(
tree,
node => {
labels.add(node.label ?? 'undefined');
},
'topDown'
);
return labels;
}
test('Remove labels from tree', function () {
const tree = parseTree(SOURCE.source, 'python');
// NOTE(review): the result of this call is discarded — confirm whether it is still needed.
setOfLabels(tree);
// Label every node: 'foo' for line nodes with an even line number, 'bar' for everything else.
visitTree(
tree,
node => {
node.label = node.type === 'line' && node.lineNumber % 2 === 0 ? 'foo' : 'bar';
},
'topDown'
);
// NOTE(review): the result of this call is discarded as well.
setOfLabels(tree);
assert.notDeepStrictEqual([...setOfLabels(tree)], ['undefined'], 'Tree never had labels');
clearLabels(tree);
assert.deepStrictEqual([...setOfLabels(tree)], ['undefined'], 'Tree still has labels');
});
test('Remove certain labels from tree', function () {
const tree = parseRaw(SOURCE.source) as IndentationTree<string>;
// Label every node: 'foo' for line nodes with an even line number, 'bar' for everything else.
visitTree(
tree,
node => {
node.label = node.type === 'line' && node.lineNumber % 2 === 0 ? 'foo' : 'bar';
},
'topDown'
);
// The expected arrays list labels in the order the Set first saw them, so order matters here.
assert.deepStrictEqual([...setOfLabels(tree)], ['bar', 'foo'], 'Did not prepare tree as expected');
clearLabelsIf<'foo', 'bar'>(
tree as IndentationTree<'foo' | 'bar'>,
// type predicate of form arg is 'bar':
(arg: 'foo' | 'bar'): arg is 'bar' => arg === 'bar'
);
assert.deepStrictEqual([...setOfLabels(tree)], ['undefined', 'foo'], 'Did not remove bar labels');
});
test('Test mapLabels', function () {
const tree = parseTree(SOURCE.source + '\n\nprint("bye")', 'python');
// Label every node: 'foo' for line nodes with an even line number, 'bar' for everything else.
visitTree(
tree,
node => {
node.label = node.type === 'line' && node.lineNumber % 2 === 0 ? 'foo' : 'bar';
},
'topDown'
);
assert.deepStrictEqual([...setOfLabels(tree)], ['bar', 'foo'], 'Did not prepare tree as expected');
// Flatten the labels (top-down) before and after mapping so they can be compared position by position.
const labelsBefore = foldTree(tree, [] as string[], (node, acc) => [...acc, node.label ?? ''], 'topDown');
const mapfct = (label: string) => (label === 'foo' ? 1 : 2);
const treeWithNumbers = mapLabels(tree as IndentationTree<'foo' | 'bar'>, mapfct);
const labelsAfter = foldTree(
treeWithNumbers,
[] as Array<string | number>,
(node, acc) => [...acc, node.label ?? ''],
'topDown'
);
assert.deepStrictEqual([...setOfLabels(treeWithNumbers)], [2, 1], 'Did not map labels');
assert.deepStrictEqual(labelsBefore.map(mapfct), labelsAfter, 'Did not map labels right');
});
});
// Tests for firstLineOf, lastLineOf and resetLineNumbers.
// Equality checks use assert.strictEqual rather than assert.ok(a === b) so that a failure
// reports the actual and expected values.
suite('Tree core functions: line numbers', function () {
    // Shared fixture; the test that mutates line numbers works on a duplicate.
    const tree = parseTree(SOURCE.source, 'python');

    test('First line of source tree is 0', function () {
        assert.strictEqual(firstLineOf(tree), 0);
    });

    test('First line of source tree + two newlines is 2', function () {
        const offsetTree = parseTree(`\n\n${SOURCE.source}`, 'python');
        // Take the third child of the top node (index 2), i.e. what follows the two leading newlines.
        const originalTree = offsetTree.subs[2];
        assert.strictEqual(firstLineOf(originalTree), 2);
    });

    test('Last line of source tree is 4', function () {
        assert.strictEqual(lastLineOf(tree), 4);
    });

    test('firstLineOf', function () {
        // The leading virtual node has no lines, so the first line comes from the second subtree.
        const firstLine = firstLineOf(
            topNode([virtualNode(0, []), virtualNode(0, [lineNode(0, 5, 'zero', [])]), lineNode(0, 6, 'one', [])])
        );
        assert.strictEqual(firstLine, 5);
    });

    test('firstLineOf undefined', function () {
        // A tree made only of virtual nodes has no line number to report.
        const firstLine = firstLineOf(topNode([virtualNode(0, []), virtualNode(0, [virtualNode(0, [])])]));
        assert.strictEqual(firstLine, undefined);
    });

    test('firstLineOf blank', function () {
        // Blank nodes count: the blank at line 1 comes before the line at 2.
        const firstLine = firstLineOf(topNode([blankNode(1), lineNode(0, 2, 'line', [])]));
        assert.strictEqual(firstLine, 1);
    });

    test('lastLineOf', function () {
        const line = lastLineOf(
            topNode([
                virtualNode(0, []),
                virtualNode(0, [lineNode(0, 1, 'first', [])]),
                lineNode(0, 2, 'second', [lineNode(0, 3, 'third', []), lineNode(0, 4, 'fourth', [])]),
            ])
        );
        assert.strictEqual(line, 4);
    });

    test('lastLineOf take by tree order, not registered line numbers', function () {
        // Line numbers decrease down the tree; the last node in tree order (child 3, line 2) must win.
        const line = lastLineOf(
            topNode([
                lineNode(
                    0,
                    5,
                    'parent',
                    [lineNode(0, 4, 'child 1', []), lineNode(0, 3, 'child 2', []), lineNode(0, 2, 'child 3', [])],
                    5
                ),
            ])
        );
        assert.strictEqual(line, 2);
    });

    test('lastLineOf undefined', function () {
        const line = lastLineOf(topNode([virtualNode(0, []), virtualNode(0, [virtualNode(0, [])])]));
        assert.strictEqual(line, undefined);
    });

    test('lastLineOf blank', function () {
        // A trailing blank node counts as the last line.
        const line = lastLineOf(topNode([lineNode(0, 1, 'line', []), blankNode(2)]));
        assert.strictEqual(line, 2);
    });

    test('Reset line numbers for tree', function () {
        const duplicatedTree = duplicateTree(tree);
        // Overwrite the line number of every line node so the reset has something to repair.
        visitTree(
            duplicatedTree,
            node => {
                if (isLine(node)) { node.lineNumber = -1; }
            },
            'topDown'
        );
        assert.strictEqual(firstLineOf(duplicatedTree), -1);
        assert.strictEqual(lastLineOf(duplicatedTree), -1);

        resetLineNumbers(duplicatedTree);

        // After the reset, line and blank nodes must be numbered 0, 1, 2, … in top-down order.
        let counter = 0;
        visitTree(
            duplicatedTree,
            node => {
                if (isLine(node) || isBlank(node)) {
                    assert.strictEqual(node.lineNumber, counter);
                    counter++;
                }
            },
            'topDown'
        );
    });
});
// Tests for deparsing, cutting and the tree visitors.
suite('Test core functions: other', function () {
const tree = parseTree(SOURCE.source, 'python');
test('deparseTree should give same output as source input', function () {
// Assert that the tree is the same as the source, ignoring trailing newlines
assert.strictEqual(deparseTree(tree).replace(/\n*$/, ''), SOURCE.source.replace(/\n*$/, ''));
});
test('deparseTree should give same output as source input with an extra blank line', function () {
const treeLonger = parseTree(`${SOURCE.source}\n`, 'python');
// Assert that the tree is the same as the source, ignoring trailing newlines
assert.strictEqual(deparseTree(treeLonger).replace(/\n*$/, ''), SOURCE.source.replace(/\n*$/, ''));
});
test('deparseAndCutTree cuts at labels', function () {
// Whitespace-sensitive fixture; do not reformat the template text.
const source = dedent`
1
2
3
4
5
6
7
8
9`;
const tree = parseRaw(source) as IndentationTree<string>;
// Mark two nodes (addressed by their position in the tree) as cut points.
tree.subs[0].subs[1].label = 'cut';
tree.subs[1].subs[0].label = 'cut';
const cuts = deparseAndCutTree(tree, ['cut']);
// since there were two cuts, it's cut in 5 bits:
assert.strictEqual(cuts.length, 5);
// it's cut at the lines labeled 'cut'
assert.strictEqual(cuts[1].source, deparseLine(tree.subs[0].subs[1] as LineNode<string>));
assert.strictEqual(cuts[3].source, deparseLine(tree.subs[1].subs[0] as LineNode<string>));
// all together give the original source (ignoring trailing newlines -- _all_ cuts are newline ended)
assert.strictEqual(cuts.map(x => x.source).join(''), source + '\n');
});
/* test('encodeTree should give an expression coding the tree', function () {
const source = dedent`
1
2
3
4 (
5
6
)
7
8
9
)`;
const tree = groupBlocks(parseTree(source));
// to eval, need to make several imports explicit
const functions = [topNode, virtualNode, lineNode, blankNode];
assert.notStrictEqual(functions, []); // make functions used
const treeAfterRoundTrip = <IndentationTree<string>>eval(`
const topNode = functions[0];
const virtualNode = functions[1];
const lineNode = functions[2];
const blankNode = functions[3];
${encodeTree(tree)}`);
compareTreeWithSpec(treeAfterRoundTrip, tree);
}); */
test('Cutting tree correctly', function () {
// Parse a fresh tree because cutTreeAfterLine is called for its effect on the tree passed in.
const cutTree = parseTree(SOURCE.source, 'python');
cutTreeAfterLine(cutTree, 2);
assert.strictEqual(lastLineOf(cutTree), 2);
});
test('VisitTreeConditionally', function () {
// Whitespace-sensitive fixture; do not reformat the template text.
const tree = parseRaw(dedent`
1
2
3
4
5
6
7
8
9`);
// Unconditional top-down visit. The expected trace contains all nine lines, so the value
// returned by the visitor evidently does not stop visitTree.
const traceTopDownAll: string[] = [];
visitTree(
tree,
node => {
if (node.type === 'line') { traceTopDownAll.push(node.sourceLine.trim()); }
return node.type === 'top';
},
'topDown'
);
assert.deepStrictEqual(
traceTopDownAll,
['1', '2', '3', '4', '5', '6', '7', '8', '9'],
'visit all in order: top to down'
);
// Unconditional bottom-up visit: children are recorded before their parent.
const traceButtonUpAll: string[] = [];
visitTree(
tree,
node => {
if (node.type === 'line') { traceButtonUpAll.push(node.sourceLine.trim()); }
return node.type === 'top';
},
'bottomUp'
);
assert.deepStrictEqual(
traceButtonUpAll,
['2', '3', '1', '5', '6', '4', '8', '9', '7'],
'visit all in order: first leaves, then parents'
);
// Conditional top-down visit: the visitor returns false once four lines were recorded.
const traceTopDown: string[] = [];
visitTreeConditionally(
tree,
node => {
if (node.type === 'line') { traceTopDown.push(node.sourceLine.trim()); }
return traceTopDown.length < 4;
},
'topDown'
);
assert.deepStrictEqual(traceTopDown, ['1', '2', '3', '4'], 'should stop after four lines');
// Conditional bottom-up visit with the same stop condition.
const traceButtomUp: string[] = [];
visitTreeConditionally(
tree,
node => {
if (node.type === 'line') { traceButtomUp.push(node.sourceLine.trim()); }
return traceButtomUp.length < 4;
},
'bottomUp'
);
assert.deepStrictEqual(traceButtomUp, ['2', '3', '1', '5'], 'should stop after four nodes');
});
});

View File

@@ -0,0 +1,265 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as assert from 'assert';
import dedent from 'ts-dedent';
import { blankNode, isLine, lineNode, parseTree, topNode, virtualNode, visitTree } from '../indentation';
import { compareTreeWithSpec } from './testHelpers';
/** Test some language specific parsing techniques */
suite('Java', function () {
test('method detection in Java', function () {
const source = dedent`
// first an import
import java.util.List;
@Override
public class Test {
public static void main(String[] args) {
System.out.println("Hello World!");
}
@Override
private List<String> list;
}`;
const javaParsedTree = parseTree(source, 'java');
// we should have picked up the correct labels
const lineLabels: string[] = [];
visitTree(
javaParsedTree,
node => {
if (isLine(node) && node.label) {
lineLabels.push(node.label);
}
},
'topDown'
);
assert.deepStrictEqual(lineLabels, [
'comment_single',
'import',
// blank
'annotation',
'class',
'member',
// not labelled
'closer',
// blank
'member', // as per explicit comment, the annotations within a class are relabeled 'member,
'member',
'closer',
]);
});
test('labelLines java', function () {
const tree = parseTree(
dedent`
package com.example;
import java.awt.*;
@annotation
final public class A {
/** A javadoc
* Second line
*/
public static void main(String[] args) {
// single-line comment
/* Multiline
* comment
*/
System.out.println("Hello, world!");
}
}
public interface I { }
`,
'java'
);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'pa...', [], 'package'),
lineNode(0, 1, 'imp..', [], 'import'),
lineNode(0, 2, '@ann...', [], 'annotation'),
lineNode(
0,
3,
'cla...',
[
lineNode(4, 4, '/**...', [lineNode(5, 5, '* ...', []), lineNode(5, 6, '* ...', [])], 'javadoc'),
lineNode(4, 7, 'public...', [
lineNode(8, 8, '//...', [], 'comment_single'),
lineNode(
8,
9,
'/*...',
[lineNode(9, 10, '* ...', []), lineNode(9, 11, '*/', [])],
'comment_multi'
),
lineNode(8, 12, 'System ...', []),
lineNode(4, 13, '}', [], 'closer'),
]),
lineNode(0, 14, '}', [], 'closer'),
],
'class'
),
lineNode(0, 15, 'public...', [], 'interface'),
])
);
});
// Checks that plain, javadoc-preceded and inline-annotated fields of a Java class are all labelled 'member'.
test('parse Java fields', function () {
//TODO: Add a field with annotation on separate line
const tree = parseTree(
dedent`
class A {
int a;
/** Javadoc */
int b;
// Comment
@Native int c;
}
`,
'java'
);
compareTreeWithSpec(
tree,
topNode([
lineNode(
0,
0,
'class...',
[
lineNode(4, 1, 'int a;', [], 'member'),
lineNode(4, 2, '/**...', [], 'javadoc'),
lineNode(4, 3, 'int b;', [], 'member'),
lineNode(4, 4, '//...', [], 'comment_single'),
lineNode(4, 5, '@Native int c;', [], 'member'),
lineNode(0, 6, '}', [], 'closer'),
],
'class'
),
])
);
});
// Checks that a nested class and a nested interface are labelled 'class' / 'interface',
// each with its own 'member' and 'closer' children.
// NOTE(review): the expected tree contains blankNode(2) and blankNode(6); the fixture as shown has no
// blank lines at those positions — confirm against the original file's whitespace.
test('parse Java inner class', function () {
const tree = parseTree(
dedent`
class A {
int a;
class Inner {
int b;
}
interface InnerInterface {
int myMethod();
}
}
`,
'java'
);
compareTreeWithSpec(
tree,
topNode([
lineNode(
0,
0,
'class A {',
[
lineNode(4, 1, 'int a;', [], 'member'),
blankNode(2),
lineNode(
4,
3,
'class Inner ...',
[lineNode(8, 4, 'int b;', [], 'member'), lineNode(4, 5, '}', [], 'closer')],
'class'
),
blankNode(6),
lineNode(
4,
7,
'interface InnerInterface ...',
[lineNode(8, 8, 'int myMethod();', [], 'member'), lineNode(4, 9, '}', [], 'closer')],
'interface'
),
lineNode(0, 10, '}', [], 'closer'),
],
'class'
),
])
);
});
});
suite('Markdown', function () {
// Checks that markdown headings become parents of the content that follows them:
// `#` lines are labelled 'heading', `##` 'subheading' and `###` 'subsubheading'.
test('header processing in markdown', function () {
const source = dedent`
A
# B
C
D
## E
F
G
# H
I
### J
K
L
M
`;
const mdParsedTree = parseTree(source, 'markdown');
compareTreeWithSpec(
mdParsedTree,
topNode([
// content before the first heading forms its own virtual block
virtualNode(0, [lineNode(0, 0, 'A', []), blankNode(1)]),
virtualNode(0, [
lineNode(
0,
2,
'# B',
[
virtualNode(0, [lineNode(0, 3, 'C', []), lineNode(0, 4, 'D', []), blankNode(5)]),
lineNode(
0,
6,
'## E',
[lineNode(0, 7, 'F', []), lineNode(0, 8, 'G', []), blankNode(9)],
'subheading'
),
],
'heading'
),
lineNode(
0,
10,
'# H',
[
virtualNode(0, [lineNode(0, 11, 'I', []), blankNode(12)]),
lineNode(
0,
13,
'### J',
[
virtualNode(0, [lineNode(0, 14, 'K', []), blankNode(15)]),
virtualNode(0, [lineNode(0, 16, 'L', []), lineNode(0, 17, 'M', [])]),
],
'subsubheading'
),
],
'heading'
),
]),
])
);
});
});

View File

@@ -0,0 +1,656 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as assert from 'assert';
import dedent from 'ts-dedent';
import {
blankNode,
buildLabelRules,
combineClosersAndOpeners,
flattenVirtual,
groupBlocks,
IndentationSubTree,
IndentationTree,
isLine,
isVirtual,
labelLines,
lineNode,
parseRaw,
parseTree,
topNode,
VirtualNode,
virtualNode,
visitTree,
} from '../indentation';
import { compareTreeWithSpec } from './testHelpers';
/**
 * Builds an expected tree from indentation alone, turning marker lines into virtual nodes.
 *
 * The text is parsed with `parseRaw`; afterwards every line node whose trimmed content is
 * exactly `-> virtual` has its `type` switched to `'virtual'` in place. For example
 *
 *     A
 *       -> virtual
 *         B
 *         C
 *
 * yields A with one virtual child whose children are B and C.
 *
 * @param sourceParsedAsIf Text whose indentation describes the desired tree shape.
 * @returns The raw-parsed tree with all marker lines retyped as virtual nodes.
 */
function parseAsIfVirtual(sourceParsedAsIf: string) {
    const virtualMarker = '-> virtual';
    const parsed = parseRaw(sourceParsedAsIf);
    visitTree(
        parsed,
        visited => {
            if (!isLine(visited) || visited.sourceLine.trim() !== virtualMarker) {
                return;
            }
            // Retype the very same object; the tree keeps referencing it.
            const retyped = visited as unknown as VirtualNode<never>;
            retyped.type = 'virtual';
        },
        'topDown'
    );
    return parsed;
}
suite('Test core parsing elements', function () {
/** An empty virtual node disappears and a virtual node with one child is replaced by that child. */
test('flattenVirtual 1', function () {
    const emptyVirtual = virtualNode(0, []);
    const wrappedLine = virtualNode(0, [lineNode(0, 0, 'lonely node', [])]);
    const flattened = flattenVirtual(topNode([emptyVirtual, wrappedLine]));
    compareTreeWithSpec(flattened, topNode([lineNode(0, 0, 'lonely node', [])]));
});
/** A virtual node nested below a line node is flattened into that line node's children. */
test('flattenVirtual 2', function () {
    const nestedVirtual = virtualNode(2, [lineNode(2, 1, 'lonely node', [])]);
    const flattened = flattenVirtual(topNode([lineNode(0, 0, 'A', [nestedVirtual])]));
    const expected = topNode([lineNode(0, 0, 'A', [lineNode(2, 1, 'lonely node', [])])]);
    compareTreeWithSpec(flattened, expected);
});
/**
 * Checks that groupBlocks wraps the top level of a raw-parsed tree in four virtual nodes
 * and that each of them holds the expected line numbers.
 */
test('groupBlocks basic cases', function () {
    const source = dedent`
A
B
C
D
E
F
G
H`;
    const tree = parseRaw(source);
    const blockTree = groupBlocks(tree);

    /**
     * Asserts that the direct children of `tree` are exactly `children`, where a virtual
     * child is written as 'v' and every other child as its line number.
     *
     * @param tree Node whose `subs` are inspected.
     * @param children Expected children, in order.
     * @param message Message reported by the assertion on mismatch.
     */
    function assertChildrenAreTheFollowingLines(
        tree: IndentationTree<never>,
        children: (string | number)[],
        message: string = ''
    ) {
        assert.deepStrictEqual(
            tree.subs.map((node: IndentationSubTree<string>) => (isVirtual(node) ? 'v' : node.lineNumber)),
            children,
            message
        );
    }

    assertChildrenAreTheFollowingLines(blockTree, ['v', 'v', 'v', 'v'], 'wrong topline blocks');
    assertChildrenAreTheFollowingLines(blockTree.subs[0], [0, 1], 'wrong zeroth block');
    assertChildrenAreTheFollowingLines(blockTree.subs[1], [2, 3, 4, 5], 'wrong first block');
    assertChildrenAreTheFollowingLines(blockTree.subs[2], [6, 7, 8], 'wrong second block');
    // Messages count from zero, so index 3 is the third block (previously mislabelled "fourth").
    assertChildrenAreTheFollowingLines(blockTree.subs[3], [9, 10], 'wrong third block');
});
// Checks the virtual-node grouping produced by groupBlocks on a nested tree with several blank-line layouts.
test('groupBlocks advanced cases', function () {
// tests consecutive blank lines, first child blank lines,
// blank lines after last child, lone blank lines,
// consecutive lone blank lines, offside blocks
let tree = parseRaw(dedent`
A
B
C
D
E
F
G
H
I
J
K
`);
tree = groupBlocks(tree);
compareTreeWithSpec(
tree,
topNode([
virtualNode(0, [
lineNode(0, 0, 'A', [
blankNode(1),
virtualNode(2, [
lineNode(2, 2, 'B', []),
lineNode(2, 3, 'C', [lineNode(4, 4, 'D', [])]),
blankNode(5),
]),
virtualNode(2, [lineNode(2, 6, 'E', []), blankNode(7), blankNode(8)]),
virtualNode(2, [lineNode(2, 9, 'F', [])]),
]),
blankNode(10),
]),
virtualNode(0, [
lineNode(0, 11, 'G', [
virtualNode(4, [
lineNode(4, 12, 'H', []),
lineNode(4, 13, 'I', []),
lineNode(2, 14, 'J', []),
blankNode(15),
]),
virtualNode(4, [lineNode(2, 16, 'K', [])]),
]),
]),
])
);
});
// Checks that blank nodes preceding the first real child stay direct children of the parent,
// while the following lines are grouped into a virtual node.
test('groupBlocks consecutive blanks as oldest children', function () {
let tree = parseRaw(dedent`
A
B1
B2
C
`);
tree = groupBlocks(tree);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'A', [
blankNode(1),
blankNode(2),
virtualNode(4, [lineNode(4, 3, 'B1', []), lineNode(4, 4, 'B2', [])]),
]),
lineNode(0, 5, 'C', []),
])
);
});
/**
 * Runs groupBlocks on a hand-built tree whose line nodes only have blank children.
 * The expected result leaves those children untouched and splits the top level into
 * two virtual blocks after the top-level blank node.
 */
test('groupBlocks subs ending with a blank line', function () {
    const grouped = groupBlocks(
        topNode([
            lineNode(0, 0, 'A', [blankNode(1)]),
            lineNode(0, 2, 'B', [blankNode(3), blankNode(4)]),
            blankNode(5),
            lineNode(0, 6, 'C', []),
        ])
    );
    const firstBlock = virtualNode(0, [
        lineNode(0, 0, 'A', [blankNode(1)]),
        lineNode(0, 2, 'B', [blankNode(3), blankNode(4)]),
        blankNode(5),
    ]);
    const secondBlock = virtualNode(0, [lineNode(0, 6, 'C', [])]);
    compareTreeWithSpec(grouped, topNode([firstBlock, secondBlock]));
});
// Checks groupBlocks with a custom delimiter predicate: lines 'B' and 'D' end a block instead of blank lines.
test('groupBlocks with different delimiter', function () {
let tree = parseRaw(dedent`
A
B
C
D
E
`) as IndentationTree<string>;
// a node is a delimiter when it is a line whose trimmed text is 'B' or 'D'
const isDelimiter = (node: IndentationTree<string>) =>
isLine(node) && (node.sourceLine.trim() === 'B' || node.sourceLine.trim() === 'D');
tree = groupBlocks(tree, isDelimiter);
compareTreeWithSpec(
tree,
topNode([
virtualNode(0, [lineNode(0, 0, 'A', []), lineNode(0, 1, 'B', [])]),
virtualNode(0, [lineNode(0, 2, 'C', []), lineNode(0, 3, 'D', [])]),
virtualNode(0, [lineNode(0, 4, 'E ', [])]),
])
);
});
});
suite('Raw parsing', function () {
// Checks the parent/child structure parseRaw derives purely from indentation,
// including a child (c3) that is indented less than its older siblings.
test('parseRaw', function () {
compareTreeWithSpec(
parseRaw(dedent`
A
a
B
b1
b2
C
c1
c2
c3
D
d1
d2
`),
topNode([
lineNode(0, 0, 'A', [lineNode(2, 1, 'a', [])]),
lineNode(0, 2, 'B', [lineNode(2, 3, 'b1', []), lineNode(2, 4, 'b2', [])]),
lineNode(0, 5, 'C', [lineNode(4, 6, 'c1', []), lineNode(4, 7, 'c2', []), lineNode(2, 8, 'c3', [])]),
lineNode(0, 9, 'D', [lineNode(2, 10, 'd1', [lineNode(4, 11, 'd2', [])])]),
])
);
});
// Checks where parseRaw places blank nodes: between siblings, as a first child, and at the top level.
test('parseRaw blanks', function () {
compareTreeWithSpec(
parseRaw(dedent`
E
e1
e2
F
f1
G
g1
H
`),
topNode([
lineNode(0, 0, 'E', [lineNode(2, 1, 'e1', []), blankNode(2), lineNode(2, 3, 'e2', [])]),
lineNode(0, 4, 'F', [blankNode(5), lineNode(2, 6, 'f1', [])]),
lineNode(0, 7, 'G', [lineNode(2, 8, 'g1', [])]),
blankNode(9),
lineNode(0, 10, 'H', []),
blankNode(11),
])
);
});
// Checks that parseTree attaches closing-brace lines as 'closer' children of the block they end,
// and that applying combineClosersAndOpeners again leaves the tree unchanged.
test('combineBraces', function () {
const tree = parseTree(dedent`
A {
}
B
b1 {
bb1
}
b2 {
bb2
}
}
C {
c1
c2
c3
c4
}
`);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'A {', [lineNode(0, 1, '}', [], 'closer')]),
lineNode(0, 2, 'B', [
lineNode(2, 3, 'b1 {', [lineNode(4, 4, 'bb1', []), lineNode(2, 5, '}', [], 'closer')]),
lineNode(2, 6, 'b2 {', [
lineNode(4, 7, 'bb2', []),
blankNode(8),
lineNode(2, 9, '}', [], 'closer'),
]),
lineNode(0, 10, '}', [], 'closer'),
]),
lineNode(0, 11, 'C {', [
lineNode(4, 12, 'c1', []),
lineNode(4, 13, 'c2', []),
lineNode(2, 14, 'c3', []),
lineNode(2, 15, 'c4', []),
lineNode(0, 16, '}', [], 'closer'),
]),
])
);
// Running the optimisation twice doesn't change the result
// (a JSON round-trip gives an independent deep copy to run it on)
let newTree = <IndentationTree<string>>JSON.parse(JSON.stringify(tree));
newTree = combineClosersAndOpeners(newTree);
compareTreeWithSpec(newTree, tree);
});
});
/**
* Many examples in this suite are taken from
* https://docs.google.com/document/d/1WxjTDzx8Qbf4Bklrp9KwiQsB4-kTOloAR5h86np3_OM/edit#
*/
suite('Test bracket indentation spec', function () {
// Compares raw and bracket-aware parsing of a lone `(` line: parseTree labels it 'opener'
// and moves it, together with its children, under the preceding line.
test('Opener merged to older sibling', function () {
const source = dedent`
A
(
B
C`;
const treeRaw = parseRaw(source);
const treeCode = parseTree(source, '');
// the raw indentation indicates line 1 is the parent of the following lines
compareTreeWithSpec(
treeRaw,
topNode([lineNode(0, 0, 'A', []), lineNode(0, 1, '(', [lineNode(4, 2, 'B', []), lineNode(4, 3, 'C', [])])])
);
// the bracket parsing indicates line 0 is the parent
compareTreeWithSpec(
treeCode,
topNode([
lineNode(0, 0, 'A', [
lineNode(0, 1, '(', [], 'opener'),
lineNode(4, 2, 'B', []),
lineNode(4, 3, 'C', []),
]),
])
);
});
// Compares raw and bracket-aware parsing of a trailing `)` line: parseTree labels it 'closer'
// and makes it the last child of the block it closes.
test('Closer merged, simplest case', function () {
const source = dedent`
A
B
)`;
const treeRaw = parseRaw(source);
const treeCode = parseTree(source, '');
// the raw indentation indicates line 2 is the sibling of 0
compareTreeWithSpec(
treeRaw,
topNode([lineNode(0, 0, 'A', [lineNode(4, 1, 'B', [])]), lineNode(0, 2, ')', [])])
);
// the bracket parsing indicates line 2 is actually another child of line 0
compareTreeWithSpec(
treeCode,
topNode([lineNode(0, 0, 'A', [lineNode(4, 1, 'B', []), lineNode(0, 2, ')', [], 'closer')])])
);
});
// Checks the children of the first top-level node before and after bracket parsing
// when a `) + (` line separates two bodies.
test('Closer merged, multi-body case', function () {
const source = dedent`
A
B
C
) + (
D
E
)`;
const treeRaw = parseRaw(source);
const treeCode = parseTree(source, '');
// before bracket parsing, A had two children, B and C
// (line children are rendered by their trimmed text, anything else as 'v')
assert.strictEqual(
treeRaw.subs[0].subs.map(x => (x.type === 'line' ? x.sourceLine.trim() : 'v')).join(),
'B,C'
);
// after, it had three children, a virtual node, line node 3 and the closer 6
assert.strictEqual(
treeCode.subs[0].subs.map(x => (x.type === 'line' ? x.sourceLine.trim() : 'v')).join(),
'v,) + (,)'
);
});
// Checks an if/else where `} else {` both closes the first body and opens the second:
// parseTree is expected to wrap the first body in a virtual node and hang the else line
// (with its own body) and the final `}` under the `if` line.
test('closer starting their next subblock, ifelse', function () {
const source = dedent`
if (new) {
print(“hello”)
print(“world”)
} else {
print(“goodbye”)
}`;
// same code with an explicit '-> virtual' marker line, used to build the expected structure
const sourceParsedAsIf = dedent`
if (new) {
-> virtual
print(“hello”)
print(“world”)
} else {
print(“goodbye”)
}`;
const treeRaw = parseRaw(source);
const treeCode = parseTree(source, '');
const treeExpected = parseAsIfVirtual(sourceParsedAsIf);
compareTreeWithSpec(
treeRaw,
topNode([
lineNode(0, 0, 'if (new) {', [
lineNode(4, 1, 'print(“hello”)', []),
lineNode(4, 2, 'print(“world”)', []),
]),
lineNode(0, 3, '} else {', [lineNode(4, 4, 'print(“goodbye”)', [])]),
lineNode(0, 5, '}', []),
])
);
compareTreeWithSpec(
treeCode,
topNode([
lineNode(0, 0, 'if (new) {', [
virtualNode(0, [lineNode(4, 1, 'print(“hello”)', []), lineNode(4, 2, 'print(“world”)', [])]),
lineNode(0, 3, '} else {', [lineNode(4, 4, 'print(“goodbye”)', [])]),
lineNode(0, 5, '}', []),
]),
])
);
// the marker-built tree is compared in 'structure' mode
compareTreeWithSpec(treeCode, treeExpected, 'structure');
});
});
suite('Special indentation styles', function () {
// Checks a function whose opening brace sits on its own line: parseTree places the brace
// lines and the body under the signature line, while parseRaw keeps three top-level siblings.
test('Allman style example (function)', function () {
const source = dedent`
function test()
{
print(“hello”)
print(“world”)
}`;
const treeRaw = parseRaw(source);
const treeCode = parseTree(source, '');
// the bracket parsing indicates line 0 is the parent
compareTreeWithSpec(
treeCode,
topNode([
lineNode(0, 0, 'function test()', [
lineNode(0, 1, '{', [], 'opener'),
lineNode(4, 2, 'print(“hello”)', []),
lineNode(4, 3, 'print(“world”)', []),
lineNode(0, 4, '}', [], 'closer'),
]),
])
);
// the next line is also moved, but by the closing part of the spec, so not tested here
compareTreeWithSpec(
treeRaw,
topNode([
lineNode(0, 0, 'function test()', []),
lineNode(0, 1, '{', [lineNode(4, 2, 'print(“hello”)', []), lineNode(4, 3, 'print(“world”)', [])]),
lineNode(0, 4, '}', []),
])
);
});
/**
 * This test is a case where our parsing isn't yet optimal: the `else` branch is expected
 * as a separate top-level node rather than as a continuation of the `if`.
 */
test('Allman style example (if-then-else)', function () {
const source = dedent`
if (condition)
{
print(“hello”)
print(“world”)
}
else
{
print(“goodbye”)
print(“phone”)
}
`;
const treeCode = parseTree(source, '');
// Currently, this is parsed the same as two consecutive if-statements,
// because generic languages do not understand that `else` should continue the `if`.
compareTreeWithSpec(
treeCode,
topNode([
lineNode(0, 0, 'if (condition)', [
lineNode(0, 1, '{', [], 'opener'),
lineNode(4, 2, 'print(“hello”)', []),
lineNode(4, 3, 'print(“world”)', []),
lineNode(0, 4, '}', [], 'closer'),
]),
lineNode(0, 5, 'else ', [
lineNode(0, 6, '{', [], 'opener'),
lineNode(4, 7, 'print(“goodbye”)', []),
lineNode(4, 8, 'print(“phone”)', []),
lineNode(0, 9, '}', [], 'closer'),
]),
])
);
});
// Checks an if/else with braces on the same line as the keywords.
test('K&R style example (if-then-else)', function () {
const source = dedent`
if (condition) {
print(“hello”)
print(“world”)
} else {
print(“goodbye”)
print(“phone”)
}
`;
const treeCode = parseTree(source, '');
// Here the whole statement is expected under the single `if` line: the if-body is wrapped
// in a virtual node, `} else {` is a 'closer' child carrying the else-body, and the final
// `}` is a second 'closer' child.
// NOTE(review): the expected line numbers start at 2 for the first body line, although it is
// line 1 of the fixture as shown here — confirm the fixture's whitespace and whether
// compareTreeWithSpec checks line numbers.
compareTreeWithSpec(
treeCode,
topNode([
lineNode(0, 0, 'if (condition) {', [
virtualNode(0, [lineNode(4, 2, 'print(“hello”)', []), lineNode(4, 3, 'print(“world”)', [])]),
lineNode(
0,
4,
'} else {',
[lineNode(4, 5, 'print(“goodbye”)', []), lineNode(4, 6, 'print(“phone”)', [])],
'closer'
),
lineNode(0, 7, '}', [], 'closer'),
]),
])
);
});
// Labels lone `{` / `}` lines as opener / closer, runs combineClosersAndOpeners, and expects
// the closer to become the last child of the (indented) opener line.
test('combineBraces GNU style indentation 1', function () {
let tree: IndentationTree<string> = parseRaw(dedent`
A
{
stmt
}
`);
// the rules match lines consisting solely of `{` or `}`
labelLines(tree, buildLabelRules({ opener: /^{$/, closer: /^}$/ }));
tree = combineClosersAndOpeners(tree);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'A', [
lineNode(2, 1, '{', [lineNode(4, 2, 'stmt', []), lineNode(2, 3, '}', [], 'closer')], 'opener'),
]),
])
);
});
// Labels lone `{` / `}` lines, combines closers and openers, flattens virtual nodes, and expects
// the brace lines and body under 'B' with the trailing blanks and 'end' left at the top level.
test('combineBraces GNU style indentation 2', function () {
let tree: IndentationTree<string> = parseRaw(dedent`
B
{
stmt
}
end
`);
labelLines(tree, buildLabelRules({ opener: /^{$/, closer: /^}$/ }));
tree = combineClosersAndOpeners(tree);
tree = flattenVirtual(tree);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'B', [
lineNode(0, 1, '{', [], 'opener'),
lineNode(4, 2, 'stmt', []),
blankNode(3),
lineNode(0, 4, '}', [], 'closer'),
]),
blankNode(5),
blankNode(6),
lineNode(0, 7, 'end', []),
])
);
});
// Same pipeline (label, combine, flatten) for a block with no statements between the braces:
// the opener, a blank node and the closer are all expected as children of 'C'.
test('combineBraces GNU style indentation 3', function () {
let tree: IndentationTree<string> = parseRaw(dedent`
C
{
}
`);
labelLines(tree, buildLabelRules({ opener: /^{$/, closer: /^}$/ }));
tree = combineClosersAndOpeners(tree);
tree = flattenVirtual(tree);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'C', [
lineNode(0, 1, '{', [], 'opener'),
blankNode(2),
lineNode(0, 3, '}', [], 'closer'),
]),
])
);
});
// Same pipeline (label, combine, flatten) for nested brace blocks: the inner opener, body
// and closer are expected under 'd', and the outer opener and closer under 'D'.
test('combineBraces GNU style indentation 4', function () {
let tree: IndentationTree<string> = parseRaw(dedent`
D
{
d
{
stmt
}
}
`);
labelLines(tree, buildLabelRules({ opener: /^{$/, closer: /^}$/ }));
tree = combineClosersAndOpeners(tree);
tree = flattenVirtual(tree);
compareTreeWithSpec(
tree,
topNode([
lineNode(0, 0, 'D', [
lineNode(0, 1, '{', [], 'opener'),
lineNode(4, 2, 'd', [
lineNode(4, 3, '{', [], 'opener'),
lineNode(8, 4, 'stmt', []),
blankNode(5),
lineNode(4, 6, '}', [], 'closer'),
]),
lineNode(0, 7, '}', [], 'closer'),
]),
])
);
});
});

View File

@@ -0,0 +1,178 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
// we need useless escapes before `!` or some tooling breaks; contact @johanrosenkilde for details
import {
comment,
commentBlockAsSingles,
getLanguage,
getLanguageMarker,
hasLanguageMarker,
mdCodeBlockLangToLanguageId,
} from '../languageMarker';
import { DocumentInfoWithOffset } from '../prompt';
import * as assert from 'assert';
import * as fs from 'fs';
import { resolve } from 'path';
suite('LanguageMarker Test Suite', function () {
/** Document shared by the tests below; rebuilt from `testdata/example.py` by `setup`. */
let doc: DocumentInfoWithOffset;
setup(function () {
    const examplePath = resolve(__dirname, 'testdata/example.py');
    doc = {
        uri: 'file:///home/user/test.py',
        source: fs.readFileSync(examplePath, 'utf8'),
        languageId: 'python',
        offset: 0,
    };
});
/** Pins the marker string returned by getLanguageMarker for a range of language ids. */
test('getLanguageMarker', function () {
    // [languageId, expected marker], checked in this order on the shared document
    const expectedMarkers = [
        ['python', '#!/usr/bin/env python3'],
        ['cpp', 'Language: cpp'],
        ['css', 'Language: css'],
        ['html', '<!DOCTYPE html>'],
        ['php', ''],
        ['yaml', '# YAML data'],
        ['unknown', 'Language: unknown'],
    ] as const;
    for (const [languageId, marker] of expectedMarkers) {
        doc.languageId = languageId;
        assert.strictEqual(getLanguageMarker(doc), marker);
    }
});
/** For python, html and shellscript: no marker is detected in the bare source, one is after prepending it. */
test('hasLanguageMarker', function () {
    /** Sets language and source on the shared doc, expects no marker, prepends `marker`, expects one. */
    const expectDetectedOnlyWithMarker = (languageId: string, body: string, marker: string) => {
        doc.languageId = languageId;
        doc.source = body;
        assert.ok(!hasLanguageMarker(doc));
        doc.source = marker + doc.source;
        assert.ok(hasLanguageMarker(doc));
    };

    // Note: '#!/bin/python' is not the shebang we add ourselves
    expectDetectedOnlyWithMarker('python', 'import mypants\ndef my_socks():\n pass', '#!/bin/python\n');
    expectDetectedOnlyWithMarker(
        'html',
        '<html><body><p>My favourite web page</p></body></html>',
        '<!DOCTYPE html>'
    );
    expectDetectedOnlyWithMarker('shellscript', 'echo Wonderful script', '#!/bin/bash\n');
});
/** `comment` puts the language's line-comment prefix in front of the text. */
test('comment normal', function () {
    const cases = [
        ['', 'python', '# '],
        ['hello', 'python', '# hello'],
        ['hello', 'typescript', '// hello'],
    ] as const;
    for (const [text, languageId, expected] of cases) {
        assert.strictEqual(comment(text, languageId), expected);
    }
});
/** Only the first line of multi-line text receives the prefix. */
test('comment demonstrate multiple lines gives unintuitive result', function () {
    const commented = comment('hello\nworld', 'typescript');
    assert.strictEqual(commented, '// hello\nworld');
});
/** A language id that is not known gets the `//` prefix. */
test('comment non-existing language', function () {
    const commented = comment('hello', 'nonexistent');
    assert.strictEqual(commented, '// hello');
});
/** Known and unknown language ids side by side. */
test('comment normal with default', function () {
    const cases = [
        ['', 'python', '# '],
        ['', 'nonexistent', '// '],
        ['hello', 'nonexistent', '// hello'],
    ] as const;
    for (const [text, languageId, expected] of cases) {
        assert.strictEqual(comment(text, languageId), expected);
    }
});
/** `commentBlockAsSingles` prefixes every line; an empty string stays empty. */
test('commentBlockAsSingles normal', function () {
    const cases = [
        ['', 'python', ''],
        ['hello', 'python', '# hello'],
        ['hello\nworld', 'python', '# hello\n# world'],
        ['hello\nworld', 'typescript', '// hello\n// world'],
    ] as const;
    for (const [text, languageId, expected] of cases) {
        assert.strictEqual(commentBlockAsSingles(text, languageId), expected);
    }
});
/** A trailing newline is kept; a lone newline yields one prefixed empty line. */
test('commentBlockAsSingles trailing newline', function () {
    const cases = [
        ['hello\nworld\n', '# hello\n# world\n'],
        ['\n', '# \n'],
    ] as const;
    for (const [text, expected] of cases) {
        assert.strictEqual(commentBlockAsSingles(text, 'python'), expected);
    }
});
/** A language id that is not known gets the `//` prefix on every line. */
test('commentBlockAsSingles nonexistent language', function () {
    const commented = commentBlockAsSingles('hello\nworld', 'nonexistent');
    assert.strictEqual(commented, '// hello\n// world');
});
/** Known and unknown language ids side by side. */
test('commentBlockAsSingles with default', function () {
    const cases = [
        ['python', '# hello\n# world'],
        ['nonexistent', '// hello\n// world'],
    ] as const;
    for (const [languageId, expected] of cases) {
        assert.strictEqual(commentBlockAsSingles('hello\nworld', languageId), expected);
    }
});
/** Markdown code-fence language tags and the language id each is expected to map to. */
const markdownLanguageIdsTestCases = [
    { input: 'h', expected: 'c' },
    { input: 'py', expected: 'python' },
    { input: 'js', expected: 'javascript' },
    { input: 'ts', expected: 'typescript' },
    { input: 'cpp', expected: 'cpp' },
    { input: 'java', expected: 'java' },
    { input: 'cs', expected: 'csharp' },
    { input: 'rb', expected: 'ruby' },
    { input: 'php', expected: 'php' },
    { input: 'html', expected: 'html' },
    { input: 'css', expected: 'css' },
    { input: 'xml', expected: 'xml' },
    { input: 'sh', expected: 'shellscript' },
    { input: 'go', expected: 'go' },
    { input: 'rs', expected: 'rust' },
    { input: 'swift', expected: 'swift' },
    { input: 'kt', expected: 'kotlin' },
    { input: 'lua', expected: 'lua' },
    { input: 'sql', expected: 'sql' },
    { input: 'yaml', expected: 'yaml' },
    { input: 'md', expected: 'markdown' },
    { input: 'plaintext', expected: undefined },
];
// One test is registered per table row, in table order.
for (const { input, expected } of markdownLanguageIdsTestCases) {
    test(`test markdownLanguageId ${input} to language id ${expected}`, function () {
        assert.strictEqual(mdCodeBlockLangToLanguageId(input), expected);
    });
}
/** Language ids with the resolved id and line-comment delimiters `getLanguage` is expected to return. */
const getLanguageTestCases = [
    { input: 'python', expected: 'python', expCommentStart: '#', expCommentEnd: '' },
    { input: 'javascript', expected: 'javascript', expCommentStart: '//', expCommentEnd: '' },
    { input: 'typescript', expected: 'typescript', expCommentStart: '//', expCommentEnd: '' },
    { input: 'cpp', expected: 'cpp', expCommentStart: '//', expCommentEnd: '' },
    { input: 'java', expected: 'java', expCommentStart: '//', expCommentEnd: '' },
    { input: 'csharp', expected: 'csharp', expCommentStart: '//', expCommentEnd: '' },
    { input: 'ruby', expected: 'ruby', expCommentStart: '#', expCommentEnd: '' },
    { input: 'php', expected: 'php', expCommentStart: '//', expCommentEnd: '' },
    { input: 'html', expected: 'html', expCommentStart: '<!--', expCommentEnd: '-->' },
    { input: 'css', expected: 'css', expCommentStart: '/*', expCommentEnd: '*/' },
    { input: 'xml', expected: 'xml', expCommentStart: '<!--', expCommentEnd: '-->' },
    { input: 'shellscript', expected: 'shellscript', expCommentStart: '#', expCommentEnd: '' },
    { input: 'go', expected: 'go', expCommentStart: '//', expCommentEnd: '' },
    { input: 'rust', expected: 'rust', expCommentStart: '//', expCommentEnd: '' },
    { input: 'swift', expected: 'swift', expCommentStart: '//', expCommentEnd: '' },
    { input: 'kotlin', expected: 'kotlin', expCommentStart: '//', expCommentEnd: '' },
    { input: 'lua', expected: 'lua', expCommentStart: '--', expCommentEnd: '' },
    { input: 'sql', expected: 'sql', expCommentStart: '--', expCommentEnd: '' },
    { input: 'yaml', expected: 'yaml', expCommentStart: '#', expCommentEnd: '' },
    { input: 'markdown', expected: 'markdown', expCommentStart: '[]: #', expCommentEnd: '' },
    { input: 'plaintext', expected: 'plaintext', expCommentStart: '//', expCommentEnd: '' },
    { input: 'not-existed', expected: 'not-existed', expCommentStart: '//', expCommentEnd: '' },
    { input: undefined, expected: 'plaintext', expCommentStart: '//', expCommentEnd: '' },
];
// One test is registered per table row, in table order.
for (const { input, expected, expCommentStart, expCommentEnd } of getLanguageTestCases) {
    test(`test getLanguage for language id ${input} to language id ${expected}`, function () {
        const resolved = getLanguage(input);
        assert.strictEqual(resolved.languageId, expected);
        assert.strictEqual(resolved.lineComment.start, expCommentStart);
        assert.strictEqual(resolved.lineComment.end, expCommentEnd);
    });
}
});

View File

@@ -0,0 +1,98 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { DocumentInfoWithOffset, SimilarFileInfo } from '../prompt';
import {
SimilarFilesOptions,
defaultSimilarFilesOptions,
getSimilarSnippets,
nullSimilarFilesOptions,
} from '../snippetInclusion/similarFiles';
import * as assert from 'assert';
import dedent from 'ts-dedent';
suite('Test Multiple Snippet Selection', function () {
// Source of the document under test; the `|` character marks where `doc.offset` is placed.
const docSource: string = dedent`
A
B
C
D|
E
F
G`;
// Document passed to getSimilarSnippets, with the offset at the `|` marker.
const doc: DocumentInfoWithOffset = {
relativePath: 'source1',
uri: 'source1',
source: docSource,
languageId: 'python',
offset: docSource.indexOf('|'), // reference snippet will be A B C D
};
// Candidate files searched for snippets similar to `doc`.
const similarFiles: SimilarFileInfo[] = [
{
relativePath: 'similarFile1',
uri: 'similarFile1',
source: dedent`
A
B
C
H
X
Y
Z
`,
},
{
relativePath: 'similarFile2',
uri: 'similarFile2',
source: dedent`
D
H
`,
},
];
// Six repetitions of the lowercase alphabet, one character per line.
const fixedWinDocSrc =
'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'
.split('')
.join('\n');
// Document for the fixed-window tests, with the offset at the very end of the source.
const fixedWinDoc: DocumentInfoWithOffset = {
relativePath: 'source1',
uri: 'source1',
source: fixedWinDocSrc,
languageId: 'python',
offset: fixedWinDocSrc.length, // Reference doc qrstuvwxyz with conservative option (10 characters), stuv...abc...xyz with eager (60 characters)
};
// Single candidate file: three repetitions of 'a'..'o', the digits '1234567890', then the full alphabet,
// again one character per line.
const fixedWinSimilarFiles: SimilarFileInfo[] = [
{
relativePath: 'similarFile1',
uri: 'similarFile1',
source: 'abcdefghijklmno1234567890abcdefghijklmnopqrstuvwxyzabcdefghijklmno1234567890abcdefghijklmnopqrstuvwxyzabcdefghijklmno1234567890abcdefghijklmnopqrstuvwxyz'
.split('')
.join('\n'),
},
];
/** With `nullSimilarFilesOptions`, no snippet is expected to be returned. */
test('FixedWindow Matcher None', async function () {
    const disabled: SimilarFilesOptions = nullSimilarFilesOptions;
    const found = await getSimilarSnippets(doc, similarFiles, disabled);
    assert.deepStrictEqual(found, []);
});
/**
 * Multi-snippet selection with the default ("eager") similar-files options.
 * Exactly one snippet is expected, spanning lines 0 to 60 of the candidate file.
 * (Per the original note: the eager window size is 60 and the minimum score threshold
 * is 0.0, so without a selection option the single best match is returned.)
 */
test('FixedWindow Matcher Eager No Selection Option', async function () {
    const eagerOptions: SimilarFilesOptions = defaultSimilarFilesOptions;
    const snippets = await getSimilarSnippets(fixedWinDoc, fixedWinSimilarFiles, eagerOptions);
    const actualLocations = snippets.map(snippet => [snippet.startLine, snippet.endLine]);
    const expectedLocations: number[][] = [[0, 60]];
    // both sides are sorted the same way so the comparison ignores ordering
    assert.deepStrictEqual(actualLocations.sort(), expectedLocations.sort());
});
});

View File

@@ -0,0 +1,30 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as parse from '../parse';
import * as assert from 'assert';
import Parser from 'web-tree-sitter';
suite('Tree-sitter Parsing Tests', function () {
/** Loads several languages one after another and expects an unknown language id to reject. */
test('language wasm loading', async function () {
    await Parser.init();
    const loadableLanguages = ['python', 'javascript', 'go', 'php', 'c', 'cpp'] as const;
    for (const languageId of loadableLanguages) {
        // sequential on purpose: each load finishes before the next one starts
        await parse.getLanguage(languageId);
    }
    await assert.rejects(async () => await parse.getLanguage('xxx'));
});
/** Pins the block-closing token per language; for python the expected value is `null`. */
suite('getBlockCloseToken', function () {
    test('all', function () {
        const expectedTokens = [
            ['javascript', '}'],
            ['typescript', '}'],
            ['python', null],
            ['ruby', 'end'],
            ['go', '}'],
        ] as const;
        for (const [languageId, closeToken] of expectedTokens) {
            assert.strictEqual(parse.getBlockCloseToken(languageId), closeToken);
        }
    });
});
});

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,478 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as assert from 'assert';
import dedent from 'ts-dedent';
import { DocumentInfoWithOffset, SimilarFileInfo } from '../prompt';
import { FixedWindowSizeJaccardMatcher, computeScore } from '../snippetInclusion/jaccardMatching';
import { ScoredSnippetMarker, SortOptions, splitIntoWords } from '../snippetInclusion/selectRelevance';
import {
SimilarFilesOptions,
conservativeFilesOptions,
defaultCppSimilarFilesOptions,
defaultSimilarFilesOptions,
getSimilarSnippets,
nullSimilarFilesOptions,
} from '../snippetInclusion/similarFiles';
import { SnippetWithProviderInfo } from '../snippetInclusion/snippets';
import { initializeTokenizers } from '../tokenization';
/**
 * Retrieves all scored snippets of `objectDoc` using a fixed-window Jaccard matcher
 * built for `referenceDoc`.
 *
 * The reference document is given an empty language id and an offset at the end of its source.
 *
 * @param objectDoc Document the snippets are taken from.
 * @param referenceDoc Document the matcher is built for.
 * @param windowLength Window length handed to the matcher factory.
 * @param sortOption Sort option forwarded to `retrieveAllSnippets`.
 * @returns Whatever `retrieveAllSnippets` yields for `objectDoc`.
 */
async function retrieveAllSnippetsWithJaccardScore(
    objectDoc: SimilarFileInfo,
    referenceDoc: SimilarFileInfo,
    windowLength: number,
    sortOption: SortOptions
): Promise<ScoredSnippetMarker[]> {
    const endOfReference = referenceDoc.source.length;
    const reference: DocumentInfoWithOffset = { ...referenceDoc, languageId: '', offset: endOfReference };
    const jaccardMatcher = FixedWindowSizeJaccardMatcher.FACTORY(windowLength).to(reference);
    return await jaccardMatcher.retrieveAllSnippets(objectDoc, sortOption);
}
/**
 * Finds the best-matching snippets of `objectDoc` using a fixed-window Jaccard matcher
 * built for `referenceDoc`.
 *
 * The reference document is given an empty language id and an offset at the end of its source;
 * the snippet limit is `defaultCppSimilarFilesOptions.maxSnippetsPerFile`.
 *
 * @param objectDoc Document the snippets are taken from.
 * @param referenceDoc Document the matcher is built for.
 * @param windowLength Window length handed to the matcher factory.
 * @returns Whatever `findBestMatch` yields for `objectDoc`.
 */
async function findBestJaccardMatch(
    objectDoc: SimilarFileInfo,
    referenceDoc: SimilarFileInfo,
    windowLength: number
): Promise<SnippetWithProviderInfo[]> {
    const endOfReference = referenceDoc.source.length;
    const reference: DocumentInfoWithOffset = { ...referenceDoc, languageId: '', offset: endOfReference };
    const jaccardMatcher = FixedWindowSizeJaccardMatcher.FACTORY(windowLength).to(reference);
    const snippetLimit = defaultCppSimilarFilesOptions.maxSnippetsPerFile;
    return await jaccardMatcher.findBestMatch(objectDoc, snippetLimit);
}
suite('selectRelevance Test Suite', function () {
// Awaits `initializeTokenizers` in the suite's setup hook — presumably so tokenizers are loaded
// before the matcher is exercised (confirm in '../tokenization').
setup(async function () {
await initializeTokenizers;
});
/**
 * Pins the score of the best match between two single-line documents
 * when matching with a window length of 1.
 */
test('findBestJaccardMatch computes correct score of two single lines', async function () {
    const uri = 'file:///home/user/test.js';
    /** Best matches of `objectSource` against `referenceSource` with a one-line window. */
    const bestMatches = (objectSource: string, referenceSource: string) =>
        findBestJaccardMatch({ source: objectSource, uri }, { source: referenceSource, uri }, 1);

    // 100% match if equal
    assert.strictEqual((await bestMatches('good morning', 'good morning'))[0].score, 1);
    // no match if different
    assert.strictEqual((await bestMatches('good morning', 'bad night')).length, 0);
    // 33% match if 1 same, 1 different (because it's 1 overlap of 3 tokens in total)
    assert.strictEqual((await bestMatches('good morning', 'good night'))[0].score, 1 / 3);
    // 50% match if half the tokens are missing (because it's 1 overlap of 2 tokens in total)
    assert.strictEqual((await bestMatches('good morning', 'good'))[0].score, 0.5);
    // order is ignored
    assert.strictEqual((await bestMatches('good morning', 'morning good'))[0].score, 1);
    // so are stop words
    assert.strictEqual((await bestMatches('good morning', 'morning is good'))[0].score, 1);
    // and non alphanumeric_ characters.
    // The separator is the euro sign, written as an escape: the previous literal held the
    // mis-decoded bytes of U+20AC, which inserted a letter-like character between the words.
    assert.strictEqual(
        (await bestMatches('good !morning sunshine', 'good\u20ACmorning,sunshine'))[0].score,
        1
    );
});
/**
 * When requesting matches with a certain window length,
 * the returned snippets have that many lines.
 */
test('findBestJaccardMatch respects windowLength', async function () {
    const uri = 'file:///home/user/test.js';
    const fourLines = 'good morning\ngood night\nthe day\nis bright';
    /** Number of lines in a snippet. */
    const lineCount = (snippet: string) => snippet.split('\n').length;

    // no window no match
    assert.strictEqual(
        (await findBestJaccardMatch({ source: fourLines, uri }, { source: fourLines, uri }, 0)).length,
        0
    );
    // for identical object and reference docs
    for (const n of [1, 2]) {
        const matches = await findBestJaccardMatch({ source: fourLines, uri }, { source: fourLines, uri }, n);
        assert.strictEqual(lineCount(matches[0].snippet), n);
    }
    // if the ref doc is shorter
    for (const n of [1, 2]) {
        const matches = await findBestJaccardMatch({ source: fourLines, uri }, { source: 'good night', uri }, n);
        assert.strictEqual(lineCount(matches[0].snippet), n);
    }
    // if the ref doc is longer
    for (const n of [1, 2]) {
        const matches = await findBestJaccardMatch(
            { source: fourLines, uri },
            { source: fourLines + '\nthe sun', uri },
            n
        );
        if (n === 1) {
            // a one-line window is expected to produce no match for the longer reference
            assert.strictEqual(matches.length, 0);
        } else {
            // n === 2: exactly one match with n lines. The former `n > 1 ? n : []` ternary and
            // the trailing `throw` could never take their other branch, so they are removed.
            assert.strictEqual(matches.length, 1);
            assert.strictEqual(lineCount(matches[0].snippet), n);
        }
    }
});
test('findBestJaccardMatch returns the best match', async function () {
	// Seven single-token lines; the reference shares tokens with 'ijkl' and 'qrst' only,
	// so the expected three-line window is the one that contains both.
	const searchedDoc = {
		source: ['abcd', 'efgh', 'ijkl', 'mnop', 'qrst', 'uvwx', 'yz'].join('\n'),
		uri: 'file:///home/user/test.js',
	};
	const referenceDoc = { source: ['ijkl', 'qrst'].join('\n'), uri: 'file:///home/user/test.js' };

	const matches = await findBestJaccardMatch(searchedDoc, referenceDoc, 3);

	assert.strictEqual(matches[0].snippet, ['ijkl', 'mnop', 'qrst'].join('\n'));
});
test('findBestJaccardMatch works on strings with or without a newline at the end', async function () {
	// NOTE(review): neither document below ends with a newline, so only the
	// "without a newline" half of the title is exercised here — consider adding
	// a trailing-newline case once its expected snippet has been confirmed.
	const searchedDoc = {
		source: ['abcd', 'efgh', 'ijkl', 'mnop', 'qrst', 'uvwx', 'yz'].join('\n'),
		uri: 'file:///home/user/test.js',
	};
	const referenceDoc = { source: ['ijkl', 'qrst'].join('\n'), uri: 'file:///home/user/test.js' };

	const matches = await findBestJaccardMatch(searchedDoc, referenceDoc, 3);

	assert.strictEqual(matches[0].snippet, ['ijkl', 'mnop', 'qrst'].join('\n'));
});
test('Tokenization splits words on whitespace', function () {
	// A single space separates two words.
	assert.deepStrictEqual(splitIntoWords('def hello'), ['def', 'hello']);
	// A run of several spaces must not produce empty tokens. The run is built with
	// repeat() so that the extra spaces cannot be silently collapsed by tooling.
	assert.deepStrictEqual(splitIntoWords('def' + ' '.repeat(3) + 'hello'), ['def', 'hello']);
	// Mixed whitespace (space, newline, tab) behaves like a single separator.
	assert.deepStrictEqual(splitIntoWords('def \n\t hello'), ['def', 'hello']);
});
test('Tokenization keeps numbers attached to words', function () {
	// Digits that follow letters stay part of the same token.
	const tokens = splitIntoWords('def hello1:\n\treturn world49');
	const expected = ['def', 'hello1', 'return', 'world49'];
	assert.deepStrictEqual(tokens, expected);
});
test('Tokenization splits words on special characters', function () {
	// Parentheses, colon, dot and plus all act as separators and are dropped.
	const tokens = splitIntoWords('def hello(world):\n\treturn a.b+1');
	const expected = ['def', 'hello', 'world', 'return', 'a', 'b', '1'];
	assert.deepStrictEqual(tokens, expected);
});
test('Tokenization splits words on underscores', function () {
	// Underscores separate words, both in identifiers and inside quoted text.
	const tokens = splitIntoWords(`def hello_world:\n\treturn 'I_am_a_sentence!'`);
	const expected = ['def', 'hello', 'world', 'return', 'I', 'am', 'a', 'sentence'];
	assert.deepStrictEqual(tokens, expected);
});
test('Find all snippets.', async function () {
	const windowLength = 2;
	const doc1 = {
		source: 'or not\ngood morning\ngood night\nthe day\nis bright\nthe morning sun\nis hot',
		uri: 'file:///home/user/test.js',
	};
	const refDoc = {
		source: 'good morning good night the day is bright',
		languageId: '',
		uri: 'file:///home/user/test.js',
	};

	// The three expected windows, named by their position in doc1.
	const firstWindow = { score: 0.6, startLine: 1, endLine: 3 };
	const secondWindow = { score: 0.4, startLine: 3, endLine: 5 };
	const thirdWindow = { score: 0.14285714285714285, startLine: 5, endLine: 7 };

	// Unsorted: windows come back in the order pinned below.
	const unsorted = await retrieveAllSnippetsWithJaccardScore(doc1, refDoc, windowLength, SortOptions.None);
	assert.deepStrictEqual(unsorted, [firstWindow, secondWindow, thirdWindow]);

	// Ascending: lowest score first.
	const ascending = await retrieveAllSnippetsWithJaccardScore(doc1, refDoc, windowLength, SortOptions.Ascending);
	assert.deepStrictEqual(ascending, [thirdWindow, secondWindow, firstWindow]);

	// Descending: highest score first.
	const descending = await retrieveAllSnippetsWithJaccardScore(doc1, refDoc, windowLength, SortOptions.Descending);
	assert.deepStrictEqual(descending, [firstWindow, secondWindow, thirdWindow]);
});
test('Test Jaccard similarity.', function () {
	const base = 'one two three four five';
	const disjoint = 'zone ztwo zthree zfour zfive';
	const oneExtraWord = 'one two three four five six'; // single word difference with base
	const oneSharedWord = 'one ztwo zthree zfour zfive'; // single word intersection with base
	const oneSharedWordRepeated = 'one ztwo ztwo zthree zfour zfive'; // repeated words

	// Scores `other` against the base sentence, building fresh token sets for every call.
	const scoreAgainstBase = (other: string) =>
		computeScore(new Set(splitIntoWords(base)), new Set(splitIntoWords(other)));

	assert.strictEqual(scoreAgainstBase(disjoint), 0);
	assert.strictEqual(scoreAgainstBase(base), 1);
	assert.strictEqual(scoreAgainstBase(oneExtraWord), 5 / 6);
	assert.strictEqual(scoreAgainstBase(oneSharedWord), 1 / 9);
	// Repetition does not change the score because the inputs are sets.
	assert.strictEqual(scoreAgainstBase(oneSharedWordRepeated), 1 / 9);
});
test('Snippets never overlap, the highest score wins.', async function () {
	// Overlapping candidate windows are resolved in favour of the best-scoring one.
	// Here the reference is "the speed of light / is incredibly fast": the window
	// "the light / is incredibly fast" matches with score 0.75, but the following window
	// "the speed of light / is incredibly fast" matches with score 1, so the earlier,
	// overlapping window is dropped from the result.
	const windowLength = 2;
	const searchedLines = [
		'the light',
		'is incredibly fast',
		'the speed of light',
		'is incredibly fast',
		'excessively bright, the morning sun',
		' was hot casting elongated shadows',
	];
	const doc1 = {
		source: searchedLines.join('\n'),
		uri: 'file:///home/user/test.js',
	};
	const refDoc = {
		source: 'the speed of light\nis incredibly fast',
		languageId: '',
		uri: 'file:///home/user/test2.js',
	};

	const snippets = await retrieveAllSnippetsWithJaccardScore(doc1, refDoc, windowLength, SortOptions.None);

	const expected = [
		{ score: 1, startLine: 1, endLine: 3 },
		{ score: 0.25, startLine: 3, endLine: 5 },
	];
	assert.deepStrictEqual(snippets, expected);
});
});
suite('Test getSimilarSnippets function', function () {
	// The `|` marks the cursor position inside the document being edited.
	const docSource: string = dedent`
A
B
C
D|
E
F
G`;
	const doc: DocumentInfoWithOffset = {
		relativePath: 'source1',
		uri: 'source1',
		source: docSource,
		languageId: 'python',
		offset: docSource.indexOf('|'), // reference snippet will be A B C D
	};
	const similarFiles: SimilarFileInfo[] = [
		{
			relativePath: 'similarFile1',
			uri: 'similarFile1',
			source: dedent`
A
B
C
H
X
Y
Z
`,
		},
		{
			relativePath: 'similarFile2',
			uri: 'similarFile2',
			source: dedent`
D
H
`,
		},
	];

	/**
	 * Runs getSimilarSnippets over the fixtures above with the given options and
	 * reduces every returned snippet to its `[startLine, endLine]` pair.
	 */
	async function locateSnippets(options: SimilarFilesOptions) {
		const snippets = await getSimilarSnippets(doc, similarFiles, options);
		return snippets.map(snippet => [snippet.startLine, snippet.endLine]);
	}

	setup(async function () {
		await initializeTokenizers;
	});

	test('Returns correct snippet in conservative mode', async function () {
		const found = await locateSnippets(conservativeFilesOptions);
		const wanted: number[][] = [
			[0, 7], // A B C H X Y Z
		];
		assert.deepStrictEqual(found, wanted);
	});

	test('Returns correct snippets in eager mode', async function () {
		const found = await locateSnippets(defaultSimilarFilesOptions);
		const wanted: number[][] = [
			[0, 7], // A B C H X Y Z
			[0, 2], // D H - included as get up to 4 similar docs
		];
		// Result order is not part of the contract here: sort both sides before comparing.
		assert.deepStrictEqual(found.sort(), wanted.sort());
	});

	test('Returns no snippet in None mode', async function () {
		const found = await locateSnippets(nullSimilarFilesOptions);
		const wanted: number[][] = [];
		assert.deepStrictEqual(found, wanted);
	});
});
suite('Test trimming reference document', function () {
	// Seven one-token lines; the `|` after `6` marks the cursor position.
	const docSource: string = dedent`
1
2
3
4
5
6|
7`;
	const doc: DocumentInfoWithOffset = {
		relativePath: 'source1',
		uri: 'source1',
		source: docSource,
		languageId: 'python',
		offset: docSource.indexOf('|'),
	};

	test('FixedWindowSizeJaccardMatcher trims reference document correctly', async function () {
		// Don't get 7 because it's after the cursor
		const tokensBeforeCursor = ['1', '2', '3', '4', '5', '6'];
		for (const windowLength of [1, 2, 3, 4, 5, 6]) {
			const matcher = FixedWindowSizeJaccardMatcher.FACTORY(windowLength).to(doc);
			const actualTokens = Array.from(await matcher.referenceTokens);
			// Only the last `windowLength` lines before the cursor should remain.
			const expectedTokens: string[] = tokensBeforeCursor.slice(-windowLength);
			assert.deepStrictEqual(actualTokens, expectedTokens);
		}
	});
});

View File

@@ -0,0 +1,32 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { SnippetProviderType, SnippetSemantics, announceSnippet } from '../snippetInclusion/snippets';
import * as assert from 'assert';
import dedent from 'ts-dedent';
suite('Unit tests for snippet.ts', () => {
	// Three-line snippet body shared by the input and the expected announcement.
	const snippetBody = dedent`
A
B
C`;
	const bogusSnippet = {
		relativePath: 'snippet1.ts',
		score: 1.0,
		startLine: 1,
		endLine: 3,
		provider: SnippetProviderType.Path,
		semantics: SnippetSemantics.Snippet,
		snippet: snippetBody,
	};

	test('announceSnippet', function () {
		// The headline names the source file; the snippet text is passed through unchanged.
		const announced = announceSnippet(bogusSnippet);
		const expected = {
			headline: 'Compare this snippet from snippet1.ts:',
			snippet: snippetBody,
		};
		assert.deepStrictEqual(announced, expected);
	});
});

View File

@@ -0,0 +1,367 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import assert from 'assert';
import { DocumentInfoWithOffset, SimilarFileInfo } from '../prompt';
import {
defaultSimilarFilesOptions,
getSimilarSnippets,
SimilarFilesOptions,
} from '../snippetInclusion/similarFiles';
import { SnippetWithProviderInfo } from '../snippetInclusion/snippets';
/**
 * Scores `relatedFiles` against `referenceDoc` using the default similar-files
 * settings, with only the matching strategy switched by the caller.
 *
 * @param referenceDoc The document (with cursor offset) the snippets are matched against.
 * @param relatedFiles Candidate files to extract snippets from.
 * @param useSubsetMatching Value forwarded as the `useSubsetMatching` option.
 * @returns The snippets produced by `getSimilarSnippets`.
 */
async function findAndScoreBlocks(
	referenceDoc: DocumentInfoWithOffset,
	relatedFiles: SimilarFileInfo[],
	useSubsetMatching: boolean
): Promise<SnippetWithProviderInfo[]> {
	// Copy the six default settings by name so nothing else leaks into the options.
	const { snippetLength, threshold, maxTopSnippets, maxCharPerFile, maxNumberOfFiles, maxSnippetsPerFile } =
		defaultSimilarFilesOptions;
	const options: SimilarFilesOptions = {
		snippetLength,
		threshold,
		maxTopSnippets,
		maxCharPerFile,
		maxNumberOfFiles,
		maxSnippetsPerFile,
		useSubsetMatching,
	};
	return getSimilarSnippets(referenceDoc, relatedFiles, options);
}
/**
 * Returns the score of the first snippet whose `relativePath` contains
 * `partialFileName`.
 *
 * Snippets without a `relativePath` never match. (Comparing the result of
 * `relativePath?.indexOf(...)` with `-1` would treat a missing path as a hit,
 * because `undefined !== -1` is true.)
 *
 * @param snippets Snippets to search, in order.
 * @param partialFileName Substring to look for in each snippet's relative path.
 * @returns The score of the first matching snippet.
 * @throws AssertionError when no snippet matches.
 */
function fileScore(snippets: SnippetWithProviderInfo[], partialFileName: string): number {
	const match = snippets.find(snippet => snippet.relativePath?.includes(partialFileName) === true);
	assert(match !== undefined, 'Expected valid file name');
	return match.score;
}
suite('Similar files with subset matching Test Suite', function () {
/**
 * Checks that subset matching builds its reference tokens from the method
 * that encloses the caret only.
 *
 * Compare this to @see FixedWindowSizeJaccardMatcher
 * which would use any tokens in the same 60-line-delimited chunk of code that
 * the caret fell in.
 *
 * When the caret sits in a sub-60-line method, or near a 60-line chunk seam,
 * that older approach ends up with results that are more related to the
 * neighboring methods.
 */
test('Only current method is considered as part of reference tokens', async function () {
	// The `|` marks the caret, placed inside Foo.
	const file0 = `
public static class TestClass
{
public static void UnrelatedMethod(IBar bar)
{
var thing = UnrelatedThing();
thing.DistractingMethodName();
}
public static void Foo(IBar bar)
{
var service = bar.GetService(typeof(IBaz));
service.DoTheThing();
|
}
public static void UnrelatedMethod2(IBar bar)
{
// This method is unrelated but can DoTheThing to a service
}
}
`;
	const file1 = `
public interface IBar
{
public object GetService(Type type);
}
`;
	const file2 = `
public interface IBaz
{
public static void DoTheThing();
}
`;
	const file3 = `
public static class DistractionClass
{
public DistractionClass UnrelatedThing()
{
TestClass.UnrelatedMethod(null);
UnrelatedMethod(null);
}
public void DistractingMethodName()
{
TestClass.UnrelatedMethod2(null);
}
}
`;

	// Scores the three related files against file0, building fresh inputs for every run.
	const scoreRelatedFiles = (useSubsetMatching: boolean) =>
		findAndScoreBlocks(
			{ source: file0, uri: 'file:///home/user/file0.js', languageId: 'csharp', offset: file0.indexOf('|') },
			[
				{ source: file1, uri: 'file:///home/user/file1.js', relativePath: 'file1' },
				{ source: file2, uri: 'file:///home/user/file2.js', relativePath: 'file2' },
				{ source: file3, uri: 'file:///home/user/file3.js', relativePath: 'file3' },
			],
			useSubsetMatching
		);

	// Old behaviour: 60-line-delimited reference token chunk.
	// The distraction class is expected to win because it shares many terms with
	// the methods that neighbor the caret.
	const oldScores = await scoreRelatedFiles(false);
	const oldPrefersDistraction =
		fileScore(oldScores, 'file3') > fileScore(oldScores, 'file2') &&
		fileScore(oldScores, 'file2') > fileScore(oldScores, 'file1');
	assert(
		oldPrefersDistraction,
		'Expected 60-line-delimited reference chunks to prefer the distraction class because it resembles neighboring methods'
	);

	// New behaviour: subset matching.
	// The second file is expected to win because it contains the most tokens that
	// match the method enclosing the caret.
	const newScores = await scoreRelatedFiles(true);
	const newPrefersIBaz =
		fileScore(newScores, 'file2') > fileScore(newScores, 'file1') &&
		fileScore(newScores, 'file1') > fileScore(newScores, 'file3');
	assert(newPrefersIBaz, 'Expected that the file containing IBaz interface would be the best match');
});
/**
 * This test ensures that methods are matched only based on the tokens from the reference
 * chunk that they do contain and are not penalized for containing additional tokens that
 * don't appear in the reference set.
 *
 * Compare this to @see FixedWindowSizeJaccardMatcher which would use Jaccard similarity to
 * score. Jaccard similarity gives preference to chunks with sets of identical tokens.
 *
 * Intuitively, scenarios where a token is a type or method reference get penalized because
 * they have tokens in common for the name of the method but have divergent content.
 */
test('Methods are not penalized for being supersets of the reference chunk', async function () {
	// The `|` marks the caret, placed inside Foo.
	const file0 = `
public static class TestClass
{
public static void Foo(Bar bar)
{
bar.Baz();
|
}
}
`;
	const file1 = `
public class Bar
{
public void Baz()
{
// This method has a bunch of extra tokens that don't match file0 and collectively
// reduce its score relative to the other files.
}
}
`;
	const file2 = `
public class Bar2
{
public void Baz()
{
}
}
`;
	const file3 = `
public class Bar3
{
public void Baz3()
{
}
}
`;

	// Scores the three related files against file0, building fresh inputs for every run.
	const scoreRelatedFiles = (useSubsetMatching: boolean) =>
		findAndScoreBlocks(
			{ source: file0, uri: 'file:///home/user/file0.js', languageId: 'csharp', offset: file0.indexOf('|') },
			[
				{ source: file1, uri: 'file:///home/user/file1.js', relativePath: 'file1' },
				{ source: file2, uri: 'file:///home/user/file2.js', relativePath: 'file2' },
				{ source: file3, uri: 'file:///home/user/file3.js', relativePath: 'file3' },
			],
			useSubsetMatching
		);

	// **********************************************************
	// Score with the old 60-line-delimited reference token chunk.
	const oldScores = await scoreRelatedFiles(false);
	// We expect the old way to prefer the simpler code samples, even when they match fewer tokens,
	// because there are fewer non-matching additional tokens.
	assert(
		fileScore(oldScores, 'file2') > fileScore(oldScores, 'file1') &&
			fileScore(oldScores, 'file1') === fileScore(oldScores, 'file3'),
		// The failure text now describes what this assertion checks; it used to be a
		// copy of the "distraction class" message from a different test.
		'Expected 60-line-delimited reference chunks to prefer the simpler file because it has fewer non-matching tokens'
	);
	// **********************************************************

	// **********************************************************
	// Score with the new method.
	const newScores = await scoreRelatedFiles(true);
	// We expect the new way to prefer the file with matching class and method names because we're no longer
	// penalizing samples for having different tokens.
	assert(
		fileScore(newScores, 'file1') > fileScore(newScores, 'file2') &&
			fileScore(newScores, 'file2') > fileScore(newScores, 'file3'),
		'Expected subset matching method to prefer the file with the most token matches'
	);
	// **********************************************************
});
/**
 * Checks that subset matching builds its reference tokens from the class
 * that encloses the caret only.
 *
 * Compare this to @see FixedWindowSizeJaccardMatcher
 * which would use any tokens in the same 60-line-delimited chunk of code that
 * the caret fell in.
 *
 * When the caret sits in a sub-60-line method, or near a 60-line chunk seam,
 * that older approach ends up with results that are more related to the
 * neighboring methods.
 */
test('Only current class is considered as part of reference tokens', async function () {
	// The `|` marks the caret: inside TestClass, but outside any of its methods.
	const file0 = `
public static class TestClass2
{
public static void UnrelatedMethod(IBar bar)
{
var thing = UnrelatedThing();
thing.DistractingMethodName();
}
}
public static class TestClass
{
public static void Foo(IBar bar)
{
var service = bar.GetService(typeof(IBaz));
service.DoTheThing();
}
|
}
public static class TestClass3
{
public static void UnrelatedMethod2(IBar bar)
{
// This method is unrelated but can DoTheThing to a service
}
}
`;
	const file1 = `
public interface IBar
{
public object GetService(Type type);
}
`;
	const file2 = `
public interface IBaz
{
public static void DoTheThing();
}
`;
	const file3 = `
public static class DistractionClass
{
public DistractionClass UnrelatedThing()
{
TestClass.UnrelatedMethod(null);
UnrelatedMethod(null);
}
public void DistractingMethodName()
{
TestClass.UnrelatedMethod2(null);
}
}
`;

	// Scores the three related files against file0, building fresh inputs for every run.
	const scoreRelatedFiles = (useSubsetMatching: boolean) =>
		findAndScoreBlocks(
			{ source: file0, uri: 'file:///home/user/file0.js', languageId: 'csharp', offset: file0.indexOf('|') },
			[
				{ source: file1, uri: 'file:///home/user/file1.js', relativePath: 'file1' },
				{ source: file2, uri: 'file:///home/user/file2.js', relativePath: 'file2' },
				{ source: file3, uri: 'file:///home/user/file3.js', relativePath: 'file3' },
			],
			useSubsetMatching
		);

	// Old behaviour: 60-line-delimited reference token chunk.
	// The distraction class is expected to win because it shares many terms with
	// the code that neighbors the caret.
	const oldScores = await scoreRelatedFiles(false);
	const oldPrefersDistraction =
		fileScore(oldScores, 'file3') > fileScore(oldScores, 'file2') &&
		fileScore(oldScores, 'file2') > fileScore(oldScores, 'file1');
	assert(
		oldPrefersDistraction,
		'Expected 60-line-delimited reference chunks to prefer the distraction class because it resembles neighboring methods'
	);

	// New behaviour: subset matching.
	// The second file is expected to win because it contains the most tokens that
	// match the class enclosing the caret; the other two files are expected to tie.
	const newScores = await scoreRelatedFiles(true);
	const newPrefersIBaz =
		fileScore(newScores, 'file2') > fileScore(newScores, 'file3') &&
		fileScore(newScores, 'file3') === fileScore(newScores, 'file1');
	assert(newPrefersIBaz, 'Expected that the file containing IBaz interface would be the best match');
});
});

View File

@@ -0,0 +1,17 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { findEditDistanceScore } from '../suffixMatchCriteria';
import * as assert from 'assert';
suite('EditDistanceScore Test Suite', function () {
	test('findEditDistanceScore computes correct score of two number[]', function () {
		// Each row: [first sequence, second sequence, expected score].
		const cases: [number[], number[], number][] = [
			[[], [], 0],
			[[1], [1], 0],
			[[1], [2], 1],
			[[1], [], 1],
			[[], [1], 1],
			[[1, 2, 3], [3, 2, 1], 2],
		];
		for (const [left, right, expectedScore] of cases) {
			assert.strictEqual(findEditDistanceScore(left, right)?.score, expectedScore);
		}
	});
});

View File

@@ -0,0 +1,59 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { describeTree, IndentationTree, isLine, VirtualNode } from '../indentation';
import * as assert from 'assert';
/**
 * Asserts that two trees are isomorphic, failing the running test on the first mismatch.
 *
 * Checks are made in this order for every node: node type, number of children,
 * and (in 'strict' mode, for line nodes only) indentation. Children are then
 * compared recursively, pairwise by index.
 *
 * @param actual The tree to test.
 * @param expected The tree expected to be equal (source lines can be abbreviated with '...').
 * @param strictness Should the tree be deeply equal (including indentation),
 * or is it enough for the children and types of each node to match?
 * @param treeParent The tree's parent for context (optional)
 * @param parentIndex The index for the tree in its parent's subs (optional)
 */
export function compareTreeWithSpec<T>(
	actual: IndentationTree<T>,
	expected: IndentationTree<T>,
	strictness: 'strict' | 'structure' = 'strict',
	treeParent?: IndentationTree<T>,
	parentIndex?: number
) {
	const typesDiffer = actual.type !== expected.type;
	if (typesDiffer) {
		const reason = `type of tree doesn't match, ${actual.type} ${expected.type}`;
		failCompare(actual, expected, reason, treeParent, parentIndex);
	}

	const childCountsDiffer = actual.subs.length !== expected.subs.length;
	if (childCountsDiffer) {
		failCompare(actual, expected, 'number of children do not match', treeParent, parentIndex);
	}

	// Indentation is only compared in strict mode, and only for line nodes.
	if (
		strictness === 'strict' &&
		isLine(actual) &&
		actual.indentation !== (expected as VirtualNode<T>).indentation
	) {
		failCompare(actual, expected, `virtual node indentation doesn't match`, treeParent, parentIndex);
	}

	// Recurse, passing the current node and child index down as context.
	actual.subs.forEach((child, index) => {
		compareTreeWithSpec(child, expected.subs[index], strictness, actual, index);
	});
}
/**
 * Fails the running test with a message describing a tree mismatch.
 *
 * @param tree The actual (sub)tree that did not match.
 * @param expected The (sub)tree it was compared against.
 * @param reason Short description of what differed.
 * @param treeParent Parent of `tree`, if known; its node type is reported so the
 * mismatch can be located inside a larger tree.
 * @param parentIndex Index of `tree` within `treeParent.subs`, if known.
 * @throws AssertionError always.
 */
function failCompare<T>(
	tree: IndentationTree<T>,
	expected: IndentationTree<T>,
	reason: string,
	treeParent?: IndentationTree<T>,
	parentIndex?: number
): never {
	const lines = [`Reason: ${reason}`];
	// Report where the mismatch sits; this context used to be accepted but never shown.
	if (treeParent !== undefined) {
		const position = parentIndex === undefined ? 'a child' : `child #${parentIndex}`;
		lines.push(`Location: ${position} of a ${treeParent.type} node`);
	}
	lines.push(`Tree: ${describeTree(tree)}`, `Expected: ${describeTree(expected)}`);
	assert.fail(lines.join('\n'));
}

View File

@@ -0,0 +1,501 @@
"""
This is an example Python source file to use as test data. It's pulled from the synth repo
with minor edits to make it a better test case.
"""
from tree_sitter import Language, Parser
import re
import os, sys
from dataclasses import dataclass, field
import codesynthesis.synthesis as synthesis
import harness.fun_run as fun_run
from harness.utils import temporary_path_change_to, get_canonical_logger, TreeSitter
import types
import subprocess
logger = get_canonical_logger("FilesWithImport")
@dataclass
class Import:
module: str
filename: str = None
source_text: str = None
as_name: str = None
original_statement: str = None
# list of imported objects
imported: list = field(default_factory=list)
def add_source(self):
assert self.filename, f"No filename for {self.module}"
with open(self.filename, "r") as f:
self.source_text = f.read() # intentional trailing whitespace
class ImportAnalysis:
"""
Methods to discribe the effect of an import statement.
"""
def __init__(self, imp: Import, clean_environment=None, import_directly=True):
self.imp = imp
# a clean dict into which to import the module
if clean_environment is None:
clean_environment = {
"__builtins__": __builtins__,
"__name__": __name__,
"__doc__": __doc__,
"__package__": "thefuck",
"__file__": __file__
}
self._clean_environment = clean_environment
if import_directly:
self.impact = self._import(imp.original_statement, imp.filename)
self.imported_objects = self._get_imported_objects()
else:
self.impact = None
self.imported_objects = None
self.imported_directly = import_directly # if False, then can only use describe_from_outside
def _import(self, statement, filename):
"""
tries to import module from filename path, if that doesn't work parent folder, etc.,
returns a dictionary of the new objects created by the import statement.
"""
path = os.path.dirname(filename)
d = self._clean_environment.copy()
while len(path) > 1:
try:
# chdir to path, then exec statement
#with temporary_path_change_to(path):
d_before = d.copy()
exec(statement, d)
# remove from d what was already in d_before
d = {k: v for k, v in d.items() if k not in d_before}
return d
except ModuleNotFoundError:
path = os.path.dirname(path)
raise FileNotFoundError(f"Unable to import module from {filename}")
def _get_imported_objects(self, include_private=False):
"""
Returns a dict of relevant imported objects (by their names)
For `from a import b`, this should be {"b": b}
For `import a as b`, this should be {"b.x": b.x} for all imported objects x
For `from a import *`, this should be {"x": x} for all imported objects x
TODO For `import a.b`, this should be {"a.b.x": a.b.x} -- could maybe just increase depth correctly?
"""
out = {}
depth = 0
while all(isinstance(is_it_a_module, types.ModuleType) for is_it_a_module in out.values()):
out = self._unpack_dict(self.impact, include_private=include_private, depth=depth)
depth += 1
return out
@classmethod
def _unpack_dict(cls, dictionary, include_private, depth, prefix=""):
# names imported directly, except modules that will be expanded on later, and perhaps private names
out = {prefix + name: obj for name, obj in dictionary.items() if \
(depth == 0 or not isinstance(obj, types.ModuleType)) and \
(include_private or not name.startswith("_"))}
if depth > 0:
# for modules, add their objects
for name, module in dictionary.items():
if isinstance(module, types.ModuleType):
one_level_deeper = cls._unpack_dict(module.__dict__, include_private, depth=depth-1, prefix=name+".")
out.update(one_level_deeper)
return out
def get_names_of_imported_by_type(self):
"""return the imported objects as a dict, with their type name as key"""
d = dict()
for object_name, obj in self.imported_objects.items():
typ = type(obj).__name__
if typ in d:
d[typ].append(object_name)
else:
d[typ] = [object_name]
return d
def get_methods_of_type(self, class_name, include_private=False):
"""return the names of methods of the given type"""
cls = self.imported_objects[class_name]
names = [name for name in dir(cls) if include_private or not name.startswith("_")]
# return only the names which describe methods
return [name for name in names if callable(getattr(cls, name))]
def get_function_description(self, func_name):
"""return a string describing the function and its arguments"""
fun = self.imported_objects[func_name]
args = "(" + ", ".join(fun.__code__.co_varnames[:fun.__code__.co_argcount]) + ")"
return f"{func_name}{args}"
def get_class_description(self, class_name):
"""return a string describing the class and its methods"""
return f"{class_name} with functions {', '.join(self.get_methods_of_type(class_name))}"
def describe(self) -> str:
"""
return a description of the import, detailing classes, functions and properties as follows:
[as name] adds the following classes:
- [class name] with functions [function name], ..., [another function name]
...
[as name] adds the functions [function name], ..., [another function name]
[as name] also adds the objects: [property name: property type], ..., [another property name: property type]
"""
imported_names = self.get_names_of_imported_by_type()
if "type" in imported_names.keys():
classes = imported_names["type"]
answer = f"{self.imp.as_name} adds the following classes:\n" + \
"\n".join([f" - {self.get_class_description(class_name)}"
for class_name in classes]) + \
"\n"
else:
answer = ""
if "function" in imported_names.keys():
functions = imported_names["function"]
answer += f"{self.imp.as_name} adds the functions {', '.join(self.get_function_description(func_name) for func_name in functions)}\n"
# all other types are designated as properties
properties = [key for key in imported_names.keys() if key not in ["function", "type"]]
if properties:
answer += f"{self.imp.as_name} {'also ' if answer else ''}adds the objects: {', '.join([f'{imported_names[name]}: {name}' for name in properties])}\n"
return answer
def description_comment(self) -> str:
"""wrap the multiline string describe() into a comment, each of whose lines begins with '#'"""
return "\n".join(["# " + line for line in self.describe().split("\n")])
def describe_from_the_outside(self, path):
"""writes the import statement plus ask to describe to a file in path, runs that file, gathers output"""
filename = os.path.abspath(path + "/" + "tmp.py")
assert os.path.exists(filename) == False, f"{filename} already exists!"
try:
with open(filename, "w") as f:
f.write(f"""from codesynthesis.files_with_import import ImportAnalysis, ImportParser
imp_statement = "{self.imp.original_statement}"
imp = ImportParser().get_all_imports(imp_statement, "{filename}")[0]
env = dict(
__builtins__= __builtins__,
__name__= __name__,
__doc__= __doc__,
__package__= __package__,
__file__= __file__
)
analysis = ImportAnalysis(imp, env)
print(analysis.describe())"""
)
# get module name from cwd and filename:
relative_filename = os.path.relpath(filename, os.getcwd())
module_name = relative_filename.replace(".py", "").replace("/", ".")
try:
out = subprocess.check_output(["python", "-m", module_name], stderr=subprocess.STDOUT).decode("utf-8").strip()
except subprocess.CalledProcessError as e:
error = e.output.decode("utf-8")
if "FileNotFoundError: Unable to import module from" in error:
logger.error(f"Could not import module through {self.imp.original_statement} in {filename}. The import of {self.imp.module} will not be documented.")
return ""
else:
raise e
finally:
os.remove(filename)
return out
def describe_from_the_outside_as_comment(self, path):
return "\n".join(["# " + line for line in self.describe_from_the_outside(path).split("\n")])
class ImportParser:
    """
    Finds Python import statements in source text with tree-sitter queries, parses them into
    `Import` objects, resolves them to files on disk, and offers helpers to remove them from
    a text or to keep their lines when truncating.
    """
    PY_LANGUAGE = TreeSitter().language("python")
    # Queries matching import statements wherever they occur in the syntax tree.
    IMP_QUERY = ["(import_statement) @import",
                 "(future_import_statement) @import",
                 "(import_from_statement) @import"]
    # Queries matching only import statements that are direct children of the module node.
    MODULE_LEVEL_IMP_QUERY = ["(module (import_statement) @import)",
                              "(module (future_import_statement) @import)",
                              "(module (import_from_statement) @import)"]
    # TODO: Define MODULE_SCOPE_IMP_QUERY, where the import can't be inside a function, but can be inside an if or try.

    def __init__(self):
        # A bare `Parser()` expression used to sit between these two lines; its result was
        # discarded, so it has been dropped.
        self.parser = TreeSitter().get_parser("python")
        self.parser.set_language(self.PY_LANGUAGE)

    @staticmethod
    def get_text_from(text, capture):
        """
        Trim the text to the content corresponding to a single tree-sitter capture expression.
        @param text: The whole text of the language document.
        @param capture: The particular capture expression within the document to trim to.
        @return: The text for the capture expression only.
        """
        lines = text.split('\n')
        relevant_lines = lines[capture.start_point[0] : capture.end_point[0]+1]
        # in case the extract is just one line, the trim on the right needs to come before the trim on the left!
        relevant_lines[-1] = relevant_lines[-1][:capture.end_point[1]]
        relevant_lines[0] = relevant_lines[0][capture.start_point[1]:]
        return '\n'.join(relevant_lines)

    @staticmethod
    def replace_text_from(text, capture, replacement):
        """
        Replaces the text covered by the capture with the replacement, leaving everything
        before and after the capture untouched.
        @param text: The whole text of the language document.
        @param capture: The capture expression whose text is to be replaced.
        @param replacement: The string to put where the capture was.
        @return: The text with the capture swapped for the replacement. Passing
                 `get_text_from(text, capture)` as replacement returns `text` unchanged.
        """
        lines = text.split('\n')
        start_row, start_col = capture.start_point[0], capture.start_point[1]
        end_row, end_col = capture.end_point[0], capture.end_point[1]
        # what precedes the capture on its first line, and what follows it on its last line
        head = lines[start_row][:start_col]
        tail = lines[end_row][end_col:]
        return '\n'.join(lines[:start_row] + [head + replacement + tail] + lines[end_row + 1:])

    def get_list_of_import_captures(self, text):
        """
        Parses the text and returns the captured node of every match of the IMP_QUERY queries,
        grouped query by query.
        """
        tree = self.parser.parse(bytes(text, "utf8"))
        list_of_capture_lists = [self.PY_LANGUAGE.query(query).captures(tree.root_node) for query in self.IMP_QUERY]
        # flatten array, keeping only the first element of each capture entry
        return [item[0] for sublist in list_of_capture_lists for item in sublist]

    def get_import_statements(self, text):
        """
        returns a list of all imports (as far as found in the statement, not the background like content)
        """
        captures = self.get_list_of_import_captures(text)
        imports = [self.parse_single_import(self.get_text_from(text, capture)) for capture in captures]
        return imports

    def parse_single_import(self, relevant_text):
        """
        parses a single import statement (without background like content)
        TODO: deal with the case of several imports in one expression, e.g. import a, b, c
        """
        # module is the XXX in from XXX import ..., or in import XXX (leading dots are dropped)
        if re.search('from ([^ ]+) import', relevant_text):
            module = re.search('from (\\.)*([^ ]+) import', relevant_text).group(2)
        else:
            module = re.search('import (\\.)*([^ ]+)', relevant_text).group(2)
        # as_name is the XXX in import ... as XXX or from ... import ... as XXX;
        # without an alias it falls back to the module
        alias_match = re.search(' as ([^ ]+)', relevant_text)
        as_name = alias_match.group(1) if alias_match else module
        # imported are the XXX, YYY, ZZZ in from XXX import XXX, YYY, ZZZ
        # but we don't need them right now, so TODO
        imported = ["TODO"]
        return Import(module=module, as_name=as_name, imported=imported, original_statement=relevant_text)

    def get_all_imports(self, source_text, source_filename):
        """
        adds all import files found in the source_text, if it were at disk as filename
        skips standard packages (i.e. anything not found on disk at location filename)
        """
        raw_imports = self.get_import_statements(source_text)
        relevant_imports = []
        for imp in raw_imports:
            # check whether filename exists; if not: check whether it exist in the parent folder, recursively
            base_folder = os.path.dirname(source_filename)
            # Note: The following does not take '..module' imports into account differently
            # (which will mostly be ok though, unless there's a name clash)
            module_path = '/'.join(imp.module.split('.'))
            while True:
                filename = base_folder + '/' + module_path
                if os.path.isfile(filename + '.py'):
                    imp.filename = filename + '.py'
                    imp.add_source()
                    relevant_imports.append(imp)
                    break
                parent_folder = os.path.dirname(base_folder)
                # Give up once there is no real folder left. The second test covers folders
                # whose dirname is the folder itself (e.g. a drive root such as 'C:\\'),
                # which would otherwise make this loop spin forever.
                if len(parent_folder) <= 1 or parent_folder == base_folder:
                    break
                base_folder = parent_folder
        return relevant_imports

    def remove_imports(self, text, imps):
        """
        returns the text minus the imports
        Note: This is theoretically a bit too aggressive, as it also removes the text of the import statements inside quotes etc
        Note: This is theoretically a bit too aggressive, as it also removes captures where the import model name is only a substring
        """
        captures = self.get_list_of_import_captures(text)
        # The capture positions refer to the text as it was parsed, so every statement is
        # extracted before anything is removed; reading them from the partially edited text
        # would use stale positions as soon as a removal shifts lines or columns.
        capture_texts = [self.get_text_from(text, capture) for capture in captures]
        for capture_text in capture_texts:
            # only act if imp.module is in the capture expression
            if any(imp.module in capture_text for imp in imps):
                text = text.replace(capture_text, '')
        return text

    def truncate_left_but_keep_module_level_imports(self, text, length_in_tokens: int, fixed_prefix: str = ""):
        """
        make sure to keep the import lines, but otherwise drop lines as needed, calling
        synthesis.truncate_left_keeping_lines_with_preference
        Note: the lines to keep come from get_list_of_import_captures, i.e. from IMP_QUERY,
        not from MODULE_LEVEL_IMP_QUERY.
        """
        # identify imports
        captures = self.get_list_of_import_captures(text)
        lines = text.split('\n')
        idc_to_keep = set()
        for capture in captures:
            # a capture may span several lines; all of them are kept
            for lineno in range(capture.start_point[0], capture.end_point[0] + 1):
                idc_to_keep.add(lineno)
        result = synthesis.truncate_left_keeping_lines_with_preference(lines, idc_to_keep, length_in_tokens, fixed_prefix)
        return result
def test_import_parser():
    """Checks that both a `from ... import` and an `import ... as` statement are resolved to at least one import."""
    parser = ImportParser()
    source_filename = "/Users/wunderalbert/openai/synth/test.py"
    source_texts = (
        "from codesynthesis.synthesis import abc\nprint(32)",
        "import codesynthesis.synthesis as abc\nprint(32)",
    )
    for source_text in source_texts:
        found_imports = parser.get_all_imports(source_text, source_filename)
        assert len(found_imports) > 0
class FunctionWithImportsKept(fun_run.PythonFunctionInTheWild):
    """
    Builds the prompt from a "# Python 3" line, the prelude and the header, and truncates it
    from the left while giving preference to the lines of import statements.
    """
    def make_prompt_for_fct(self, max_length_in_tokens = synthesis.CONTEXT_WINDOW_SIZE):
        untruncated_prompt = '\n'.join(['# Python 3', self.prelude, self.header])
        import_parser = ImportParser()
        return import_parser.truncate_left_but_keep_module_level_imports(untruncated_prompt, max_length_in_tokens)
class FunctionWithImports(fun_run.PythonFunctionInTheWild):
    """
    Base class for prompt schemes that need the imports of the function's file: it records the
    imports found on disk and a copy of the source lines with those imports removed.
    """
    def __init__(self, function_location, discriminative_model):
        super().__init__(function_location, discriminative_model)
        # get the absolute path
        self.filename = os.path.abspath(self.function_location.path)
        source_text = ''.join(self.source_lines)
        self.imports = ImportParser().get_all_imports(source_text, self.filename)
        importless_text = ImportParser().remove_imports(source_text, self.imports)
        # split drops the newline characters, so each line gets one appended again
        self.importless_source_lines = [f"{line}\n" for line in importless_text.split("\n")]

    def make_prompt_for_fct_without_imports(self):
        """call super.make_prompt_for_function, but with importless_source_lines temporarily replacing source_lines"""
        original_source_lines = self.source_lines
        self.source_lines = self.importless_source_lines
        try:
            return super().make_prompt_for_fct(fun_run.VERY_LARGE_NUMBER)
        finally:
            # put the full source back, whether or not building the prompt succeeded
            self.source_lines = original_source_lines
class FunctionWithImportsPastedVerbatim(FunctionWithImports):
    """
    Puts the source text of every import, unchanged, in front of the import-free prompt.
    """
    def make_prompt_for_fct(self, max_length_in_tokens):
        prompt_without_imports = super().make_prompt_for_fct_without_imports()
        pasted_sources = "\n\n".join(imp.source_text for imp in self.imports)
        full_prompt = pasted_sources + "\n\n" + prompt_without_imports
        return synthesis.truncate_left(full_prompt, max_length_in_tokens)
class FunctionWithImportsPastedWithComments(FunctionWithImports):
    """
    Puts the source text of every import in front of the prompt as a block of comment lines,
    headed by a line naming the module the content belongs to.
    """
    def make_prompt_for_fct(self, max_length_in_tokens):
        sections = []
        for imp in self.imports:
            # every line of the module source becomes a "# " comment line
            commented_source = "\n# ".join(imp.source_text.split('\n'))
            sections.append(f"# Content of module {imp.as_name}\n# " + commented_source + "\n\n")
        full_prompt = "".join(sections) + super().make_prompt_for_fct()
        return synthesis.truncate_left(full_prompt, max_length_in_tokens)
class FunctionWithImportsNamespacedInClasses(FunctionWithImports):
    """
    Wraps the source of every import in a `class <as_name>:` block and puts those blocks in
    front of the import-free prompt. The result is not sound code: functions lack the self
    argument, and variables are not prefixed with `class.`
    """
    def make_prompt_for_fct(self, max_length_in_tokens):
        class_blocks = []
        for imp in self.imports:
            indentation = self.STRING_FOR_INDENTATION_LEVEL_INCREASE
            # indent every line of the module source by one level
            indented_source = f"\n{indentation}".join(imp.source_text.split('\n'))
            class_blocks.append(f"class {imp.as_name}:\n{indentation}" + indented_source + "\n\n")
        full_prompt = "".join(class_blocks) + super().make_prompt_for_fct_without_imports()
        return synthesis.truncate_left(full_prompt, max_length_in_tokens)
class FunctionWithImportsReplacedOneByOne(fun_run.PythonFunctionInTheWild):
    """
    Schemes where all imports that import local files are replaced or added to,
    e.g. by summarizing their content, or quoting it, etc.
    Subclasses implement `replace_import`.
    """
    # If True, the lines holding the replaced imports are given preference when truncating;
    # if False, the prompt is simply truncated from the left.
    keep_imports_and_description_by_preference = True

    def __init__(self, function_location, discriminative_model):
        super().__init__(function_location, discriminative_model)
        # get the absolute path
        self.filename = os.path.abspath(self.function_location.path)
        source_text = ''.join(self.source_lines)
        self.imports = ImportParser().get_all_imports(source_text, self.filename)
        logger.debug(f"In the function {self.function_location.name}, there are {len(self.imports)} repo imports: {[imp.module for imp in self.imports]}")

    def replace_import(self, imp: Import):
        """
        Given one import, returns the replacement text pasted into where that import was.
        E.g. for a text `foo\nimport bar\nbaz`, if the import is replaced by the string "bat",
        the result will be "foo\nbat\nbaz"
        """
        raise NotImplementedError("Needs to be implemented in subclass.")

    def make_prompt_for_fct(self, max_length_in_tokens):
        """
        Builds the prompt of the superclass, swaps every import statement for the result of
        `replace_import`, and truncates the result to max_length_in_tokens.
        """
        prompt = super().make_prompt_for_fct(max_length_in_tokens)
        # the imports are taken in the order they were found (no sorting happens here) ...
        sorted_imports = self.imports.copy()
        # ... and replaced from last to first
        sorted_imports.reverse()
        import_replacements = set()
        for imp in sorted_imports:
            new_statement = self.replace_import(imp)
            import_replacements.add(new_statement)
            prompt = prompt.replace(imp.original_statement, new_statement)
        # if descriptions have higher priority, extract the lines where they are and pass those lines to synthesis.truncate_left_keeping_lines_with_preference
        if self.keep_imports_and_description_by_preference:
            new_lines = set()
            for statement in import_replacements:
                match = prompt.find(statement)
                if match == -1:
                    # The replacement never made it into the prompt (its import statement was
                    # not part of the superclass prompt), so there are no lines to prefer.
                    # Using -1 as a slice index below would mark unrelated lines instead.
                    continue
                lines_of_match = range(
                    prompt[:match].count("\n"),
                    prompt[:match+len(statement)].count("\n")+1)
                new_lines.update(lines_of_match)
            truncated_prompt = synthesis.truncate_left_keeping_lines_with_preference(prompt.split("\n"), new_lines, max_length_in_tokens)
        else:
            truncated_prompt = synthesis.truncate_left(prompt, max_length_in_tokens)
        return truncated_prompt
class FunctionWithImportsCommentedWithTheFunctionsTheyImport(FunctionWithImportsReplacedOneByOne):
    """
    adds a comment to each import with the objects it imports
    """
    def replace_import(self, imp: Import):
        """
        Returns the original import statement followed by comment lines that list the
        toplevel functions of the imported module, and its toplevel classes with the
        functions defined directly inside them (at four spaces of indentation).
        """
        # find all toplevel functions in imp.source_text, i.e. the `foo` from lines of the form `def foo()`
        source_starting_with_newline = ("\n" + imp.source_text)
        toplevel_functions = re.findall(r"\ndef\s*([a-zA-Z0-9_]+)\s*\(", source_starting_with_newline)
        toplevel_classes = re.findall(r"\nclass\s*([a-zA-Z0-9_]+)", source_starting_with_newline)
        toplevel_classes_with_their_functions = {}
        for classname in toplevel_classes:
            # Locate the class header with the same shape the findall above matched. The
            # lookahead keeps `Foo` from matching the header of `FooBar`, and `\s*` accepts
            # whatever whitespace followed `class`, so the split always yields two parts.
            header_pattern = r"\nclass\s*" + re.escape(classname) + r"(?![a-zA-Z0-9_])"
            text_after_class_name = re.split(header_pattern, source_starting_with_newline, maxsplit=1)[1]
            # drop the remainder of the header line itself
            cls_source_lines_with_trailing = text_after_class_name.split("\n")[1:]
            if not cls_source_lines_with_trailing:
                continue
            # remove everything from the first non-blank line with less than 4 spaces
            # indentation that is not a comment; blank lines do not end the class
            indices_after_class_end = [
                i for i, line in enumerate(cls_source_lines_with_trailing)
                if line.strip() and not line.startswith(" " * 4) and not line.startswith("#")
            ]
            if indices_after_class_end:
                cls_source_lines = cls_source_lines_with_trailing[:indices_after_class_end[0]]
            else:
                cls_source_lines = cls_source_lines_with_trailing
            cls_body = "\n".join(cls_source_lines)
            # the functions of a class sit one indentation level in, so the `def` has to be
            # looked for behind four spaces rather than at the very start of a line
            class_functions = re.findall(r"^ {4}def\s+([a-zA-Z0-9_]+)\s*\(", cls_body, re.MULTILINE)
            toplevel_classes_with_their_functions[classname] = class_functions
        description = imp.original_statement
        if toplevel_functions:
            description = description + f"\n# module {imp.as_name} declares the following functions: {', '.join(toplevel_functions)}"
        for classname, functions in toplevel_classes_with_their_functions.items():
            description = description + f"\n# module {imp.as_name} declares the class {classname}"
            if functions:
                description = description + f", which contains the following functions: {', '.join(functions)}"
        return description
class FunctionWithImportsCommentedWithImportAnalysis(FunctionWithImportsReplacedOneByOne):
    """
    Follows each import with the comment produced by ImportAnalysis.describe_from_the_outside_as_comment.
    """
    def replace_import(self, imp: Import):
        analysis = ImportAnalysis(imp, import_directly=False)
        # the analysis is run from the folder containing the function's file
        comment = analysis.describe_from_the_outside_as_comment(os.path.dirname(self.filename))
        description = "\n".join([imp.original_statement, comment])
        logger.debug(f"To the import {imp.module} as {imp.as_name}, we add the following comment: {description}")
        return description

View File

@@ -0,0 +1,5 @@
def greet(name: str) -> str:
    """Returns the greeting "Hello <name>"."""
    return f"Hello {name}"


# `name` has no default value, so the call has to supply one; a bare `greet()` raises TypeError.
greet("World")

View File

@@ -0,0 +1,11 @@
// This is a test file for the sake of testing actual file reads
// We had silently failing tests in the past due to improper
// file spoofing
export interface Tokenizer {
	/**
	 * Tokenizes the given text.
	 *
	 * @param text The input string.
	 * @returns The tokenization of `text` as a list of integers, each representing a token.
	 */
	tokenize(text: string): number[];
}

View File

@@ -0,0 +1,14 @@
// This is a test file for the sake of testing actual file reads
// We had silently failing tests in the past due to improper
// file spoofing
import { Tokenizer } from './testTokenizer';
/**
 * An object to keep track of a list of desired prompt elements,
 * and assemble the prompt text from them.
 */
export class PromptWishlist {
	/**
	 * @param _tokenizer The tokenizer passed to the wishlist; this constructor does not use it.
	 */
	constructor(_tokenizer: Tokenizer) { }
}

View File

@@ -0,0 +1,595 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as assert from 'assert';
import * as fs from 'fs';
import { resolve } from 'path';
import { ApproximateTokenizer, getTokenizer, TokenizerName } from '../tokenization';
// Load the Python fixture as UTF-8 text and turn every CRLF or lone CR into LF,
// so the assertions below see the same text whatever the checkout's line endings.
const source = fs
	.readFileSync(resolve(__dirname, 'testdata', 'example.py'), { encoding: 'utf8' })
	.replace(/\r\n?/g, '\n');
suite('Tokenizers can be loaded', function () {
	// One test per value of TokenizerName; a test passes when getTokenizer does not throw.
	Object.values(TokenizerName).forEach(name => {
		test(`Tokenizer ${name} can be loaded`, function () {
			getTokenizer(name);
		});
	});
});
// Test suite for the tokenizer registered under TokenizerName.mock
suite('MockTokenizer', function () {
	const mock = getTokenizer(TokenizerName.mock);
	// Input shared by every test in this suite; the tests below expect it to be 5 tokens.
	const sample = 'a b c';

	test('tokenize', function () {
		const ids = mock.tokenize(sample);
		assert.strictEqual(ids.length, 5);
		for (const id of ids) {
			assert.strictEqual(typeof id, 'number');
		}
	});

	test('detokenize', function () {
		const roundTripped = mock.detokenize(mock.tokenize(sample));
		// unfortunately the mock tokenizer doesn't correctly round-trip the text
		// because the token representation is a number. If this matters then we'll
		// have to change the mock tokenizer to use a different representation.
		assert.strictEqual(roundTripped, '97 32 98 32 99');
	});

	test('tokenLength', function () {
		assert.strictEqual(mock.tokenLength(sample), 5);
	});

	test('takeFirstTokens', function () {
		const head = mock.takeFirstTokens(sample, 3);
		assert.strictEqual(head.text, 'a b');
		assert.strictEqual(head.tokens.length, 3);
	});

	test('takeLastTokens', function () {
		const tail = mock.takeLastTokens(sample, 3);
		assert.strictEqual(tail.text, 'b c');
	});

	test('takeLastLinesTokens', function () {
		const tailText = mock.takeLastLinesTokens(sample, 3);
		assert.strictEqual(tailText, 'b c');
	});
});
// Tests for the tokenizer registered under TokenizerName.cl100k: pinned token ids for small
// inputs, detokenize(tokenize(x)) round-trips, and the takeFirst/takeLast helpers applied to
// the Python fixture held in `source`.
suite('Tokenizer Test Suite - cl100k', function () {
const tokenizer = getTokenizer(TokenizerName.cl100k);
// The next tests each pin the exact token ids of a small input and check the round-trip.
test('empty string', function () {
const str = '';
assert.deepStrictEqual(tokenizer.tokenize(str), []);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('space', function () {
const str = ' ';
assert.deepStrictEqual(tokenizer.tokenize(str), [220]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('tab', function () {
const str = '\t';
assert.deepStrictEqual(tokenizer.tokenize(str), [197]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('simple text', function () {
const str = 'This is some text';
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
assert.deepStrictEqual(tokenizer.tokenize(str), [2028, 374, 1063, 1495]);
});
test('multi-token word', function () {
const str = 'indivisible';
assert.deepStrictEqual(tokenizer.tokenize(str), [485, 344, 23936]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('emojis', function () {
const str = 'hello 👋 world 🌍';
assert.deepStrictEqual(tokenizer.tokenize(str), [15339, 62904, 233, 1917, 11410, 234, 235]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('contractions', function () {
const str = `you'll`;
assert.deepStrictEqual(tokenizer.tokenize(str), [9514, 3358]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('assert that consecutive newline is never tokenized as multiple newlines', function () {
// This is due to a regular expression change in the tokenizer.
// Loop through all possible ascii numbers and letters
for (let i = 0; i < 128; i++) {
const char = String.fromCharCode(i);
if (char !== '\n') {
// two newlines followed by one other ASCII character must come out as exactly 2 tokens
assert.deepStrictEqual(tokenizer.tokenLength(`\n\n${char}`), 2);
}
}
// Test special characters
assert.deepStrictEqual(tokenizer.tokenize('\n\n👋'), [271, 9468, 239, 233]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n '), [271, 220]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n 👋'), [271, 62904, 233]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n\t'), [271, 197]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n\r'), [271, 201]);
// New lines are treated specially tho
// (a run of 1 to 9 newlines is expected to be a single token)
for (let i = 1; i < 10; i++) {
assert.deepStrictEqual(tokenizer.tokenLength('\n'.repeat(i)), 1);
}
});
// tokenizeStrings must yield one string per token of tokenize, and the strings must join
// back to the input.
test('tokenizeStrings', function () {
const tokens_s = tokenizer.tokenizeStrings(source);
assert.strictEqual(tokens_s.join(''), source, 'tokenizeStrings does not join to form the input string');
const tokens = tokenizer.tokenize(source);
assert.strictEqual(tokens_s.length, tokens.length, 'tokenizeStrings should have same length as tokenize');
const half = Math.floor(tokens_s.length / 2);
assert.strictEqual(
tokens_s.slice(0, half).join(''),
tokenizer.detokenize(tokens.slice(0, half)),
'tokenizeStrings slice should represent the corresponding slice with tokenize'
);
});
// The last 25 tokens must not change when characters are cut from the start of the text.
// The expected suffix is the end of the Python fixture.
test('takeLastTokens invariant of starting position', function () {
const suffix = tokenizer.takeLastTokens(source, 25);
assert.strictEqual(
suffix.text,
`"To the import {imp.module} as {imp.as_name}, we add the following comment: {description}")\n return description`
);
assert.strictEqual(suffix.tokens.length, 25);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(50), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(100), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(150), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(200), 25).text);
});
test('takeLastTokens returns the desired number of tokens', function () {
assert.strictEqual(tokenizer.takeLastTokens(source, 30).tokens.length, 30);
assert.strictEqual(tokenizer.takeLastTokens(source, 29).tokens.length, 29);
assert.strictEqual(tokenizer.takeLastTokens(source, 28).tokens.length, 28);
assert.strictEqual(tokenizer.takeLastTokens(source, 5).tokens.length, 5);
assert.strictEqual(tokenizer.takeLastTokens(source, 0).tokens.length, 0);
assert.strictEqual(tokenizer.takeLastTokens(source, 1).tokens.length, 1);
assert.strictEqual(tokenizer.takeLastTokens(source, 1000).tokens.length, 1000);
// asking for more tokens than this test expects the text to hold returns the whole text
assert.strictEqual(tokenizer.takeLastTokens(source, 100000).text, source);
assert.strictEqual(tokenizer.takeLastTokens('\n\n\n', 1).tokens.length, 1);
});
test('takeLastTokens returns a suffix of the sought length', function () {
// checks that the returned text is n tokens long and is literally the tail of `source`
function check(n: number): void {
const { text: suffix } = tokenizer.takeLastTokens(source, n);
assert.strictEqual(tokenizer.tokenLength(suffix), n);
assert.strictEqual(suffix, source.substring(source.length - suffix.length));
}
check(0);
check(1);
check(5);
check(29);
check(30);
check(100);
check(1000);
assert.strictEqual(tokenizer.takeLastTokens(source, 100000).text, source);
});
// Pins what takeLastLinesTokens returns for growing token budgets on two small examples.
test('test takeLastLinesTokens', function () {
let example = 'a b c\nd e f\ng h i';
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 3), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 4), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 5), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 6), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 7), 'd e f\ng h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 11), example);
example = 'a b\n\n c d';
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 2), ' c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 3), '\n c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 4), '\n c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 5), 'a b\n\n c d');
});
// The `text` and `tokens` fields of the result must describe the same prefix.
test('takeFirstTokens return corresponding text and tokens', function () {
let prefix = tokenizer.takeFirstTokens(source, 30);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens(source, 0);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens('', 30);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens('', 0);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
});
// The first 29 tokens must not change when the text is cut off further along.
// The expected prefix is the start of the Python fixture.
test('takeFirstTokens invariant of ending position', function () {
const prefix = tokenizer.takeFirstTokens(source, 29).text;
const expected = `"""
This is an example Python source file to use as test data. It's pulled from the synth repo
with minor edits to make it`;
assert.strictEqual(prefix, expected);
assert.strictEqual(tokenizer.tokenLength(prefix), 29);
assert.strictEqual(prefix, tokenizer.takeFirstTokens(source.substring(0, 150), 29).text);
assert.strictEqual(prefix, tokenizer.takeFirstTokens(source.substring(0, 200), 29).text);
});
test('takeFirstTokens returns the desired number of tokens', function () {
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 30).text), 30);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 29).text), 29);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 28).text), 28);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 5).text), 5);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 0).text), 0);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 1).text), 1);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 1000).text), 1000);
// asking for more tokens than this test expects the text to hold returns the whole text
assert.strictEqual(tokenizer.takeFirstTokens(source, 100000).text, source);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens('\n\n\n', 1).text), 1);
});
test('takeFirstTokens returns a prefix of the sought length', function () {
// checks that the returned text is n tokens long and is literally the head of `source`
function check(n: number): void {
const prefix = tokenizer.takeFirstTokens(source, n).text;
assert.strictEqual(tokenizer.tokenLength(prefix), n);
assert.strictEqual(prefix, source.substring(0, prefix.length));
}
check(0);
check(1);
check(5);
check(29);
check(30);
check(100);
check(1000);
assert.strictEqual(tokenizer.takeFirstTokens(source, 100000).text, source);
});
/**
* Long sequences of spaces are tokenized as a sequence of 16-space tokens. This tests that
* the logic in takeFirstTokens correctly handles very long tokens.
*/
test('takeFirstTokens handles very long tokens', function () {
// raise the timeout for this test to 15 seconds
this.timeout(15000);
const longestSpaceToken = ' '.repeat(4000);
const tokens = tokenizer.takeFirstTokens(longestSpaceToken, 30);
assert.strictEqual(tokenizer.tokenLength(tokens.text), 30);
});
});
suite('Tokenizer Test Suite - o200k', function () {
const tokenizer = getTokenizer(TokenizerName.o200k);
test('empty string', function () {
const str = '';
assert.deepStrictEqual(tokenizer.tokenize(str), []);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('space', function () {
const str = ' ';
assert.deepStrictEqual(tokenizer.tokenize(str), [220]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('tab', function () {
const str = '\t';
assert.deepStrictEqual(tokenizer.tokenize(str), [197]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('simple text', function () {
const str = 'This is some text';
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
assert.deepStrictEqual(tokenizer.tokenize(str), [2500, 382, 1236, 2201]);
});
test('multi-token word', function () {
const str = 'indivisible';
assert.deepStrictEqual(tokenizer.tokenize(str), [521, 349, 181386]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('emojis', function () {
const str = 'hello 👋 world 🌍';
assert.deepStrictEqual(tokenizer.tokenize(str), [24912, 61138, 233, 2375, 130321, 235]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('contractions', function () {
const str = `you'll`;
assert.deepStrictEqual(tokenizer.tokenize(str), [13320, 6090]);
assert.strictEqual(tokenizer.detokenize(tokenizer.tokenize(str)), str);
});
test('assert that consecutive newline is never tokenized as multiple newlines', function () {
// This is due to a regular expression change in the tokenizer.
// Loop through all possible ascii numbers and letters
for (let i = 0; i < 128; i++) {
const char = String.fromCharCode(i);
if (char !== '\n') {
assert.deepStrictEqual(tokenizer.tokenLength(`\n\n${char}`), 2);
}
}
// Test special characters
assert.deepStrictEqual(tokenizer.tokenize('\n\n👋'), [279, 28823, 233]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n '), [279, 220]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n 👋'), [279, 61138, 233]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n\t'), [279, 197]);
assert.deepStrictEqual(tokenizer.tokenize('\n\n\r'), [279, 201]);
// New lines are treated specially tho
for (let i = 1; i < 10; i++) {
assert.deepStrictEqual(tokenizer.tokenLength('\n'.repeat(i)), 1);
}
});
test('tokenizeStrings', function () {
const tokens_s = tokenizer.tokenizeStrings(source);
assert.strictEqual(tokens_s.join(''), source, 'tokenizeStrings does not join to form the input string');
const tokens = tokenizer.tokenize(source);
assert.strictEqual(tokens_s.length, tokens.length, 'tokenizeStrings should have same length as tokenize');
const half = Math.floor(tokens_s.length / 2);
assert.strictEqual(
tokens_s.slice(0, half).join(''),
tokenizer.detokenize(tokens.slice(0, half)),
'tokenizeStrings slice should represent the corresponding slice with tokenize'
);
});
test('takeLastTokens invariant of starting position', function () {
const suffix = tokenizer.takeLastTokens(source, 25);
assert.strictEqual(
suffix.text,
`To the import {imp.module} as {imp.as_name}, we add the following comment: {description}")\n return description`
);
assert.strictEqual(suffix.tokens.length, 25);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(50), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(100), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(150), 25).text);
assert.strictEqual(suffix.text, tokenizer.takeLastTokens(source.substring(200), 25).text);
});
test('takeLastTokens returns the desired number of tokens', function () {
assert.strictEqual(tokenizer.takeLastTokens(source, 30).tokens.length, 30);
assert.strictEqual(tokenizer.takeLastTokens(source, 29).tokens.length, 29);
assert.strictEqual(tokenizer.takeLastTokens(source, 28).tokens.length, 28);
assert.strictEqual(tokenizer.takeLastTokens(source, 5).tokens.length, 5);
assert.strictEqual(tokenizer.takeLastTokens(source, 0).tokens.length, 0);
assert.strictEqual(tokenizer.takeLastTokens(source, 1).tokens.length, 1);
assert.strictEqual(tokenizer.takeLastTokens(source, 1000).tokens.length, 1000);
assert.strictEqual(tokenizer.takeLastTokens(source, 100000).text, source);
assert.strictEqual(tokenizer.takeLastTokens('\n\n\n', 1).tokens.length, 1);
});
test('takeLastTokens returns a suffix of the sought length', function () {
function check(n: number): void {
const { text: suffix } = tokenizer.takeLastTokens(source, n);
assert.strictEqual(tokenizer.tokenLength(suffix), n);
assert.strictEqual(suffix, source.substring(source.length - suffix.length));
}
check(0);
check(1);
check(5);
check(29);
check(30);
check(100);
check(1000);
assert.strictEqual(tokenizer.takeLastTokens(source, 100000).text, source);
});
test('test takeLastLinesTokens', function () {
let example = 'a b c\nd e f\ng h i';
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 3), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 4), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 5), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 6), 'g h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 7), 'd e f\ng h i');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 11), example);
example = 'a b\n\n c d';
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 2), ' c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 3), '\n c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 4), '\n c d');
assert.strictEqual(tokenizer.takeLastLinesTokens(example, 5), 'a b\n\n c d');
});
test('takeFirstTokens return corresponding text and tokens', function () {
let prefix = tokenizer.takeFirstTokens(source, 30);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens(source, 0);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens('', 30);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
prefix = tokenizer.takeFirstTokens('', 0);
assert.strictEqual(prefix.text, tokenizer.detokenize(prefix.tokens));
});
test('takeFirstTokens invariant of ending position', function () {
const prefix = tokenizer.takeFirstTokens(source, 29).text;
const expected = `"""
This is an example Python source file to use as test data. It's pulled from the synth repo
with minor edits to make it a`;
assert.strictEqual(prefix, expected);
assert.strictEqual(tokenizer.tokenLength(prefix), 29);
assert.strictEqual(prefix, tokenizer.takeFirstTokens(source.substring(0, 150), 29).text);
assert.strictEqual(prefix, tokenizer.takeFirstTokens(source.substring(0, 200), 29).text);
});
test('takeFirstTokens returns the desired number of tokens', function () {
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 30).text), 30);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 29).text), 29);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 28).text), 28);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 5).text), 5);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 0).text), 0);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 1).text), 1);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens(source, 1000).text), 1000);
assert.strictEqual(tokenizer.takeFirstTokens(source, 100000).text, source);
assert.strictEqual(tokenizer.tokenLength(tokenizer.takeFirstTokens('\n\n\n', 1).text), 1);
});
test('takeFirstTokens returns a prefix of the sought length', function () {
	// The result must have exactly `count` tokens and be a literal prefix of `source`.
	const assertExactPrefix = (count: number): void => {
		const taken = tokenizer.takeFirstTokens(source, count).text;
		assert.strictEqual(tokenizer.tokenLength(taken), count);
		assert.strictEqual(taken, source.substring(0, taken.length));
	};
	for (const count of [0, 1, 5, 29, 30, 100, 1000]) {
		assertExactPrefix(count);
	}
	// Asking for more tokens than exist returns the whole source.
	assert.strictEqual(tokenizer.takeFirstTokens(source, 100000).text, source);
});
/**
 * Checks that takeFirstTokens copes with very long tokens: long runs of spaces are
 * tokenized as a sequence of 16-space tokens.
 */
test('takeFirstTokens handles very long tokens', function () {
	// `function` (not an arrow) is required so that `this` is the test context.
	this.timeout(15000);
	const spaceRun = ' '.repeat(4000);
	const taken = tokenizer.takeFirstTokens(spaceRun, 30);
	assert.strictEqual(tokenizer.tokenLength(taken.text), 30);
});
});
suite('ApproximateTokenizer', function () {
	// Fixtures: two language-aware tokenizers plus one built with all defaults (o200k, no language).
	const pythonCl100k = new ApproximateTokenizer(TokenizerName.cl100k, 'python');
	const pythonO200k = new ApproximateTokenizer(TokenizerName.o200k, 'python');
	const plain = new ApproximateTokenizer();

	suite('tokenizeStrings', function () {
		test('should split text into chunks of 4 characters', function () {
			assert.deepStrictEqual(plain.tokenizeStrings('abcdefgh'), ['abcd', 'efgh']);
		});
		test('should handle text not divisible by 4', function () {
			assert.deepStrictEqual(plain.tokenizeStrings('abcdefg'), ['abcd', 'efg']);
		});
		test('should handle empty string', function () {
			assert.deepStrictEqual(plain.tokenizeStrings(''), []);
		});
		test('should handle single character', function () {
			assert.deepStrictEqual(plain.tokenizeStrings('a'), ['a']);
		});
	});

	suite('tokenize', function () {
		test('should convert string chunks to numeric tokens', function () {
			const tokens = plain.tokenize('ab');
			assert.ok(Array.isArray(tokens));
			assert.strictEqual(tokens.length, 1);
			assert.strictEqual(typeof tokens[0], 'number');
		});
		test('should produce consistent tokens for same input', function () {
			const firstRun = plain.tokenize('test');
			const secondRun = plain.tokenize('test');
			assert.deepStrictEqual(firstRun, secondRun);
		});
	});

	suite('detokenize', function () {
		test('should convert tokens back to string', function () {
			const input = 'test';
			const roundTripped = plain.detokenize(plain.tokenize(input));
			assert.strictEqual(roundTripped, input);
		});
		test('should handle empty token array', function () {
			assert.strictEqual(plain.detokenize([]), '');
		});
	});

	test('tokenLength', function () {
		assert.strictEqual(pythonCl100k.tokenLength('a b c'), 2);
	});
	test('tokenLength with language take approximated char chunks', function () {
		assert.strictEqual(pythonCl100k.tokenLength('abc def gh'), 3);
	});
	test('tokenLength with no language take 4 char chunks', function () {
		const text = 'w'.repeat(400);
		assert.strictEqual(pythonCl100k.tokenLength(text), 101);
		assert.strictEqual(plain.tokenLength(text), 100);
	});
	test('tokenLength approximated char chunks are correct for each approximated tokenizer', function () {
		const text = 'w'.repeat(400);
		assert.strictEqual(pythonCl100k.tokenLength(text), 101);
		assert.strictEqual(pythonO200k.tokenLength(text), 99);
	});

	test('takeFirstTokens', function () {
		const taken = pythonCl100k.takeFirstTokens('123 456 7890', 2);
		assert.deepStrictEqual(taken, { text: '123 456', tokens: [0, 1] });
		assert.strictEqual(pythonCl100k.tokenLength(taken.text), 2);
	});
	test('takeFirstTokens returns the full string if shorter', function () {
		const taken = pythonCl100k.takeFirstTokens('123 456 7890', 100);
		assert.deepStrictEqual(taken, { text: '123 456 7890', tokens: [0, 1, 2, 3] });
		assert.strictEqual(pythonCl100k.tokenLength(taken.text), 4);
	});
	test('takeLastTokens', function () {
		const taken = pythonCl100k.takeLastTokens('123 456 7890', 2);
		assert.deepStrictEqual(taken, { text: '56 7890', tokens: [0, 1] });
		assert.strictEqual(pythonCl100k.tokenLength(taken.text), 2);
	});
	test('takeLastTokens returns the full string if shorter', function () {
		const taken = pythonCl100k.takeLastTokens('123 456 7890', 100);
		assert.deepStrictEqual(taken, { text: '123 456 7890', tokens: [0, 1, 2, 3] });
		assert.strictEqual(pythonCl100k.tokenLength(taken.text), 4);
	});

	suite('takeLastLinesTokens', function () {
		test('should return complete lines from suffix', function () {
			const lines = 'line1\nline2\nline3\nline4';
			assert.strictEqual(pythonCl100k.takeLastLinesTokens(lines, 4), 'line3\nline4');
		});
		test('should handle text already within token limit', function () {
			const shortText = 'short\ntext';
			assert.strictEqual(pythonCl100k.takeLastLinesTokens(shortText, 100), shortText);
		});
		test('should handle text ending with newline', function () {
			const trailingNewline = 'line1\nline2\n';
			const taken = pythonCl100k.takeLastLinesTokens(trailingNewline, 10);
			assert.strictEqual(typeof taken, 'string');
		});
	});
});

View File

@@ -0,0 +1,106 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { getIndentationWindowsDelineations } from '../snippetInclusion/windowDelineations';
import * as assert from 'assert';
import dedent from 'ts-dedent';
/**
 * Shared fixture for the "standard input" tests: two top-level lines (`f1:`, `f2:`)
 * whose bodies are the indented lines beneath them.
 *
 * The indentation inside the template literal is significant. The expected windows
 * in those tests (e.g. `[0, 2] // f1: a1`) only make sense when `a1` is nested under
 * `f1:` and `a2`/`a3` under `f2:`; without the nesting this fixture would be
 * structurally identical to the "flat input" fixture, which expects different windows.
 */
const SOURCE = {
	source: dedent`
f1:
    a1
f2:
    a2
    a3
`,
	name: '',
};
suite('Test window delineation', function () {
	/** Compares two sets of `[start, end]` windows after sorting both, so order does not matter. */
	function assertSameWindows(actual: [number, number][], expected: [number, number][]): void {
		assert.deepStrictEqual(actual.sort(), expected.sort());
	}

	test('Correct line number range, standard input', function () {
		const windows: [number, number][] = getIndentationWindowsDelineations(
			SOURCE.source.split('\n'),
			'python',
			1,
			3
		);
		assertSameWindows(windows, [
			[0, 2], // f1: a1
			[1, 2], // a1
			[2, 5], // f2: a2 a3
			[3, 4], // a2
			[4, 5], // a3
		]);
	});

	test('Correct line number range, standard input, decreased maxLength', function () {
		const windows: [number, number][] = getIndentationWindowsDelineations(
			SOURCE.source.split('\n'),
			'python',
			1,
			2
		);
		assertSameWindows(windows, [
			[0, 2], // f1: a1
			[1, 2], // a1
			[3, 4], // a2
			[4, 5], // a3
			// [2, 5] (f2: a2 a3) is now too long and disappears,
			// but the windows it previously swallowed up show up instead:
			[2, 4], // f2: a2
			[3, 5], // a2 a3
		]);
	});

	test('Correct line number range, standard input, increased minLength', function () {
		const windows: [number, number][] = getIndentationWindowsDelineations(
			SOURCE.source.split('\n'),
			'python',
			2,
			3
		);
		assertSameWindows(windows, [
			[0, 2], // f1: a1
			[2, 5], // f2: a2 a3
			// The single-line windows are now too short and disappear:
			// [1, 2] a1
			// [3, 4] a2
			// [4, 5] a3
		]);
	});

	test('Correct line number range, flat input', function () {
		const flatSource: string = dedent`
a1
a2
a3
`;
		const windows: [number, number][] = getIndentationWindowsDelineations(
			flatSource.split('\n'),
			'python',
			1,
			3
		);
		assertSameWindows(windows, [
			[0, 1], // a1
			[1, 2], // a2
			[2, 3], // a3
			[0, 3], // a1 a2 a3
			// [0, 2] and [1, 3] are not produced: they are neither single children nor the whole tree
		]);
	});

	test('Check degenerate case', function () {
		const windows: [number, number][] = getIndentationWindowsDelineations(
			SOURCE.source.split('\n'),
			'python',
			0,
			0
		);
		assertSameWindows(windows, []);
	});
});

View File

@@ -0,0 +1,6 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
export * from './tokenizer';

View File

@@ -0,0 +1,385 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { TikTokenizer, createTokenizer, getRegexByEncoder, getSpecialTokensByEncoder } from '@microsoft/tiktokenizer';
import { parseTikTokenBinary } from '../../../../../../platform/tokenizer/node/parseTikTokens';
import { CopilotPromptLoadFailure } from '../error';
import { locateFile } from '../fileLoader';
/**
 * Names of the tokenizers this module knows about.
 *
 * The string value of the two real encodings doubles as the base name of the
 * `<value>.tiktoken` file that `TTokenizer.create` loads.
 */
export enum TokenizerName {
/** Loaded from `cl100k_base.tiktoken`. */
cl100k = 'cl100k_base',
/** Loaded from `o200k_base.tiktoken`; the default for `getTokenizer` and `ApproximateTokenizer`. */
o200k = 'o200k_base',
/** Registered with a `MockTokenizer` instance; no file is loaded for it. */
mock = 'mock',
}
/** Tokenizers that have been registered so far, keyed by name. */
const tokenizers = new Map<TokenizerName, Tokenizer>();

/**
 * Looks up a tokenizer by name.
 *
 * Resolution order: the requested tokenizer, then o200k, then a freshly created
 * `ApproximateTokenizer` when neither is registered.
 *
 * @param name Tokenizer to look up; defaults to o200k.
 * @returns A usable tokenizer; never `undefined`.
 */
export function getTokenizer(name: TokenizerName = TokenizerName.o200k): Tokenizer {
	// Each fallback is only evaluated when the previous lookup came back empty.
	return tokenizers.get(name) ?? tokenizers.get(TokenizerName.o200k) ?? new ApproximateTokenizer();
}
/**
 * Common contract implemented by the tokenizers in this file
 * (`TTokenizer`, `MockTokenizer` and `ApproximateTokenizer`).
 */
export interface Tokenizer {
/**
 * Return the length of `text` in number of tokens.
 *
 * @param text - The input text
 * @returns The number of tokens `text` is counted as by this tokenizer.
 */
tokenLength(text: string): number;
/**
 * Returns the tokens created from tokenizing `text`.
 *
 * @param text - The text to tokenize
 * @returns The tokens in integer representation.
 */
tokenize(text: string): number[];
/**
 * Returns the string representation of the tokens in `tokens`, given in integer
 * representation.
 *
 * This is meant to be the functional inverse of `tokenize`. Not every implementation
 * honours that: `MockTokenizer` hashes its tokens, so its output cannot be reversed.
 *
 * @param tokens - The tokens to convert back to text
 */
detokenize(tokens: number[]): string;
/**
 * Returns the tokenization of the input string as a list of strings.
 *
 * The concatenation of the output of this function is equal to the input.
 *
 * @param text - The text to tokenize
 */
tokenizeStrings(text: string): string[];
/**
 * Return a suffix of `text` which is `n` tokens long.
 * If `text` is at most `n` tokens, return `text`.
 *
 * Note: Implementations do not attempt to return
 * the longest possible suffix, only *some* suffix of at
 * most `n` tokens.
 *
 * @param text - The text from which to take
 * @param n - How many tokens to take
 * @returns A suffix of `text`, as a `{ text: string, tokens: number[] }`.
 */
takeLastTokens(text: string, n: number): { text: string; tokens: number[] };
/**
 * Return a prefix of `text` which is `n` tokens long.
 * If `text` is at most `n` tokens, return `text`.
 *
 * Note: Implementations do not attempt to return
 * the longest possible prefix, only *some* prefix of at
 * most `n` tokens.
 *
 * @param text - The text from which to take
 * @param n - How many tokens to take
 * @returns A prefix of `text`, as a `{ text: string, tokens: number[] }`.
 */
takeFirstTokens(text: string, n: number): { text: string; tokens: number[] };
/**
 * Return the longest suffix of `text` that consists of complete lines and is at most
 * `n` tokens long.
 *
 * @param text - The text from which to take
 * @param n - How many tokens to take
 * @returns The selected suffix of `text`.
 */
takeLastLinesTokens(text: string, n: number): string;
}
/**
 * Tokenizer that delegates encoding and decoding to a `TikTokenizer` instance.
 */
export class TTokenizer implements Tokenizer {
	constructor(private readonly _tokenizer: TikTokenizer) { }

	/**
	 * Builds a tokenizer from the `<encoder>.tiktoken` file found through `locateFile`.
	 *
	 * @param encoder Name of the encoding to load.
	 * @returns A ready-to-use tokenizer.
	 * @throws CopilotPromptLoadFailure when loading fails with an `Error`; any other
	 * thrown value is rethrown unchanged.
	 */
	static async create(encoder: TokenizerName): Promise<TTokenizer> {
		try {
			const parsedEncoding = parseTikTokenBinary(locateFile(`${encoder}.tiktoken`));
			const specialTokens = getSpecialTokensByEncoder(encoder);
			const pattern = getRegexByEncoder(encoder);
			// NOTE(review): 32768 is presumably a cache size for the tokenizer — confirm against the library.
			return new TTokenizer(createTokenizer(parsedEncoding, specialTokens, pattern, 32768));
		} catch (err: unknown) {
			if (!(err instanceof Error)) {
				throw err;
			}
			throw new CopilotPromptLoadFailure(`Could not load tokenizer`, err);
		}
	}

	/** Encodes `text` into token ids. */
	tokenize(text: string): number[] {
		return this._tokenizer.encode(text);
	}

	/** Decodes token ids back into text. */
	detokenize(tokens: number[]): string {
		return this._tokenizer.decode(tokens);
	}

	/** Number of tokens `text` encodes to. */
	tokenLength(text: string): number {
		return this.tokenize(text).length;
	}

	/** Text of every token of `text`, obtained by decoding each token id on its own. */
	tokenizeStrings(text: string): string[] {
		return this.tokenize(text).map(id => this.detokenize([id]));
	}

	/**
	 * Returns the last `n` tokens of `text`, or all of `text` when it is shorter than that.
	 *
	 * Only a trailing window of the text is tokenized. The window starts at 4 characters
	 * per requested token and grows by 1 character per requested token until it holds at
	 * least `n + 2` tokens or covers the whole text. The 2 spare tokens avoid the edge
	 * case where cutting at exactly `n` tokens yields an odd tokenization.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result.
	 */
	takeLastTokens(text: string, n: number): { text: string; tokens: number[] } {
		if (n <= 0) {
			return { text: '', tokens: [] };
		}
		const initialCharsPerToken = 4;
		const extraCharsPerToken = 1;
		let windowSize = Math.min(text.length, n * initialCharsPerToken); // first guess
		let encoded = this.tokenize(text.slice(-windowSize));
		while (encoded.length < n + 2 && windowSize < text.length) {
			windowSize = Math.min(text.length, windowSize + n * extraCharsPerToken);
			encoded = this.tokenize(text.slice(-windowSize));
		}
		if (encoded.length < n) {
			// The window reached the whole text, which is therefore at most n tokens long.
			return { text, tokens: encoded };
		}
		const tail = encoded.slice(-n);
		return { text: this.detokenize(tail), tokens: tail };
	}

	/**
	 * Returns the first `n` tokens of `text`, or all of `text` when it is shorter than that.
	 *
	 * Mirrors `takeLastTokens`: a leading window of the text is tokenized and grown until it
	 * holds at least `n + 2` tokens or covers the whole text.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result.
	 */
	takeFirstTokens(text: string, n: number): { text: string; tokens: number[] } {
		if (n <= 0) {
			return { text: '', tokens: [] };
		}
		const initialCharsPerToken = 4;
		const extraCharsPerToken = 1;
		let windowSize = Math.min(text.length, n * initialCharsPerToken); // first guess
		let encoded = this.tokenize(text.slice(0, windowSize));
		while (encoded.length < n + 2 && windowSize < text.length) {
			windowSize = Math.min(text.length, windowSize + n * extraCharsPerToken);
			encoded = this.tokenize(text.slice(0, windowSize));
		}
		if (encoded.length < n) {
			// The window reached the whole text, which is therefore at most n tokens long.
			return { text: text, tokens: encoded };
		}
		// Design note: this implicit "truncate final tokens" step could be extracted into a
		// generic snippet text processing function managed by the SnippetTextProcessor class.
		const head = encoded.slice(0, n);
		return { text: this.detokenize(head), tokens: head };
	}

	/**
	 * Takes the last `n` tokens of `text` and trims the result to start on a line boundary.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take.
	 */
	takeLastLinesTokens(text: string, n: number): string {
		const tail = this.takeLastTokens(text, n).text;
		const tookEverything = tail.length === text.length;
		const startsOnLineBoundary = text[text.length - tail.length - 1] === '\n';
		if (tookEverything || startsOnLineBoundary) {
			// Already made of whole lines.
			return tail;
		}
		// Drop the partial first line. Without any newline, indexOf gives -1 and the tail is kept as is.
		return tail.substring(tail.indexOf('\n') + 1);
	}
}
/**
 * Deterministic stand-in tokenizer: text is split at regex word boundaries (`\b`) and
 * every piece is hashed to a 16-bit number.
 *
 * Because the tokens are hashes, `detokenize` cannot restore the original text.
 * Note that `''.split(/\b/)` yields `['']`, so the empty string counts as one token here.
 */
class MockTokenizer implements Tokenizer {
	/** Rolling hash (`hash * 31 + charCode`) kept within 16 bits after every character. */
	private hash = (str: string): number => {
		let hash = 0;
		for (let i = 0; i < str.length; i++) {
			hash = ((hash << 5) - hash + str.charCodeAt(i)) & 0xffff;
		}
		return hash;
	};

	/** Hash of every piece produced by `tokenizeStrings`. */
	tokenize(text: string): number[] {
		return this.tokenizeStrings(text).map(this.hash);
	}

	/**
	 * Joins the decimal form of the tokens with single spaces.
	 * The hashing is not reversible, so this does not return the original input.
	 */
	detokenize(tokens: number[]): string {
		return tokens.map(token => token.toString()).join(' ');
	}

	/** Splits `text` at word boundaries; the pieces concatenate back to `text`. */
	tokenizeStrings(text: string): string[] {
		return text.split(/\b/);
	}

	/** Number of pieces `text` splits into. */
	tokenLength(text: string): number {
		return this.tokenizeStrings(text).length;
	}

	/**
	 * Returns the last `n` pieces of `text`.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result, as in the
	 * other tokenizers of this file.
	 */
	takeLastTokens(text: string, n: number): { text: string; tokens: number[] } {
		// Without this guard `slice(-0)` would return every piece, and a negative `n`
		// would turn into a positive start index.
		if (n <= 0) {
			return { text: '', tokens: [] };
		}
		const pieces = this.tokenizeStrings(text).slice(-n);
		return { text: pieces.join(''), tokens: pieces.map(this.hash) };
	}

	/**
	 * Returns the first `n` pieces of `text`.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result.
	 */
	takeFirstTokens(text: string, n: number): { text: string; tokens: number[] } {
		// A negative `n` would otherwise make `slice(0, n)` drop pieces from the end instead.
		if (n <= 0) {
			return { text: '', tokens: [] };
		}
		const pieces = this.tokenizeStrings(text).slice(0, n);
		return { text: pieces.join(''), tokens: pieces.map(this.hash) };
	}

	/**
	 * Takes the last `n` tokens of `text` and trims the result to start on a line boundary.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take.
	 */
	takeLastLinesTokens(text: string, n: number): string {
		const { text: suffix } = this.takeLastTokens(text, n);
		if (suffix.length === text.length || text[text.length - suffix.length - 1] === '\n') {
			// Edge case: we already took whole lines
			return suffix;
		}
		// Drop the partial first line. Without any newline, indexOf gives -1 and the suffix is kept as is.
		const newline = suffix.indexOf('\n');
		return suffix.substring(newline + 1);
	}
}
/**
 * Effective token length, in characters per token, for each tokenizer and language id.
 * `ApproximateTokenizer` divides the text length by these values to estimate a token count.
 *
 * The values are based on empirical data and balance the risk of accidental overflow
 * against overeager elision.
 * Note: they may need to be recalculated if typical prompt lengths change significantly.
 */
const EFFECTIVE_TOKEN_LENGTH: Partial<Record<TokenizerName, Record<string, number>>> = {
[TokenizerName.cl100k]: {
python: 3.99,
typescript: 4.54,
typescriptreact: 4.58,
javascript: 4.76,
csharp: 5.13,
java: 4.86,
cpp: 3.85,
php: 4.1,
html: 4.57,
vue: 4.22,
go: 3.93,
dart: 5.66,
javascriptreact: 4.81,
css: 3.37,
},
[TokenizerName.o200k]: {
python: 4.05,
typescript: 4.12,
typescriptreact: 5.01,
javascript: 4.47,
csharp: 5.47,
java: 4.86,
cpp: 3.8,
php: 4.35,
html: 4.86,
vue: 4.3,
go: 4.21,
dart: 5.7,
javascriptreact: 4.83,
css: 3.33,
},
};
/**
 * Number of decimal digits reserved for each character code when `ApproximateTokenizer`
 * packs a chunk of characters into a single numeric mock token.
 *
 * NOTE(review): `charCodeAt` can return values up to 65535 (five digits), so characters
 * with a code of 10000 or more would not survive a tokenize/detokenize round trip —
 * confirm whether any caller relies on that round trip for non-ASCII text.
 */
const MAX_CODE_POINT_SIZE = 4;
/**
 * A best-effort tokenizer that estimates the token length of a text by dividing its
 * number of characters by an empirically chosen constant close to 4.
 * It is not a real tokenizer: `tokenize` and `tokenizeStrings` cut the text into fixed
 * 4-character chunks, independently of the estimate used by `tokenLength`.
 */
export class ApproximateTokenizer implements Tokenizer {
	tokenizerName: TokenizerName;

	/**
	 * @param tokenizerName Tokenizer whose effective token lengths are used; defaults to o200k.
	 * @param languageId Language whose effective token length is used; without it the default of 4 applies.
	 */
	constructor(
		tokenizerName: TokenizerName = TokenizerName.o200k,
		private languageId?: string
	) {
		this.tokenizerName = tokenizerName;
	}

	/**
	 * Packs every chunk from `tokenizeStrings` into one number by reserving
	 * `MAX_CODE_POINT_SIZE` decimal digits per character code.
	 */
	tokenize(text: string): number[] {
		const base = Math.pow(10, MAX_CODE_POINT_SIZE);
		return this.tokenizeStrings(text).map(chunk => {
			let packed = 0;
			for (let i = 0; i < chunk.length; i++) {
				packed = packed * base + chunk.charCodeAt(i);
			}
			return packed;
		});
	}

	/**
	 * Unpacks tokens produced by `tokenize`, reading `MAX_CODE_POINT_SIZE` decimal digits
	 * per character from the right-hand side of each number.
	 */
	detokenize(tokens: number[]): string {
		return tokens
			.map(token => {
				const chars: string[] = [];
				let digits = token.toString();
				while (digits.length > 0) {
					const charCode = digits.slice(-MAX_CODE_POINT_SIZE);
					chars.unshift(String.fromCharCode(parseInt(charCode, 10)));
					digits = digits.slice(0, -MAX_CODE_POINT_SIZE);
				}
				return chars.join('');
			})
			.join('');
	}

	/**
	 * Mock tokenization: cuts `text` into consecutive chunks of up to 4 characters
	 * (the default effective token length).
	 */
	tokenizeStrings(text: string): string[] {
		// `[\s\S]` instead of `.`: the dot does not match line terminators, which would
		// silently drop every newline and break the contract that the chunks
		// concatenate back to the input.
		return text.match(/[\s\S]{1,4}/g) ?? [];
	}

	/** Characters per token for the configured tokenizer and language, or 4 when unknown. */
	private getEffectiveTokenLength(): number {
		// Our default is 4, used for tail languages and error handling
		const defaultETL = 4;
		if (this.tokenizerName && this.languageId) {
			// Use our calculated effective token length for head languages
			return EFFECTIVE_TOKEN_LENGTH[this.tokenizerName]?.[this.languageId] ?? defaultETL;
		}
		return defaultETL;
	}

	/** Estimated token count: text length divided by the effective token length, rounded up. */
	tokenLength(text: string): number {
		return Math.ceil(text.length / this.getEffectiveTokenLength());
	}

	/**
	 * Returns approximately the last `n` tokens' worth of characters of `text`.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result.
	 * @returns The suffix plus placeholder tokens `0..k-1`; only their count is meaningful.
	 */
	takeLastTokens(text: string, n: number): { text: string; tokens: number[] } {
		if (n <= 0) { return { text: '', tokens: [] }; }
		const charBudget = Math.floor(n * this.getEffectiveTokenLength());
		// `slice(-0)` would return the whole text, so an empty budget is handled explicitly.
		const suffix = charBudget > 0 ? text.slice(-charBudget) : '';
		return { text: suffix, tokens: Array.from({ length: this.tokenLength(suffix) }, (_, i) => i) };
	}

	/**
	 * Returns approximately the first `n` tokens' worth of characters of `text`.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take; values `<= 0` give an empty result.
	 * @returns The prefix plus placeholder tokens `0..k-1`; only their count is meaningful.
	 */
	takeFirstTokens(text: string, n: number): { text: string; tokens: number[] } {
		if (n <= 0) { return { text: '', tokens: [] }; }
		const charBudget = Math.floor(n * this.getEffectiveTokenLength());
		const prefix = text.slice(0, charBudget);
		return { text: prefix, tokens: Array.from({ length: this.tokenLength(prefix) }, (_, i) => i) };
	}

	/**
	 * Takes the last `n` tokens of `text` and trims the result to start on a line boundary.
	 *
	 * @param text The text from which to take.
	 * @param n How many tokens to take.
	 */
	takeLastLinesTokens(text: string, n: number): string {
		const { text: suffix } = this.takeLastTokens(text, n);
		if (suffix.length === text.length || text[text.length - suffix.length - 1] === '\n') {
			// Edge case: we already took whole lines
			return suffix;
		}
		// Drop the partial first line. Without any newline, indexOf gives -1 and the suffix is kept as is.
		const newline = suffix.indexOf('\n');
		return suffix.substring(newline + 1);
	}
}
/**
 * Loads the tokenizer called `name` and registers it in `tokenizers`.
 * Loading is best effort: when it fails nothing is registered and the error is dropped.
 */
async function setTokenizer(name: TokenizerName): Promise<void> {
	let loaded: TTokenizer;
	try {
		loaded = await TTokenizer.create(name);
	} catch {
		// Ignore errors loading tokenizer
		return;
	}
	tokenizers.set(name, loaded);
}
/**
 * Starts loading the tokenizers as soon as this module is evaluated.
 * The promise is exported so that initialization code can await it.
 */
export const initializeTokenizers = (async () => {
	// The mock tokenizer needs no loading, so it is registered before the first await.
	tokenizers.set(TokenizerName.mock, new MockTokenizer());
	const loadedFromFile = [TokenizerName.cl100k, TokenizerName.o200k];
	await Promise.all(loadedFromFile.map(name => setTokenizer(name)));
})();