"""
    ChainTransform(transforms)

Transformation that applies a chain of transformations `transforms` to the input.

The transformation `first(transforms)` is applied first, and the output of each
transformation is passed as the input of the next one. A chain can also be built by
composing transforms with `∘`: `t2 ∘ t1` applies `t1` first and then `t2`.

# Examples

```jldoctest
julia> l = rand(); A = rand(3, 4); t1 = ScaleTransform(l); t2 = LinearTransform(A);

julia> X = rand(4, 10);

julia> map(ChainTransform([t1, t2]), ColVecs(X)) == ColVecs(A * (l .* X))
true

julia> map(t2 ∘ t1, ColVecs(X)) == ColVecs(A * (l .* X))
true
```
"""
struct ChainTransform{V} <: Transform
    # Ordered collection (e.g. a `Tuple` or `Vector`) of transforms; index 1 runs first.
    transforms::V
end

# NOTE(review): `@functor` is presumably the Functors.jl macro (imported elsewhere in the
# module). If so, this exposes the `transforms` field to functor-based traversal (e.g.
# parameter collection/reconstruction) — confirm against the module's imports.
@functor ChainTransform

# The length of a chain is the number of transformations it contains.
function Base.length(chain::ChainTransform)
    ts = chain.transforms
    return length(ts)
end

"""
    ChainTransform(v, θ::AbstractVector)

Build a `ChainTransform` from a collection of transform constructors `v` and a matching
vector of parameters `θ`. The `i`-th transform of the chain is `v[i](θ[i])`.

Throws a `DimensionMismatch` if `v` and `θ` have different lengths.
"""
function ChainTransform(v, θ::AbstractVector)
    if length(v) != length(θ)
        throw(
            DimensionMismatch(
                "got $(length(v)) transform constructors but $(length(θ)) parameters",
            ),
        )
    end
    # Pair each constructor with its own parameter. The previous `v.(θ)` tried to call the
    # collection `v` itself on every parameter, which fails for a vector of constructors.
    return ChainTransform([T(p) for (T, p) in zip(v, θ)])
end

# Composition follows the mathematical convention: `outer ∘ inner` applies `inner` first,
# so `inner`'s transform(s) are placed before `outer`'s in the resulting chain. Existing
# chains are flattened into the new chain instead of being nested.
function Base.:∘(outer::Transform, inner::Transform)
    return ChainTransform((inner, outer))
end
function Base.:∘(outer::Transform, inner::ChainTransform)
    return ChainTransform((inner.transforms..., outer))
end
function Base.:∘(outer::ChainTransform, inner::Transform)
    return ChainTransform((inner, outer.transforms...))
end
function Base.:∘(outer::ChainTransform, inner::ChainTransform)
    return ChainTransform((inner.transforms..., outer.transforms...))
end

# Apply the chain to a single input: feed `x` through each transform in order, passing
# each output on as the next input.
function (chain::ChainTransform)(x)
    return foldl(chain.transforms; init=x) do acc, tr
        tr(acc)
    end
end

# Collection version of applying the chain: each transform's `_map` is applied to the
# whole collection produced by the previous one, in chain order.
function _map(chain::ChainTransform, x::AbstractVector)
    return foldl(chain.transforms; init=x) do xs, tr
        _map(tr, xs)
    end
end

# Forward `set!` element-wise: transform `i` of the chain receives parameter `θ[i]`
# (following broadcasting rules). NOTE(review): presumably mutates the transforms in place.
function set!(chain::ChainTransform, θ)
    return broadcast(set!, chain.transforms, θ)
end
# Build a new chain in which every transform is duplicated with its matching parameter.
function duplicate(chain::ChainTransform, θ)
    duplicated = map(chain.transforms, θ) do tr, p
        duplicate(tr, p)
    end
    return ChainTransform(duplicated)
end

# Display a chain by delegating to `printshifted`, starting without extra indentation.
function Base.show(io::IO, chain::ChainTransform)
    return printshifted(io, chain, 0)
end

# Print a chain as a header line followed by its transforms joined by " |> ", indented by
# `shift + 1` tabs. Nested transforms are printed with `shift + 2` so that nested chains
# are indented further.
function printshifted(io::IO, t::ChainTransform, shift::Int)
    println(io, "Chain of ", length(t), " transforms:")
    # An empty chain has no first element to print; indexing it would throw.
    isempty(t.transforms) && return nothing
    for _ in 1:(shift + 1)
        print(io, "\t")
    end
    print(io, " - ")
    printshifted(io, first(t.transforms), shift + 2)
    for tr in Iterators.drop(t.transforms, 1)
        print(io, " |> ")
        printshifted(io, tr, shift + 2)
    end
    return nothing
end
