The nimxla module wraps the XLA PJRT CPU and GPU client objects which are used to execute computations on a device. It also handles transferring data to and from the device and provides utilities to manage that data on the host. See https://www.tensorflow.org/xla for an overview of the XLA library.
It depends on the graph module, which contains the functions to define computations from a graph of operation nodes, the tensor utility module, and the xla_wrapper module, which has the definitions for the C bindings to the XLA C++ API.
Here's a simple example that builds and executes a graph which squares the elements of a vector and adds a constant, then converts the result to a matrix in column-major order.
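Example:
import nimxla

# create a client for the CPU device
let c = newCPUClient()
echo c
# define the computation graph: x*x + 10 over a vector of 50 floats
let b = newBuilder("example")
let x = b.parameter(F32, [50])
let sum = x * x + b^10f32
let comp = b.build sum.reshape(10, 5).transpose
# compile for this client, then run with the values 1..50 as input
let exec = c.compile(comp)
let input = toTensor[float32](1..50).toLiteral
let res = exec.run([input]).f32
echo res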
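To make the expected result concrete, the same arithmetic can be reproduced on the host in plain Nim, without nimxla. This is only a reference sketch. It assumes that reshape(10, 5) fills the matrix row by row and that transpose with no arguments swaps the two dimensions, giving a 5 x 10 result. The helper name expectedResult is hypothetical.

```nim
# Host-side reference for the example above, using only plain Nim.
# Assumes a row-major reshape to 10x5 followed by a transpose to 5x10.
proc expectedResult(): seq[seq[float32]] =
  result = newSeq[seq[float32]](5)
  for i in 0 ..< 5:
    result[i] = newSeq[float32](10)
    for j in 0 ..< 10:
      # element (j, i) of the 10x5 matrix has flat index j*5 + i
      let x = float32(j * 5 + i + 1)   # input values run from 1 to 50
      result[i][j] = x * x + 10'f32

let expected = expectedResult()
for row in expected:
  echo row
```

Comparing this against the tensor printed by the example is a quick way to check the reshape and transpose assumptions on a given installation.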
Types
Buffer = ref BufferObj
- Buffer represents the device memory allocated by the client for a given tensor or tuple of tensors.
Client = ref ClientObj
- A client connects to a device such as a CPU or Cuda GPU driver and provides methods to perform computations.
Executable = object
  name*: string
  params*: seq[string]
  outputs*: seq[string]
  inShapes*: seq[Shape]
  outShape*: Shape
- An executable is a compiled graph of operations which can be called with a defined list of parameters.
Procs
proc `$`(buf: Buffer): string {.raises: [Exception], tags: [RootEffect], forbids: [].}
- Print shape info.
proc `$`(client: Client): string {.raises: [ValueError], tags: [], forbids: [].}
- Summary of client info.
proc `$`(exec: Executable): string {.raises: [Exception], tags: [RootEffect], forbids: [].}
proc compile(client: Client; comp: Computation; outputs: openArray[string] = []): Executable {.raises: [Exception, BuilderError, XLAError, ValueError], tags: [RootEffect], forbids: [].}
- Compile a computation so that it can be executed on this client. outputs may optionally be specified to name the output values; these names are used by the runWith method.
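As an illustration of the optional outputs argument, the sketch below compiles a one-parameter computation, names its single output and then inspects the public fields of the resulting Executable. It uses only procs and fields listed on this page plus the builder calls from the module example. The builder name "named" and the output name "squared" are arbitrary choices.

```nim
import nimxla

let c = newCPUClient()
let b = newBuilder("named")
let x = b.parameter(F32, [4])
let comp = b.build(x * x)

# Name the single output "squared" (the name itself is arbitrary).
let exec = c.compile(comp, ["squared"])

echo exec.name       # name of the executable
echo exec.params     # parameter names
echo exec.outputs    # output names
echo exec.noutputs   # number of outputs (tuple length, else 1)
```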
proc deviceCount(client: Client): int {.raises: [], tags: [], forbids: [].}
- Returns the number of devices associated with this client.
proc initParams(pairs: openArray[(string, Buffer)]): Params {.raises: [], tags: [], forbids: [].}
- Create a new parameter list.
proc newBuffer(client: Client; dtype: DataType; dims: openArray[int]): Buffer {.raises: [XLAError], tags: [], forbids: [].}
- Allocate a new buffer on the device with the given shape. Initialises values to zero.
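A minimal sketch of allocating a device buffer follows. It assumes a CPU client is available and that the .f32 accessor can be applied to a Buffer, as it is applied to the result of run in the module example.

```nim
import nimxla

let c = newCPUClient()

# 2 x 3 buffer of 32-bit floats, zero-initialised on the device.
let buf = c.newBuffer(F32, [2, 3])

echo buf        # shape info via `$`
echo buf.f32    # contents copied back to the host (assumed accessor, see above)
```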
proc newCPUClient(logLevel = Warning): Client {.raises: [OSError, XLAError, Exception, ValueError], tags: [WriteEnvEffect, RootEffect], forbids: [].}
- Create a new client for running computations on the CPU.
proc newGPUClient(memoryFraction = 1.0; preallocate = false; logLevel = Warning): Client {.raises: [OSError, XLAError, Exception, ValueError], tags: [WriteEnvEffect, RootEffect], forbids: [].}
- Create a new client for running computations on the GPU using Cuda. memoryFraction limits the maximum fraction of device memory which can be allocated. If preallocate is set then this memory is allocated at startup.
proc newTPUClient(maxInflightComputations: int; logLevel = Warning): Client {.raises: [OSError, XLAError, Exception, ValueError], tags: [WriteEnvEffect, RootEffect], forbids: [].}
- Create a new client for running computations on a Google TPU accelerator.
proc noutputs(exec: Executable): int {.raises: [], tags: [], forbids: [].}
- If the output is a tuple then the tuple length, else 1.
proc platformName(client: Client): string {.raises: [], tags: [], forbids: [].}
- Returns the name of the platform (CPU or Cuda).
proc platformVersion(client: Client): string {.raises: [], tags: [], forbids: [].}
- Returns the version of the platform (e.g. Cuda version).
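The client constructors and query procs above might be combined as in the following sketch. The compile-time switch useGpu is hypothetical (it is not part of nimxla) and is only there so that the same file builds on machines without a Cuda device. The values 0.5 and true are illustrative.

```nim
import nimxla

# Hypothetical switch: compile with -d:useGpu on a machine with a Cuda GPU.
let c =
  when defined(useGpu):
    # cap allocations at half of device memory and reserve it at startup
    newGPUClient(memoryFraction = 0.5, preallocate = true)
  else:
    newCPUClient()

echo c.platformName
echo c.platformVersion
echo c.deviceCount
```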
proc run(exec: Executable; checkShape = true): Buffer {.raises: [XLAError], tags: [], forbids: [].}
- Convenience method for use where the executable does not take any input arguments. Otherwise behaves as exec.run(args).
proc run[T: Buffer | Literal](exec: Executable; args: openArray[T]; checkShape = true): Buffer
- Pass the given literal or buffer arguments to the executable, launch the kernel on the associated device and return a single buffer with the results, i.e. tuple results are not unpacked.
By default this checks that the data type and shape of the parameters match the inputs and raises an exception if there is a mismatch. Set checkShape to false to only have the runtime check the size of the input buffers.
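The module example passes a Literal to run. The sketch below passes a device Buffer instead and also shows the checkShape flag. It relies on newBuffer returning a zero-filled buffer, so each output element should be 0 * 0 + 10. As in the module example, .f32 is assumed to copy the result back to the host.

```nim
import nimxla

let c = newCPUClient()
let b = newBuilder("run_buffer")
let x = b.parameter(F32, [4])
let exec = c.compile(b.build(x * x + b^10f32))

# Zero-filled input that already lives on the device.
let input = c.newBuffer(F32, [4])

# Default: data type and shape of the arguments are checked before launch.
let res = exec.run([input])
echo res.f32

# Skip the host-side check; the runtime still checks the buffer sizes.
let res2 = exec.run([input], checkShape = false)
echo res2.f32
```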
proc runAndUnpack(exec: Executable; checkShape = true): seq[Buffer] {.raises: [XLAError], tags: [], forbids: [].}
- Convenience method for use where the executable returns a tuple of results but does not take any input arguments. Otherwise behaves as runAndUnpack(args).
proc runAndUnpack[T: Buffer | Literal](exec: Executable; args: openArray[T]; checkShape = true): seq[Buffer]
- For use where the executable returns a tuple of results. Passes the given literal or buffer arguments to the executable, launches the kernel on the associated device and returns a list of buffers unpacked from the returned tuple.
By default this checks that the data type and shape of the parameters match the inputs and raises an exception if there is a mismatch. Set checkShape to false to only have the runtime check the size of the input buffers.
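A tuple-returning computation could look like the sketch below. makeTuple appears in the export list, but its signature is not shown on this page. The call b.makeTuple(node1, node2) is therefore an assumption and should be checked against the graph module documentation.

```nim
import nimxla

let c = newCPUClient()
let b = newBuilder("tuple")
let x = b.parameter(F32, [4])

# Assumed signature: makeTuple takes the nodes to bundle as arguments.
let comp = b.build(b.makeTuple(x + x, x * x))
let exec = c.compile(comp)

let input = toTensor[float32](1..4).toLiteral
let outs = exec.runAndUnpack([input])

echo outs.len          # one buffer per tuple element
for buf in outs:
  echo buf.f32         # copy each result back to the host
```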
proc runWith(exec: Executable; params: var Params; checkShape = true) {.raises: [KeyError, XLAError, ValueError, Exception], tags: [RootEffect], forbids: [].}
- Run the executable with the given set of named input parameters. Updates the params table with the outputs, named as specified when compile was called, or using the result<n> format if no names were given.
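The named-parameter flow can be sketched as follows. Two points are assumptions rather than documented behaviour: that Params can be indexed by name like a std/tables Table, and that exec.params[0] holds the name under which the first parameter must be supplied. The output name "y" is arbitrary.

```nim
import std/tables
import nimxla

let c = newCPUClient()
let b = newBuilder("named_io")
let x = b.parameter(F32, [4])
let exec = c.compile(b.build(x * x + b^10f32), ["y"])

# Bind a zero-filled buffer to the executable's first parameter name.
var params = initParams({exec.params[0]: c.newBuffer(F32, [4])})
exec.runWith(params)

# The result is stored back into params under the name given to compile.
# Indexing with [] assumes Params behaves like a Table[string, Buffer].
echo params["y"].f32
```

Because the outputs are written back into the same table, the same params value can be passed to a later runWith call without rebuilding it.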
proc setLogLevel(level: LogLevel) {.raises: [OSError], tags: [WriteEnvEffect], forbids: [].}
- Set the log level used by the TensorFlow XLA library. Defaults to Info.
proc toLiterals(buffers: openArray[Buffer]): seq[Literal] {.raises: [XLAError], tags: [], forbids: [].}
- Copy a list of buffers back to the host.
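As a small usage sketch, toLiterals can copy several device buffers back in one call. The buffers here are zero-filled ones from newBuffer. The .f32 accessor on a Literal is taken from the export list and is assumed to return a float32 tensor.

```nim
import nimxla

let c = newCPUClient()

# Two zero-filled device buffers with different shapes.
let bufs = [c.newBuffer(F32, [2]), c.newBuffer(F32, [2, 3])]

# Copy both back to the host in one call.
let lits = toLiterals(bufs)
echo lits.len
for lit in lits:
  echo lit.f32   # assumed accessor from the literal module
```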
Exports
-
broadcast, tanh, rsqrt, ==, ^, seq3, reduceWindow, $, OpType, seq2, Builder, fromHlo, constant, build, exp, <=, DataType, avgPool1d, select, round, argMin, minValue, *, pad, transpose, argMax, arrayShape, padSame, Shape, rank, isFinite, dtype, convolution, reshape, iota, mean, batchNormGrad, sign, broadcastInDim, abs, constant, [], !=, Node, ElemType, iota, newBuilder, toShape, maxPool1d, errorNode, $, min, sigmoid, max, toHlo, conv1d, <, maxPool2d, oneHot, convert, seq2, constant, makeTuple, dtypeOf, resultShape, logicalOr, gradient, crossEntropyLoss, parameters, zero, batchNormTraining, max, clamp, log, min, pad, dims, pow, repr, floor, Padding, /, conv3d, addAt, constant, rem, clone, rawPtr, avgPool2d, name, constant, constant, sqrt, $, narrow, flatten, relu, HloModule, scatter, constant, max, constant, !, Opt2d, zerosLike, maxValue, maxPool3d, rngUniform, Pad3d, pad, batchNormInference, conv2d, >=, -, rngNormal, -, log1p, sin, constant, $, Computation, copy, constant, seq3, mean, reverse, constant, raiseError, sum, softmax, selectAndScatter, one, Opt3d, concat, $, reduce, ShapeKind, parameter, ==, avgPool3d, logicalAnd, +, dot, min, gather, sum, collapse, mseLoss, constant, constant, >, BuilderError, ceil, len, dump, Pad2d, cos, toString, []=, reshape, toTensor, zeros, clone, rawPtr, toTensor, ==, toTensor, convert, @@, write, toTensor, Tensor, len, $, readTensor, append, fill, setPrintOpts, toSeq, at, approxEqual, readTensor, newTensor, write, format, shape, [], len, dtype, newLiteral, lit, f32, reshape, Literal, lit, clone, addrOf, rawPtr, lit, toLiteral, i32, toTensor, lit, convert, f64, decomposeTuple, i64, shape, $, Padding, Pad3d, pad, arrayShape, pad, seq3, DataType, dtypeOf, Opt3d, ==, $, $, ElemType, ShapeKind, seq2, Pad2d, Opt2d, seq2, seq3, toShape, Shape, padSame, XLAError