diff --git a/run.py b/run.py new file mode 100644 index 0000000..3fd8894 --- /dev/null +++ b/run.py @@ -0,0 +1,55 @@ +import os +import copy +import time +import random +import sys +import shutil +import subprocess + +from subprocess import STDOUT + +def execute(command): + popen = subprocess.Popen(command, stdout=subprocess.PIPE) + lines_iterator = iter(popen.stdout.readline, b"") + for line in lines_iterator: + print(line) # yield line + +#parameter w.r.t. MPI +work_dir = 'D:\\Your Directory' +port = '5719' +machinefile= 'host.txt' + +#parameter w.r.t. SG-Mixture Training +size = 50 +train = 'Your Training File' +read_vocab = 'Your Vocab File' +sense_file = 'Your Sense File, see sense_file.txt as an example' +binary = 2 +init_learning_rate = 0.025 +epoch = 1 +window = 5 +threads = 8 +mincount = 5 +EM_iteration = 1 +momentum = 0.05 +top_n = 0 +top_ratio = 0 +default_sense = 1 +sense_num_multi = 5 +binary_embedding_file = 'emb.bin' +text_embedding_file = 'emb.txt' +huff_tree_file = 'huff.txt' +outputlayer_binary_file = 'emb_out.bin' +outputlayer_text_file = 'emb_out.txt' +preload_cnt = 5 +data_block_size = 50000 +pipline = '0' +multinomial = '0' + +mpi_args = '-port {0} -wdir {1} -machinefile {2} '.format(port, work_dir, machinefile) +sg_mixture_args = ' -train_file {0} -binary_embedding_file {1} -text_embedding_file {2} -threads {3} -size {4} -binary {5} -epoch {6} -init_learning_rate {7} -min_count {8} -window {9} -momentum {10} -EM_iteration {11} -top_n {12} -top_ratio {13} -default_sense {14} -sense_num_multi {15} -huff_tree_file {16} -vocab_file {17} -outputlayer_binary_file {18} -outputlayer_text_file {19} -read_sense {20} -data_block_size {21} -is_pipline {22} -store_multinomial {23} -max_preload_size {24}'.format(train, binary_embedding_file, text_embedding_file, threads, size, binary, epoch, init_learning_rate, mincount, window, momentum, EM_iteration, top_n, top_ratio, default_sense, sense_num_multi, huff_tree_file, read_vocab, outputlayer_binary_file, 
outputlayer_text_file, sense_file, data_block_size, pipline, multinomial, preload_cnt) + +print mpi_args +print sg_mixture_args + +proc = execute("mpiexec " + mpi_args + 'distributed_skipgram_mixture ' + sg_mixture_args) diff --git a/sense_file.txt b/sense_file.txt new file mode 100644 index 0000000..b719164 --- /dev/null +++ b/sense_file.txt @@ -0,0 +1,1724 @@ +love +sex +tiger +cat +book +paper +computer +keyboard +internet +plane +car +train +telephone +communication +television +radio +media +drug +abuse +bread +butter +cucumber +potato +doctor +nurse +professor +student +smart +stupid +company +stock +market +phone +jaguar +egg +fertility +live +life +library +bank +money +wood +forest +cash +king +cabbage +queen +rook +bishop +rabbi +holy +fuck +football +soccer +basketball +tennis +racket +peace +terror +law +lawyer +movie +star +popcorn +critic +theater +physics +proton +chemistry +space +alcohol +vodka +gin +brandy +drink +ear +mouth +eat +baby +mother +automobile +gem +jewel +journey +voyage +boy +lad +coast +shore +asylum +madhouse +magician +wizard +midday +noon +furnace +stove +food +fruit +bird +cock +crane +tool +implement +brother +monk +oracle +cemetery +woodland +rooster +hill +graveyard +slave +chord +smile +glass +string +dollar +currency +wealth +property +possession +deposit +withdrawal +laundering +operation +feline +carnivore +mammal +animal +organism +fauna +zoo +psychology +psychiatry +anxiety +fear +depression +clinic +mind +health +science +discipline +cognition +planet +constellation +moon +sun +galaxy +astronomer +precedent +example +information +collection +group +antecedent +cup +coffee +tableware +article +artifact +object +entity +substance +liquid +energy +secretary +senate +laboratory +weapon +secret +fingerprint +investigation +effort +water +scientist +news +report +canyon +landscape +image +surface +discovery +seepage +sign +recess +mile +kilometer +territory +atmosphere +president +medal +war +troops +record +number +skin 
+eye +history +volunteer +motto +prejudice +recognition +decoration +valor +century +year +nation +delay +racism +minister +party +plan +minority +attempt +government +crisis +deployment +departure +announcement +stroke +hospital +disability +death +victim +emergency +treatment +recovery +journal +association +personnel +liability +insurance +school +center +reason +hypertension +criterion +hundred +percent +infrastructure +row +inmate +evidence +term +word +similarity +board +recommendation +governor +interview +country +travel +activity +competition +price +consumer +confidence +problem +airport +flight +credit +card +hotel +reservation +grocery +registration +arrangement +accommodation +month +type +kind +arrival +bed +closet +clothes +situation +conclusion +isolation +impartiality +interest +direction +combination +street +place +avenue +block +children +listing +proximity +category +cell +production +hike +benchmark +index +trading +gain +dividend +payment +calculation +computation +oil +warning +profit +loss +yen +buck +software +network +hardware +equipment +maker +luxury +five +investor +earning +baseball +season +game +victory +team +marathon +sprint +series +defeat +seven +seafood +sea +lobster +wine +preparation +video +archive +start +match +round +boxing +championship +tournament +fighting +defeating +line +day +summer +drought +nature +dawn +environment +ecology +man +woman +murder +manslaughter +soap +opera +performance +lesson +focus +crew +film +lover +quarrel +viewer +serial +possibility +girl +population +development +morality +importance +marriage +gender +equality +change +attitude +family +planning +industry +sugar +approach +practice +institution +ministry +culture +challenge +size +prominence +citizen +people +issue +experience +music +project +metal +aluminum +chance +credibility +exhibit +memorabilia +concert +virtuoso +rock +jazz +museum +observation +architecture +world +preservation +admission +ticket +shower +thunderstorm +flood 
+weather +forecast +disaster +area +office +brazil +nut +triple +cd +aglow +harvard +yale +cambridge +israel +east +israeli +japanese +american +jerusalem +wall +mexico +puebla +opec +saudi +palestinian +arab +wednesday +weekday +haven +ability +know-how +persecution +accepted +acknowledged +believe +welcome +accommodate +adjust +settlement +acronym +form +inaction +confession +matriculation +advance +headway +propose +advised +inform +affect +tense +aged +develop +old +worn +young +aim +cause +hold +thing +airfield +sterol +component +ambush +surprise +herbivore +statement +answering +counter +react +worry +access +converge +arch +curve +building +urban +shoulder +flower +weight +ash +bone +assignment +document +assimilation +americanization +realtor +astrophysicist +bedlam +shelter +stp +crime +attempted +initiate +give +attended +guard +respect +window +treat +back +lumbar +out +second +side +riverbank +base +air +build +structure +ball +tip-off +bat +bats +placental +turn +beam +signal +plant +platform +plot +point +bent +double +inclination +bias +experimenter +bigger +large +meat +bitter +ale +resentful +taste +close +blow +blowing +exhale +insufflate +blown +swat +deal +directorate +open +phrase +picture +booster +advertiser +advocate +bootleg +covering +produce +whiskey +bore +cut +flow +stuff +bound +boundary +skirt +vault +bow +knot +bowling +wheel +yield +hit +angstrom +branching +bifurcation +grow +cover +bring +return +bronze +nickel +sculpture +freemason +bill +horse +buffer +zone +bug +defect +insect +burning +important +burnt +treated +bury +hide +lay +steal +cake +dish +tablet +interpolation +camp +gathering +camping +pitch +capital +assets +primary +seed +playing +ration +see +limousine +predator +carrier +traveler +vehicle +cast +shoot +catching +catch +perceive +compartment +membrane +site +central +fiscal +refer +think +tract +integer +litigate +contest +fortune +ring +relation +channels +transmission +charcoal +fuel +checked +examine 
+offspring +cholera +infectious +arpeggio +play +circulation +dissemination +citation +act +award +national +clean +dry +jerk +remove +clear +innocence +make +session +closed +state +up +storage +toilet +closing +finale +motion +snap +glide +slope +penis +populace +collect +due +archives +data +prayer +scrape +take +colored +black +grey +colors +emblem +ensign +pigment +scheme +commit +vow +abstraction +communications +compact +case +pack +packed +short +write +compound +enhance +recombine +whole +conversion +procedure +assumption +judgment +connect +natural +topology +content +disapproval +limit +convinced +disarm +cool +air-conditioned +coldness +answer +negative +ally +course +stream +workshop +blanket +cloak +submerge +crabs +crab +crack +check +cracking +noise +sound +ovation +condition +measurement +cross +marking +meet +gherkin +melon +counterculture +subculture +shape +curb +bit +smother +exchange +prevalence +cutting +nip +damned +lost +raise +begin +hour +morrow +deaf +organic +deceased +born +dead +service +set +disappointment +defeated +upset +ending +overrun +veto +defense +biological +protection +hesitate +moratorium +deliver +serve +demand +call +claim +request +supply +demonic +evil +variation +redeployment +geological +derived +reap +mental +deepening +elaboration +qibla +anorgasmia +disabled +unfit +adversity +wave +system +breakthrough +disclosure +display +light +model +stick +ditch +abandon +drain +waterway +numerator +diving +one-half +theologian +dominican +dot +disk +seashore +multiple +draw +entertainer +gully +drawing +drag +frame +twitch +dress +formal +morning +neckline +wear +drive +impulse +mechanism +operate +ride +drop +descend +measure +dropping +sink +teardrop +time +anesthetic +dwelling +exist +education +effects +backdate +happen +repercussion +liberation +white +eliminate +temp +emphasized +stress +endeavor +worst +force +heat +habitat +equal +differ +inadequate +sameness +tie +person +rescue +establishment +collectivization 
+even +identification +notarize +specimen +parade +possess +show +inexperience +suffer +extension +expansion +look +fail +choke +shipwreck +fair +join +midway +moderate +child +famine +irish +lack +fast +abstain +hunger +sudden +avifauna +fearlessness +panic +sterility +fiction +canard +fantasy +field +handle +brush +resist +filled +medium +fine +precise +firm +corporation +hard +forward +fleet +aircraft +sortie +debacle +fill +flush +age +down +good +rich +fly +blur +concentration +distinctness +follow +evaluate +predict +prognosis +biome +growth +abbreviation +fort +presidio +bold +player +position +transport +found +wage +freeing +parole +freelance +work +bear +product +intercourse +fire +full +further +far +promote +fused +gauge +united +account +obtain +gallery +audience +parlay +playoff +gas +sewer +art +crystalline +asexual +gen +gig +harpoon +seidel +gone +gore +blood +pierce +control +graduate +confer +high +instrument +receive +scholar +greater +shelf +vinyl +fractious +soft +weaponry +hay +fodder +healing +better +illness +heel +stack +height +degree +increase +tor +recital +topographical +home +away +honest +sincere +honey +hooks +hand +hot +calorific +violent +cardiovascular +ice +appearance +visualize +use +standing +increased +add +enhanced +maximize +list +mass +safety +enterprise +confirmation +datum +initial +first +letter +resident +instability +disorder +vicariate +instruction +recipe +contract +benefit +enthusiasm +intermediate +chemical +introduction +opening +usher +count +fishing +iron +alpha +robust +alienation +solitude +edit +edition +overriding +unblock +neck +writing +ship +keep +continue +stronghold +holder +typeset +rival +region +exclude +painting +scenery +language +soliloquy +last +end +past +populate +rank +run +senior +machine +wash +joke +mosaic +barrister +song +learned +discover +educated +scholarly +left +dad +admonition +lever +house +being +lift +aid +consequence +lower +move +lighter +heavy +headlight +lighten +smoke +limb 
+branch +crib +lie +matter +organization +post +sick +shopping +looking +search +sightseeing +sparkle +reducing +wastage +sleep +agape +care +like +lowering +decrease +devalue +movement +reef +escapologist +magus +manufacturer +homicide +manufacture +commercial +mark +broad +buoy +class +labor +shop +trade +spouse +married +final +matchstick +mature +distinguish +telecommunication +mention +allusion +comment +photograph +merit +demerit +worthiness +silver +tombac +middle +phase +put +milk +nutriment +recall +mine +strip +sulfur +clergyman +foreign +priesthood +minor +miss +failure +mock +derision +tease +ptolemaic +mole +spy +pile +religious +solar +week +inner +virtue +morocco +levant +moroccan +motel +motor +mother-in-law +yeast +agent +stepper +catchphrase +mourning +sadness +beak +tongue +ms +gulf +depository +narrow +determine +limited +strait +spanish +characteristic +nerves +psychological +newspaper +item +publisher +update +nobility +aristocrat +majority +practical +business +printing +offset +balance +compensation +rift +musical +action +prophecy +orchestrated +score +individual +oxygen +battalion +panorama +parent +part +interruption +dance +union +pass +accomplishment +passing +done +leave +patronage +blessing +pat +touch +repayment +amity +occupancy +perfect +polish +unbroken +magic +department +police +delegate +siphon +plain +knit +llano +obvious +simple +employee +seaplane +follower +schedule +plasma +gamma +curtain +played +die +foul +thousand +poll +homo +straw +vote +pop +father +overpopulation +criminal +liabilities +expectation +vine +pound +formalism +learn +prairie +grassland +civil +smell +self-preservation +pressing +squeeze +cost +principle +hellenism +yin +print +contact +difficulty +poser +growing +academician +advantage +bulge +projecting +communicate +moneymaker +limelight +promise +betrothal +declare +pledge +nucleon +vicinity +psychotherapy +argue +quiet +louden +order +leader +title +anti-semitism +profiling +broadcasting +raising 
+bump +rise +upbringing +reading +interpretation +read +explanation +generalize +present +rebel +soldier +designation +advice +puff +improvement +rally +reduced +abbreviate +bated +low +reduce +simplify +slash +reed +body +entrance +relieve +comfort +free +announce +blue +resistance +resolution +physical +retirement +status +reverse +gear +opposition +right +proper +chondrite +limestone +rod +role +hat +rolling +robbery +romance +intrigue +quality +chicken +rose +damask +soar +rough +crushed +golf +ammunition +dispute +terrace +rubber +crepe +running +dash +sacrifice +kill +sales +divestiture +income +sample +satellite +outer +scattered +break +disband +incoherent +alumnus +virtuosity +cosmographer +south +leap +ordinal +assistant +password +appreciate +diocese +seeing +exudation +sent +ordered +quarterly +broadcast +spot +foreplay +sheet +expanse +worksheet +shifting +shining +shooting +sucker +arrive +diamond +shot +colorful +grapeshot +plumbing +token +single +common +singles +badminton +singleton +sit +element +classify +coat +magnitude +slain +worker +slight +dismiss +insignificant +silent +slip +freudian +slow +gradual +smashing +blast +expression +smoking +emit +inhale +vaporization +bribe +saddle +compatible +sole +solo +perform +sounding +depth +location +spiral +coiled +helix +splash +spread +discharge +disparity +dispersion +distributed +square +angular +lawful +polygon +stalls +livery +have +path +beginning +starter +get +stated +still +silence +merchandise +stocks +framework +heater +striking +conspicuous +crash +impressive +arrange +fingerboard +strike +study +examination +learning +review +submarine +sandwich +subordinate +dog +insubordinate +under +succeeding +suffering +enjoy +pain +beet +solstice +sunburst +bubble +horizontal +suspension +lapse +sway +power +cutlery +color +telling +narration +referent +tenure +africa +box +things +expect +storm +tag +means +topping +tops +crown +touched +touching +universe +trace +support +traffic +commerce 
+gravitation +tour +walk +dealing +trench +trick +deceive +shift +daze +hostile +march +trouble +perturbation +twist +adult +identify +version +uniform +jump +upgrade +afflict +agitation +disturbance +troubled +application +consume +used +functional +usual +familiar +vanished +vector +radius +videotape +violate +musician +vision +imagination +screwdriver +void +invalid +nonexistence +validate +waking +sleeping +walking +accompany +locomotion +wade +wandering +about +stray +battle +hostility +strategic +shampoo +lake +perspiration +wet +ways +shipyard +weak +diluted +weakening +transformation +flimsy +abundance +persuasion +precipitation +welsh +brythonic +wetness +bleach +light-skinned +dark +delaware +withdraw +retrograde +female +silva +red +workings +excavation +chow +poor +result +wounded +wrong +false +improper +injury +y2k +yielding +assent +facility +undertaking +retail +credence +bundle +overlook +happening +room +relearn +express +anagram +insane +awake +lamb +mutiny +enclose +mute +foundation +nonprofessional +grazed +extort +efferent +brave +hurl +asian +longing +intolerable +grieve +hailstone +considered +wind +kick +formation +scan +protest +desk +indirect +compel +prepared +ask +stabilize +conference +impermanent +backward +solemnize +liquidate +damage +likeness +constitute +patterned +prepare +english +lip +nearness +nonexistent +imperative +mocha +sauce +living +calmness +male +flap +maradona +arafat diff --git a/src/multiverso_skipgram_mixture.cpp b/src/multiverso_skipgram_mixture.cpp index f7890a2..87e9435 100755 --- a/src/multiverso_skipgram_mixture.cpp +++ b/src/multiverso_skipgram_mixture.cpp @@ -1,4 +1,4 @@ -#include "multiverso_skipgram_mixture.h" +#include "MultiversoSkipGramMixture.h" #include MultiversoSkipGramMixture::MultiversoSkipGramMixture(Option *option, Dictionary *dictionary, HuffmanEncoder *huffman_encoder, Reader *reader) @@ -35,7 +35,7 @@ void MultiversoSkipGramMixture::InitSenseCntInfo() 
m_word_sense_info.word_sense_cnts_info[wordlist[i].first] = 1; //Then, read words #sense info from the sense file - if (m_option->sense_file) + if (m_option->sense_file) { FILE* fid = fopen(m_option->sense_file, "r"); char word[1000]; @@ -58,7 +58,7 @@ void MultiversoSkipGramMixture::InitSenseCntInfo() int cnt = 0; m_word_sense_info.multi_senses_words_cnt = 0; - for (int i = 0; i < m_dictionary->Size(); ++i) + for (int i = 0; i < m_dictionary->Size(); ++i) { m_word_sense_info.p_input_embedding[i] = cnt; if (m_word_sense_info.word_sense_cnts_info[i] > 1) @@ -107,7 +107,7 @@ void MultiversoSkipGramMixture::Train(int argc, char *argv[]) multiverso::Log::ResetLogFile("log.txt"); m_process_id = multiverso::Multiverso::ProcessRank(); PrepareMultiversoParameterTables(m_option, m_dictionary); - + printf("Start to train ...\n"); TrainNeuralNetwork(); printf("Rank %d Finish training\n", m_process_id); @@ -156,7 +156,7 @@ void MultiversoSkipGramMixture::PrepareMultiversoParameterTables(Option *opt, Di { for (int col = 0; col < opt->sense_num_multi; ++col) { - multiverso::Multiverso::AddToServer(kWordSensePriorTableId, row, col, + multiverso::Multiverso::AddToServer(kWordSensePriorTableId, row, col, static_cast(m_option->store_multinomial ? 
1.0 / m_option->sense_num_multi : log(1.0 / m_option->sense_num_multi))); } } @@ -198,7 +198,7 @@ void MultiversoSkipGramMixture::PushDataBlock( { std::chrono::milliseconds dura(200); std::this_thread::sleep_for(dura); - + RemoveDoneDataBlock(datablock_queue); } } diff --git a/src/param_loader.cpp b/src/param_loader.cpp index f1117d7..f2be012 100755 --- a/src/param_loader.cpp +++ b/src/param_loader.cpp @@ -1,4 +1,4 @@ -#include "param_loader.h" +#include "ParamLoader.h" template ParameterLoader::ParameterLoader(Option *option, void** word2vector_neural_networks, WordSenseInfo* word_sense_info) @@ -21,12 +21,12 @@ void ParameterLoader::ParseAndRequest(multiverso::DataBlockBase *data_block) fprintf(m_log_file, "%lf\n", (clock() - m_start_time) / (double)CLOCKS_PER_SEC); multiverso::Log::Info("Rank %d ParameterLoader begin %d\n", multiverso::Multiverso::ProcessRank(), m_parse_and_request_count); DataBlock *data = reinterpret_cast(data_block); - + SkipGramMixtureNeuralNetwork* sg_mixture_neural_network = reinterpret_cast*>(m_sgmixture_neural_networks[m_parse_and_request_count % 2]); ++m_parse_and_request_count; data->UpdateNextRandom(); sg_mixture_neural_network->PrepareParmeter(data); - + std::vector& input_layer_nodes = sg_mixture_neural_network->GetInputLayerNodes(); std::vector& output_layer_nodes = sg_mixture_neural_network->GetOutputLayerNodes(); assert(sg_mixture_neural_network->status == 0); @@ -62,4 +62,4 @@ void ParameterLoader::ParseAndRequest(multiverso::DataBlockBase *data_block) } template class ParameterLoader; -template class ParameterLoader; +template class ParameterLoader; \ No newline at end of file diff --git a/src/skipgram_mixture_neural_network.cpp b/src/skipgram_mixture_neural_network.cpp index 2c3370c..9eab9bf 100755 --- a/src/skipgram_mixture_neural_network.cpp +++ b/src/skipgram_mixture_neural_network.cpp @@ -1,7 +1,7 @@ -#include "skipgram_mixture_neural_network.h" +#include "SkipGramMixtureNeuralNetwork.h" template 
-SkipGramMixtureNeuralNetwork::SkipGramMixtureNeuralNetwork(Option* option, HuffmanEncoder* huffmanEncoder, WordSenseInfo* word_sense_info, Dictionary* dic, int dicSize) +SkipGramMixtureNeuralNetwork::SkipGramMixtureNeuralNetwork(Option* option, HuffmanEncoder* huffmanEncoder, WordSenseInfo* word_sense_info, Dictionary* dic, int dicSize) { status = 0; m_option = option; @@ -37,18 +37,18 @@ SkipGramMixtureNeuralNetwork::~SkipGramMixtureNeuralNetwork() } template -void SkipGramMixtureNeuralNetwork::Train(int* sentence, int sentence_length, T* gamma, T* fTable, T* input_backup) +void SkipGramMixtureNeuralNetwork::Train(int* sentence, int sentence_length, T* gamma, T* f_table, T* input_backup) { - ParseSentence(sentence, sentence_length, gamma, fTable, input_backup, &SkipGramMixtureNeuralNetwork::TrainSample); + ParseSentence(sentence, sentence_length, gamma, f_table, input_backup, &SkipGramMixtureNeuralNetwork::TrainSample); } template //The E - step, estimate the posterior multinomial probabilities -T SkipGramMixtureNeuralNetwork::Estimate_Gamma_m(int word_input, std::vector >& output_nodes, T* posterior_ll, T* estimation, T* sense_prior, T* f_m) +T SkipGramMixtureNeuralNetwork::EstimateGamma(int word_input, std::vector >& output_nodes, T* posterior_ll, T* estimation, T* sense_prior, T* f_m) { - T* inputEmbedding = m_input_embedding_weights_ptr[word_input]; + T* input_embedding = m_input_embedding_weights_ptr[word_input]; T f, log_likelihood = 0; - for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, inputEmbedding += m_option->embeding_size) + for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, input_embedding += m_option->embeding_size) { posterior_ll[sense_idx] = sense_prior[sense_idx] < eps ? 
MIN_LOG : log(sense_prior[sense_idx]); //posterior likelihood for each sense @@ -56,7 +56,7 @@ T SkipGramMixtureNeuralNetwork::Estimate_Gamma_m(int word_input, std::vector< for (int d = 0; d < output_nodes.size(); ++d, fidx++) { - f = Util::InnerProduct(inputEmbedding, m_output_embedding_weights_ptr[output_nodes[d].first], m_option->embeding_size); + f = Util::InnerProduct(input_embedding, m_output_embedding_weights_ptr[output_nodes[d].first], m_option->embeding_size); f = Util::Sigmoid(f); f_m[fidx] = f; if (output_nodes[d].second) //huffman code, 0 or 1 @@ -78,7 +78,7 @@ T SkipGramMixtureNeuralNetwork::Estimate_Gamma_m(int word_input, std::vector< template //The M Step: update the sense prior probabilities to maximize the Q function -void SkipGramMixtureNeuralNetwork::Maximize_Pi(int word_input, T* log_likelihood) +void SkipGramMixtureNeuralNetwork::MaximizeSensePriors(int word_input, T* log_likelihood) { if (m_word_sense_info->word_sense_cnts_info[word_input] == 1) { @@ -101,11 +101,11 @@ void SkipGramMixtureNeuralNetwork::UpdateEmbeddings(int word_input, std::vect { T g; T* output_embedding; - T* inputEmbedding; + T* input_embedding; if (direction == UpdateDirection::UPDATE_INPUT) - inputEmbedding = m_input_embedding_weights_ptr[word_input]; - else inputEmbedding = input_backup; - for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, inputEmbedding += m_option->embeding_size) + input_embedding = m_input_embedding_weights_ptr[word_input]; + else input_embedding = input_backup; + for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, input_embedding += m_option->embeding_size) { int64_t fidx = sense_idx * MAX_CODE_LENGTH; for (int d = 0; d < output_nodes.size(); ++d, ++fidx) @@ -115,12 +115,12 @@ void SkipGramMixtureNeuralNetwork::UpdateEmbeddings(int word_input, std::vect if (direction == UpdateDirection::UPDATE_INPUT) //Update Input { for (int j = 0; j < 
m_option->embeding_size; ++j) - inputEmbedding[j] += g * output_embedding[j]; + input_embedding[j] += g * output_embedding[j]; } else // Update Output { for (int j = 0; j < m_option->embeding_size; ++j) - output_embedding[j] += g * inputEmbedding[j]; + output_embedding[j] += g * input_embedding[j]; } } } @@ -132,7 +132,7 @@ template void SkipGramMixtureNeuralNetwork::TrainSample(int input_node, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup) { T* gamma = (T*)v_gamma; //stores the posterior probabilities - T* fTable = (T*)v_fTable; //stores the inner product values of input and output embeddings + T* f_table = (T*)v_fTable; //stores the inner product values of input and output embeddings T* input_backup = (T*)v_input_backup; T posterior_ll[MAX_SENSE_CNT]; //stores the posterior log likelihood @@ -149,16 +149,16 @@ void SkipGramMixtureNeuralNetwork::TrainSample(int input_node, std::vectorstore_multinomial) - Maximize_Pi(input_node, gamma); + MaximizeSensePriors(input_node, gamma); else - Maximize_Pi(input_node, posterior_ll); + MaximizeSensePriors(input_node, posterior_ll); - UpdateEmbeddings(input_node, output_nodes, gamma, fTable, input_backup, UpdateDirection::UPDATE_INPUT); - UpdateEmbeddings(input_node, output_nodes, gamma, fTable, input_backup, UpdateDirection::UPDATE_OUTPUT); + UpdateEmbeddings(input_node, output_nodes, gamma, f_table, input_backup, UpdateDirection::UPDATE_INPUT); + UpdateEmbeddings(input_node, output_nodes, gamma, f_table, input_backup, UpdateDirection::UPDATE_OUTPUT); } } @@ -205,10 +205,10 @@ void SkipGramMixtureNeuralNetwork::DealPrepareParameter(int input_node, std:: template /* - Parse a sentence and deepen into two branchs: - one for TrainNN, the other one is for Parameter_parse&request +Parse a sentence and deepen into two branchs: +one for TrainNN, the other one is for Parameter_parse&request */ -void SkipGramMixtureNeuralNetwork::ParseSentence(int* sentence, int sentence_length, T* gamma, T* fTable, 
T* input_backup, FunctionType function) +void SkipGramMixtureNeuralNetwork::ParseSentence(int* sentence, int sentence_length, T* gamma, T* f_table, T* input_backup, FunctionType function) { if (sentence_length == 0) return; @@ -220,7 +220,7 @@ void SkipGramMixtureNeuralNetwork::ParseSentence(int* sentence, int sentence_ { if (sentence[sentence_position] == -1) continue; int feat_size = 0; - + for (int i = 0; i < m_option->window_size * 2 + 1; ++i) if (i != m_option->window_size) { @@ -233,7 +233,7 @@ void SkipGramMixtureNeuralNetwork::ParseSentence(int* sentence, int sentence_ input_node = feat[feat_size - 1]; output_nodes.clear(); Parse(input_node, sentence[sentence_position], output_nodes); - (this->*function)(input_node, output_nodes, gamma, fTable, input_backup); + (this->*function)(input_node, output_nodes, gamma, f_table, input_backup); } } } @@ -282,7 +282,7 @@ std::vector& SkipGramMixtureNeuralNetwork::GetOutputLayerNodes() } template -void SkipGramMixtureNeuralNetwork::SetInputEmbeddingWeights(int input_node_id, T* ptr) +void SkipGramMixtureNeuralNetwork::SetinputEmbeddingWeights(int input_node_id, T* ptr) { m_input_embedding_weights_ptr[input_node_id] = ptr; } @@ -306,7 +306,7 @@ void SkipGramMixtureNeuralNetwork::SetSensePriorParaWeights(int input_node_id } template -T* SkipGramMixtureNeuralNetwork::GetInputEmbeddingWeights(int input_node_id) +T* SkipGramMixtureNeuralNetwork::GetinputEmbeddingWeights(int input_node_id) { return m_input_embedding_weights_ptr[input_node_id]; } diff --git a/src/skipgram_mixture_neural_network.h b/src/skipgram_mixture_neural_network.h index 2471fc2..7b98d30 100755 --- a/src/skipgram_mixture_neural_network.h +++ b/src/skipgram_mixture_neural_network.h @@ -17,6 +17,146 @@ enum class UpdateDirection template class SkipGramMixtureNeuralNetwork { +#pragma once + +#include + +#include "Util.h" +#include +#include "HuffmanEncoder.h" +#include "MultiversoSkipGramMixture.h" +#include "cstring" + + enum class UpdateDirection + { + 
UPDATE_INPUT, + UPDATE_OUTPUT + }; + + template + class SkipGramMixtureNeuralNetwork + { + public: + T learning_rate; + T sense_prior_momentum; + + int status; + SkipGramMixtureNeuralNetwork(Option* option, HuffmanEncoder* huffmanEncoder, WordSenseInfo* word_sense_info, Dictionary* dic, int dicSize); + ~SkipGramMixtureNeuralNetwork(); + + void Train(int* sentence, int sentence_length, T* gamma, T* fTable, T* input_backup); + + /*! + * \brief Collect all the input words and output nodes in the data block + */ + void PrepareParmeter(DataBlock *data_block); + + std::vector& GetInputLayerNodes(); + std::vector& GetOutputLayerNodes(); + + /*! + * \brief Set the pointers to those local parameters + */ + void SetInputEmbeddingWeights(int input_node_id, T* ptr); + void SetOutputEmbeddingWeights(int output_node_id, T* ptr); + void SetSensePriorWeights(int input_node_id, T*ptr); + void SetSensePriorParaWeights(int input_node_id, T* ptr); + + /*! + * \brief Get the pointers to those locally updated parameters + */ + T* GetInputEmbeddingWeights(int input_node_id); + T* GetEmbeddingOutputWeights(int output_node_id); + T* GetSensePriorWeights(int input_node_id); + T* GetSensePriorParaWeights(int input_node_id); + + private: + Option *m_option; + Dictionary *m_dictionary; + HuffmanEncoder *m_huffman_encoder; + int m_dictionary_size; + + WordSenseInfo* m_word_sense_info; + + T** m_input_embedding_weights_ptr; //Points to every word's input embedding vector + bool *m_seleted_input_embedding_weights; + T** m_output_embedding_weights_ptr; //Points to every huffman node's embedding vector + bool *m_selected_output_embedding_weights; + + T** m_sense_priors_ptr; //Points to the multinomial parameters, if store_multinomial is set to zero. + T** m_sense_priors_paras_ptr;//Points to sense prior parameters. 
If store_multinomial is zero, then it points to the log of multinomial, otherwise points to the multinomial parameters + + std::vector m_input_layer_nodes; + std::vector m_output_layer_nodes; + + typedef void(SkipGramMixtureNeuralNetwork::*FunctionType)(int input_node, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup); + + /*! + * \brief Parse the needed parameter in a window + */ + void Parse(int feat, int word_idx, std::vector >& output_nodes); + + /*! + * \brief Parse a sentence and deepen into two branchs + * \one for TrainNN,the other one is for Parameter_parse&request + */ + void ParseSentence(int* sentence, int sentence_length, T* gamma, T* fTable, T* input_backup, FunctionType function); + + /*! + * \brief Copy the input_nodes&output_nodes to WordEmbedding private set + */ + void DealPrepareParameter(int input_nodes, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup); + + /*! + * \brief Train a window sample and update the + * \input-embedding&output-embedding vectors + * \param word_input represent the input words + * \param output_nodes represent the ouput nodes on huffman tree, including the node index and path label + * \param v_gamma is the temp memory to store the posterior probabilities of each sense + * \param v_fTable is the temp memory to store the sigmoid value of inner product of input and output embeddings + * \param v_input_backup stores the input embedding vectors as backup + */ + void TrainSample(int word_input, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup); + + /*! 
+ * \brief The E-step, estimate the posterior multinomial probabilities + * \param word_input represent the input words + * \param output_nodes represent the ouput nodes on huffman tree, including the node index and path label + * \param posterior represents the calculated posterior log likelihood + * \param estimation represents the calculated gammas (see the paper), that is, the softmax terms of posterior + * \param sense_prior represents the parameters of sense prior probablities for each polysemous words + * \param f_m is the temp memory to store the sigmoid value of inner products of input and output embeddings + */ + T EstimateGamma(int word_input, std::vector >& output_nodes, T* posterior, T* estimation, T* sense_prior, T* f_m); + + /*! + * \brief The M step: update the embedding vectors to maximize the Q function + * \param word_input represent the input words + * \param output_nodes represent the ouput nodes on huffman tree, including the node index and path label + * \param estimation represents the calculated gammas (see the paper), that is, the softmax terms of posterior + * \param f_m is the temp memory to store the sigmoid value of inner products of input and output embeddings + * \param input_backup stores the input embedding vectors as backup + * \param direction: update input vectors or output vectors + */ + void UpdateEmbeddings(int word_input, std::vector >& output_nodes, T* estimation, T* f_m, T* input_backup, UpdateDirection direction); + + /*! 
+ * \brief The M Step: update the sense prior probabilities to maximize the Q function + * \param word_input represent the input words + * \param curr_priors are the closed form values of the sense priors in this iteration + */ + void MaximizeSensePriors(int word_input, T* curr_priors); + + /* + * \brief Record the input word so that parameter loader can be performed + */ + void AddInputLayerNode(int node_id); + + /* + * \brief Record the huffman tree node so that parameter loader can be performed + */ + void AddOutputLayerNode(int node_id); + }; public: T learning_rate; T sense_prior_momentum; diff --git a/src/trainer.cpp b/src/trainer.cpp index d99047d..657dca3 100755 --- a/src/trainer.cpp +++ b/src/trainer.cpp @@ -1,4 +1,4 @@ -#include "trainer.h" +#include "Trainer.h" template Trainer::Trainer(int trainer_id, Option *option, void** word2vector_neural_networks, multiverso::Barrier *barrier, Dictionary* dictionary, WordSenseInfo* word_sense_info, HuffmanEncoder* huff_encoder) @@ -12,8 +12,8 @@ Trainer::Trainer(int trainer_id, Option *option, void** word2vector_neural_ne m_word_sense_info = word_sense_info; m_huffman_encoder = huff_encoder; - gamma = (T*)calloc(m_option-> window_size * MAX_SENSE_CNT, sizeof(T)); - fTable = (T*)calloc(m_option-> window_size * MAX_CODE_LENGTH * MAX_SENSE_CNT, sizeof(T)); + gamma = (T*)calloc(m_option->window_size * MAX_SENSE_CNT, sizeof(T)); + fTable = (T*)calloc(m_option->window_size * MAX_CODE_LENGTH * MAX_SENSE_CNT, sizeof(T)); input_backup = (T*)calloc(m_option->embeding_size * MAX_SENSE_CNT, sizeof(T)); m_start_time = 0; @@ -62,7 +62,7 @@ void Trainer::TrainIteration(multiverso::DataBlockBase *data_block) { local_output_layer_nodes.push_back(output_layer_nodes[i]); } - + CopyParameterFromMultiverso(local_input_layer_nodes, local_output_layer_nodes, word2vector_neural_network); multiverso::Row& word_count_actual_row = GetRow(kWordCountActualTableId, 0); @@ -72,11 +72,11 @@ void Trainer::TrainIteration(multiverso::DataBlockBase 
*data_block) word2vector_neural_network->learning_rate = learning_rate; //Linearly increase the momentum from init_sense_prior_momentum to 1 - word2vector_neural_network->sense_prior_momentum = m_option->init_sense_prior_momentum + + word2vector_neural_network->sense_prior_momentum = m_option->init_sense_prior_momentum + (1 - m_option->init_sense_prior_momentum) * word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1); - + m_barrier->Wait(); - + for (int i = m_trainer_id; i < data->Size(); i += m_option->thread_cnt) //i iterates over all sentences { int sentence_length; @@ -86,7 +86,7 @@ void Trainer::TrainIteration(multiverso::DataBlockBase *data_block) data->Get(i, sentence, sentence_length, word_count_deta, next_random); word2vector_neural_network->Train(sentence, sentence_length, gamma, fTable, input_backup); - + m_word_count += word_count_deta; if (m_word_count - m_last_word_count > 10000) { @@ -94,7 +94,7 @@ void Trainer::TrainIteration(multiverso::DataBlockBase *data_block) Add(kWordCountActualTableId, 0, 0, m_word_count - m_last_word_count); m_last_word_count = m_word_count; m_now_time = clock(); - + if (m_trainer_id % 3 == 0) { multiverso::Log::Info("Rank %d Trainer %d lr: %.5f Mom: %.4f Progress: %.2f%% Words/thread/sec(total): %.2fk W/t/sec(executive): %.2fk\n", @@ -115,12 +115,12 @@ void Trainer::TrainIteration(multiverso::DataBlockBase *data_block) word2vector_neural_network->sense_prior_momentum = m_option->init_sense_prior_momentum + (1 - m_option->init_sense_prior_momentum) * word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1); } } - + m_barrier->Wait(); AddParameterToMultiverso(local_input_layer_nodes, local_output_layer_nodes, word2vector_neural_network); - + m_executive_time += clock() - train_interation_start; - + multiverso::Log::Info("Rank %d Train %d end at %lfs, cost %lfs, total cost %lfs\n", m_process_id, m_trainer_id, clock() / (double)CLOCKS_PER_SEC, @@ -159,7 +159,7 @@ template int 
Trainer::CopyParameterFromMultiverso(std::vector& input_layer_nodes, std::vector& output_layer_nodes, void* local_word2vector_neural_network) { SkipGramMixtureNeuralNetwork* word2vector_neural_network = (SkipGramMixtureNeuralNetwork*)local_word2vector_neural_network; - + //Copy input embedding for (int i = 0; i < input_layer_nodes.size(); ++i) { @@ -169,7 +169,7 @@ int Trainer::CopyParameterFromMultiverso(std::vector& input_layer_nodes, CopyMemory(ptr + j * m_option->embeding_size, GetRow(kInputEmbeddingTableId, row_id), m_option->embeding_size); word2vector_neural_network->SetInputEmbeddingWeights(input_layer_nodes[i], ptr); } - + //Copy output embedding for (int i = 0; i < output_layer_nodes.size(); ++i) { @@ -184,7 +184,7 @@ int Trainer::CopyParameterFromMultiverso(std::vector& input_layer_nodes, } word2vector_neural_network->SetOutputEmbeddingWeights(output_layer_nodes[i], ptr); } - + //Copy sense prior for (int i = 0; i < input_layer_nodes.size(); ++i) { @@ -285,7 +285,7 @@ void Trainer::SaveMultiInputEmbedding(const int epoch_id) fid = fopen(outfile, "wb"); - fprintf(fid, "%d %d %d\n", m_dictionary->Size(), m_word_sense_info->total_senses_cnt, m_option->embeding_size); + fprintf(fid, "%d %d %d\n", m_dictionary->Size(), m_word_sense_info->total_senses_cnt, m_option->embeding_size); for (int i = 0; i < m_dictionary->Size(); ++i) { fprintf(fid, "%s %d ", m_dictionary->GetWordInfo(i)->word.c_str(), m_word_sense_info->word_sense_cnts_info[i]); @@ -297,7 +297,7 @@ void Trainer::SaveMultiInputEmbedding(const int epoch_id) CopyMemory(sense_priors_ptr, GetRow(kWordSensePriorTableId, m_word_sense_info->p_wordidx2sense_idx[i]), m_option->sense_num_multi); if (!m_option->store_multinomial) Util::SoftMax(sense_priors_ptr, sense_priors_ptr, m_option->sense_num_multi); - + for (int j = 0; j < m_option->sense_num_multi; ++j) { fwrite(sense_priors_ptr + j, sizeof(real), 1, fid); @@ -317,7 +317,7 @@ void Trainer::SaveMultiInputEmbedding(const int epoch_id) fwrite(&prob, 
sizeof(real), 1, fid); emb_row_id = m_word_sense_info->p_input_embedding[i]; multiverso::Row& embedding = GetRow(kInputEmbeddingTableId, emb_row_id); - + for (int k = 0; k < m_option->embeding_size; ++k) { emb_tmp = embedding.At(k); @@ -442,4 +442,4 @@ void Trainer::SaveHuffEncoder() } template class Trainer; -template class Trainer; +template class Trainer; \ No newline at end of file