Skip to content

Commit

Permalink
Update pytorch.js (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jan 15, 2025
1 parent a49e537 commit ad8865c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1,799 deletions.
186 changes: 72 additions & 114 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -5350,7 +5350,7 @@ python.Execution = class {
return false;
}
sweep(block, recurse) {
const nodes = block.nodes().reverse();
const nodes = Array.from(block.nodes()).reverse();
for (const node of nodes) {
this.removeDeadBlockOutputs(node);
this.removeDeadLoopOutputs(node);
Expand Down Expand Up @@ -5380,7 +5380,7 @@ python.Execution = class {
return it;
}
const has_side_effects = node.hasSideEffects() ||
node.blocks().some((b) => b.nodes().some((n) => this.hasSideEffects(n))) ||
node.blocks().some((b) => Array.from(b.nodes()).some((n) => this.hasSideEffects(n))) ||
this.hasUntrackedMutation(node);
this._memo.set(node, has_side_effects);
return has_side_effects;
Expand Down Expand Up @@ -9871,9 +9871,34 @@ python.Execution = class {
return [op, overload_names];
}
});
this.registerType('torch._C.graph_node_list', class {
constructor(head) {
this.head = head;
}
front() {
return this.head.next;
}
end() {
return this.head.prev;
}
[Symbol.iterator]() {
let current = this.head.next;
const prev = this.head.prev;
return {
next() {
if (current !== prev) {
const value = current;
current = current.next;
return { value, done: false };
}
return { done: true };
}
};
}
});
this.registerType('torch.Graph', class {
constructor() {
this._next_unique = 1;
this._next_unique = 0;
this._unique_names = new Map();
this._name_base_suffix = new Map();
this.all_nodes = new Set();
Expand Down Expand Up @@ -10075,7 +10100,9 @@ python.Execution = class {
return torch._C.insertConstant(this, val, loc, scope);
}
insertMethodCall(method_name, matched) {
const result = this.insertNode(this.create('prim::CallMethod', matched.inputs)).s_('name', method_name).output().setType(matched.return_types[0]);
const result = this.insertNode(this.create('prim::CallMethod', matched.inputs))
.s_('name', method_name)
.output().setType(matched.return_types[0]);
return result;
}
insertUncheckedCast(v, type) {
Expand Down Expand Up @@ -10201,13 +10228,7 @@ python.Execution = class {
return this._output.inputs();
}
nodes() {
const nodes = [];
let current = this._input.next;
while (current !== this._input.prev) {
nodes.push(current);
current = current.next;
}
return nodes;
return new torch._C.graph_node_list(this._input);
}
return_node() {
return this._output;
Expand Down Expand Up @@ -10717,17 +10738,33 @@ python.Execution = class {
const names = this.attributeNames();
for (let i = 0; i < names.length; i++) {
const name = names[i];
if (ignore_subgraph && name === 'subgraph') {
if (ignore_subgraph && name === 'Subgraph') {
continue;
}
if (i > 0) {
out.write(', ');
}
out.write(`${name}=`);
out.write(this._values.get(name)[0]); // this.printAttrValue(out, name);
this.printAttrValue(out, name);
}
out.write(']');
}
printAttrValue(out, name) {
const kind = this.kindOf(name);
switch (kind) {
case 'c': case 'cs': case 'f': case 'fs': case 'i': case 'is':
case 'ss': case 't': case 'ival': case 'ty':
out.write(this[kind](name));
break;
case 's':
out.write(`"${this.s(name)}"`);
break;
case 'ts': out.write('[<Tensors>]'); break;
case 'g': out.write('[<Graph>]'); break;
case 'gs': out.write('[<Graphs>]'); break;
default: throw new python.Error(`Unknown attribute kind '${kind}'.`);
}
}
print(out, level, groups, print_source_locations, print_attributes, print_scopes, print_body) {
print_source_locations = print_source_locations === false ? false : true;
print_attributes = print_attributes === false ? false : true;
Expand Down Expand Up @@ -10999,7 +11036,7 @@ python.Execution = class {
}
torch._C.printValueRef(out, n);
out.write(' : ');
out.write(n.type().toString());
out.write(n.type().str());
}
});
this.register('torch.jit._script');
Expand Down Expand Up @@ -11283,9 +11320,7 @@ python.Execution = class {
const type_parser = new torch._C.ScriptTypeParser(this);
for (const assign of attributes) {
const name = assign.name;
const annotation = this._cu.execution.to_ir ?
type_parser.parseTypeFromExpr(assign.annotation) :
this._cu.execution.type(assign.annotation, null);
const annotation = type_parser.parseTypeFromExpr(assign.annotation);
const is_parameter = parameter_names.has(name);
const is_buffer = buffer_names.has(name);
class_type.addAttribute(name, annotation, is_parameter, is_buffer);
Expand Down Expand Up @@ -11981,16 +12016,20 @@ python.Execution = class {
if (name.length === 0 && name[0] === '$') {
return false;
}
return name[0] !== '_' && !/[0-9]/.test(name.slice(1));
if (name[0] !== '_') {
return true;
}
return !/\d+/.test(name.slice(1));
});
this.registerFunction('torch._C.materializeConstant', (val, graph, r, map) => {
const existing_constant = map.get(val);
const key = `${val.value}:${val.tag}`;
const existing_constant = map.get(key);
if (existing_constant) {
return existing_constant;
}
const guard = new torch._C.WithInsertPoint(graph.block().nodes()[0]);
const guard = new torch._C.WithInsertPoint(graph.block().nodes().front());
const new_constant = graph.insertConstant(val, r);
map.set(val, new_constant);
map.set(key, new_constant);
guard.dispose();
return new_constant;
});
Expand Down Expand Up @@ -13806,90 +13845,8 @@ python.Execution = class {
}
get graph() {
if (!this._graph) {
if (execution.to_ir) {
const fn = this._typ.getMethod('forward');
this._graph = fn.graph();
} else {
const isTensor = (obj) => {
const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
switch (name) {
case 'torch':
case 'torch.cuda':
return obj.__class__.__name__.endsWith('Tensor');
case 'torch.nn.parameter':
return obj.__class__.__name__ === 'Parameter';
default:
return false;
}
};
if (!this.forward) {
return null;
}
const args = [];
if (this.forward.__code__ && this.forward.__code__.args) {
const params = this.forward.__code__.args.args;
for (let i = 0; i < params.length; i++) {
const arg = params[i];
const value = execution.graph.addInput(arg.arg);
if (i === 0 && arg.arg === 'self' && !arg.annotation) {
value.setType(this.type());
} else {
value.setType(execution.type(arg.annotation));
}
if (isTensor(value)) {
value.__variable__ = arg.name;
value.__origin__ = 'graph-input';
}
args.push(value);
}
}
execution.purge = new Set();
const result = this.forward.__call__(args);
const queue = Array.from(execution.purge);
const visited = new Set();
while (queue.length > 0) {
const node = queue.shift();
if (visited.has(node)) {
continue;
}
visited.add(node);
if (node.outputs().every((output) => output.uses().length === 0)) {
for (const input of node.inputs()) {
queue.push(input.node());
}
node.destroy();
}
}
if (Array.isArray(result)) {
for (const output of result) {
if (isTensor(output)) {
const value = execution.variable(output);
execution.graph.return_node().addInput(value);
}
}
} else if (isTensor(result)) {
const value = execution.variable(result);
execution.graph.return_node().addInput(value);
} else if (result instanceof torch.Value) {
execution.graph.return_node().addInput(result);
} else if (Object(result) === result) {
for (const key of Object.keys(result)) {
const item = result[key];
if (Array.isArray(item)) {
for (const output of item) {
if (isTensor(output)) {
const value = execution.variable(output);
execution.graph.return_node().addInput(value);
}
}
} else if (isTensor(item)) {
const value = execution.variable(item);
execution.graph.return_node().addInput(value);
}
}
}
this._graph = execution.graph;
}
const fn = this._typ.getMethod('forward');
this._graph = fn.graph();
}
return this._graph;
}
Expand Down Expand Up @@ -14025,10 +13982,8 @@ python.Execution = class {
this._def_stack[this._def_stack.length - 1]._declared_return_type = schema.returns[0].type;
}
const args = this.emitFormalArguments(def, self, schema, block);
if (execution.to_ir) {
this.emitStatements(def.body);
this.handleMaybeNoReturn(def, block);
}
this.emitStatements(def.body);
this.handleMaybeNoReturn(def, block);
const returns = [this.emitOutput(def, schema, block)];
return new torch.FunctionSchema(def.name, '', args, returns);
}
Expand Down Expand Up @@ -14370,6 +14325,9 @@ python.Execution = class {
}
} break;
*/
if (expr instanceof ast.UnaryOp) {
throw new python.Error('Not implemented.');
}
if (expr instanceof ast.Call) {
const apply = expr;
const callee = expr.func;
Expand Down Expand Up @@ -15612,7 +15570,7 @@ python.Execution = class {
const [b] = args;
{
this._graph = b.owningGraph();
const guard = new torch._C.WithInsertPoint(b.nodes()[0]);
const guard = new torch._C.WithInsertPoint(b.nodes().front());
this._false_val = this._graph.insertConstant(false);
guard.dispose();
}
Expand Down Expand Up @@ -15766,7 +15724,7 @@ python.Execution = class {
this._graph = graph;
this._target_block = null;
this._unit_values = new Map();
const guard = new torch._C.WithInsertPoint(this._graph.block().nodes()[0]);
const guard = new torch._C.WithInsertPoint(this._graph.block().nodes().front());
this._true_val = this._graph.insertConstant(true);
this._false_val = this._graph.insertConstant(false);
this._throws_val = this.getUnitValue(torch.BoolType.get());
Expand Down Expand Up @@ -15799,11 +15757,11 @@ python.Execution = class {
}
deleteAfterExitNodes(block, iter) {
const nodes = block.nodes();
if (iter === nodes[nodes.length - 1].next) {
if (iter === nodes.end()) {
return;
}
const insert = new torch._C.WithInsertPoint(block.nodes()[0]);
for (const it of nodes.reverse()) {
const insert = new torch._C.WithInsertPoint(block.nodes().front());
for (const it of Array.from(nodes).reverse()) {
if (it === iter) {
break;
}
Expand Down
Loading

0 comments on commit ad8865c

Please sign in to comment.