Walkthrough: The Kernel Translation Process#

This guide is meant to give an overview of pystencils’ code generation toolkit by walking through the kernel translation procedure step by step. We will cover the various steps of the pipeline, which are

  • Setup of the context;

  • Parsing of the symbolic kernel from SymPy into the intermediate representation (IR);

  • Materialization of the kernel’s iteration space to an iteration strategy using the iteration axes system;

  • A selection of optimizing transformations;

  • Lowering of the IR to C++ code;

  • And finally, packaging of the kernel as an output object.

In the process, a number of core concepts and classes of pystencils’ backend will be introduced. This guide effectively shows an abridged version of the translation pipeline implemented in the DefaultKernelCreationDriver.
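Conceptually, the driver threads the AST through such a sequence of passes, each consuming the previous pass's output. A minimal sketch of that control flow, using placeholder passes rather than the real pystencils classes:

```python
from functools import reduce

def run_pipeline(ast, passes):
    """Thread an AST through a sequence of transformation passes, in order."""
    return reduce(lambda tree, p: p(tree), passes, ast)

# Placeholder passes acting on a string stand-in for the AST
passes = [
    lambda t: t + " -> parsed",
    lambda t: t + " -> materialized",
    lambda t: t + " -> lowered",
]
print(run_pipeline("kernel", passes))  # kernel -> parsed -> materialized -> lowered
```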

import sympy as sp
import pystencils as ps

For illustration, we’re going to define a very simple kernel which scales a scalar field \(f\) by a factor \(c\) and computes the maximum value of \(f\) in the process:

f = ps.fields("f: [3D]", layout="fzyx")
x, c = sp.symbols("x, c")
wmax = ps.TypedSymbol("wmax", ps.DynamicType.NUMERIC_TYPE)

assignments = [
    ps.Assignment(x, f()),
    ps.Assignment(f(), c * x),
    ps.MaxReductionAssignment(wmax, x)
]

assignments
[Assignment(x, f_C), Assignment(f_C, c*x), MaxReductionAssignment(wmax, x)]

Context Setup#

Before kernel translation can begin, we need to instantiate the backend’s context objects:

  • The KernelCreationContext manages all global information about the kernel, and primarily serves as a symbol table for the kernel’s variables, memory buffers, and reduction targets. The context object is conventionally called ctx.

  • The IterationSpace defines the index space on which the kernel is executed; it is attached to ctx and required during parsing of field accesses and later to materialize the index space to an iteration strategy.

Import the required classes from pystencils.backend.kernelcreation and initialize them:

from pystencils.backend.kernelcreation import (
    KernelCreationContext,
    FullIterationSpace
)

ctx = KernelCreationContext()
ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, f)
ctx.set_iteration_space(ispace)

Here, we created a FullIterationSpace object, which represents a dense index space. Using create_with_ghost_layers, we initialized the index space according to the rank and memory layout of the field \(f\); it is now a 3D index space with dimensions ordered such that the shortest-stride dimension of \(f\) is mapped by the fastest-increasing coordinate of ispace.
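The layout-dependent ordering can be illustrated with NumPy (an illustration of the stride concept only, not pystencils API): in a column-major array, which mimics the fzyx layout, the first (\(x\)) index has the smallest stride and should therefore be traversed by the fastest-running counter.

```python
import numpy as np

# Column-major ("F") storage mimics the fzyx layout: the first (x) index
# varies fastest in memory and therefore has the smallest stride.
arr = np.zeros((4, 5, 6), order="F")
elem_strides = tuple(s // arr.itemsize for s in arr.strides)
print(elem_strides)  # (1, 4, 20): x is the shortest-stride dimension
```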

Parsing of the Kernel Body#

Next, the kernel’s symbolic representation from above must be translated into pystencils’ intermediate representation (IR). The IR has two layers: a layer of backend objects, comprising symbols, constants, functions, and buffers; and, built on top of it, the abstract syntax tree (AST).

While constants and functions are transient and created on demand, symbols and buffers must be unique and are therefore managed by the kernel creation context. Symbols (PsSymbol) are the backend analogue of SymPy’s symbols; buffers, in turn, represent the memory regions behind the frontend’s fields.

Generating an IR syntax tree from the symbolic kernel is a two-stage process, consisting of the freeze and typify steps. During freeze, symbolic equations become IR AST nodes, on which typify then computes and applies data types. The backend’s AstFactory class offers a convenient interface that streamlines this process:

from pystencils.backend.kernelcreation import AstFactory

factory = AstFactory(ctx)

We will use the AstFactory in several places throughout this guide to create IR objects from symbolic forms.

The kernel body will be a PsBlock, with declarations and assignments parsed from SymPy:

from pystencils.backend.ast.structural import PsBlock

body = PsBlock([
    factory.parse_sympy(asm) for asm in assignments
])

ps.inspect(body)
{
   x: float64 = _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
   _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x;
   wmax_local = max(wmax_local, x);
}

We can use ps.inspect to print a text representation of our abstract syntax trees. Observe that symbols and constants have all been assigned data types by the Typifier. This is according to the first rule for the canonical form of abstract syntax trees.

Canonical Form I: Data Types

Each symbol, constant, and expression node inside an AST must be annotated with a data type. This is ensured by running the Typifier on all newly created syntax trees.
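The effect of this rule can be mimicked with a toy typifier (a sketch over a simplified tuple-based expression tree, not the backend’s Typifier) that annotates every node with a data type:

```python
def typify(node, default_type="float64"):
    """Annotate every node of a toy expression tree with a data type.
    Operator nodes are tuples (op, *args); leaves are symbol names."""
    if isinstance(node, tuple):
        op, *args = node
        return {"op": op, "args": [typify(a, default_type) for a in args],
                "dtype": default_type}
    return {"symbol": node, "dtype": default_type}

typed = typify(("mul", "c", "x"))
print(typed["dtype"], [a["dtype"] for a in typed["args"]])
```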

Iteration Space Materialization#

Now that our kernel body is complete, we need to turn its index space into syntactic structure. We achieve this via the iteration axes system, which is part of the AST class hierarchy.

The Main Iteration Cube#

First, we manifest the entire iteration space as an axes cube, which represents abstract iteration over all iteration dimensions in their required order. We use the AST factory to create the cube:

cube = factory.cube_from_ispace(ispace, body)

ps.inspect(cube)
axes-cube(
   range(ctr_2 : [[0: int64] : _size_f_2 - [0: int64] : [1: int64]]),
   range(ctr_1 : [[0: int64] : _size_f_1 - [0: int64] : [1: int64]]),
   range(ctr_0 : [[0: int64] : _size_f_0 - [0: int64] : [1: int64]])
)
{
   x: float64 = _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
   _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x;
   wmax_local = max(wmax_local, x);
}

Observe the order of the cube’s coordinates, which are listed slowest-to-fastest. According to the fzyx memory layout previously specified for the field \(f\), the \(x\) coordinate is the fastest while the \(z\) coordinate is the slowest. This is reflected in the cube, where ctr_2, the \(z\) iteration counter, is listed first.

Symbol Canonicalization#

Now we have reached an important point: All symbols required by the kernel have been introduced and defined in its AST. This includes, in our case, all numerical symbols, field buffers, and the iteration counters. At this point, we should canonicalize the symbol declarations.

Canonical Form II: Symbol Declarations

For an AST to be in canonical form,

  • Each symbol has at most one declaration;

  • Each symbol that is never written to apart from its declaration has a const type; and

  • Each symbol whose type is not const has at least one non-declaring assignment.

This form is achieved by running the CanonicalizeSymbols pass on the AST.
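Before running the real pass, the rules can be illustrated with a toy checker (a sketch over simplified statement tuples, not pystencils code):

```python
def check_canonical(stmts):
    """Check toy canonical-form rules on statements (kind, symbol),
    where kind is "decl", "decl_const", or "assign"."""
    decls = [s for kind, s in stmts if kind.startswith("decl")]
    if len(decls) != len(set(decls)):
        return False  # rule 1: each symbol has at most one declaration
    assigned = {s for kind, s in stmts if kind == "assign"}
    for kind, s in stmts:
        if kind == "decl" and s not in assigned:
            return False  # rules 2/3: write-once symbols must be const
    return True

print(check_canonical([("decl_const", "x"), ("decl", "w"), ("assign", "w")]))  # True
print(check_canonical([("decl", "x")]))  # False: x is never written, needs const
```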

from pystencils.backend.transformations import CanonicalizeSymbols

canonicalize = CanonicalizeSymbols(ctx)
cube = canonicalize(cube)

ps.inspect(cube)
axes-cube(
   range(ctr_2 : [[0: int64] : _size_f_2 - [0: int64] : [1: int64]]),
   range(ctr_1 : [[0: int64] : _size_f_1 - [0: int64] : [1: int64]]),
   range(ctr_0 : [[0: int64] : _size_f_0 - [0: int64] : [1: int64]])
)
{
   x: const float64 = _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
   _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x;
   wmax_local = max(wmax_local, x);
}

Axis Expansion#

Now, we will turn the iteration cube into a tree of nested iteration axes using the AxisExpansion transformer. This is where our kernel’s iteration strategy is introduced and all decisions concerning loop structure, loop tiling and blocking, parallelization and vectorization, as well as GPU thread-block indexing are made.

We would like to create a CPU kernel with the outermost loop parallelized using OpenMP, and the innermost loop vectorized with four SIMD lanes. We express this as an expansion strategy:

from pystencils.backend.transformations import AxisExpansion

ae = AxisExpansion(ctx)
strategy = ae.create_strategy(
    [
        ae.parallel_loop(num_threads=8, schedule="static,16"),
        ae.loop(),
        ae.peel_for_divisibility(4),
        [
            ae.block_loop(4, assume_divisible=True),
            ae.simd(4)
        ],
        [
            ae.loop()
        ]
    ]
)

Each step in this strategy modifies or peels off one dimension of the iteration cube, from slowest to fastest. Let’s briefly take this apart:

  • The first expansion, parallel_loop(), introduces a loop parallelized by OpenMP for the cube’s leading dimension and strips that dimension from the cube.

  • The next expansion, loop(), turns the now-leading cube dimension (ctr_1) into a plain loop.

  • The only remaining dimension (ctr_0) is then peeled (peel_for_divisibility()), splitting the cube into two sub-cubes. The iteration limits of the first sub-cube are chosen such that its iteration count is divisible by four, while the second sub-cube holds all remaining iterations.

  • We then have a branch in the strategy. The first sub-strategy introduces a blocked loop with block size four. The remaining four iterations per block are then vectorized using the simd() expansion. The second sub-strategy merely produces a remainder loop.
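The arithmetic behind peel_for_divisibility(4) can be checked in plain Python (a sketch of the index computation, independent of pystencils):

```python
def peel(n, width=4):
    """Split range(n) into a main part divisible by `width` and a remainder."""
    rem_start = n // width * width       # cf. ctr_0__rem_start in the output below
    main = range(0, rem_start, width)    # blocked / SIMD iterations
    remainder = range(rem_start, n)      # scalar remainder loop
    return main, remainder

main, rem = peel(10, 4)
covered = [i + lane for i in main for lane in range(4)] + list(rem)
print(covered)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: every index covered exactly once
```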

Let’s apply the strategy to our iteration cube, and observe how it is replaced by iteration axes according to the strategy definition:

kernel_ast = strategy(cube)
ps.inspect(kernel_ast)

{
   parallel-loop-axis< num_threads(8), schedule(static,16) >(range(ctr_2 : [[0: int64] : _size_f_2 - [0: int64] : [1: int64]]))
   {
      loop-axis(range(ctr_1 : [[0: int64] : _size_f_1 - [0: int64] : [1: int64]]))
      {
         {
            ctr_0__rem_start: int64 = _size_f_0 / [4: int64] * [4: int64];
            loop-axis(range(ctr_0__1 : [[0: int64] : ctr_0__rem_start : [4: int64]]))
            {
               simd-axis< 4 >(range(ctr_0 : [ctr_0__1 : ctr_0__1 + [4: int64] : [1: int64]]))
               {
                  x: const float64 = _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
                  _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x;
                  wmax_local = max(wmax_local, x);
               }
            }
            loop-axis(range(ctr_0__0 : [ctr_0__rem_start : _size_f_0 - [0: int64] : [1: int64]]))
            {
               x__0: const float64 = _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
               _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x__0;
               wmax_local = max(wmax_local, x__0);
            }
         }
      }
   }
}

We now have our entire iteration strategy represented by axis nodes in the AST. These are still fully general, and will be gradually lowered to target-specific implementations.

Axis-Invariant Code Motion#

Before we take the first lowering step, let’s optimize our kernel a bit. At this point, it is advisable to perform an axis-invariant code motion pass. This pass detects declarations that are independent of their surrounding iteration axes and moves them as far outward as possible to avoid computing them multiple times.

from pystencils.backend.transformations import HoistIterationInvariantDeclarations

hoist = HoistIterationInvariantDeclarations(ctx)
kernel_ast = hoist(kernel_ast)

ps.inspect(kernel_ast)

{
   ctr_0__rem_start: int64 = _size_f_0 / [4: int64] * [4: int64];
   parallel-loop-axis< num_threads(8), schedule(static,16) >(range(ctr_2 : [[0: int64] : _size_f_2 - [0: int64] : [1: int64]]))
   {
      loop-axis(range(ctr_1 : [[0: int64] : _size_f_1 - [0: int64] : [1: int64]]))
      {
         {
            loop-axis(range(ctr_0__1 : [[0: int64] : ctr_0__rem_start : [4: int64]]))
            {
               simd-axis< 4 >(range(ctr_0 : [ctr_0__1 : ctr_0__1 + [4: int64] : [1: int64]]))
               {
                  x: const float64 = _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
                  _data_f[ctr_0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x;
                  wmax_local = max(wmax_local, x);
               }
            }
            loop-axis(range(ctr_0__0 : [ctr_0__rem_start : _size_f_0 - [0: int64] : [1: int64]]))
            {
               x__0: const float64 = _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
               _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x__0;
               wmax_local = max(wmax_local, x__0);
            }
         }
      }
   }
}

As we can see, the declaration of ctr_0__rem_start was moved outside of the axes tree.
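Why this matters can be quantified with a small evaluation counter (a toy model of the pass's payoff, not pystencils code): without hoisting, the invariant expression would be recomputed in every surrounding iteration.

```python
def count_evaluations(hoisted, nz=2, ny=3):
    """Count how often the loop-invariant rem_start expression is evaluated."""
    evals = 0
    def rem_start(n):
        nonlocal evals
        evals += 1
        return n // 4 * 4
    if hoisted:
        r = rem_start(16)            # computed once, outside the loops
        for _z in range(nz):
            for _y in range(ny):
                pass
    else:
        for _z in range(nz):
            for _y in range(ny):
                r = rem_start(16)    # recomputed in every (z, y) iteration
    return evals

print(count_evaluations(hoisted=False), count_evaluations(hoisted=True))  # 6 1
```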

Axis Materialization#

In the next step, we will materialize the abstract iteration axes to more concrete IR code. Loop axes will become loops, OpenMP directives will be introduced for parallelization, and SIMD axes will be manifested as vectorized arithmetic. During this process, local accumulator variables for the kernel’s reductions (here, the max-reduction into wmax) are also introduced.
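The role of these local accumulators can be simulated in plain Python (a sketch of the max-reduction scheme, not backend code): each thread or lane group reduces its own slice into an accumulator initialized to the neutral element of max, and the partial results are combined at the end.

```python
import math

data = [3.0, -1.5, 7.25, 0.5, 2.0, 9.0, -4.0]

# Each "thread" (or SIMD lane group) reduces its own slice of the data
# into a local accumulator initialized to the neutral element of max.
partials = []
for chunk in (data[0::2], data[1::2]):
    acc = -math.inf
    for x in chunk:
        acc = max(acc, x)
    partials.append(acc)

result = max(partials)  # combine step, cf. vec_horizontal_max / OpenMP reduction
print(result)  # 9.0
```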

Let us thus invoke the axes materializer:

from pystencils.backend.transformations import MaterializeAxes

mat_axes = MaterializeAxes(ctx)
kernel_ast = mat_axes(kernel_ast)

ps.inspect(kernel_ast)

{
   ctr_0__rem_start: int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__0: float64<4> = vec_broadcast<4>(neg_infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2 - [0: int64]; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1 - [0: int64]; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: int64 = ctr_0__1;
                  x__1: float64<4> = vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2);
                  vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2) = (vec_broadcast<4>(c)) * x__1;
                  wmax_local__0 = max(wmax_local__0, x__1);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0 - [0: int64]; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
                  _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x__0;
                  wmax_local = max(wmax_local, x__0);
               }
            }
         }
         wmax_local = vec_horizontal_max(('wmax_local', 'wmax_local__0'));
      }
   }
}

Reductions to Memory#

Our reduction into wmax is not complete yet; the accumulated result from the kernel’s local accumulator variables must still be written back to the reduction’s target memory location. We invoke the ReductionsToMemory pass for this:
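The intended write-back semantics, combining the local result with the value already stored at the target (as the generated fmax write-back later makes explicit), can be sketched in plain Python:

```python
def write_back(memory, idx, local_acc):
    """Combine the local accumulator with the value already stored in memory."""
    memory[idx] = max(memory[idx], local_acc)

wmax_mem = [5.0]
write_back(wmax_mem, 0, 9.0)
print(wmax_mem[0])  # 9.0
write_back(wmax_mem, 0, 2.0)
print(wmax_mem[0])  # still 9.0: results accumulate across write-backs
```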

from pystencils.backend.transformations import ReductionsToMemory

reduce_to_memory = ReductionsToMemory(ctx, ctx.reduction_data.values())
kernel_ast = reduce_to_memory(kernel_ast)

ps.inspect(kernel_ast)

{
   wmax_local: float64 = neg_infinity();
   ctr_0__rem_start: int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__0: float64<4> = vec_broadcast<4>(neg_infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2 - [0: int64]; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1 - [0: int64]; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: int64 = ctr_0__1;
                  x__1: float64<4> = vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2);
                  vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2) = (vec_broadcast<4>(c)) * x__1;
                  wmax_local__0 = max(wmax_local__0, x__1);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0 - [0: int64]; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]];
                  _data_f[ctr_0__0 + [0: int64], ctr_1 + [0: int64], ctr_2 + [0: int64], [0: int64]] = c * x__0;
                  wmax_local = max(wmax_local, x__0);
               }
            }
         }
         wmax_local = vec_horizontal_max(('wmax_local', 'wmax_local__0'));
      }
   }
   wmax[[0: int64]] = WriteBackToPtr(wmax, wmax_local);
}

Optimization Passes#

At this point, we can run another set of optimization passes. To illustrate, we will run the EliminateConstants pass to simplify any constant subexpressions in the AST. First, however, we should run another symbol canonicalization pass for good measure:

from pystencils.backend.transformations import EliminateConstants

kernel_ast = canonicalize(kernel_ast)

elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
kernel_ast = elim_constants(kernel_ast)

ps.inspect(kernel_ast)

{
   wmax_local: float64 = neg_infinity();
   ctr_0__rem_start: const int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__0: float64<4> = vec_broadcast<4>(neg_infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: const int64 = ctr_0__1;
                  x__1: float64<4> = vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2);
                  vec_memacc< 4, stride=_stride_f_0 >(_data_f, ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2) = (vec_broadcast<4>(c)) * x__1;
                  wmax_local__0 = max(wmax_local__0, x__1);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0, ctr_1, ctr_2, [0: int64]];
                  _data_f[ctr_0__0, ctr_1, ctr_2, [0: int64]] = c * x__0;
                  wmax_local = max(wmax_local, x__0);
               }
            }
         }
         wmax_local = vec_horizontal_max(('wmax_local', 'wmax_local__0'));
      }
   }
   wmax[[0: int64]] = WriteBackToPtr(wmax, wmax_local);
}
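The simplifications visible above, such as `_size_f_2 - 0` collapsing to `_size_f_2`, follow the usual constant-folding pattern; a toy folder over a simplified (op, lhs, rhs) tree illustrates the idea (a sketch, not the real pass):

```python
def fold(node):
    """Fold trivial constant subexpressions in a toy (op, lhs, rhs) tree."""
    if not isinstance(node, tuple):
        return node
    op, a, b = node
    a, b = fold(a), fold(b)
    if op == "sub" and b == 0:
        return a            # x - 0  ->  x, as in the loop bounds above
    if isinstance(a, int) and isinstance(b, int):
        return {"add": a + b, "sub": a - b, "mul": a * b}[op]
    return (op, a, b)

print(fold(("sub", "_size_f_2", 0)))    # _size_f_2
print(fold(("mul", 2, ("sub", 3, 0))))  # 6
```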

Lowering#

By now, our kernel has already taken on a very concrete form, but still contains various IR constructs that are not yet valid C code. Most importantly, its vectorized arithmetic still needs to be turned into architecture-specific intrinsics, and it contains arithmetic functions (max) that must be mapped onto a platform-dependent implementation.

Let’s go through the required lowering passes one by one.

Select Vector Intrinsics#

First, we will lower vectorized operations and functions to a target architecture’s vector intrinsics. We will use an x86 AVX512 architecture, and thus need to set up the corresponding platform object:

from pystencils.backend.platforms import X86VectorCpu, X86VectorArch

platform = X86VectorCpu(ctx, X86VectorArch.AVX512)

Now, we invoke its intrinsics selector on our kernel AST:

select_intrin = platform.get_intrinsic_selector()
kernel_ast = select_intrin(kernel_ast)

ps.inspect(kernel_ast)

{
   wmax_local: float64 = neg_infinity();
   ctr_0__rem_start: const int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__1: __m256d = _mm256_set1_pd(neg_infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: const int64 = ctr_0__1;
                  x__2: const __m256d = _mm256_i64gather_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), [8: int32]);
                  _mm256_i64scatter_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), _mm256_mul_pd(_mm256_set1_pd(c), x__2), [8: int32]);
                  wmax_local__1 = _mm256_max_pd(wmax_local__1, x__2);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0, ctr_1, ctr_2, [0: int64]];
                  _data_f[ctr_0__0, ctr_1, ctr_2, [0: int64]] = c * x__0;
                  wmax_local = max(wmax_local, x__0);
               }
            }
         }
         wmax_local = _mm256_horizontal_max_pd(wmax_local, wmax_local__1);
      }
   }
   wmax[[0: int64]] = WriteBackToPtr(wmax, wmax_local);
}

Lowering of Buffer Accesses and Functions#

The buffer accesses in the vectorized code have already been lowered to memory access intrinsics, but plain IR buffer accesses remain in the remainder loop. These need to be linearized to raw C pointer arithmetic. We use the LowerToC pass for this:
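Linearization computes a flat offset from the per-dimension strides; a quick NumPy check of the index formula (an illustration of the arithmetic, not the pass itself):

```python
import numpy as np

f = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4, order="F")
flat = f.reshape(-1, order="F")                    # the raw buffer behind f
s0, s1, s2 = (s // f.itemsize for s in f.strides)  # strides in elements

i, j, k = 1, 2, 3
# The linearized access _data_f[i*s0 + j*s1 + k*s2] must match f[i, j, k]
print(flat[i * s0 + j * s1 + k * s2] == f[i, j, k])  # True
```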

from pystencils.backend.transformations import LowerToC

lower_to_c = LowerToC(ctx)
kernel_ast = lower_to_c(kernel_ast)

ps.inspect(kernel_ast)

{
   wmax_local: float64 = neg_infinity();
   ctr_0__rem_start: const int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__1: __m256d = _mm256_set1_pd(neg_infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: const int64 = ctr_0__1;
                  x__2: const __m256d = _mm256_i64gather_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), [8: int32]);
                  _mm256_i64scatter_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), _mm256_mul_pd(_mm256_set1_pd(c), x__2), [8: int32]);
                  wmax_local__1 = _mm256_max_pd(wmax_local__1, x__2);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2];
                  _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2] = c * x__0;
                  wmax_local = max(wmax_local, x__0);
               }
            }
         }
         wmax_local = _mm256_horizontal_max_pd(wmax_local, wmax_local__1);
      }
   }
   wmax[[0: int64]] = WriteBackToPtr(wmax, wmax_local);
}

Finally, all that remains is to map IR functions to target-specific library functions using the SelectFunctions pass.

from pystencils.backend.transformations import SelectFunctions

select_functions = SelectFunctions(platform)
kernel_ast = select_functions(kernel_ast)

ps.inspect(kernel_ast)

{
   wmax_local: float64 = -std::numeric_limits< double >::infinity();
   ctr_0__rem_start: const int64 = _size_f_0 / [4: int64] * [4: int64];
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         wmax_local__1: __m256d = _mm256_set1_pd(-std::numeric_limits< double >::infinity());
         #pragma omp for schedule(static,16)
         for(ctr_2: int64 = [0: int64]; ctr_2 < _size_f_2; ctr_2 += [1: int64])
         {
            for(ctr_1: int64 = [0: int64]; ctr_1 < _size_f_1; ctr_1 += [1: int64])
            {
               for(ctr_0__1: int64 = [0: int64]; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += [4: int64])
               {
                  ctr_0: const int64 = ctr_0__1;
                  x__2: const __m256d = _mm256_i64gather_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), [8: int32]);
                  _mm256_i64scatter_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x([3: int64] * _stride_f_0, [2: int64] * _stride_f_0, _stride_f_0, [0: int64]), _mm256_mul_pd(_mm256_set1_pd(c), x__2), [8: int32]);
                  wmax_local__1 = _mm256_max_pd(wmax_local__1, x__2);
               }
               for(ctr_0__0: int64 = ctr_0__rem_start; ctr_0__0 < _size_f_0; ctr_0__0 += [1: int64])
               {
                  x__0: const float64 = _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2];
                  _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2] = c * x__0;
                  wmax_local = fmax(wmax_local, x__0);
               }
            }
         }
         wmax_local = _mm256_horizontal_max_pd(wmax_local, wmax_local__1);
      }
   }
   wmax[[0: int64]] = fmax(wmax[[0: int64]], wmax_local);
}

Wrapping Up#

There it is: our finished kernel. We can now use ps.inspect in C++ mode as well, since no non-C++ concepts remain in the AST; earlier in the pipeline, this would have failed with a printing error:

ps.inspect(kernel_ast, show_cpp=True)
{
   double wmax_local = -std::numeric_limits< double >::infinity();
   const int64_t ctr_0__rem_start = _size_f_0 / 4LL * 4LL;
   {
      #pragma omp parallel num_threads(8) reduction(max: wmax_local)
      {
         __m256d wmax_local__1 = _mm256_set1_pd(-std::numeric_limits< double >::infinity());
         #pragma omp for schedule(static,16)
         for(int64_t ctr_2 = 0LL; ctr_2 < _size_f_2; ctr_2 += 1LL)
         {
            for(int64_t ctr_1 = 0LL; ctr_1 < _size_f_1; ctr_1 += 1LL)
            {
               for(int64_t ctr_0__1 = 0LL; ctr_0__1 < ctr_0__rem_start; ctr_0__1 += 4LL)
               {
                  const int64_t ctr_0 = ctr_0__1;
                  const __m256d x__2 = _mm256_i64gather_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x(3LL * _stride_f_0, 2LL * _stride_f_0, _stride_f_0, 0LL), 8);
                  _mm256_i64scatter_pd(&_data_f[ctr_0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2], _mm256_set_epi64x(3LL * _stride_f_0, 2LL * _stride_f_0, _stride_f_0, 0LL), _mm256_mul_pd(_mm256_set1_pd(c), x__2), 8);
                  wmax_local__1 = _mm256_max_pd(wmax_local__1, x__2);
               }
               for(int64_t ctr_0__0 = ctr_0__rem_start; ctr_0__0 < _size_f_0; ctr_0__0 += 1LL)
               {
                  const double x__0 = _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2];
                  _data_f[ctr_0__0 * _stride_f_0 + ctr_1 * _stride_f_1 + ctr_2 * _stride_f_2] = c * x__0;
                  wmax_local = fmax(wmax_local, x__0);
               }
            }
         }
         wmax_local = _mm256_horizontal_max_pd(wmax_local, wmax_local__1);
      }
   }
   wmax[0LL] = fmax(wmax[0LL], wmax_local);
}

To make the kernel available to the runtime system, JIT compiler, or pystencils-sfg, we need to wrap the AST inside a Kernel object. We use the KernelFactory from the codegen module for this task. To create the kernel, we have to specify its platform, AST, name, target, and a JIT compiler (use no_jit if not applicable):

from pystencils.codegen.driver import KernelFactory
from pystencils.jit.cpu import CpuJit, CompilerInfo

kfactory = KernelFactory(ctx)
ker = kfactory.create_generic_kernel(
    platform,
    kernel_ast,
    "my_kernel",
    ps.Target.X86_AVX512,
    CpuJit(CompilerInfo.get_default()),
)

print(ker)
my_kernel(_data_f : double * RESTRICT const, _size_f_0 : const int64, _size_f_1 : const int64, _size_f_2 : const int64, _stride_f_0 : const int64, _stride_f_1 : const int64, _stride_f_2 : const int64, c : const float64, wmax : double * const)

This concludes the walkthrough of the kernel creation pipeline.