From 65d2dc509e851d138154ee6a8a3ff3acb3780a30 Mon Sep 17 00:00:00 2001 From: gaoruoshu Date: Wed, 9 Aug 2023 19:15:07 +0800 Subject: [PATCH 2/3] bugfix: training model can only save file to specified dir --- analysis/default_config.py | 1 + analysis/engine/parser.py | 4 +-- analysis/engine/train.py | 38 ++++++++++++++++++---- api/profile/profile.pb.go | 6 ++-- api/profile/profile.proto | 2 +- common/models/training.go | 2 +- modules/client/profile/profile_train.go | 42 +++++++++---------------- modules/server/profile/profile.go | 4 +-- 8 files changed, 56 insertions(+), 43 deletions(-)

Review notes (reviewer commentary, not part of the commit message; text
between the diffstat and the first diff header is skipped by git am and
git apply):

* Purpose: stop the train API from writing to a caller-chosen path. The
  client now sends only a model name (TrainMessage field 2, renamed
  OutputPath to ModelName; REST key "outputpath" becomes "modelname" in
  both common/models/training.go and analysis/engine/parser.py), and
  Training.post in analysis/engine/train.py writes the retrained model to
  TRAINING_MODEL_PATH + model_name.
* valid_model_name() rejects the name with HTTP 400 unless all of the
  following hold:
  - os.path.splitext yields the extension ".m";
  - the stem is not one of the bundled model names (scaler, aencoder,
    tencoder, default_clf, total_clf, throughput_performance_clf);
  - every stem character is an ASCII letter or digit, with "_" allowed
    only in the interior.
  Because "/" and "." are refused in the stem, the string concatenation
  cannot escape TRAINING_MODEL_PATH.
* Behaviour change: the branch that deleted data_path (shutil.rmtree or
  os.remove) after retraining is removed. NOTE(review): confirm that the
  uploaded training data (trainPath, which the server passes as
  "datapath") is cleaned up elsewhere. Otherwise it now stays on disk
  after every training run.
* NOTE(review): the client reports the output location as
  config.DefaultAnalysisPath + "/models/" + modelName. The service,
  however, writes under the hard-coded TRAINING_MODEL_PATH
  ('/usr/libexec/atuned/analysis/models/'). The two only agree if
  DefaultAnalysisPath is '/usr/libexec/atuned/analysis', and that value
  is not visible in this patch.
* NOTE(review): checkTrainCtx only requires model_name to be non-empty
  and to pass utils.IsInputStringValid. The ".m" and character rules are
  enforced only by the analysis service, so a bad name is presumably
  reported back as a service error rather than caught by the CLI. Verify
  how trainBody.Post surfaces the 400.
* Compatibility: field 2 keeps its tag number. An older client that still
  sends an absolute path there will have it read as a model name; such a
  path contains "/" and is rejected by valid_model_name.
* NOTE(review): profile.pb.go looks hand-edited (only the struct field and
  the getter change). If the file also embeds a generated file
  descriptor, that copy still names field 2 OutputPath; regenerating with
  protoc would keep the two in sync.
* Nits left as-is, because they are runtime strings:
  "Illegal model name provide",
  "cannot contains special character",
  the garbled train help sentence ("... name of model to be generated the
  trained model, which must be end with .m."),
  and what appears to be an added blank line at the end of train.py.
* NOTE(review): in this copy the patch's line breaks and indentation have
  been collapsed, so it will not apply as-is. Regenerate it from the
  original commit rather than restoring whitespace by hand, because the
  context lines must match the target files exactly.

diff --git a/analysis/default_config.py b/analysis/default_config.py index cf56ac2..7c921cc 100644 --- a/analysis/default_config.py +++ b/analysis/default_config.py @@ -23,6 +23,7 @@ GRPC_CERT_PATH = '/etc/atuned/grpc_certs' ANALYSIS_DATA_PATH = '/var/atune_data/analysis/' TUNING_DATA_PATH = '/var/atune_data/tuning/' TUNING_DATA_DIRS = ['running', 'finished', 'error'] +TRAINING_MODEL_PATH = '/usr/libexec/atuned/analysis/models/' def get_or_default(config, section, key, value): diff --git a/analysis/engine/parser.py b/analysis/engine/parser.py index c16089f..c36c74d 100644 --- a/analysis/engine/parser.py +++ b/analysis/engine/parser.py @@ -69,8 +69,8 @@ CLASSIFICATION_POST_PARSER.add_argument('model', TRAIN_POST_PARSER = reqparse.RequestParser() TRAIN_POST_PARSER.add_argument('datapath', required=True, help="The datapath can not be null") -TRAIN_POST_PARSER.add_argument('outputpath', required=True, - help="The output path can not be null") +TRAIN_POST_PARSER.add_argument('modelname', required=True, + help="The model name can not be null") TRAIN_POST_PARSER.add_argument('modelpath', required=True, help="The model path can not be null") diff --git a/analysis/engine/train.py b/analysis/engine/train.py index 9fdca46..7608660 100644 --- a/analysis/engine/train.py +++ b/analysis/engine/train.py @@ -22,6 +22,7 @@ from flask_restful import Resource from 
analysis.optimizer.workload_characterization import WorkloadCharacterization from analysis.engine.parser import TRAIN_POST_PARSER +from analysis.default_config import TRAINING_MODEL_PATH LOGGER = logging.getLogger(__name__) @@ -29,7 +30,7 @@ LOGGER = logging.getLogger(__name__) class Training(Resource): """provide the method of post for training""" model_path = "modelpath" - output_path = "outputpath" + model_name = "modelname" data_path = "datapath" def post(self): @@ -40,18 +41,43 @@ class Training(Resource): LOGGER.info(args) model_path = args.get(self.model_path) - output_path = args.get(self.output_path) + model_name = args.get(self.model_name) data_path = args.get(self.data_path) + valid, err = valid_model_name(model_name) + if not valid: + return "Illegal model name provide: {}".format(err), 400 + characterization = WorkloadCharacterization(model_path) try: + output_path = TRAINING_MODEL_PATH + model_name characterization.retrain(data_path, output_path) except Exception as err: LOGGER.error(err) abort(500) - if os.path.isdir(data_path): - shutil.rmtree(data_path) - else: - os.remove(data_path) return {}, 200 + + +def valid_model_name(name): + file_name, file_ext = os.path.splitext(name) + + if file_ext != ".m": + return False, "the ext of model name should be .m" + + if file_name in ['scaler', 'aencoder', 'tencoder', 'default_clf', 'total_clf', 'throughput_performance_clf']: + return False, "model name cannot be set as default_clf/scaler/aencoder/tencoder/throughput_performance_clf/total_clf" + + for ind, char in enumerate(file_name): + if 'a' <= char <= 'z': + continue + if 'A' <= char <= 'Z': + continue + if '0' <= char <= '9': + continue + if ind != 0 and ind != len(file_name) - 1 and char == '_': + continue + return False, "model name cannot contains special character" + + return True, None + diff --git a/api/profile/profile.pb.go b/api/profile/profile.pb.go index dba166d..81a6757 100644 --- a/api/profile/profile.pb.go +++ b/api/profile/profile.pb.go @@ 
-493,7 +493,7 @@ func (m *CollectFlag) GetType() string { type TrainMessage struct { DataPath string `protobuf:"bytes,1,opt,name=DataPath,proto3" json:"DataPath,omitempty"` - OutputPath string `protobuf:"bytes,2,opt,name=OutputPath,proto3" json:"OutputPath,omitempty"` + ModelName string `protobuf:"bytes,2,opt,name=ModelName,proto3" json:"ModelName,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -531,9 +531,9 @@ func (m *TrainMessage) GetDataPath() string { return "" } -func (m *TrainMessage) GetOutputPath() string { +func (m *TrainMessage) GetModelName() string { if m != nil { - return m.OutputPath + return m.ModelName } return "" } diff --git a/api/profile/profile.proto b/api/profile/profile.proto index 2a6e751..29cf91e 100755 --- a/api/profile/profile.proto +++ b/api/profile/profile.proto @@ -84,7 +84,7 @@ message CollectFlag { message TrainMessage { string DataPath = 1; - string OutputPath = 2; + string ModelName = 2; } message DetectMessage { diff --git a/common/models/training.go b/common/models/training.go index 3cc9d60..9497261 100644 --- a/common/models/training.go +++ b/common/models/training.go @@ -24,7 +24,7 @@ import ( type Training struct { DataPath string `json:"datapath"` ModelPath string `json:"modelpath"` - OutputPath string `json:"outputpath"` + ModelName string `json:"modelname"` } // Post method call training service diff --git a/modules/client/profile/profile_train.go b/modules/client/profile/profile_train.go index f4e68cb..a645bc3 100644 --- a/modules/client/profile/profile_train.go +++ b/modules/client/profile/profile_train.go @@ -18,9 +18,9 @@ import ( "gitee.com/openeuler/A-Tune/common/client" SVC "gitee.com/openeuler/A-Tune/common/service" "gitee.com/openeuler/A-Tune/common/utils" + "gitee.com/openeuler/A-Tune/common/config" "fmt" "io" - "os" "path/filepath" "github.com/urfave/cli" @@ -39,8 +39,8 @@ var trainCommand = cli.Command{ Value: "", }, cli.StringFlag{ - 
Name: "output_file,o", - Usage: "the model to be generated", + Name: "model_name,m", + Usage: "the model name of generate model", Value: "", }, }, @@ -48,9 +48,9 @@ var trainCommand = cli.Command{ desc := ` training a new model with the self collected data, data_path option specified the path that storage the collected data, the collected data must have more - than two workload type. output_file specified the file path where to store + than two workload type. model_name specified the name of model to be generated the trained model, which must be end with .m. - example: atune-adm train --data_path=./data --output_file=./model/trained.m` + example: atune-adm train --data_path=/home/data --model_name=trained.m` return desc }(), Action: train, @@ -82,13 +82,13 @@ func checkTrainCtx(ctx *cli.Context) error { return fmt.Errorf("input:%s is invalid", dataPath) } - outputPath := ctx.String("output_file") - if outputPath == "" { + modelName := ctx.String("model_name") + if modelName == "" { _ = cli.ShowCommandHelp(ctx, "train") - return fmt.Errorf("error: output_file must be specified") + return fmt.Errorf("error: model_name must be specified") } - if !utils.IsInputStringValid(outputPath) { - return fmt.Errorf("input:%s is invalid", outputPath) + if !utils.IsInputStringValid(modelName) { + return fmt.Errorf("input:%s is invalid", modelName) } return nil @@ -103,23 +103,8 @@ func train(ctx *cli.Context) error { if err != nil { return err } - outputPath, err := filepath.Abs(ctx.String("output_file")) - if err != nil { - return err - } - - dir := filepath.Dir(outputPath) - exist, err := utils.PathExist(dir) - if err != nil { - return err - } - if !exist { - err = os.MkdirAll(dir, utils.FilePerm) - if err != nil { - return err - } - } + modelName := ctx.String("model_name") c, err := client.NewClientFromContext(ctx) if err != nil { @@ -128,7 +113,7 @@ func train(ctx *cli.Context) error { defer c.Close() svc := PB.NewProfileMgrClient(c.Connection()) - stream, err := 
svc.Training(CTX.Background(), &PB.TrainMessage{DataPath: dataPath, OutputPath: outputPath}) + stream, err := svc.Training(CTX.Background(), &PB.TrainMessage{DataPath: dataPath, ModelName: modelName}) if err != nil { return err } @@ -145,6 +130,7 @@ func train(ctx *cli.Context) error { } utils.Print(reply) } - fmt.Println("the model generate path:", outputPath) + modelPath := fmt.Sprintf("%s/models/%s", config.DefaultAnalysisPath, modelName) + fmt.Println("the model generate path:", modelPath) return nil } diff --git a/modules/server/profile/profile.go b/modules/server/profile/profile.go index 3264167..70047ca 100644 --- a/modules/server/profile/profile.go +++ b/modules/server/profile/profile.go @@ -1137,7 +1137,7 @@ func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_ } DataPath := message.GetDataPath() - OutputPath := message.GetOutputPath() + ModelName := message.GetModelName() compressPath, err := utils.CreateCompressFile(DataPath) if err != nil { @@ -1156,7 +1156,7 @@ func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_ trainBody := new(models.Training) trainBody.DataPath = trainPath - trainBody.OutputPath = OutputPath + trainBody.ModelName = ModelName trainBody.ModelPath = path.Join(config.DefaultAnalysisPath, "models") success, err := trainBody.Post() -- 2.27.0