nimxla/graph


The graph module wraps the XLA API in order to build and compile a graph of operations.

A graph is a tree of Nodes (each of which wraps an XLA operation). The Builder is used to construct new Nodes and to finalize the resulting graph into a Computation.

The nodes in the graph are typed, i.e. each has a shape which defines its data type and dimensions. Each node derives its shape from its inputs, or throws an XLAError exception at build time if they are not compatible. See https://www.tensorflow.org/xla/operation_semantics for more details.

If there is an error during construction of the graph, a BuilderError exception is raised with the details of the current AST tree and the reason for the error.

A computation can be compiled for a specific device using the Client, which is defined in the nimxla module. Shapes and host literal types are defined in the literal module.

Types

Builder = ref object
  
Builder is used to construct new operations. It holds a reference to an xla_builder.
BuilderError = ref object of CatchableError
  origMsg*: string
  at*: Node
Exception raised while building the graph. origMsg is the original status message. at indicates the node in the graph where the error occurred. The repr of this node is added to the msg field.
Computation = ref ComputationObj
A Computation wraps the constructed graph after it has been finalized. It holds a reference to the xla_computation object.
HloModule = object
  
HloModule is a serialized version of a Computation in StableHLO format.
Node = ref object
  id*: int
  shape*: Shape
  args*: seq[Node]
  noGrad*: bool
  builder*: Builder
  case kind*: OpType
  of tParam:
      name*: string
    
  of tError:
    
  of tTupleElement, tConcat:
    
  of tReshape, tBroadcast, tCollapse, tTranspose, tNarrow, tRngUniform,
     tRngNormal, tReduce, tReduceSum, tReduceMin, tReduceMax, tArgmin, tArgmax,
     tReverse:
    
  of tBroadcastInDim:
    
  of tReduceWindow, tMaxPool, tSumPool:
    
  of tConv:
    
  of tPad:
    
  of tBatchNormInference, tBatchNormTraining, tBatchNormGrad:
    
  of tSoftmax:
    
  else:
      nil

  
A Node is generated from each xla_op once it is added to the graph. The id number is the index into the nodes sequence in the Computation object and is set by the builder. The shape is the output data type and dimensions for the op. This must be fixed and known at build time. If the noGrad attribute is set then gradients are not accumulated from this node or its inputs.
OpType = enum
  tNone, tConst, tLiteral, tParam, tError, tIota, ## leaf nodes
  tNot, tNeg, tAbs, tExp, tFloor, tCeil, tRound, tLog, ## 1 arg ops
  tLog1p, tSigmoid, tRelu, tSign, tCos, tSin, tTanh, tSqrt, ## ..
  tRsqrt, tIsFinite, tCopy, tZerosLike, tTupleElement, ## ..
  tReshape, tBroadcast, tBroadcastInDim, tCollapse, ## ..
  tTranspose, tNarrow, tConvert, tReverse, tMaxPool, tSumPool, ## ..
  tReduceSum, tReduceMin, tReduceMax, tArgmin, tArgmax, ## ..
  tSoftmax,                 ## ..
  tAdd, tSub, tMul, tDiv, tRem, tMax, tMin, tPow, tDot, ## 2 arg ops
  tAnd, tOr, tEq, tNe, tGe, tGt, tLe, tLt, tRngUniform, ## ..
  tRngNormal, tReduce, tGather, tConv, tReduceWindow, ## ..
  tSelectAndScatter, tPad,  ## ..
  tSelect, tClamp, tTuple, tConcat, tScatter, ## 3 or more arg ops
  tBatchNormInference, tBatchNormTraining, tBatchNormGrad ## ..

Procs

proc `!`(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise logical not.
proc `!=`(a, b: Node): Node {....raises: [Exception, BuilderError],
                              tags: [RootEffect], forbids: [].}
Elementwise not equal. Returns a Bool array.
proc `$`(comp: Computation): string {....raises: [XLAError, Exception],
                                      tags: [RootEffect], forbids: [].}
Dumps out the name, parameters and info for each node added to the graph.
proc `$`(hlo: HloModule): string {....raises: [XLAError], tags: [], forbids: [].}
Dump computation in StableHLO MLIR text format.
proc `$`(n: Node): string {....raises: [ValueError, Exception], tags: [RootEffect],
                            forbids: [].}
Print node id, type, shape and info fields.
proc `*`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise multiply
proc `+`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise add
proc `-`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise subtract
proc `-`(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise arithmetic negation
proc `/`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise divide
proc `<`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise less than. Returns a Bool array.
proc `<=`(a, b: Node): Node {....raises: [Exception, BuilderError],
                              tags: [RootEffect], forbids: [].}
Elementwise less than or equal. Returns a Bool array.
proc `==`(a, b: Node): Node {....raises: [Exception, BuilderError],
                              tags: [RootEffect], forbids: [].}
Elementwise equal. Returns a Bool array.
proc `>`(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise greater than. Returns a Bool array.
proc `>=`(a, b: Node): Node {....raises: [Exception, BuilderError],
                              tags: [RootEffect], forbids: [].}
Elementwise greater or equal. Returns a Bool array.
proc `[]`(a: Node; index: int): Node {....raises: [Exception, BuilderError],
                                       tags: [RootEffect], forbids: [].}
Return the element from the input tuple at index.
proc abs(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise absolute value
proc addAt(a, indices, b: Node): Node {....raises: [Exception, BuilderError],
                                        tags: [RootEffect], forbids: [].}

Adds values from array b to array a at the given indices.

For example:

a = <f32 2 3>[[1 2 3]    ix = <i64 3 2>[[0 0]   b = <f32 3>[1 2 3]
              [4 5 6]]                  [1 2]
                                        [0 0]]
a.addAt(ix, b) = <f32 2 3>[[5 2 3]
                           [4 5 8]]
This is implemented using the scatter op.
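As an illustration of the semantics only, here is a minimal plain-Nim sketch, assuming a 2d float array and (row, col) index pairs. addAtRef is a hypothetical helper that does not call nimxla; it reproduces the example above, including the accumulation at the repeated index [0 0].

```nim
# Hypothetical plain-Nim reference (addAtRef is not part of nimxla) for the
# addAt semantics: each row of ix is a (row, col) position in a, and b[k] is
# added at ix[k]. Repeated positions accumulate.
proc addAtRef(a: seq[seq[float]]; ix: seq[array[2, int]];
              b: seq[float]): seq[seq[float]] =
  result = a                          # work on a copy; the input is unchanged
  for k, pos in ix:
    result[pos[0]][pos[1]] += b[k]

let a = @[@[1.0, 2.0, 3.0], @[4.0, 5.0, 6.0]]
let ix = @[[0, 0], [1, 2], [0, 0]]
let b = @[1.0, 2.0, 3.0]
echo addAtRef(a, ix, b)               # @[@[5.0, 2.0, 3.0], @[4.0, 5.0, 8.0]]
```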

proc argMax(a: Node; axis: int; keepDims = false; ixType = I64): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Get the indices of the maximum values along the given axis. By default the shape of the result will be as per the input with this axis removed. If keepDims is set the axis for the reduction is kept in the output with size of 1. If a negative axis is given then this is taken relative to the number of dimensions of the input.
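To make the axis handling concrete, the following is a plain-Nim sketch of argMax on a 2d array, assuming row-major nested seqs. argMaxRef is a hypothetical helper, not the nimxla implementation. Tie-breaking is not specified above; this sketch assumes the lowest index wins.

```nim
# Hypothetical plain-Nim reference (argMaxRef is not part of nimxla) for
# argMax on a 2d row-major array. A negative axis counts from the end.
# Ties resolve to the lowest index here (an assumption, not documented).
proc argMaxRef(a: seq[seq[float]]; axis: int): seq[int] =
  let ax = if axis < 0: axis + 2 else: axis      # rank is fixed at 2 here
  let outer = if ax == 0: a[0].len else: a.len   # number of results
  let inner = if ax == 0: a.len else: a[0].len   # length of each reduced slice
  for i in 0 ..< outer:
    var best = 0
    for j in 1 ..< inner:
      let cur = if ax == 0: a[j][i] else: a[i][j]
      let top = if ax == 0: a[best][i] else: a[i][best]
      if cur > top: best = j
    result.add best

let m = @[@[1.0, 5.0, 2.0], @[7.0, 0.0, 3.0]]
echo argMaxRef(m, 1)    # @[1, 0]    (shape [2]; keepDims would give [2, 1])
echo argMaxRef(m, 0)    # @[1, 0, 1]
```

argMin follows the same pattern with the comparison reversed.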
proc argMin(a: Node; axis: int; keepDims = false; ixType = I64): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Get the indices of the minimum values along the given axis. By default the shape of the result will be as per the input with this axis removed. If keepDims is set the axis for the reduction is kept in the output with size of 1. If a negative axis is given then this is taken relative to the number of dimensions of the input.
proc avgPool1d(a: Node; kernelSize: int; strides = 0; padding = pad(0);
               channelsFirst = false): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Average pooling over a 1 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv1d.
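A minimal plain-Nim sketch of the pooling arithmetic, assuming a single channel and no padding. avgPool1dRef is a hypothetical helper; how nimxla divides windows that include padding is not covered here.

```nim
# Hypothetical plain-Nim reference (avgPool1dRef is not part of nimxla) for a
# single-channel 1d average pool without padding. A strides value of 0 falls
# back to kernelSize, giving non-overlapping windows.
proc avgPool1dRef(x: seq[float]; kernelSize: int; strides = 0): seq[float] =
  let s = if strides == 0: kernelSize else: strides
  var start = 0
  while start + kernelSize <= x.len:       # only complete windows
    var total = 0.0
    for i in start ..< start + kernelSize:
      total += x[i]
    result.add total / kernelSize.float
    start += s

echo avgPool1dRef(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2)      # @[1.5, 3.5, 5.5]
echo avgPool1dRef(@[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 1)   # @[2.0, 3.0, 4.0, 5.0]
```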
proc avgPool2d(a: Node; kernelSize: Opt2d; strides: Opt2d = 0;
               padding: Pad2d = pad(0); channelsFirst = false): Node
Average pooling over a 2 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv2d.
proc avgPool3d(a: Node; kernelSize: Opt3d; strides: Opt3d = 0;
               padding: Pad3d = pad(0); channelsFirst = false): Node
Average pooling over a 3 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv3d.
proc batchNormGrad(a, scale, mean, variance, gradOutput: Node; epsilon: float;
                   axis: int): Node {....raises: [Exception, BuilderError,
    ValueError], tags: [RootEffect], forbids: [].}
Calculates the gradient of batch norm. Returns a tuple of (grad_a, grad_scale, grad_offset). See BatchNormGrad.
proc batchNormInference(a, scale, offset, mean, variance: Node; epsilon: float;
                        axis: int): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Implements batch normalization in inference mode. axis should be the axis of the feature dimension - e.g. 3 or -1 for images in [N,H,W,C] format. See BatchNormInference.
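For a single feature channel, inference-mode batch norm reduces to the XLA BatchNormInference formula y = scale * (x - mean) / sqrt(variance + epsilon) + offset. The sketch below applies it in plain Nim. batchNormInferenceRef is a hypothetical helper, and treating each argument as a per-channel scalar is an assumption made for brevity.

```nim
import std/math

# Hypothetical plain-Nim reference (batchNormInferenceRef is not part of nimxla)
# for one feature channel, using the BatchNormInference formula
#   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
proc batchNormInferenceRef(x: seq[float]; scale, offset, mean, variance,
                           epsilon: float): seq[float] =
  let invStd = 1.0 / sqrt(variance + epsilon)
  for v in x:
    result.add scale * (v - mean) * invStd + offset

echo batchNormInferenceRef(@[1.0, 3.0], scale = 2.0, offset = 0.5,
                           mean = 2.0, variance = 1.0, epsilon = 0.0)  # @[-1.5, 2.5]
```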
proc batchNormTraining(a, scale, offset: Node; epsilon: float; axis: int): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Implements batch normalization in training mode. Returns a tuple of (output, batch_mean, batch_var). See BatchNormTraining.
proc broadcast(a: Node; dims: openArray[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
proc broadcastInDim(a: Node; outSize, bcastDims: openArray[int]): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Expand dims at each index in bcastDims from 1 to the corresponding value in outSize as per BroadcastInDim.
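The shape rule behind BroadcastInDim can be checked without building a graph. This plain-Nim sketch follows the XLA rule that operand dimension i maps to output dimension bcastDims[i] and must have size 1 or the same size as that output dimension. canBroadcastInDim is a hypothetical helper, not something nimxla exports.

```nim
# Hypothetical plain-Nim shape check (canBroadcastInDim is not part of nimxla)
# following the XLA BroadcastInDim rule: operand dimension i maps to output
# dimension bcastDims[i], and its size must be 1 or equal to that output size.
proc canBroadcastInDim(inDims, outSize, bcastDims: openArray[int]): bool =
  if inDims.len != bcastDims.len: return false
  for i, d in inDims:
    let target = outSize[bcastDims[i]]
    if d != 1 and d != target: return false
  return true

echo canBroadcastInDim([3], [2, 3], [1])        # true: a [3] vector fills each row
echo canBroadcastInDim([1, 3], [4, 3], [0, 1])  # true: size 1 expands to 4
echo canBroadcastInDim([3], [2, 3], [0])        # false: 3 cannot map onto size 2
```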
proc build(b: Builder; root: Node): Computation {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Build a computation from the specified root operation. Should only be called once for a given graph.
proc ceil(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Elementwise ceil rounding
proc clamp(a, min, max: Node): Node {....raises: [Exception, BuilderError],
                                      tags: [RootEffect], forbids: [].}
Clamp values in a to be between min and max.
proc clone(b: Builder; node: Node): Node {....raises: [Exception],
    tags: [RootEffect], forbids: [].}
Recursively make a copy of node and all of its inputs. This will assign a new id to each of the non-leaf nodes.
proc collapse(a: Node; dims: openArray[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Collapse the given dimensions into a single dimension as per Collapse. dims should be an in-order consecutive subset of the input dims.
proc concat(a: Node; nodes: openArray[Node]; axis: int): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Concatenate the given nodes with a along the given axis.
proc constant(b: Builder; lit: Literal): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new constant from the given literal.
proc constant(b: Builder; value: float32): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new scalar constant from the given value.
proc constant(b: Builder; value: float64): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new scalar constant from the given value.
proc constant(b: Builder; value: int): Node {....raises: [Exception, BuilderError],
    tags: [RootEffect], forbids: [].}
Create a new int64 scalar constant
proc constant(b: Builder; value: int32): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new scalar constant from the given value.
proc constant(b: Builder; value: int64): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new scalar constant from the given value.
proc constant(b: Builder; value: openArray[float32]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new vector constant from the given value.
proc constant(b: Builder; value: openArray[float64]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new vector constant from the given value.
proc constant(b: Builder; value: openArray[int32]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new vector constant from the given value.
proc constant(b: Builder; value: openArray[int64]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new vector constant from the given value.
proc constant(b: Builder; value: openArray[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new int64 vector constant
proc constant[T: ElemType](b: Builder; t: Tensor[T]): Node
Create a new constant from the given tensor.
proc constant[T: float | int](b: Builder; value: T; dtype: DataType): Node
Create a new scalar constant with the given type.
proc conv1d(a, kernel: Node; strides = 1; padding = pad(0); dilation = 1;
            groups = 1; channelsFirst = false): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}

One dimensional convolution with optional low and high padding and either strides or dilation.

The default layout is:

  • a: [N, W, C]
  • kernel: [W, C, K]

If channelsFirst is set then

  • a: [N, C, W]
  • kernel: [K, C, W]

Where N = number of batches, C = input channels, K = output channels and W is the spatial dimension. If groups is > 1 then performs a grouped convolution. In this case C and K should be divisible by groups.
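The interaction of strides, padding and dilation is easiest to see on a single channel. The sketch below is a plain-Nim reference for one batch, one input channel and one output channel, with zero padding; conv1dRef is a hypothetical helper, not the nimxla implementation. It assumes the usual deep-learning convention of cross-correlation (the kernel is not flipped), and its output length is (W + padLo + padHi - span) div stride + 1 with span = (kernel length - 1) * dilation + 1.

```nim
# Hypothetical plain-Nim reference (conv1dRef is not part of nimxla) for a
# single batch, single input/output channel 1d convolution with zero padding.
# It computes cross-correlation: the kernel is not flipped.
proc conv1dRef(x, k: seq[float]; stride = 1; padLo = 0; padHi = 0;
               dilation = 1): seq[float] =
  let padded = newSeq[float](padLo) & x & newSeq[float](padHi)
  let span = (k.len - 1) * dilation + 1          # extent of the dilated kernel
  let outLen = (padded.len - span) div stride + 1
  for o in 0 ..< outLen:
    var acc = 0.0
    for j in 0 ..< k.len:
      acc += padded[o * stride + j * dilation] * k[j]
    result.add acc

let x = @[1.0, 2.0, 3.0, 4.0]
echo conv1dRef(x, @[1.0, 1.0])                       # @[3.0, 5.0, 7.0]
echo conv1dRef(x, @[1.0, 1.0], stride = 2)           # @[3.0, 7.0]
echo conv1dRef(x, @[1.0, 1.0], dilation = 2)         # @[4.0, 6.0]
echo conv1dRef(x, @[1.0, 1.0], padLo = 1, padHi = 1) # @[1.0, 3.0, 5.0, 7.0, 4.0]
```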

proc conv2d(a, kernel: Node; strides: Opt2d = 1; padding: Pad2d = pad(0);
            dilation: Opt2d = 1; groups = 1; channelsFirst = false): Node

Two dimensional convolution with optional padding and either strides or dilation.

The default layout is:

  • a: [N, H, W, C]
  • kernel: [H, W, C, K]

If channelsFirst is set then

  • a: [N, C, H, W]
  • kernel: [K, C, H, W]

Where N = number of batches, C = input channels, K = output channels and H, W are spatial dimensions. If groups is > 1 then performs a grouped convolution. In this case C and K should be divisible by groups.

proc conv3d(a, kernel: Node; strides: Opt3d = 1; padding: Pad3d = pad(0);
            dilation: Opt3d = 1; groups = 1; channelsFirst = false): Node

Three dimensional convolution with optional padding and either strides or dilation.

The default layout is:

  • a: [N, D, H, W, C]
  • kernel: [D, H, W, C, K]

If channelsFirst is set then

  • a: [N, C, D, H, W]
  • kernel: [K, C, D, H, W]

Where N = number of batches, C = input channels, K = output channels and D, H, W are spatial dimensions. If groups is > 1 then performs a grouped convolution. In this case C and K should be divisible by groups.

proc convert(a: Node; dtype: DataType): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
proc convolution(a, kernel: Node;
                 inputDims, outputDims, kernelDims: openArray[int];
                 strides: openArray[int] = []; padding: openArray[Padding] = [];
                 dilation: openArray[int] = [];
                 inputDilation: openArray[int] = []; groups = 1; batchGroups = 1): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}

General n dimensional convolution call. See conv1d, conv2d and conv3d for simplified versions.

inputDims, outputDims and kernelDims provide the layout for the dimensions of each tensor:

  • dims[0] = batch / kernel output dimension
  • dims[1] = channel / kernel input dimension
  • dims[2..] = spatial dimensions

If set, then strides, padding and dilation should have the same number of entries as there are spatial dimensions.

See ConvWithGeneralPadding.

proc copy(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Returns a copy of the input.
proc cos(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise cosine
proc crossEntropyLoss(pred, target: Node): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Mean cross entropy loss computed from the softmax of pred. pred should hold the predicted values (unnormalized logits) with shape [n, classes], while target is a 1d integer vector of labels, each in the range 0 ..< classes. Note that the softmax function is applied to pred as part of this function to optimise the gradient calculation.
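What the loss computes can be written out directly. This plain-Nim sketch takes the mean over rows of -log(softmax(pred[i])[target[i]]), using the log-sum-exp trick to avoid overflow. crossEntropyRef is a hypothetical helper; it is not the graph that nimxla builds.

```nim
import std/math

# Hypothetical plain-Nim reference (crossEntropyRef is not part of nimxla):
# the mean over rows of -log(softmax(pred[i])[target[i]]), computed with the
# log-sum-exp trick so large logits do not overflow.
proc crossEntropyRef(pred: seq[seq[float]]; target: seq[int]): float =
  for i, row in pred:
    let m = max(row)                   # row maximum, for numerical stability
    var denom = 0.0
    for v in row:
      denom += exp(v - m)
    result += ln(denom) - (row[target[i]] - m)   # = -log softmax at the label
  result /= pred.len.float

echo crossEntropyRef(@[@[0.0, 0.0]], @[1])   # ln(2), about 0.693
```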
proc dims(n: Node): seq[int] {....raises: [], tags: [], forbids: [].}
Dimensions of this node output.
proc dot(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Vector or matrix dot product as per Dot.
proc dtype(n: Node): DataType {....raises: [], tags: [], forbids: [].}
Element type for this node
proc dump(n: Node; maxDepth = -1; depth = 0): string {.
    ...raises: [ValueError, Exception], tags: [RootEffect], forbids: [].}
proc errorNode(b: Builder; message: string): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Node used to record an error e.g. due to invalid input types or shapes.
proc exp(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise natural exponential
proc flatten(a: Node; startDim = 0; endDim = -1): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reshapes input such that any dimensions between startDim and endDim are flattened. e.g. flatten(a, 1) returns a 2d array keeping the first dimension and collapsing all the remainder into the second.
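The resulting shape can be computed as in the plain-Nim sketch below. It assumes endDim is inclusive and that only endDim may be negative (matching the -1 default). flattenDims is a hypothetical helper, not part of nimxla.

```nim
# Hypothetical plain-Nim shape helper (flattenDims is not part of nimxla):
# dimensions startDim..endDim (inclusive) are merged into one. Only a negative
# endDim is resolved relative to the rank here, matching the -1 default.
proc flattenDims(shape: openArray[int]; startDim = 0; endDim = -1): seq[int] =
  let e = if endDim < 0: endDim + shape.len else: endDim
  for i in 0 ..< startDim:
    result.add shape[i]
  var merged = 1
  for i in startDim .. e:
    merged *= shape[i]
  result.add merged
  for i in e + 1 ..< shape.len:
    result.add shape[i]

echo flattenDims([4, 3, 5, 2], 1)      # @[4, 30]
echo flattenDims([4, 3, 5, 2])         # @[120]
echo flattenDims([4, 3, 5, 2], 1, 2)   # @[4, 15, 2]
```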
proc floor(a: Node): Node {....raises: [Exception, BuilderError],
                            tags: [RootEffect], forbids: [].}
Elementwise floor rounding
proc fromHlo(hlo: HloModule): Computation {....raises: [], tags: [], forbids: [].}
Load computation from StableHLO format.
proc gather(a, indices: Node): Node {....raises: [Exception, BuilderError],
                                      tags: [RootEffect], forbids: [].}

Builds a new tensor by taking individual values from the original tensor at the given indices. The last dimension of indices should have the same size as the rank of the tensor: you can think of indices as a list of index tuples, each of which describes a position in the source array.

For example:

a = <f32 2 2>[[1 2]    ix = <i32 3 2>[[1 1]
              [3 4]]                  [0 1]
                                      [1 0]]
a.gather(ix) = <f32 3>[4 2 3]

This is a simplified version of the Gather op.
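In plain Nim, the gather semantics amount to converting each index row into a row-major offset. The sketch below assumes the source array is stored flat with an explicit shape. gatherRef is a hypothetical helper that reproduces the 2x2 example above; it does not use nimxla.

```nim
# Hypothetical plain-Nim reference (gatherRef is not part of nimxla). The
# source array is stored flat in row-major order with the given shape; each
# entry of ix holds one index per dimension, so ix[k].len == shape.len.
proc gatherRef(data: seq[float]; shape: seq[int]; ix: seq[seq[int]]): seq[float] =
  for pos in ix:
    doAssert pos.len == shape.len, "last dim of indices must equal the rank"
    var flat = 0
    for d in 0 ..< shape.len:
      flat = flat * shape[d] + pos[d]   # row-major offset of this position
    result.add data[flat]

# a = [[1 2] [3 4]] stored flat as [1, 2, 3, 4]
echo gatherRef(@[1.0, 2.0, 3.0, 4.0], @[2, 2], @[@[1, 1], @[0, 1], @[1, 0]])  # @[4.0, 2.0, 3.0]
```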

proc gradient(b: Builder; output: Node; inputs: openArray[string]): seq[Node] {.
    ...raises: [Exception, ValueError, KeyError, BuilderError], tags: [RootEffect],
    forbids: [].}

Generate the graph to calculate the gradients at each of the given input parameters for the graph given by output.

This returns a sequence of nodes, where each one calculates the gradient of the corresponding input node.

Here's an example of creating an expression and its backward graph, which calculates the gradients.

Example:

let b = newBuilder("test")
# forward expression
let x = b.parameter(F32, [], "x")
let y = b.parameter(F32, [2, 2], "y")
let fwd = x * (x + y)
# returns a seq with the expressions to calculate grad(x) and grad(y)
let grads = b.gradient(fwd, ["x", "y"])
# builds a computation with 2 input parameters which will return a tuple with 3 results
let comp = b.build b.makeTuple(fwd & grads)
# will dump out details of each node for debugging
echo comp
proc iota(b: Builder; dtype: DataType; dims: openArray[int]; axis: int): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Creates an array that has the specified shape and holds values starting at zero and incrementing by one along the specified axis.
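The effect of the axis argument is shown by this plain-Nim sketch of a 2d iota, in which each element holds its own index along the chosen axis. iota2dRef is a hypothetical helper that uses int values only; it does not call nimxla.

```nim
# Hypothetical plain-Nim reference (iota2dRef is not part of nimxla) for a 2d
# iota: every element holds its own index along the chosen axis.
proc iota2dRef(rows, cols, axis: int): seq[seq[int]] =
  result = newSeq[seq[int]](rows)
  for i in 0 ..< rows:
    for j in 0 ..< cols:
      result[i].add(if axis == 0: i else: j)

echo iota2dRef(2, 3, 1)   # @[@[0, 1, 2], @[0, 1, 2]]
echo iota2dRef(2, 3, 0)   # @[@[0, 0, 0], @[1, 1, 1]]
```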
proc iota(b: Builder; dtype: DataType; length: int): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Creates a one dimensional vector of the given length with values starting from zero.
proc isFinite(a: Node): Node {....raises: [Exception, BuilderError],
                               tags: [RootEffect], forbids: [].}
Elementwise test that each value is neither NaN nor +/-Inf. Returns a Bool array.
proc len(n: Node): int {....raises: [], tags: [], forbids: [].}
Get the number of inputs to the node.
proc log(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise natural log
proc log1p(a: Node): Node {....raises: [Exception, BuilderError],
                            tags: [RootEffect], forbids: [].}
Elementwise log(1 + a)
proc logicalAnd(a, b: Node): Node {....raises: [Exception, BuilderError],
                                    tags: [RootEffect], forbids: [].}
Elementwise logical and between two Bool arrays
proc logicalOr(a, b: Node): Node {....raises: [Exception, BuilderError],
                                   tags: [RootEffect], forbids: [].}
Elementwise logical or between two Bool arrays
proc makeTuple(b: Builder; args: varargs[Node]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Creates a new tuple from a list of ops.
proc max(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise maximum of 2 arrays
proc max(a: Node; axis: int; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to maximum value of elements across the given axis in the input. See reduce for details
proc max(a: Node; dims: openArray[int] = []; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to maximum value of elements across one or more dimensions in the input. See reduce for details
proc maxPool1d(a: Node; kernelSize: int; strides = 0; padding = pad(0);
               channelsFirst = false): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Max pooling over a 1 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv1d.
proc maxPool2d(a: Node; kernelSize: Opt2d; strides: Opt2d = 0;
               padding: Pad2d = pad(0); channelsFirst = false): Node
Max pooling over a 2 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv2d.
proc maxPool3d(a: Node; kernelSize: Opt3d; strides: Opt3d = 0;
               padding: Pad3d = pad(0); channelsFirst = false): Node
Max pooling over a 3 dimensional input array. strides defaults to kernelSize if left as 0. The channelsFirst setting is as per conv3d.
proc maxValue(b: Builder; dtype = F32; dims: openArray[int] = []): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a node with the maximum value for the given datatype, i.e. +Inf for floating point types. This is broadcast to dims if provided.
proc mean(a: Node; axis: int; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to mean of elements across the given axis in the input. See reduce for details
proc mean(a: Node; dims: openArray[int] = []; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to mean of elements across one or more dimensions in the input. See reduce for details
proc min(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise minimum of 2 arrays
proc min(a: Node; axis: int; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to minimum value of elements across the given axis in the input. See reduce for details
proc min(a: Node; dims: openArray[int] = []; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to minimum value of elements across one or more dimensions in the input. See reduce for details
proc minValue(b: Builder; dtype = F32; dims: openArray[int] = []): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a node with the minimum value for the given datatype, i.e. -Inf for floating point types. This is broadcast to dims if provided.
proc mseLoss(pred, target: Node): Node {....raises: [Exception, BuilderError],
    tags: [RootEffect], forbids: [].}
Mean square error loss function.
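A plain-Nim sketch of the mean square error, assuming the mean is taken over all elements of pred and target (the text above does not state the reduction explicitly). mseRef is a hypothetical helper, not part of nimxla.

```nim
# Hypothetical plain-Nim reference (mseRef is not part of nimxla), assuming
# the mean is taken over all elements of pred and target.
proc mseRef(pred, target: seq[float]): float =
  doAssert pred.len == target.len
  for i in 0 ..< pred.len:
    let d = pred[i] - target[i]
    result += d * d
  result /= pred.len.float

echo mseRef(@[1.0, 2.0, 3.0], @[1.0, 0.0, 5.0])   # (0 + 4 + 4) / 3
```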
proc name(comp: Computation): string {....raises: [], tags: [], forbids: [].}
Name of the computation as specified when the builder was created, plus a count of the number of ops.
proc narrow(a: Node; dim, start, stop: int; stride = 1): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Returns the data narrowed such that dimension dim ranges from start..stop-1 with a step of stride, as per Slice.
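Along a single dimension, narrow behaves like the plain-Nim sketch below, which keeps indices start, start+stride, ... strictly below stop. narrowRef is a hypothetical helper that works on a 1d seq and does not call nimxla.

```nim
# Hypothetical plain-Nim reference (narrowRef is not part of nimxla) for
# narrow along a single dimension: indices start, start+stride, ... below stop.
proc narrowRef[T](x: seq[T]; start, stop: int; stride = 1): seq[T] =
  var i = start
  while i < stop:          # stop is exclusive
    result.add x[i]
    i += stride

let v = @[0, 1, 2, 3, 4, 5, 6]
echo narrowRef(v, 2, 5)      # @[2, 3, 4]
echo narrowRef(v, 1, 6, 2)   # @[1, 3, 5]
```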
proc newBuilder(name: string): Builder {....raises: [], tags: [], forbids: [].}
Create a new builder which is used to generate a new graph. The name is used for debug info.
proc one(b: Builder; dtype = F32; dims: openArray[int] = []): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a node with the unit value for the given datatype. This is broadcast to dims if provided.
proc oneHot(x: Node; classes: int; dtype: DataType): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Convert a vector into a 2d [x.len, classes] array of type dtype where result[i, x[i]] = 1 and other values are zero.
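The definition result[i, x[i]] = 1 translates directly into plain Nim. The sketch assumes every label is in 0 ..< classes, so an out-of-range label fails the bounds check here; how nimxla treats such labels is not covered. oneHotRef is a hypothetical helper.

```nim
# Hypothetical plain-Nim reference (oneHotRef is not part of nimxla), assuming
# every label is in 0 ..< classes (out-of-range labels fail the bounds check).
proc oneHotRef(x: seq[int]; classes: int): seq[seq[float]] =
  for label in x:
    var row = newSeq[float](classes)   # zero-initialised
    row[label] = 1.0
    result.add row

echo oneHotRef(@[2, 0], 3)   # @[@[0.0, 0.0, 1.0], @[1.0, 0.0, 0.0]]
```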
proc pad(a, padValue: Node; padConfig: openArray[(int, int, int)]): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}
Add padding to an array. padValue is the scalar value which is used to fill the new elements. padConfig is a list of (pad_low, pad_high, pad_interior) tuples for each dimension of a.
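For one dimension, a (pad_low, pad_high, pad_interior) entry behaves as in the plain-Nim sketch below. Interior padding is inserted only between neighbouring elements. pad1dRef is a hypothetical helper, and it handles only non-negative padding (XLA also permits negative low/high padding).

```nim
# Hypothetical plain-Nim reference (pad1dRef is not part of nimxla) for one
# dimension with a (pad_low, pad_high, pad_interior) entry. Only non-negative
# padding is handled here.
proc pad1dRef[T](x: seq[T]; padValue: T; low, high, interior: int): seq[T] =
  for n in 0 ..< low:
    result.add padValue
  for i, v in x:
    if i > 0:
      for n in 0 ..< interior:          # interior padding goes between elements
        result.add padValue
    result.add v
  for n in 0 ..< high:
    result.add padValue

echo pad1dRef(@[1, 2, 3], 0, 1, 2, 1)   # @[0, 1, 0, 2, 0, 3, 0, 0]
```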
proc parameter(b: Builder; dtype: DataType; dims: openArray[int] = []; name = ""): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a new parameter with the given shape. The parameter index is set automatically based on the number of parameters already created by this builder. If the name is blank then it defaults to the p<index> format.
proc parameters(comp: Computation): (seq[string], seq[Shape]) {.
    ...raises: [XLAError], tags: [], forbids: [].}
Names and shapes of the parameters which have been defined.
proc pow(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise a raised to power b
proc raiseError(message: string; node: Node = nil) {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Raise a BuilderError exception annotated with the repr for the given node.
proc rank(n: Node): int {....raises: [], tags: [], forbids: [].}
Number of dimensions in the node output.
proc rawPtr(comp: Computation): xla_computation {....raises: [], tags: [],
    forbids: [].}
proc reduce(a, initValue: Node; comp: Computation; dims: openArray[int] = [];
            nodeType = tReduce; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Apply a reduction across one or more dimensions, i.e. comp is applied repeatedly to pairs of elements from the input node a. initValue defines the initial 'zero' value for the reduction. If no dims are given then the reduction is applied across all of the input dimensions to reduce to a scalar. If a dimension index is negative then it is relative to the number of dimensions. If keepDims is set then the reduced dimensions are kept with a size of 1, else they are removed and the number of dimensions in the result is reduced.
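The output shape rules above (empty dims, negative dims, keepDims) can be captured in a small plain-Nim helper. reducedShape is hypothetical and only computes the shape; it does not perform the reduction itself.

```nim
# Hypothetical plain-Nim shape helper (reducedShape is not part of nimxla) for
# the rules above: empty dims reduces every dimension, negative dims count
# from the end, and keepDims keeps reduced dimensions with size 1.
proc reducedShape(shape: openArray[int]; dims: openArray[int] = [];
                  keepDims = false): seq[int] =
  var axes: seq[int]
  if dims.len == 0:
    for i in 0 ..< shape.len: axes.add i
  else:
    for d in dims: axes.add(if d < 0: d + shape.len else: d)
  for i, n in shape:
    if i in axes:
      if keepDims: result.add 1
    else:
      result.add n

echo reducedShape([2, 3, 4], [1])                     # @[2, 4]
echo reducedShape([2, 3, 4], [-1], keepDims = true)   # @[2, 3, 1]
echo reducedShape([2, 3, 4])                          # @[] (a scalar)
```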
proc reduceWindow(a, initValue: Node; comp: Computation;
                  windowDims, strides: openArray[int];
                  padding: openArray[Padding] = []; nodeType = tReduceWindow): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}

Apply a reduction to all elements in each window of a sequence of N multi-dimensional arrays. The number of entries in the windowDims and strides arrays should equal the rank of the input array (i.e. entries should be 1 for non-spatial dimensions). If set, the number of entries in the padding array should also equal the input rank, where non-spatial dimensions have padding of 0.

This can be used to implement pooling layers. See ReduceWindow for details.

proc relu(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Rectified linear unit activation function: max(0, a)
proc rem(a, b: Node): Node {....raises: [Exception, BuilderError],
                             tags: [RootEffect], forbids: [].}
Elementwise remainder
proc repr(n: Node): string {....raises: [ValueError, Exception],
                             tags: [RootEffect], forbids: [].}
Formatted AST tree of this node and all of its children.
proc reshape(a: Node; dims: varargs[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reshape the input node to dims. Total number of elements is unchanged. If one of the dimensions is -1 then this value is inferred from the total number of elements.
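The -1 inference works as in the plain-Nim sketch below: the single -1 entry becomes the element count divided by the product of the other entries, and the total must stay unchanged. inferReshape is a hypothetical helper that only computes the target dims.

```nim
# Hypothetical plain-Nim shape helper (inferReshape is not part of nimxla):
# at most one -1 entry is replaced by total / (product of the other entries),
# and the element count must stay unchanged.
proc inferReshape(total: int; dims: openArray[int]): seq[int] =
  var known = 1
  var inferAt = -1
  for i, d in dims:
    if d == -1:
      doAssert inferAt == -1, "only one dimension may be -1"
      inferAt = i
    else:
      known *= d
  result = @dims
  if inferAt >= 0:
    result[inferAt] = total div known
  var count = 1
  for d in result: count *= d
  doAssert count == total, "reshape must keep the number of elements"

echo inferReshape(24, [2, -1, 4])   # @[2, 3, 4]
echo inferReshape(24, [-1])         # @[24]
```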
proc resultShape(comp: Computation): Shape {....raises: [XLAError], tags: [],
    forbids: [].}
proc reverse(a: Node; axes: varargs[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reverse the order of the elements in the input along the given axes.
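For a 2-D input, reversing along axis 0 flips the order of the rows and reversing along axis 1 flips the elements within each row. The hypothetical reverse2D below works on nested seqs purely to show these semantics; it is not how nimxla implements reverse.

```nim
import std/algorithm

# Hypothetical 2-D reference, not part of nimxla: reverses a matrix stored as
# nested seqs along axis 0 (row order) or axis 1 (order within each row).
proc reverse2D[T](m: seq[seq[T]]; axis: int): seq[seq[T]] =
  case axis
  of 0:
    result = m.reversed
  of 1:
    for row in m:
      result.add row.reversed
  else:
    raise newException(ValueError, "axis must be 0 or 1")

let m = @[@[1, 2, 3], @[4, 5, 6]]
echo reverse2D(m, 0)  # @[@[4, 5, 6], @[1, 2, 3]]
echo reverse2D(m, 1)  # @[@[3, 2, 1], @[6, 5, 4]]
```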
proc rngNormal(mean, stddev: Node; dims: openArray[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Generate a tensor with a normal random distribution described by mean, standard deviation, data type and dimensions. The inputs must have the same data type, which is used as the element type for the output.
proc rngUniform(min, max: Node; dims: openArray[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Generate a tensor with a uniform random distribution of values from min to max and the given dimensions. The inputs must have the same data type, which is used as the element type for the output.
proc round(a: Node): Node {....raises: [Exception, BuilderError],
                            tags: [RootEffect], forbids: [].}
Elementwise nearest rounding
proc rsqrt(a: Node): Node {....raises: [Exception, BuilderError],
                            tags: [RootEffect], forbids: [].}
Elementwise 1/sqrt(a)
proc scatter(a, indices, b: Node; comp: Computation): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}

Return a copy of the input array a in which the elements at the specified indices are updated with the values in b using comp. indices should have shape [n, a.rank], i.e. each row identifies one element to update and the columns give its location in the target array. Indices may be repeated, in which case each update is combined using comp.

This is the opposite of gather. It is a simplified version of the XLA Scatter operation; see Scatter for details.
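A sketch of the scatter semantics for a rank 2 input with addition as comp. scatter2D is hypothetical and works on nested seqs; each entry of indices is one [row, col] row of the [n, 2] index array. Repeated indices are applied here in row order; XLA does not necessarily guarantee an order, so an order-independent comp such as addition is the safe choice.

```nim
# Hypothetical 2-D reference, not part of nimxla: each entry of `indices` is a
# [row, col] location in `a`, and b[i] is combined into that location with
# `comp`. Repeated locations are combined one after another.
proc scatter2D(a: seq[seq[float]]; indices: seq[array[2, int]]; b: seq[float];
               comp: proc (x, y: float): float): seq[seq[float]] =
  doAssert indices.len == b.len, "one update value per index row"
  result = a  # the input is copied, not modified in place
  for i, idx in indices:
    let (r, c) = (idx[0], idx[1])
    result[r][c] = comp(result[r][c], b[i])

let a = @[@[0.0, 0.0], @[0.0, 0.0]]
let addOf = proc (x, y: float): float = x + y
# location [0, 1] appears twice, so 1.0 and 3.0 are both added there
echo scatter2D(a, @[[0, 1], [1, 0], [0, 1]], @[1.0, 2.0, 3.0], addOf)
# @[@[0.0, 4.0], @[2.0, 0.0]]
```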

proc select(a, onTrue, onFalse: Node): Node {....raises: [Exception, BuilderError],
    tags: [RootEffect], forbids: [].}
Select values from onTrue where a is true else from onFalse.
proc selectAndScatter(a, source: Node; windowDims, strides: openArray[int];
                      padding: openArray[Padding] = []): Node {.
    ...raises: [Exception, BuilderError, ValueError], tags: [RootEffect],
    forbids: [].}

Composite operation that first computes ReduceWindow on the operand array to select an element from each window, and then scatters the source array to the indices of the selected elements to construct an output array with the same shape as the operand array.

Used to compute the gradient of the maxPool function. See SelectAndScatter for details.
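A 1-D sketch of the max-pool gradient case. selectAndScatter1D is hypothetical: each window selects the position of its largest element (the first one on ties, an assumption of this sketch), and the source value for that window is added to the output at that position. Padding is not handled.

```nim
# Hypothetical 1-D reference, not part of nimxla, for the max-pool gradient.
# `source` holds one value per window; the output has the operand's shape.
proc selectAndScatter1D(x, source: seq[float]; window, stride: int): seq[float] =
  let outLen = (x.len - window) div stride + 1
  doAssert source.len == outLen, "source must have one value per window"
  result = newSeq[float](x.len)  # zero-initialised, same shape as x
  for w in 0 ..< outLen:
    var best = w * stride  # select step: position of the window maximum
    for k in 1 ..< window:
      let i = w * stride + k
      if x[i] > x[best]: best = i
    result[best] += source[w]  # scatter step: accumulate into that position

echo selectAndScatter1D(@[1.0, 3.0, 2.0, 5.0, 4.0, 0.0], @[10.0, 20.0, 30.0], 2, 2)
# @[0.0, 10.0, 0.0, 20.0, 30.0, 0.0]
```

Where windows overlap, an element that is the maximum of several windows accumulates several source values, which is what the gradient of an overlapping max pool requires.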

proc sigmoid(a: Node): Node {....raises: [Exception, BuilderError],
                              tags: [RootEffect], forbids: [].}
Elementwise 1/(1 + exp(-a))
proc sign(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Elementwise sign. Returns -1, 0, +1 or NaN.
proc sin(a: Node): Node {....raises: [Exception, BuilderError], tags: [RootEffect],
                          forbids: [].}
Elementwise sine
proc softmax(a: Node; axis = -1): Node {....raises: [Exception, BuilderError],
    tags: [RootEffect], forbids: [].}
Softmax function along the given axis, adjusted for numerical stability.
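Assuming "adjusted for numerical stability" refers to the common trick of subtracting the maximum before exponentiating, a 1-D version looks like this. softmax1D is a hypothetical pure-Nim helper, not part of nimxla. The shift cancels out in the ratio, so the result is mathematically unchanged, but exp never sees a large positive argument.

```nim
import std/math

# Hypothetical 1-D helper, not part of nimxla: softmax with the
# max-subtraction trick (assumed to be what the node computes).
proc softmax1D(x: seq[float]): seq[float] =
  let m = max(x)  # largest input value
  var total = 0.0
  for v in x:
    let e = exp(v - m)  # every exponent is <= 0, so exp cannot overflow
    result.add e
    total += e
  for e in result.mitems:
    e /= total

# exp(1000.0) on its own would overflow to Inf
echo softmax1D(@[1000.0, 1000.0])  # @[0.5, 0.5]
```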
proc sqrt(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Elementwise square root
proc sum(a: Node; axis: int; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to the sum of elements across the given axis in the input. See reduce for details.
proc sum(a: Node; dims: openArray[int] = []; keepDims = false): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Reduce to the sum of elements across one or more dimensions in the input. See reduce for details.
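A small worked case for a rank 2 input, showing the values and the result shape with and without keepDims. sum2D and the Reduced tuple are hypothetical and only illustrate the documented behaviour on nested seqs.

```nim
type Reduced = tuple[shape: seq[int], data: seq[float]]

# Hypothetical 2-D reference, not part of nimxla: sums a rows x cols matrix
# along one axis and reports the result shape with or without keepDims.
proc sum2D(m: seq[seq[float]]; axis: int; keepDims = false): Reduced =
  let rows = m.len
  let cols = m[0].len
  let ax = if axis < 0: axis + 2 else: axis  # negative axis is relative to rank 2
  if ax == 0:
    result.data = newSeq[float](cols)
    for row in m:
      for j, v in row:
        result.data[j] += v
    result.shape = if keepDims: @[1, cols] else: @[cols]
  else:
    for row in m:
      var s = 0.0
      for v in row: s += v
      result.data.add s
    result.shape = if keepDims: @[rows, 1] else: @[rows]

let m = @[@[1.0, 2.0, 3.0], @[4.0, 5.0, 6.0]]
echo sum2D(m, 0).data                     # @[5.0, 7.0, 9.0]
echo sum2D(m, -1, keepDims = true).shape  # @[2, 1]
```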
proc tanh(a: Node): Node {....raises: [Exception, BuilderError],
                           tags: [RootEffect], forbids: [].}
Elementwise hyperbolic tangent
proc toHlo(comp: Computation): HloModule {....raises: [], tags: [], forbids: [].}
Serialize the computation to an HloModule in StableHLO format.
proc toString(n: Node): string {....raises: [Exception], tags: [RootEffect],
                                 forbids: [].}
Node name and argument names, expanded
proc transpose(a: Node; axes: varargs[int]): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Permute the given axes. If no axes are given then the last 2 axes are swapped. Axis indices may be negative, in which case they are relative to the number of dimensions.
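The resulting shape can be sketched as follows, using the usual XLA convention that output dimension i has the size of input dimension axes[i]. transposedShape is a hypothetical helper, not part of nimxla; requiring axes to list every dimension is an assumption of this sketch.

```nim
# Hypothetical helper, not part of nimxla: result shape of a transpose.
# Output dimension i takes the size of input dimension axes[i]; with no
# axes the last two dimensions are swapped.
proc transposedShape(shape: seq[int]; axes: seq[int] = @[]): seq[int] =
  let rank = shape.len
  var perm: seq[int]
  if axes.len == 0:
    doAssert rank >= 2, "default transpose needs at least 2 dimensions"
    for i in 0 ..< rank: perm.add i
    swap(perm[rank - 2], perm[rank - 1])
  else:
    # assumption of this sketch: axes lists every dimension exactly once
    doAssert axes.len == rank, "axes must be a permutation of all dimensions"
    for a in axes:
      perm.add(if a < 0: a + rank else: a)  # negative axes are relative to rank
  for p in perm:
    result.add shape[p]

echo transposedShape(@[2, 3, 4])             # @[2, 4, 3]
echo transposedShape(@[2, 3, 4], @[2, 0, 1]) # @[4, 2, 3]
```

Under these rules a [2, 3, 4] node transposed with a.transpose(2, 0, 1) would have shape [4, 2, 3], and a.transpose() would have shape [2, 4, 3].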
proc zero(b: Builder; dtype = F32; dims: openArray[int] = []): Node {.
    ...raises: [Exception, BuilderError], tags: [RootEffect], forbids: [].}
Create a node with the zero value for the given datatype. This is broadcast to dims if provided.
proc zerosLike(a: Node): Node {....raises: [Exception, BuilderError],
                                tags: [RootEffect], forbids: [].}
Create a new zero-valued node with the element type and shape of the input.

Templates

template `^`(b: Builder; value: untyped): Node
Shorthand to generate a new constant node