Skip to content

Commit

Permalink
Initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
hdecarne committed Dec 17, 2023
1 parent 21a622d commit 022b25f
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 6 deletions.
68 changes: 63 additions & 5 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"maps"
"runtime"
"strings"
"time"
Expand Down Expand Up @@ -323,6 +324,32 @@ func (registry *Registry) Entry(name string) (*RegistryEntry, error) {
return entry, nil
}

func (registry *Registry) CertPools() (*x509.CertPool, *x509.CertPool, error) {
roots := x509.NewCertPool()
intermediates := x509.NewCertPool()
entries, err := registry.Entries()
if err != nil {
return nil, nil, err
}
for {
entry, err := entries.Next()
if err != nil {
return nil, nil, err
}
if entry == nil {
break
}
if entry.IsCA() {
if entry.IsRoot() {
roots.AddCert(entry.Certificate())
} else {
intermediates.AddCert(entry.Certificate())
}
}
}
return roots, intermediates, nil
}

func (registry *Registry) isValidEntryName(name string) bool {
return !strings.HasPrefix(name, ".")
}
Expand Down Expand Up @@ -372,7 +399,7 @@ func (registry *Registry) getEntryData(name string) (*registryEntryData, error)
}

func (registry *Registry) unmarshalEntryData(dataBytes []byte) (*registryEntryData, error) {
data := &registryEntryData{}
data := &registryEntryData{Attributes: make(map[string]string, 0)}
err := json.Unmarshal(dataBytes, data)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal entry data (cause: %w)", err)
Expand Down Expand Up @@ -448,6 +475,7 @@ type RegistryEntry struct {
certificate *x509.Certificate
certificateRequest *x509.CertificateRequest
revocationList *x509.RevocationList
attributes map[string]string
}

func (entry *RegistryEntry) Name() string {
Expand All @@ -461,8 +489,15 @@ func (entry *RegistryEntry) IsRoot() bool {
return certs.IsRoot(entry.certificate)
}

func (entry *RegistryEntry) CanIssue() bool {
return entry.key != nil && entry.certificate != nil
func (entry *RegistryEntry) IsCA() bool {
if entry.certificate == nil {
return false
}
return entry.certificate.IsCA
}

func (entry *RegistryEntry) CanIssue(keyUsage x509.KeyUsage) bool {
return entry.key != nil && entry.certificate != nil && (entry.certificate.KeyUsage&keyUsage) == keyUsage
}

func (entry *RegistryEntry) HasKey() bool {
Expand Down Expand Up @@ -493,8 +528,8 @@ func (entry *RegistryEntry) CertificateRequest() *x509.CertificateRequest {
}

func (entry *RegistryEntry) ResetRevocationList(factory certs.RevocationListFactory, user string) (*x509.RevocationList, error) {
if !entry.CanIssue() {
return nil, fmt.Errorf("cannot create revocation list for a non-issueing certificate")
if !entry.CanIssue(x509.KeyUsageCRLSign) {
return nil, fmt.Errorf("cannot create revocation list for selected certificate")
}
revocationList, err := factory.New(entry.Certificate(), entry.Key(user))
if err != nil {
Expand All @@ -516,6 +551,18 @@ func (entry *RegistryEntry) RevocationList() *x509.RevocationList {
return entry.revocationList
}

func (entry *RegistryEntry) Attributes() map[string]string {
return maps.Clone(entry.attributes)
}

func (entry *RegistryEntry) SetAttributes(attributes map[string]string) error {
err := entry.mergeAttributes(attributes)
if err != nil {
return err
}
return nil
}

func (entry *RegistryEntry) matchCertificate(certificate *x509.Certificate) bool {
if entry.HasCertificate() {
return bytes.Equal(entry.certificate.Raw, certificate.Raw)
Expand Down Expand Up @@ -603,6 +650,17 @@ func (entry *RegistryEntry) mergeRevocationList(revocationList *x509.RevocationL
return nil
}

func (entry *RegistryEntry) mergeAttributes(attributes map[string]string) error {
data, err := entry.registry.getEntryData(entry.name)
if err != nil {
return err
}
data.Attributes = maps.Clone(attributes)
entry.registry.updateEntryData(entry.name, data)
entry.attributes = data.Attributes
return nil
}

type registryEntryData struct {
EncodedKey string `json:"key"`
EncodedCertificate string `json:"crt"`
Expand Down
54 changes: 53 additions & 1 deletion store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestCreateCertificate(t *testing.T) {
entryCertificate := entry.Certificate()
require.NotNil(t, entryCertificate)
require.True(t, entry.IsRoot())
require.True(t, entry.CanIssue())
require.True(t, entry.CanIssue(x509.KeyUsageCertSign))
}

func TestCreateCertificateRequest(t *testing.T) {
Expand Down Expand Up @@ -98,6 +98,22 @@ func TestResetRevocationList(t *testing.T) {
require.Equal(t, revocationList1, revocationList2)
}

func TestAttributes(t *testing.T) {
name := "TestAttributes"
user := name + "User"
registry, err := store.NewStore(storage.NewMemoryStorage(testVersionLimit), 0)
require.NoError(t, err)
factory := newTestRootCertificateFactory(name)
createdName, err := registry.CreateCertificate(name, factory, user)
require.NoError(t, err)
entry, err := registry.Entry(createdName)
require.NoError(t, err)
attributes := map[string]string{"Key": "Value"}
err = entry.SetAttributes(attributes)
require.NoError(t, err)
require.Equal(t, attributes, entry.Attributes())
}

func TestMerge(t *testing.T) {
path, err := os.MkdirTemp("", "TestMerge*")
require.NoError(t, err)
Expand Down Expand Up @@ -137,6 +153,42 @@ func TestEntries(t *testing.T) {
checkStoreEntries(t, registry, 1120, 10)
}

func TestCertPools(t *testing.T) {
registry, err := store.NewStore(storage.NewMemoryStorage(testVersionLimit), 0)
require.NoError(t, err)
user := "TestCertPoolsUser"
populateTestStore(t, registry, user, 5)
roots, intermediates, err := registry.CertPools()
require.NoError(t, err)
require.NotNil(t, roots)
require.NotNil(t, intermediates)
entries, err := registry.Entries()
require.NoError(t, err)
for {
entry, err := entries.Next()
require.NoError(t, err)
if entry == nil {
break
}
if entry.HasCertificate() {
options := &x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
}
chains, err := entry.Certificate().Verify(*options)
require.NoError(t, err)
require.Equal(t, 1, len(chains))
if entry.IsRoot() {
require.Equal(t, 1, len(chains[0]))
} else if entry.IsCA() {
require.Equal(t, 2, len(chains[0]))
} else {
require.Equal(t, 3, len(chains[0]))
}
}
}
}

func checkStoreEntries(t *testing.T, registry *store.Registry, total int, roots int) {
entries, err := registry.Entries()
require.NoError(t, err)
Expand Down

0 comments on commit 022b25f

Please sign in to comment.