From 6775c99cd008c457ce3eed401ac1c60c3812dbfa Mon Sep 17 00:00:00 2001 From: Tim Allclair Date: Mon, 10 Oct 2022 18:15:22 -0700 Subject: [PATCH] Validate etcd paths --- .../pkg/storage/etcd3/linearized_read_test.go | 5 +- .../apiserver/pkg/storage/etcd3/store.go | 138 +++++++++++------ .../apiserver/pkg/storage/etcd3/store_test.go | 140 ++++++++++++++---- .../pkg/storage/testing/store_tests.go | 24 +-- 4 files changed, 218 insertions(+), 89 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/linearized_read_test.go b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/linearized_read_test.go index bb1b9df7818..7331c8245ad 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/linearized_read_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/linearized_read_test.go @@ -37,7 +37,8 @@ func TestLinearizedReadRevisionInvariant(t *testing.T) { // [1] https://etcd.io/docs/v3.5/learning/api_guarantees/#isolation-level-and-consistency-of-replicas ctx, store, etcdClient := testSetup(t) - key := "/testkey" + dir := "/testing" + key := dir + "/testkey" out := &example.Pod{} obj := &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo", SelfLink: "testlink"}} @@ -53,7 +54,7 @@ func TestLinearizedReadRevisionInvariant(t *testing.T) { } list := &example.PodList{} - if err := store.GetList(ctx, "/", storage.ListOptions{Predicate: storage.Everything, Recursive: true}, list); err != nil { + if err := store.GetList(ctx, dir, storage.ListOptions{Predicate: storage.Everything, Recursive: true}, list); err != nil { t.Errorf("Unexpected List error: %v", err) } finalRevision := list.ResourceVersion diff --git a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go index 163eb111bcb..b2e4c674687 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go @@ -98,16 +98,21 @@ func New(c *clientv3.Client, codec runtime.Codec, newFunc func() runtime.Object, func newStore(c *clientv3.Client, codec runtime.Codec, newFunc func() runtime.Object, prefix string, groupResource schema.GroupResource, transformer value.Transformer, pagingEnabled bool, leaseManagerConfig LeaseManagerConfig) *store { versioner := storage.APIObjectVersioner{} + // for compatibility with etcd2 impl. + // no-op for default prefix of '/registry'. + // keeps compatibility with etcd2 impl for custom prefixes that don't start with '/' + pathPrefix := path.Join("/", prefix) + if !strings.HasSuffix(pathPrefix, "/") { + // Ensure the pathPrefix ends in "/" here to simplify key concatenation later. + pathPrefix += "/" + } result := &store{ - client: c, - codec: codec, - versioner: versioner, - transformer: transformer, - pagingEnabled: pagingEnabled, - // for compatibility with etcd2 impl. - // no-op for default prefix of '/registry'. - // keeps compatibility with etcd2 impl for custom prefixes that don't start with '/' - pathPrefix: path.Join("/", prefix), + client: c, + codec: codec, + versioner: versioner, + transformer: transformer, + pagingEnabled: pagingEnabled, + pathPrefix: pathPrefix, groupResource: groupResource, groupResourceString: groupResource.String(), watcher: newWatcher(c, codec, newFunc, versioner, transformer), @@ -123,9 +128,12 @@ func (s *store) Versioner() storage.Versioner { // Get implements storage.Interface.Get. func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, out runtime.Object) error { - key = path.Join(s.pathPrefix, key) + preparedKey, err := s.prepareKey(key) + if err != nil { + return err + } startTime := time.Now() - getResp, err := s.client.KV.Get(ctx, key) + getResp, err := s.client.KV.Get(ctx, preparedKey) metrics.RecordEtcdRequestLatency("get", getTypeName(out), startTime) if err != nil { return err @@ -138,11 +146,11 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou if opts.IgnoreNotFound { return runtime.SetZeroValue(out) } - return storage.NewKeyNotFoundError(key, 0) + return storage.NewKeyNotFoundError(preparedKey, 0) } kv := getResp.Kvs[0] - data, _, err := s.transformer.TransformFromStorage(ctx, kv.Value, authenticatedDataString(key)) + data, _, err := s.transformer.TransformFromStorage(ctx, kv.Value, authenticatedDataString(preparedKey)) if err != nil { return storage.NewInternalError(err.Error()) } @@ -152,6 +160,10 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou // Create implements storage.Interface.Create. func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, ttl uint64) error { + preparedKey, err := s.prepareKey(key) + if err != nil { + return err + } trace := utiltrace.New("Create etcd3", utiltrace.Field{"audit-id", endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{"key", key}, @@ -170,14 +182,13 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, if err != nil { return err } - key = path.Join(s.pathPrefix, key) opts, err := s.ttlOpts(ctx, int64(ttl)) if err != nil { return err } - newData, err := s.transformer.TransformToStorage(ctx, data, authenticatedDataString(key)) + newData, err := s.transformer.TransformToStorage(ctx, data, authenticatedDataString(preparedKey)) trace.Step("TransformToStorage finished", utiltrace.Field{"err", err}) if err != nil { return storage.NewInternalError(err.Error()) @@ -185,9 +196,9 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, startTime := time.Now() txnResp, err := s.client.KV.Txn(ctx).If( - notFound(key), + notFound(preparedKey), ).Then( - clientv3.OpPut(key, string(newData), opts...), + clientv3.OpPut(preparedKey, string(newData), opts...), ).Commit() metrics.RecordEtcdRequestLatency("create", getTypeName(obj), startTime) trace.Step("Txn call finished", utiltrace.Field{"err", err}) @@ -196,7 +207,7 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, } if !txnResp.Succeeded { - return storage.NewKeyExistsError(key, 0) + return storage.NewKeyExistsError(preparedKey, 0) } if out != nil { @@ -212,12 +223,15 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, func (s *store) Delete( ctx context.Context, key string, out runtime.Object, preconditions *storage.Preconditions, validateDeletion storage.ValidateObjectFunc, cachedExistingObject runtime.Object) error { + preparedKey, err := s.prepareKey(key) + if err != nil { + return err + } v, err := conversion.EnforcePtr(out) if err != nil { return fmt.Errorf("unable to convert output object to pointer: %v", err) } - key = path.Join(s.pathPrefix, key) - return s.conditionalDelete(ctx, key, out, v, preconditions, validateDeletion, cachedExistingObject) + return s.conditionalDelete(ctx, preparedKey, out, v, preconditions, validateDeletion, cachedExistingObject) } func (s *store) conditionalDelete( @@ -330,6 +344,10 @@ func (s *store) conditionalDelete( func (s *store) GuaranteedUpdate( ctx context.Context, key string, destination runtime.Object, ignoreNotFound bool, preconditions *storage.Preconditions, tryUpdate storage.UpdateFunc, cachedExistingObject runtime.Object) error { + preparedKey, err := s.prepareKey(key) + if err != nil { + return err + } trace := utiltrace.New("GuaranteedUpdate etcd3", utiltrace.Field{"audit-id", endpointsrequest.GetAuditIDTruncated(ctx)}, utiltrace.Field{"key", key}, @@ -340,16 +358,15 @@ func (s *store) GuaranteedUpdate( if err != nil { return fmt.Errorf("unable to convert output object to pointer: %v", err) } - key = path.Join(s.pathPrefix, key) getCurrentState := func() (*objState, error) { startTime := time.Now() - getResp, err := s.client.KV.Get(ctx, key) + getResp, err := s.client.KV.Get(ctx, preparedKey) metrics.RecordEtcdRequestLatency("get", getTypeName(destination), startTime) if err != nil { return nil, err } - return s.getState(ctx, getResp, key, v, ignoreNotFound) + return s.getState(ctx, getResp, preparedKey, v, ignoreNotFound) } var origState *objState @@ -365,9 +382,9 @@ func (s *store) GuaranteedUpdate( } trace.Step("initial value restored") - transformContext := authenticatedDataString(key) + transformContext := authenticatedDataString(preparedKey) for { - if err := preconditions.Check(key, origState.obj); err != nil { + if err := preconditions.Check(preparedKey, origState.obj); err != nil { // If our data is already up to date, return the error if origStateIsCurrent { return err @@ -453,11 +470,11 @@ func (s *store) GuaranteedUpdate( startTime := time.Now() txnResp, err := s.client.KV.Txn(ctx).If( - clientv3.Compare(clientv3.ModRevision(key), "=", origState.rev), + clientv3.Compare(clientv3.ModRevision(preparedKey), "=", origState.rev), ).Then( - clientv3.OpPut(key, string(newData), opts...), + clientv3.OpPut(preparedKey, string(newData), opts...), ).Else( - clientv3.OpGet(key), + clientv3.OpGet(preparedKey), ).Commit() metrics.RecordEtcdRequestLatency("update", getTypeName(destination), startTime) trace.Step("Txn call finished", utiltrace.Field{"err", err}) @@ -467,8 +484,8 @@ func (s *store) GuaranteedUpdate( trace.Step("Transaction committed") if !txnResp.Succeeded { getResp := (*clientv3.GetResponse)(txnResp.Responses[0].GetResponseRange()) - klog.V(4).Infof("GuaranteedUpdate of %s failed because of a conflict, going to retry", key) - origState, err = s.getState(ctx, getResp, key, v, ignoreNotFound) + klog.V(4).Infof("GuaranteedUpdate of %s failed because of a conflict, going to retry", preparedKey) + origState, err = s.getState(ctx, getResp, preparedKey, v, ignoreNotFound) if err != nil { return err } @@ -502,18 +519,21 @@ func getNewItemFunc(listObj runtime.Object, v reflect.Value) func() runtime.Obje } func (s *store) Count(key string) (int64, error) { - key = path.Join(s.pathPrefix, key) + preparedKey, err := s.prepareKey(key) + if err != nil { + return 0, err + } // We need to make sure the key ended with "/" so that we only get children "directories". // e.g. if we have key "/a", "/a/b", "/ab", getting keys with prefix "/a" will return all three, // while with prefix "/a/" will return only "/a/b" which is the correct answer. - if !strings.HasSuffix(key, "/") { - key += "/" + if !strings.HasSuffix(preparedKey, "/") { + preparedKey += "/" } startTime := time.Now() - getResp, err := s.client.KV.Get(context.Background(), key, clientv3.WithRange(clientv3.GetPrefixRangeEnd(key)), clientv3.WithCountOnly()) - metrics.RecordEtcdRequestLatency("listWithCount", key, startTime) + getResp, err := s.client.KV.Get(context.Background(), preparedKey, clientv3.WithRange(clientv3.GetPrefixRangeEnd(preparedKey)), clientv3.WithCountOnly()) + metrics.RecordEtcdRequestLatency("listWithCount", preparedKey, startTime) if err != nil { return 0, err } @@ -522,6 +542,10 @@ func (s *store) Count(key string) (int64, error) { // GetList implements storage.Interface. func (s *store) GetList(ctx context.Context, key string, opts storage.ListOptions, listObj runtime.Object) error { + preparedKey, err := s.prepareKey(key) + if err != nil { + return err + } recursive := opts.Recursive resourceVersion := opts.ResourceVersion match := opts.ResourceVersionMatch @@ -542,16 +566,15 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption if err != nil || v.Kind() != reflect.Slice { return fmt.Errorf("need ptr to slice: %v", err) } - key = path.Join(s.pathPrefix, key) // For recursive lists, we need to make sure the key ended with "/" so that we only // get children "directories". e.g. if we have key "/a", "/a/b", "/ab", getting keys // with prefix "/a" will return all three, while with prefix "/a/" will return only // "/a/b" which is the correct answer. - if recursive && !strings.HasSuffix(key, "/") { - key += "/" + if recursive && !strings.HasSuffix(preparedKey, "/") { + preparedKey += "/" } - keyPrefix := key + keyPrefix := preparedKey // set the appropriate clientv3 options to filter the returned data set var limitOption *clientv3.OpOption @@ -590,7 +613,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption rangeEnd := clientv3.GetPrefixRangeEnd(keyPrefix) options = append(options, clientv3.WithRange(rangeEnd)) - key = continueKey + preparedKey = continueKey // If continueRV > 0, the LIST request needs a specific resource version. // continueRV==0 is invalid. @@ -657,7 +680,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption }() for { startTime := time.Now() - getResp, err = s.client.KV.Get(ctx, key, options...) + getResp, err = s.client.KV.Get(ctx, preparedKey, options...) if recursive { metrics.RecordEtcdRequestLatency("list", getTypeName(listPtr), startTime) } else { @@ -729,7 +752,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption } *limitOption = clientv3.WithLimit(limit) } - key = string(lastKey) + "\x00" + preparedKey = string(lastKey) + "\x00" if withRev == 0 { withRev = returnedRV options = append(options, clientv3.WithRev(withRev)) @@ -794,12 +817,15 @@ func growSlice(v reflect.Value, maxCapacity int, sizes ...int) { // Watch implements storage.Interface.Watch. func (s *store) Watch(ctx context.Context, key string, opts storage.ListOptions) (watch.Interface, error) { + preparedKey, err := s.prepareKey(key) + if err != nil { + return nil, err + } rev, err := s.versioner.ParseResourceVersion(opts.ResourceVersion) if err != nil { return nil, err } - key = path.Join(s.pathPrefix, key) - return s.watcher.Watch(ctx, key, int64(rev), opts.Recursive, opts.ProgressNotify, opts.Predicate) + return s.watcher.Watch(ctx, preparedKey, int64(rev), opts.Recursive, opts.ProgressNotify, opts.Predicate) } func (s *store) getState(ctx context.Context, getResp *clientv3.GetResponse, key string, v reflect.Value, ignoreNotFound bool) (*objState, error) { @@ -911,6 +937,30 @@ func (s *store) validateMinimumResourceVersion(minimumResourceVersion string, ac return nil } +func (s *store) prepareKey(key string) (string, error) { + if key == ".." || + strings.HasPrefix(key, "../") || + strings.HasSuffix(key, "/..") || + strings.Contains(key, "/../") { + return "", fmt.Errorf("invalid key: %q", key) + } + if key == "." || + strings.HasPrefix(key, "./") || + strings.HasSuffix(key, "/.") || + strings.Contains(key, "/./") { + return "", fmt.Errorf("invalid key: %q", key) + } + if key == "" || key == "/" { + return "", fmt.Errorf("empty key: %q", key) + } + // We ensured that pathPrefix ends in '/' in construction, so skip any leading '/' in the key now. + startIndex := 0 + if key[0] == '/' { + startIndex = 1 + } + return s.pathPrefix + key[startIndex:], nil +} + // decode decodes value of bytes into object. It will also set the object resource version to rev. // On success, objPtr would be set to the object. func decode(codec runtime.Codec, versioner storage.Versioner, value []byte, objPtr runtime.Object, rev int64) error { diff --git a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store_test.go b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store_test.go index 84ee65a1fac..2f0dfe6e588 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store_test.go @@ -54,6 +54,7 @@ var scheme = runtime.NewScheme() var codecs = serializer.NewCodecFactory(scheme) const defaultTestPrefix = "test!" +const basePath = "/keybase" func init() { metav1.AddToGroupVersion(scheme, metav1.SchemeGroupVersion) @@ -446,13 +447,13 @@ func TestTransformationFailure(t *testing.T) { obj *example.Pod storedObj *example.Pod }{{ - key: "/one-level/test", + key: basePath + "/one-level/test", obj: &example.Pod{ ObjectMeta: metav1.ObjectMeta{Name: "bar"}, Spec: storagetesting.DeepEqualSafePodSpec(), }, }, { - key: "/two-level/1/test", + key: basePath + "/two-level/1/test", obj: &example.Pod{ ObjectMeta: metav1.ObjectMeta{Name: "baz"}, Spec: storagetesting.DeepEqualSafePodSpec(), @@ -484,7 +485,7 @@ func TestTransformationFailure(t *testing.T) { Predicate: storage.Everything, Recursive: true, } - if err := store.GetList(ctx, "/", storageOpts, &got); !storage.IsInternalError(err) { + if err := store.GetList(ctx, basePath, storageOpts, &got); !storage.IsInternalError(err) { t.Errorf("Unexpected error %v", err) } @@ -531,7 +532,7 @@ func TestListContinuation(t *testing.T) { etcdClient.KV = recorder // Setup storage with the following structure: - // / + // /keybase/ // - one-level/ // | - test // | @@ -548,15 +549,15 @@ func TestListContinuation(t *testing.T) { storedObj *example.Pod }{ { - key: "/one-level/test", + key: basePath + "/one-level/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/1/test", + key: basePath + "/two-level/1/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/2/test", + key: basePath + "/two-level/2/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "bar"}}, }, } @@ -588,7 +589,7 @@ func TestListContinuation(t *testing.T) { Predicate: pred(1, ""), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get initial list: %v", err) } if len(out.Continue) == 0 { @@ -613,13 +614,13 @@ func TestListContinuation(t *testing.T) { Predicate: pred(0, continueFromSecondItem), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get second page: %v", err) } if len(out.Continue) != 0 { t.Fatalf("Unexpected continuation token set") } - key, rv, err := storage.DecodeContinue(continueFromSecondItem, "/") + key, rv, err := storage.DecodeContinue(continueFromSecondItem, basePath) t.Logf("continue token was %d %s %v", rv, key, err) storagetesting.ExpectNoDiff(t, "incorrect second page", []example.Pod{*preset[1].storedObj, *preset[2].storedObj}, out.Items) if transformer.reads != 2 { @@ -638,7 +639,7 @@ func TestListContinuation(t *testing.T) { Predicate: pred(1, continueFromSecondItem), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get second page: %v", err) } if len(out.Continue) == 0 { @@ -662,7 +663,7 @@ func TestListContinuation(t *testing.T) { Predicate: pred(1, continueFromThirdItem), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get second page: %v", err) } if len(out.Continue) != 0 { @@ -688,7 +689,7 @@ func TestListPaginationRareObject(t *testing.T) { podCount := 1000 var pods []*example.Pod for i := 0; i < podCount; i++ { - key := fmt.Sprintf("/one-level/pod-%d", i) + key := basePath + fmt.Sprintf("/one-level/pod-%d", i) obj := &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("pod-%d", i)}} storedObj := &example.Pod{} err := store.Create(ctx, key, obj, storedObj, 0) @@ -711,7 +712,7 @@ func TestListPaginationRareObject(t *testing.T) { }, Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get initial list: %v", err) } if len(out.Continue) != 0 { @@ -782,7 +783,7 @@ func TestListContinuationWithFilter(t *testing.T) { for i, ps := range preset { preset[i].storedObj = &example.Pod{} - err := store.Create(ctx, ps.key, ps.obj, preset[i].storedObj, 0) + err := store.Create(ctx, basePath+ps.key, ps.obj, preset[i].storedObj, 0) if err != nil { t.Fatalf("Set failed: %v", err) } @@ -810,7 +811,7 @@ func TestListContinuationWithFilter(t *testing.T) { Predicate: pred(2, ""), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Errorf("Unable to get initial list: %v", err) } if len(out.Continue) == 0 { @@ -842,7 +843,7 @@ func TestListContinuationWithFilter(t *testing.T) { Predicate: pred(2, cont), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Errorf("Unable to get second page: %v", err) } if len(out.Continue) != 0 { @@ -863,7 +864,7 @@ func TestListInconsistentContinuation(t *testing.T) { ctx, store, client := testSetup(t) // Setup storage with the following structure: - // / + // /keybase/ // - one-level/ // | - test // | @@ -880,15 +881,15 @@ func TestListInconsistentContinuation(t *testing.T) { storedObj *example.Pod }{ { - key: "/one-level/test", + key: basePath + "/one-level/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/1/test", + key: basePath + "/two-level/1/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/2/test", + key: basePath + "/two-level/2/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "bar"}}, }, } @@ -920,7 +921,7 @@ func TestListInconsistentContinuation(t *testing.T) { Predicate: pred(1, ""), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get initial list: %v", err) } if len(out.Continue) == 0 { @@ -964,7 +965,7 @@ func TestListInconsistentContinuation(t *testing.T) { Predicate: pred(0, continueFromSecondItem), Recursive: true, } - err = store.GetList(ctx, "/", options, out) + err = store.GetList(ctx, basePath, options, out) if err == nil { t.Fatalf("unexpected no error") } @@ -986,7 +987,7 @@ func TestListInconsistentContinuation(t *testing.T) { Predicate: pred(1, inconsistentContinueFromSecondItem), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get second page: %v", err) } if len(out.Continue) == 0 { @@ -1005,7 +1006,7 @@ func TestListInconsistentContinuation(t *testing.T) { Predicate: pred(1, continueFromThirdItem), Recursive: true, } - if err := store.GetList(ctx, "/", options, out); err != nil { + if err := store.GetList(ctx, basePath, options, out); err != nil { t.Fatalf("Unable to get second page: %v", err) } if len(out.Continue) != 0 { @@ -1127,9 +1128,9 @@ func testSetup(t *testing.T, opts ...setupOption) (context.Context, *store, *cli func TestPrefix(t *testing.T) { testcases := map[string]string{ - "custom/prefix": "/custom/prefix", - "/custom//prefix//": "/custom/prefix", - "/registry": "/registry", + "custom/prefix": "/custom/prefix/", + "/custom//prefix//": "/custom/prefix/", + "/registry": "/registry/", } for configuredPrefix, effectivePrefix := range testcases { _, store, _ := testSetup(t, withPrefix(configuredPrefix)) @@ -1302,7 +1303,7 @@ func TestConsistentList(t *testing.T) { Predicate: predicate, Recursive: true, } - if err := store.GetList(ctx, "/", options, &result1); err != nil { + if err := store.GetList(ctx, basePath, options, &result1); err != nil { t.Fatalf("failed to list objects: %v", err) } @@ -1315,7 +1316,7 @@ func TestConsistentList(t *testing.T) { } result2 := example.PodList{} - if err := store.GetList(ctx, "/", options, &result2); err != nil { + if err := store.GetList(ctx, basePath, options, &result2); err != nil { t.Fatalf("failed to list objects: %v", err) } @@ -1325,7 +1326,7 @@ func TestConsistentList(t *testing.T) { options.ResourceVersionMatch = metav1.ResourceVersionMatchNotOlderThan result3 := example.PodList{} - if err := store.GetList(ctx, "/", options, &result3); err != nil { + if err := store.GetList(ctx, basePath, options, &result3); err != nil { t.Fatalf("failed to list objects: %v", err) } @@ -1333,7 +1334,7 @@ func TestConsistentList(t *testing.T) { options.ResourceVersionMatch = metav1.ResourceVersionMatchExact result4 := example.PodList{} - if err := store.GetList(ctx, "/", options, &result4); err != nil { + if err := store.GetList(ctx, basePath, options, &result4); err != nil { t.Fatalf("failed to list objects: %v", err) } @@ -1384,3 +1385,78 @@ func TestLeaseMaxObjectCount(t *testing.T) { } } } + +func TestValidateKey(t *testing.T) { + validKeys := []string{ + "/foo/bar/baz/a.b.c/", + "/foo", + "foo/bar/baz", + "/foo/bar..baz/", + "/foo/bar..", + "foo", + "foo/bar", + "/foo/bar/", + } + invalidKeys := []string{ + "/foo/bar/../a.b.c/", + "..", + "/..", + "../", + "/foo/bar/..", + "../foo/bar", + "/../foo", + "/foo/bar/../", + ".", + "/.", + "./", + "/./", + "/foo/.", + "./bar", + "/foo/./bar/", + } + const ( + pathPrefix = "/first/second" + expectPrefix = pathPrefix + "/" + ) + _, store, _ := testSetup(t, withPrefix(pathPrefix)) + + for _, key := range validKeys { + k, err := store.prepareKey(key) + if err != nil { + t.Errorf("key %q should be valid; unexpected error: %v", key, err) + } else if !strings.HasPrefix(k, expectPrefix) { + t.Errorf("key %q should have prefix %q", k, expectPrefix) + } + } + + for _, key := range invalidKeys { + _, err := store.prepareKey(key) + if err == nil { + t.Errorf("key %q should be invalid", key) + } + } +} + +func TestInvalidKeys(t *testing.T) { + const invalidKey = "/foo/bar/../baz" + expectedError := fmt.Sprintf("invalid key: %q", invalidKey) + + expectInvalidKey := func(methodName string, err error) { + if err == nil { + t.Errorf("[%s] expected invalid key error; got nil", methodName) + } else if err.Error() != expectedError { + t.Errorf("[%s] expected invalid key error; got %v", methodName, err) + } + } + + ctx, store, _ := testSetup(t) + expectInvalidKey("Create", store.Create(ctx, invalidKey, nil, nil, 0)) + expectInvalidKey("Delete", store.Delete(ctx, invalidKey, nil, nil, nil, nil)) + _, watchErr := store.Watch(ctx, invalidKey, storage.ListOptions{}) + expectInvalidKey("Watch", watchErr) + expectInvalidKey("Get", store.Get(ctx, invalidKey, storage.GetOptions{}, nil)) + expectInvalidKey("GetList", store.GetList(ctx, invalidKey, storage.ListOptions{}, nil)) + expectInvalidKey("GuaranteedUpdate", store.GuaranteedUpdate(ctx, invalidKey, nil, true, nil, nil, nil)) + _, countErr := store.Count(invalidKey) + expectInvalidKey("Count", countErr) +} diff --git a/staging/src/k8s.io/apiserver/pkg/storage/testing/store_tests.go b/staging/src/k8s.io/apiserver/pkg/storage/testing/store_tests.go index c83ee833ee6..df7078f6bd9 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/testing/store_tests.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/testing/store_tests.go @@ -42,6 +42,8 @@ import ( type KeyValidation func(ctx context.Context, t *testing.T, key string) +const basePath = "/keybase" + func RunTestCreate(ctx context.Context, t *testing.T, store storage.Interface, validation KeyValidation) { key := "/testkey" out := &example.Pod{} @@ -445,11 +447,11 @@ func RunTestList(ctx context.Context, t *testing.T, store storage.Interface) { Predicate: storage.Everything, Recursive: true, } - if err := store.GetList(ctx, "/two-level", storageOpts, list); err != nil { + if err := store.GetList(ctx, basePath+"/two-level", storageOpts, list); err != nil { t.Errorf("Unexpected error: %v", err) } continueRV, _ := strconv.Atoi(list.ResourceVersion) - secondContinuation, err := storage.EncodeContinue("/two-level/2", "/two-level/", int64(continueRV)) + secondContinuation, err := storage.EncodeContinue(basePath+"/two-level/2", basePath+"/two-level/", int64(continueRV)) if err != nil { t.Fatal(err) } @@ -827,7 +829,7 @@ func RunTestList(ctx context.Context, t *testing.T, store storage.Interface) { Predicate: tt.pred, Recursive: true, } - err = store.GetList(ctx, tt.prefix, storageOpts, out) + err = store.GetList(ctx, basePath+tt.prefix, storageOpts, out) if tt.expectRVTooLarge { if err == nil || !storage.IsTooLargeResourceVersion(err) { t.Fatalf("expecting resource version too high error, but get: %s", err) @@ -926,7 +928,7 @@ func RunTestListWithoutPaging(ctx context.Context, t *testing.T, store storage.I Recursive: true, } - if err := store.GetList(ctx, tt.prefix, storageOpts, out); err != nil { + if err := store.GetList(ctx, basePath+tt.prefix, storageOpts, out); err != nil { t.Fatalf("GetList failed: %v", err) return } @@ -952,7 +954,7 @@ func RunTestListWithoutPaging(ctx context.Context, t *testing.T, store storage.I // from before any were created along with the full set of objects that were persisted func seedMultiLevelData(ctx context.Context, store storage.Interface) (string, []*example.Pod, error) { // Setup storage with the following structure: - // / + // /keybase/ // - one-level/ // | - test // | @@ -975,30 +977,30 @@ func seedMultiLevelData(ctx context.Context, store storage.Interface) (string, [ storedObj *example.Pod }{ { - key: "/one-level/test", + key: basePath + "/one-level/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/1/test", + key: basePath + "/two-level/1/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}}, }, { - key: "/two-level/2/test", + key: basePath + "/two-level/2/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "bar"}}, }, { - key: "/z-level/3/test", + key: basePath + "/z-level/3/test", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "fourth"}}, }, { - key: "/z-level/3/test-2", + key: basePath + "/z-level/3/test-2", obj: &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "bar"}}, }, } // we want to figure out the resourceVersion before we create anything initialList := &example.PodList{} - if err := store.GetList(ctx, "/", storage.ListOptions{Predicate: storage.Everything, Recursive: true}, initialList); err != nil { + if err := store.GetList(ctx, basePath, storage.ListOptions{Predicate: storage.Everything, Recursive: true}, initialList); err != nil { return "", nil, fmt.Errorf("failed to determine starting resourceVersion: %w", err) } initialRV := initialList.ResourceVersion -- 2.25.1