diff options
| author | Peng Wu <alexepico@gmail.com> | 2011-07-25 11:54:05 +0800 |
|---|---|---|
| committer | Peng Wu <alexepico@gmail.com> | 2011-07-25 11:54:05 +0800 |
| commit | 9cca4180400774231f64363c0f6acb87e5444f23 (patch) | |
| tree | 02cd6c71e728157d87a54fa914678db249ebaa05 /tryprune.py | |
| parent | 114342747eaf4e5719745c0cd113272e3c4c51ba (diff) | |
| download | trainer-9cca4180400774231f64363c0f6acb87e5444f23.tar.gz trainer-9cca4180400774231f64363c0f6acb87e5444f23.tar.xz trainer-9cca4180400774231f64363c0f6acb87e5444f23.zip | |
write tryprune.py in progress
Diffstat (limited to 'tryprune.py')
| -rw-r--r-- | tryprune.py | 58 |
1 files changed, 55 insertions, 3 deletions
diff --git a/tryprune.py b/tryprune.py index ee0acbd..f86bd48 100644 --- a/tryprune.py +++ b/tryprune.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 import os +import os.path +import shutil import sys from subprocess import Popen, PIPE from argparse import ArgumentParser @@ -45,7 +47,17 @@ def exportModel(modelfile, textmodel): sys.exit('Corrupted model found when exporting:' + modelfile) #end processing -def mergeOneModel(mergedmodel, onemodel): +def mergeOneModel(mergedmodel, onemodel, score): + #validate first + validateModel(onemodel) + + onemodelstatuspath = onemodel + config.getStatusPostfix() + onemodelstatus = utils.load_status(onemodelstatuspath) + if not utils.check_epoch(onemodelstatus, 'Estimate'): + raise utils.Epoch('Please estimate first.\n') + if score != onemodelstatus['EstimateScore']: + raise AssertionError('estimate scores mis-match.\n') + #begin processing cmdline = ['./merge_k_mixture_model', \ '--result-file', \ @@ -59,8 +71,28 @@ def mergeOneModel(mergedmodel, onemodel): sys.exit('Corrupted model found when merging:' + onemodel) #end processing -def mergeSomeModels(indexfile, mergenum): - pass +def mergeSomeModels(tryname, mergedmodel, sortedindexname, mergenum): + last_score = 1. + #begin processing + indexfile = open(sortedindexname, 'r') + for i in range(mergenum): + line = indexfile.readline() + if not line: + raise AssertionError('No more models.\n') + line = line.rstrip(os.linesep) + (subdir, modelname, score) = line.split('#', 2) + score = float(score) + if score > last_score: + raise AssertionError('score must be descending.\n') + + onemodel = os.path.join(config.getModelDir(), subdir, modelname) + mergeOneModel(mergedmodel, onemodel, score) + last_score = score + indexfile.close() + #end processing + + #validate merged model + validateModel(mergedmodel) def pruneModel(modelfile, k, CDF): #begin processing @@ -98,3 +130,23 @@ if __name__ == '__main__': args = parser.parse_args() print(args) + tryname = 'try' + args.tryname + #merge model candidates + mergedmodel = os.path.join(config.getFinalDir(), tryname, 'merged.db') + sortedindexname = os.path.join(args.modeldir, \ + config.getSortedEstimateIndex()) + mergeSomeModels(tryname, mergedmodel, sortedindexname, args.mergenumber) + + #export textual format + exportfile = os.path.join(config.getFinalDir(), tryname, 'kmm_merged.text') + exportModel(mergedmodel, exportfile) + + #prune merged model + prunedmodel = os.path.join(config.getFinalDir(), tryname, 'pruned.db') + #backup merged model + shutil.copyfile(mergedmodel, prunedmodel) + pruneModel(prunedmodel, args.k, args.CDF) + + #export textual format + exportfile = os.path.join(config.getFinalDir(), tryname, 'kmm_pruned.text') + exportModel(prunedmodel, exportModel) |
