summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xevaluate.py6
-rwxr-xr-xtryprune.py20
2 files changed, 15 insertions, 11 deletions
diff --git a/evaluate.py b/evaluate.py
index a609e79..0e99abb 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -165,8 +165,7 @@ if __name__ == '__main__':
buildData()
print('estimating')
- reportfile = os.path.join \
- (trydir, 'estimate' + config.getReportPostfix())
+ reportfile = os.path.join(trydir, 'estimate' + config.getReportPostfix())
avg_lambda = estimateModel(reportfile)
cwdstatus['EvaluateAverageLambda'] = avg_lambda
@@ -176,8 +175,7 @@ if __name__ == '__main__':
modifyCodeforLambda(avg_lambda)
print('evaluating')
- reportfile = os.path.join \
- (trydir, 'evaluate' + config.getReportPostfix())
+ reportfile = os.path.join(trydir, 'evaluate' + config.getReportPostfix())
rate = evaluateModel(reportfile)
print(tryname + "'s correction rate:", rate)
diff --git a/tryprune.py b/tryprune.py
index 9a820e9..805f4dd 100755
--- a/tryprune.py
+++ b/tryprune.py
@@ -68,8 +68,6 @@ def convertModel(kmm_model, inter_model):
def mergeOneModel(mergedmodel, onemodel, score):
- #validate first
- validateModel(onemodel)
onemodelstatuspath = onemodel + config.getStatusPostfix()
onemodelstatus = utils.load_status(onemodelstatuspath)
@@ -107,14 +105,16 @@ def mergeSomeModels(mergedmodel, sortedindexname, mergenum):
raise AssertionError('scores must be descending.\n')
onemodel = os.path.join(config.getModelDir(), subdir, modelname)
+
+ #validate first
+ print('validating')
+ validateModel(onemodel)
+
mergeOneModel(mergedmodel, onemodel, score)
last_score = score
indexfile.close()
#end processing
- #validate merged model
- validateModel(mergedmodel)
-
def pruneModel(prunedmodel, k, CDF):
#begin processing
@@ -129,8 +129,6 @@ def pruneModel(prunedmodel, k, CDF):
sys.exit('Corrupted model found when pruning:' + modelfile)
#end processing
- #validate pruned model
- validateModel(prunedmodel)
if __name__ == '__main__':
parser = ArgumentParser(description='Try prune models.')
@@ -178,6 +176,10 @@ if __name__ == '__main__':
config.getSortedEstimateIndex())
mergeSomeModels(mergedmodel, sortedindexname, args.merge)
+ #validate merged model
+ print('validating')
+ validateModel(mergedmodel)
+
#export textual format
print('exporting')
exportfile = os.path.join(trydir, 'kmm_merged.text')
@@ -190,6 +192,10 @@ if __name__ == '__main__':
shutil.copyfile(mergedmodel, prunedmodel)
pruneModel(prunedmodel, args.k, args.CDF)
+ #validate pruned model
+ print('validating')
+ validateModel(prunedmodel)
+
#export textual format
print('exporting')
exportfile = os.path.join(trydir, 'kmm_pruned.text')