Skip to content
This repository was archived by the owner on Jun 18, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

## unreleased

- #497:

- `sicmutils.expression.compile/compile-state-fn` and its non-memoized version
`compile-state-fn*` can now take an explicit `:mode` option; when supplied, it
overrides the dynamically bound `*mode*`.

Invalid modes supplied via `:mode` will cause `compile-state-fn` to throw an
exception.

- Fixes a bug where function compilation tried to evaluate the non-numeric
operations `up` and `down` at compile time whenever all of their arguments were
numbers, throwing an error.

- #498: replace all long-form GPL headers with `"SPDX-License-Identifier:
GPL-3.0"`.

Expand All @@ -14,7 +26,7 @@

A simulation run of the double pendulum example in the [clerk-demo
repository](https://github.com/nextjournal/clerk-demo/blob/20a404a271bea29ef98ee4e60a05e54345aa43ba/notebooks/sicmutils.clj)
now runs in 700ms vs the former 2.2 seconds, a major win.
now runs in 350ms vs the former 2.2 seconds, a major win.

- Function compilation now pre-simplifies numerical forms encountered inside a
function, like `(/ 1 2)`, instead of letting them be evaluated on every fn
Expand Down
80 changes: 49 additions & 31 deletions src/sicmutils/expression/compile.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@
new-expression)))]
(extract-common-subexpressions expr callback opts))))

(def ^{:private true
       :doc "The entries of [[compiled-fn-whitelist]] that [[apply-numeric-ops]]
  may evaluate at compile time: every entry except the structure constructors
  `up` and `down`, which are not numeric operations."}
  numeric-whitelist
  (reduce dissoc compiled-fn-whitelist '[up down]))

(defn ^:no-doc apply-numeric-ops
"Takes a function body and returns a new body with all numeric operations
like `(/ 1 2)` evaluated and all numerical literals converted to `double` or
Expand All @@ -365,7 +371,7 @@
(sequential? expr)
(let [[f & xs] expr]
(if-let [m (and (every? number? xs)
(compiled-fn-whitelist f))]
(numeric-whitelist f))]
(u/double (apply (:f m) xs))
expr))
:else expr))
Expand Down Expand Up @@ -394,23 +400,30 @@
valid-modes
#{:sci :native :source})

(defn set-compiler-mode!
"Set the default compilation mode by supplying an entry from [[valid-modes]]."
(defn validate-mode!
"Given a keyword `mode` specifying a compilation mode, returns `mode` if valid,
and throws otherwise."
[mode]
(if (valid-modes mode)
#?(:cljs (set! *mode* mode)
:clj (alter-var-root #'*mode* (constantly mode)))
(u/illegal (str "Invalid compilation mode supplied: " mode
". Please supply one of " valid-modes))))
(or (valid-modes mode)
(throw
(ex-info
(str "Invalid compilation mode supplied: " mode
". Please supply (or bind to `*mode*`) one of " valid-modes)
{:mode mode
:valid-mode valid-modes}))))

(defn compiler-mode
"Validates and returns the dynamically bound compilation [[*mode*]].
Throws on an invalid setting."
[]
(or (valid-modes *mode*)
(u/illegal (str "Invalid compilation mode bound: "
*mode*
". Please supply one of " valid-modes))))
(validate-mode! *mode*))

(defn set-compiler-mode!
  "Installs `mode` as the default value of [[*mode*]].

  `mode` must be a member of [[valid-modes]]; any other value makes
  [[validate-mode!]] throw before [[*mode*]] is modified.

  On the JVM the var's root value is replaced via `alter-var-root`, so a
  `binding` of [[*mode*]] that is already in effect on the current thread still
  wins; in ClojureScript the var is updated with `set!`."
  [mode]
  ;; `validate-mode!` either returns the (always truthy) keyword or throws, so
  ;; the body below runs exactly when `mode` is valid.
  (when (validate-mode! mode)
    #?(:cljs (set! *mode* mode)
       :clj  (alter-var-root #'*mode* (constantly mode)))))

;; Native compilation works on the JVM, and on Clojurescript if you're running
;; in a self-hosted CLJS environment. Enable this mode by wrapping your call in
Expand Down Expand Up @@ -561,15 +574,18 @@

- an optional argument `opts`. Options accepted are:

- `:flatten?` if `true` (default), the returned function will have
- `:flatten?`: if `true` (default), the returned function will have
signature `(f <flattened-state> [params])`. If `false`, the first arg of the
returned function will be expected to have the same shape as `initial-state`

- `:generic-params?` if `true` (default), the returned function will take a
- `:generic-params?`: if `true` (default), the returned function will take a
second argument for the parameters of the state derivative and keep params
generic. If false, the returned function will take a single state argument,
and the supplied params will be hardcoded.

- `:mode`: Explicitly set the compilation mode to one of the values
in [[valid-modes]]. Explicit alternative to dynamically binding [[*mode*]].

The returned, compiled function expects all `Double` (or `js/Number`) for all
state primitives. The function body is simplified and all common
subexpressions identified during compilation are extracted and computed only
Expand All @@ -579,26 +595,28 @@
cache, see `compile-state-fn`."
([f params initial-state]
(compile-state-fn* f params initial-state {}))
([f params initial-state {:keys [generic-params? gensym-fn]
([f params initial-state {:keys [generic-params?
gensym-fn
mode]
:or {generic-params? true
gensym-fn gensym}
:as opts}]
(let [sw (us/stopwatch)
params (if generic-params?
(for [_ params] (gensym-fn 'p))
params)
generic-state (state->argv initial-state gensym-fn)
g (apply f params)
body (-> (g generic-state)
(g/simplify)
(v/freeze)
(cse-form)
(apply-numeric-ops))
compiler (case (compiler-mode)
:source compile-state->source
:native compile-state-native
:sci compile-state-sci)
compiled-fn (compiler params generic-state body opts)]
(let [sw (us/stopwatch)
params (if generic-params?
(for [_ params] (gensym-fn 'p))
params)
generic-state (state->argv initial-state gensym-fn)
g (apply f params)
body (-> (g generic-state)
(g/simplify)
(v/freeze)
(cse-form)
(apply-numeric-ops))
compiler (case (validate-mode! (or mode *mode*))
:source compile-state->source
:native compile-state-native
:sci compile-state-sci)
compiled-fn (compiler params generic-state body opts)]
(log/info "compiled state function in" (us/repr sw) "with mode" *mode*)
compiled-fn)))

Expand Down
26 changes: 22 additions & 4 deletions test/sicmutils/expression/compile_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
[sicmutils.expression.compile :as c]
[sicmutils.generic :as g]
[sicmutils.structure :refer [up down]]
[sicmutils.value :as v]))
[sicmutils.value :as v])
#?(:clj
(:import (clojure.lang ExceptionInfo))))

(deftest mode-binding-test
(testing "set-compiler-mode! works"
Expand All @@ -26,7 +28,7 @@
"valid modes are all returned by compiler-mode.")))

(binding [c/*mode* :TOTALLY-INVALID]
(is (thrown? #?(:clj IllegalArgumentException :cljs js/Error)
(is (thrown? ExceptionInfo
(c/compiler-mode))
"invalid modes throw.")))

Expand Down Expand Up @@ -60,6 +62,22 @@
f-source)
"source code!")

(binding [c/*mode* :native]
(is (= `(fn [[~'y] [~'p]]
(vector (+ (* ~'p ~'y) (* 0.5 ~'p))))
(c/compile-state-fn*
f params initial-state
{:gensym-fn identity
:mode :source}))
"explicit `:mode` overrides the dynamic binding."))

(is (thrown? ExceptionInfo
(c/compile-state-fn*
f params initial-state
{:gensym-fn identity
:mode :invalid}))
"explicit invalid modes will throw!")

(is (= expected ((c/sci-eval f-source)
initial-state params))
"source compiles to SCI and gives us the desired result."))))))
Expand Down Expand Up @@ -134,10 +152,10 @@
x))))
"all remaining numerical literals are doubles.")

(is (= `(fn [~'x] (+ ~'x 0.5))
(is (= `(fn [~'x] (vector 2.0 (+ ~'x 0.5)))
(c/compile-fn
(fn [x]
(g/+ (g// 1 2) x))))
(up 2 (g/+ (g// 1 2) x)))))
"`(/ 1 2)` is resolved into 0.5 at compile time.")))))))

(let [f (fn [x] (g/+ 1 (g/square (g/sin x))))
Expand Down