diff options
| author | Peng Wu <pwu@redhat.com> | 2024-09-02 11:42:17 +0800 |
|---|---|---|
| committer | Peng Wu <pwu@redhat.com> | 2024-09-02 14:24:51 +0800 |
| commit | 429f9fd8dd2966db9038a4194e5970ba22e8e5c0 (patch) | |
| tree | 6367d7d4b7b735bb98d2be0446de854b52fce99f | |
| parent | d789eb1cd9185eac7f3f82f230b1ee4414aed68c (diff) | |
| download | trainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.tar.gz trainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.tar.xz trainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.zip | |
write genpunct.py
| -rwxr-xr-x | genpunct.py | 89 |
1 files changed, 82 insertions, 7 deletions
diff --git a/genpunct.py b/genpunct.py index 22dc5aa..f66b35d 100755 --- a/genpunct.py +++ b/genpunct.py @@ -2,6 +2,7 @@ import os import os.path from argparse import ArgumentParser +from operator import itemgetter import utils from myconfig import MyConfig from dirwalk import walkIndex @@ -17,6 +18,8 @@ os.chdir(puncts_dir) # The order is important Punct_Search = ['……', '…', ',', '。', ';', '?', '!', ':', '“', '”', '、'] +all_punct_pairs = {} + ############################################################ # Handle File # ############################################################ @@ -77,11 +80,10 @@ def handleOneText(infile, punct_pairs): docfile.close() -def prunePunctPairFromOneIndex(workdir, threshold): +def prunePunctPair(workdir, threshold, infilename, outfilename): punct_pairs = {} #load the punct pairs from text files - punctfile = os.path.join(workdir, \ - config.getPunctuationPerIndexFileName()) + punctfile = os.path.join(workdir, infilename) with open(punctfile, 'r') as f: punct_pairs = eval(f.read()) @@ -98,8 +100,7 @@ def prunePunctPairFromOneIndex(workdir, threshold): newpunctpairs[key] = newpuncts #save the punct pairs to text files - punctfile = os.path.join(workdir, \ - config.getPunctuationPruneIndexFileName()) + punctfile = os.path.join(workdir, outfilename) with open(punctfile, 'w') as f: f.write(repr(newpunctpairs)) @@ -157,13 +158,72 @@ def handleOneIndex(indexpath, subdir, indexname): loadPunctPairFromOneIndex(indexpath, workdir) # Prune the pair in the current index - prunePunctPairFromOneIndex(workdir, \ - config.getPunctuationPerIndexPruneThreshold()) + prunePunctPair(workdir, \ + config.getPunctuationPerIndexPruneThreshold(), \ + config.getPunctuationPerIndexFileName(), \ + config.getPunctuationPruneIndexFileName()) #sign epoch utils.sign_epoch(indexstatus, 'Punctuation') utils.store_status(indexstatuspath, indexstatus) +def loadOnePrune(indexpath, subdir, indexname): + global all_punct_pairs + print(indexpath, subdir, indexname) + + workdir = config.getGeneratePunctuationDir() + os.sep + \ + subdir + os.sep + indexname + print(workdir) + + # Load the word and punctuation pair + punct_pairs = {} + #load the punct pairs from text files + punctfile = os.path.join(workdir, \ + config.getPunctuationPruneIndexFileName()) + with open(punctfile, 'r') as f: + punct_pairs = eval(f.read()) + + #merge into all punct pairs + for key, puncts in punct_pairs.items(): + if key not in all_punct_pairs: + all_punct_pairs[key] = puncts + continue + + #combine the puncts + newpuncts = [] + oldpuncts = all_punct_pairs[key] + keys = set() + for punct, freq in oldpuncts + puncts: + keys.add(punct) + for punctkey in keys: + #old freq + oldfreq = [freq for punct, freq in oldpuncts if punct == punctkey] + #print("old freq", oldfreq) + freq = sum(oldfreq) + #new freq + newfreq = [freq for punct, freq in puncts if punct == punctkey] + freq += sum(newfreq) + newpuncts.append([punctkey, freq]) + + all_punct_pairs[key] = newpuncts + + +def exportAllPunctPairs(workdir, infilename, outfilename): + # Load the word and punctuation pair + punct_pairs = {} + #load the punct pairs from text files + punctfile = os.path.join(workdir, infilename) + with open(punctfile, 'r') as f: + punct_pairs = eval(f.read()) + + tablefile = open(outfilename, 'w') + for key, puncts in punct_pairs.items(): + (token, word) = key + puncts.sort(key=itemgetter(1), reverse=True) + for punct, freq in puncts: + line = "{0} {1} {2} {3}".format(token, word, punct, freq) + tablefile.writelines([line, os.linesep]) + tablefile.close() if __name__ == '__main__': parser = ArgumentParser(description='Generate punctuation.') @@ -176,9 +236,24 @@ if __name__ == '__main__': print(args) walkIndex(handleOneIndex, args.indexdir) # Merge the word and punctuation pairs in all the index + walkIndex(loadOnePrune, args.indexdir) + + #save the punct pairs to text files + punctfile = os.path.join(config.getGeneratePunctuationDir(), \ + config.getPunctuationAllIndexFileName()) + with open(punctfile, 'w') as f: + f.write(repr(all_punct_pairs)) # Prune the pairs in all the index + prunePunctPair(config.getGeneratePunctuationDir(), \ + config.getPunctuationAllIndexPruneThreshold(), \ + config.getPunctuationAllIndexFileName(), \ + config.getPunctuationPruneAllIndexFileName()) + # Export all the remaining pairs + exportAllPunctPairs(config.getGeneratePunctuationDir(), \ + config.getPunctuationPruneAllIndexFileName(), \ + config.getPunctuationTextFileName()) print('done') |
