Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions internals/daemon/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,43 @@ func (ac MetricsAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState)
return Unauthorized(accessDenied)
}
}

// PairingAccess is the access checker reserved for the pairing endpoint. It
// lets a not-yet-known mTLS client identity reach the pairing manager without
// identity verification, and it only admits requests while the pairing
// manager's pairing window is enabled. Enabling that window typically
// involves some proof of server ownership, such as a controlled power cycle
// or a button press.
type PairingAccess struct{}

// pairingWindowEnabled reports whether the daemon's pairing window is
// currently enabled. It is held in a package variable instead of being
// called as a method directly so that tests can substitute it without
// setting up a pairing manager.
var pairingWindowEnabled = (*Daemon).pairingWindowEnabled

// CheckAccess decides whether a request may reach the pairing handler. It
// returns nil (allowed) only when all of the following hold, checked in this
// order:
//
//   - the request path is exactly /v1/pairing,
//   - the request arrived over the HTTPS transport, and
//   - the pairing window is currently enabled.
//
// Otherwise it returns an Unauthorized response. The user argument is not
// consulted: a client that is asking to pair has no verified identity yet.
func (ac PairingAccess) CheckAccess(d *Daemon, r *http.Request, user *UserState) Response {
	switch {
	case r.URL.Path != "/v1/pairing":
		// This checker should only ever be attached to /v1/pairing;
		// refuse anything else defensively.
		return Unauthorized(accessDenied)
	case RequestTransportType(r) != TransportTypeHTTPS:
		// Only mTLS client certificates can be paired at the moment, and
		// those are only available over HTTPS.
		return Unauthorized(accessDenied)
	case !pairingWindowEnabled(d):
		// Pairing is refused outside of an enabled pairing window.
		return Unauthorized(accessDenied)
	}

	// The window is open, so let the request through. This is not the final
	// word on whether pairing succeeds: it merely avoids forwarding requests
	// to the pairing manager needlessly. The pairing manager serializes
	// incoming requests, accepts only the first one, and then closes the
	// window.
	return nil
}
49 changes: 49 additions & 0 deletions internals/daemon/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr daemon.Response
userCheckErr daemon.Response
metricsCheckErr daemon.Response
pairingCheckErr daemon.Response
}{
// API source: Unix Domain Socket
{
Expand All @@ -49,6 +50,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: UntrustedAccess
apiSource: daemon.TransportTypeUnixSocket,
Expand All @@ -57,6 +59,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: MetricsAccess
apiSource: daemon.TransportTypeUnixSocket,
Expand All @@ -65,6 +68,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}, {
// User access: ReadAccess
apiSource: daemon.TransportTypeUnixSocket,
Expand All @@ -73,6 +77,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: nil,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}, {
// User access: AdminAccess
apiSource: daemon.TransportTypeUnixSocket,
Expand All @@ -81,6 +86,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: nil,
userCheckErr: nil,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
},
// API source: HTTP
{
Expand All @@ -91,6 +97,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: UntrustedAccess
apiSource: daemon.TransportTypeHTTP,
Expand All @@ -99,6 +106,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: MetricsAccess
apiSource: daemon.TransportTypeHTTP,
Expand All @@ -107,6 +115,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}, {
// User access: ReadAccess
apiSource: daemon.TransportTypeHTTP,
Expand All @@ -115,6 +124,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: AdminAccess
apiSource: daemon.TransportTypeHTTP,
Expand All @@ -123,6 +133,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
},
// API source: HTTPS
{
Expand All @@ -133,6 +144,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: UntrustedAccess
apiSource: daemon.TransportTypeHTTPS,
Expand All @@ -141,6 +153,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: errUnauthorized,
pairingCheckErr: errUnauthorized,
}, {
// User access: MetricsAccess
apiSource: daemon.TransportTypeHTTPS,
Expand All @@ -149,6 +162,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: errUnauthorized,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}, {
// User access: ReadAccess
apiSource: daemon.TransportTypeHTTPS,
Expand All @@ -157,6 +171,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: errUnauthorized,
userCheckErr: nil,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}, {
// User access: AdminAccess
apiSource: daemon.TransportTypeHTTPS,
Expand All @@ -165,6 +180,7 @@ func (s *accessSuite) TestAccess(c *C) {
adminCheckErr: nil,
userCheckErr: nil,
metricsCheckErr: nil,
pairingCheckErr: errUnauthorized,
}}
for _, t := range tests {
// Fake a test request.
Expand All @@ -188,5 +204,38 @@ func (s *accessSuite) TestAccess(c *C) {
metricsAccess := daemon.MetricsAccess{}
err = metricsAccess.CheckAccess(nil, r, t.user)
c.Assert(err, DeepEquals, t.metricsCheckErr)
// Check PairingAccess
pairingAccess := daemon.PairingAccess{}
err = pairingAccess.CheckAccess(nil, r, t.user)
c.Assert(err, DeepEquals, t.pairingCheckErr)
}
}

// TestPairingAccessWithPairingWindow tests the pairing specific behaviour
// related to whether the pairing window is open or closed. It also verifies
// that an open window does not bypass the path and transport checks, which
// the table-driven TestAccess only exercises with the window closed.
func (s *accessSuite) TestPairingAccessWithPairingWindow(c *C) {
	pairingAccess := daemon.PairingAccess{}

	// newRequest builds a request for path, tagged with the given transport
	// type in its context (the same way the original test tags requests).
	newRequest := func(path string, transport any) *http.Request {
		r := &http.Request{URL: &url.URL{Path: path}}
		return r.WithContext(context.WithValue(context.Background(), daemon.TransportTypeKey{}, transport))
	}

	// A single fake controlled by windowOpen, restored exactly once, instead
	// of stacking fakes and relying on deferred restores unwinding in order.
	windowOpen := false
	restore := daemon.FakePairingWindowEnabled(func(d *daemon.Daemon) bool {
		return windowOpen
	})
	defer restore()

	pairingReq := newRequest("/v1/pairing", daemon.TransportTypeHTTPS)

	// Pairing window closed: a well-formed pairing request is rejected.
	windowOpen = false
	err := pairingAccess.CheckAccess(nil, pairingReq, nil)
	c.Assert(err, DeepEquals, errUnauthorized)

	// Pairing window open: a well-formed pairing request is allowed.
	windowOpen = true
	err = pairingAccess.CheckAccess(nil, pairingReq, nil)
	c.Assert(err, IsNil)

	// Pairing window open, but the path is not the pairing endpoint.
	err = pairingAccess.CheckAccess(nil, newRequest("/v1/identities", daemon.TransportTypeHTTPS), nil)
	c.Assert(err, DeepEquals, errUnauthorized)

	// Pairing window open, but the transport is not HTTPS.
	err = pairingAccess.CheckAccess(nil, newRequest("/v1/pairing", daemon.TransportTypeHTTP), nil)
	c.Assert(err, DeepEquals, errUnauthorized)
	err = pairingAccess.CheckAccess(nil, newRequest("/v1/pairing", daemon.TransportTypeUnixSocket), nil)
	c.Assert(err, DeepEquals, errUnauthorized)
}
4 changes: 4 additions & 0 deletions internals/daemon/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ var API = []*Command{{
WriteAccess: AdminAccess{},
GET: v1GetIdentities,
POST: v1PostIdentities,
}, {
Path: "/v1/pairing",
WriteAccess: PairingAccess{},
POST: v1PostPairing,
}, {
Path: "/v1/metrics",
ReadAccess: MetricsAccess{},
Expand Down
53 changes: 53 additions & 0 deletions internals/daemon/api_pairing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2024 Canonical Ltd
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 3 as
// published by the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package daemon

import (
"encoding/json"
"net/http"
)

// v1PostPairing handles POST /v1/pairing. The JSON request body carries an
// "action" field, and "pair" is the only supported action. For "pair", the
// single TLS peer certificate presented by the client is handed to the
// pairing manager's PairMTLS. On success it returns an empty sync response.
//
// Error responses:
//   - BadRequest if the body cannot be decoded, the action is unknown, the
//     client did not present exactly one peer certificate, or PairMTLS fails;
//   - InternalError if the request carries no TLS connection state.
func v1PostPairing(c *Command, r *http.Request, user *UserState) Response {
	var req struct {
		Action string `json:"action"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return BadRequest("cannot decode request body: %v", err)
	}

	if req.Action != "pair" {
		return BadRequest(`invalid action %q, must be "pair"`, req.Action)
	}

	tlsState := r.TLS
	if tlsState == nil {
		return InternalError("cannot find TLS connection state")
	}

	// Exactly one peer certificate is accepted: the leaf certificate, which
	// is the client identity being paired.
	peers := tlsState.PeerCertificates
	if n := len(peers); n != 1 {
		return BadRequest("cannot support client: single certificate expected, got %d", n)
	}

	if err := c.d.overlord.PairingManager().PairMTLS(peers[0]); err != nil {
		return BadRequest("cannot pair client: %v", err)
	}
	return SyncResponse(nil)
}
Loading
Loading