
Implementing a custom optimizer with Optax

Optax is an optimization library for JAX that provides out-of-the-box APIs for commonly used ML loss functions, optimizers, and utilities such as learning rate schedulers.

In general, defining your own optimizer with Optax is quite simple: use the optax.chain API to combine different transforms. Transforms are pure functions that act on the gradient updates (rather than on the parameters directly), transforming them according to the update rules defined within the transform function, and they are available as optax.transform_foo. One can simply think of them as chunks of math operations that take an input x, perform some computation on it, and return the transformed result. For example, if you want to define a new optimizer that follows the update rule [adam -> clip gradients by 0.8 -> apply learning rate of 0.02], the code would look like this:

transform1 = optax.scale_by_adam()
transform2 = optax.clip(0.8)
transform3 = optax.scale_by_learning_rate(0.02)
chained_transform = optax.chain(transform1, transform2, transform3)

This chained_transform is now a complete optimizer and can be used to update the parameters of a model (similar to how one would use optax.adam()).
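For instance (a quick sketch, assuming params is your model's parameter pytree and grads are the gradients of your loss):

opt_state = chained_transform.init(params)
updates, opt_state = chained_transform.update(grads, opt_state)
params = optax.apply_updates(params, updates)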

However, there might be times when we need to construct an optimizer whose update rules cannot be expressed with the transform functions Optax already provides. For example, if you are trying to implement the Muon optimizer, the update steps look something like this:

Muon optimizer update equations

(taken from Keller Jordan's blog)
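In essence (paraphrasing that post), Muon keeps a momentum buffer B_t = μ·B_{t-1} + G_t, orthogonalizes it via O_t = NewtonSchulz5(B_t), and updates the weights as θ_t = θ_{t-1} − η·O_t.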

where the NewtonSchulz5 operation is not available as a transform function in Optax (yet). This is where this post comes in: a general guide to implementing custom optimizers with Optax when they require mathematical transforms that are not part of the core library.

Looking into the Optax Optimizer


At its core, Optax is built around its base module, which contains all the interfaces and datatypes required by the library. The interface we care about here is GradientTransformation, which has the following signature

class GradientTransformation(NamedTuple):
    init: TransformInitFn
    update: TransformUpdateFn

TransformInitFn is a callable responsible for setting up the initial optimizer state, while TransformUpdateFn is a callable that applies the update rules. Everything passed to these callables (parameters, updates, state) is treated as a pytree.
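Roughly, the two callables have the following shape (a sketch written in terms of the public Optax base types; the real definitions are the TransformInitFn and TransformUpdateFn protocols):

from typing import Optional

import optax

def init(params: optax.Params) -> optax.OptState:
    # build the initial optimizer state from the parameter pytree
    ...

def update(
    updates: optax.Updates,
    state: optax.OptState,
    params: Optional[optax.Params] = None,
) -> tuple[optax.Updates, optax.OptState]:
    # transform the incoming updates (usually the gradients) and
    # return them along with the new optimizer state
    ...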

The former is invoked when you initialize the optimizer as

solver = optax.adam(lr)
state = solver.init(params)

and returns the initial state of the optimizer, which typically mirrors the pytree structure of the parameters. The latter is invoked when you want to compute the parameter updates from the gradients as

updates, state = solver.update(grads, state)

Implementing the custom transform

With this in mind, we can now define our own transform by writing a function that returns a GradientTransformation object

from typing import NamedTuple, Optional

import jax
import optax


class CustomOptState(NamedTuple):
    evolving_state: optax.Updates


def scale_by_custom_optimizer(constant: float) -> optax.GradientTransformation:
    def matrix_ops(X: jax.Array) -> jax.Array:
        # perform the custom matrix operations
        return X

    def init_fn(params):
        # initialize the state of the optimizer as initial_state (defined below)
        return CustomOptState(evolving_state=initial_state)

    def update_fn(updates, state, params: Optional[optax.Params] = None):
        # update the state, apply matrix_ops and, optionally,
        # use params for extra updates such as weight decay (defined below)
        return updates, CustomOptState(evolving_state=updated_state)

    return optax.GradientTransformation(init_fn, update_fn)

Since JAX functions are pure, optimizers themselves are stateless, so we need a way to keep track of the optimizer's evolving state. This is where CustomOptState comes in. It is defined as a NamedTuple and its contents can be any pytree. In most cases it will mirror the nested structure of the model's parameters so that per-parameter state can be stored.
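For instance, with a hypothetical two-layer parameter dict, the mirrored state would look like this:

import jax
import jax.numpy as jnp

# hypothetical two-layer parameter dict
params = {"layer1": {"w": jnp.ones((4, 8)), "b": jnp.zeros(8)},
          "layer2": {"w": jnp.ones((8, 2)), "b": jnp.zeros(2)}}

# one state leaf per parameter leaf, with matching shapes
state = CustomOptState(evolving_state=jax.tree.map(jnp.zeros_like, params))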

We can then define our transform, which must implement two functions, init_fn and update_fn. Alongside these, it can contain any other custom matrix-operation functions (like matrix_ops above) that the update rule needs.

The init_fn function is passed a pytree that matches the structure of the params to be updated and returns the initial state of the optimizer as a CustomOptState object. In most cases, the passed argument is the model parameters dict.

def init_fn(params):
    # initialize the state of the optimizer as all zeros
    initial_state = jax.tree.map(jnp.zeros_like, params)
    return CustomOptState(evolving_state=initial_state)

The update_fn is where you pass the new values (e.g. the gradients) and the current state (i.e. the running state of the optimizer as a CustomOptState object); it returns the parameter update values together with a CustomOptState holding the updated state. Optionally it also receives the model parameters, which can be used for weight decay or other parameter-dependent updates.

def update_fn(grads: optax.Updates, state: CustomOptState,
              params: Optional[optax.Params] = None):
    # params is unused here, but the argument is kept so the
    # signature matches what optax.chain expects
    del params
    # simple update rule for example:
    # calculate the new momentum/running average
    updated_state = jax.tree.map(
        lambda s, g: constant * s + g, state.evolving_state, grads)
    # calculate the custom optimization step for each parameter's state
    update = jax.tree.map(matrix_ops, updated_state)
    return update, CustomOptState(evolving_state=updated_state)

The final step is to return the GradientTransformation object with these two functions.

return optax.GradientTransformation(init_fn, update_fn)

We have our custom transform ready; the only step left is to use it to define the optimizer that calls this transform.

def custom_optimizer(lr: float, constant: float) -> optax.GradientTransformation:
    return optax.chain(
        scale_by_custom_optimizer(constant),
        optax.scale_by_learning_rate(lr),
    )

The scale_by_learning_rate step is just -(update * lr). Ideally, lr should be annotated as optax.ScalarOrSchedule, since it can be a float value or a schedule.
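For example, assuming lr is annotated as optax.ScalarOrSchedule as suggested, a schedule can be passed instead of a constant (the values here are just for illustration):

lr_schedule = optax.cosine_decay_schedule(init_value=0.02, decay_steps=10_000)
solver = custom_optimizer(lr=lr_schedule, constant=0.9)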

And we have our new optimizer ready.

solver = custom_optimizer(lr, constant)
state = solver.init(params)
updates, state = solver.update(grads, state)

At this point, updates can be applied to parameters using optax.apply_updates.
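Putting it all together, a minimal training-step sketch might look like this (loss_fn and batch are placeholders for whatever your model defines):

# compute gradients of an assumed loss_fn w.r.t. the parameters
grads = jax.grad(loss_fn)(params, batch)
# transform them with our custom optimizer and apply the result
updates, state = solver.update(grads, state)
params = optax.apply_updates(params, updates)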

References

[1] Optax documentation - https://optax.readthedocs.io/en/latest/api/combining_optimizers.html

[2] Muon optimizer - https://kellerjordan.github.io/posts/muon/

[3] Code - github