Search code examples
clstm

LSTM cell written from scratch in C doesn't predict the same values as Keras


I'm writing an LSTM cell from scratch in C. I tested it with an input of length 16. The predicted output, which has length 10, only matches the last 6 elements of what Keras predicts.

First, I wrote a simple LSTM model with Keras:

# Input: sequences of any length (None) whose steps are 16-dimensional vectors.
inp = layers.Input((None, 16))
# A single LSTM layer with 10 units; return_sequences=True keeps the time axis
# in the output, hence the (batch, time, 10) shape of the prediction.
x = layers.LSTM(10, return_sequences=True)(inp)
model = keras.models.Model(inputs=[inp], outputs=[x])

I can get a prediction by calling the model directly:

# One batch item, one time step, 16 features -> input shape (1, 1, 16).
model(tf.constant([[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]]]))

which printed:

<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=
array([[[-0.01320663,  0.3537513 ,  0.00737759, -0.0113485 ,
         -0.1471649 , -0.07326563, -0.04981037, -0.21077922,
          0.10876717,  0.1565821 ]]], dtype=float32)>

And then I extracted the weights of the model for future use:

# Emit the LSTM layer's weights as C brace-initializer macros for weights.h.
# Index 1 skips the Input layer created by layers.Input; get_weights() of the
# LSTM layer is [kernel, recurrent_kernel, bias].
# Each matrix row becomes one "{...}" inner initializer.
print("#define LSTM1_KERNEL {", ", ".join(['{%s}' % ", ".join(map(str, i.tolist())) for i in model.layers[1].get_weights()[0]]), "}", sep="")
print("#define LSTM1_RECURRENT_KERNEL {", ", ".join(['{%s}' % ", ".join(map(str, i.tolist())) for i in model.layers[1].get_weights()[1]]), "}", sep="")
# The bias is 1-D, so it is written as a single flat initializer list.
print("#define LSTM1_BIASES {", str(", ".join(map(str, model.layers[1].get_weights()[2]))), "}", sep="")

which printed

#define LSTM1_KERNEL {{-0.2937382459640503, 0.25691062211990356, -0.30560800433158875, -0.08145883679389954, 0.12993448972702026, -0.1012168824672699, -0.3098611533641815, 0.08793318271636963, 0.257276713848114, 0.05901506543159485, 0.3080512285232544, 0.29551106691360474, 0.26505959033966064, 0.27081406116485596, -0.14918622374534607, 0.15604108572006226, 0.3055707812309265, -0.2799621522426605, 0.041829854249954224, -0.25907522439956665, 0.2499493956565857, -0.2985055148601532, -0.135391503572464, -0.06500610709190369, 0.2709929943084717, -0.29365330934524536, 0.23996424674987793, 0.21257370710372925, 0.05963394045829773, 0.2999299168586731, -0.2082209587097168, 0.17517662048339844, 0.317105233669281, 0.16582608222961426, -0.09632466733455658, -0.1994607299566269, -0.3120834231376648, 0.011822879314422607, 0.14041826128959656, 0.030010759830474854}, {-0.07048910856246948, -0.24438296258449554, -0.3232172727584839, -0.23560135066509247, 0.16623079776763916, -0.27455994486808777, -0.21574059128761292, 0.1577852964401245, 0.3088122606277466, 0.21187055110931396, -0.1580374538898468, -0.1526646614074707, 0.3011627197265625, 0.15145424008369446, -0.22061829268932343, 0.16724830865859985, -0.2885115146636963, 0.08485367894172668, -0.23716357350349426, -0.08275860548019409, 0.26693373918533325, -0.006903558969497681, -0.1573537439107895, 0.045299410820007324, 0.29986828565597534, 0.31810545921325684, 0.22004854679107666, 0.28779375553131104, 0.130957692861557, -0.05509829521179199, 0.17370259761810303, 0.25324076414108276, 0.23543721437454224, 0.3074221611022949, -0.03730162978172302, 0.2944799065589905, -0.23619344830513, -0.21501901745796204, -0.1435796171426773, 0.2039697766304016}, {-0.0936431884765625, 0.04898190498352051, -0.17694081366062164, -0.11506195366382599, -0.26340368390083313, -0.07742957770824432, 0.22780388593673706, -0.15559649467468262, 0.0004108846187591553, -0.24243374168872833, 0.21332329511642456, -0.2266865074634552, -0.1415550857782364, 
0.1927451491355896, -0.2913603186607361, 0.1829729676246643, 0.30117177963256836, -0.31317561864852905, -0.10855123400688171, -0.02568337321281433, 0.2505541443824768, 0.17569029331207275, 0.3120304346084595, -0.022581875324249268, 0.1718236804008484, 0.22337841987609863, -0.24180543422698975, -0.2772238552570343, 0.2667417526245117, -0.011020451784133911, 0.06895411014556885, 0.21713727712631226, 0.023944228887557983, -0.16436180472373962, -0.1355346292257309, -0.18569877743721008, -0.12961304187774658, 0.03008010983467102, 0.3092449903488159, -0.2526698708534241}, {0.1286792755126953, -0.11293900012969971, 0.013389289379119873, -0.18133169412612915, 0.2839161157608032, 0.2839195132255554, 0.24429547786712646, 0.3261539340019226, 0.2771833539009094, 0.2762939929962158, -0.01121985912322998, 0.1633266806602478, 0.13623419404029846, -0.22840574383735657, -0.008434563875198364, -0.20190590620040894, -0.18728604912757874, -0.30082377791404724, 0.05349135398864746, -0.032782286405563354, 0.26623207330703735, -0.05535465478897095, -0.06507837772369385, -0.06029370427131653, -0.05910545587539673, 0.21522831916809082, -0.07686832547187805, -0.030431240797042847, 0.13506802916526794, 0.31972724199295044, 0.14279630780220032, 0.30336707830429077, 0.2874157428741455, -0.21258234977722168, -0.11647495627403259, -0.30905622243881226, 0.15032967925071716, -0.09122250974178314, 0.137144535779953, 0.3054274320602417}, {0.19253015518188477, -0.2128022015094757, 0.1264038383960724, -0.2520808279514313, 0.26445841789245605, 0.08062925934791565, 0.08375704288482666, -0.2729018032550812, 0.09173852205276489, -0.17347314953804016, 0.26803749799728394, -0.16099271178245544, -0.2895902097225189, -0.31336697936058044, -0.3269425630569458, 0.30058449506759644, 0.23531174659729004, -0.1584257036447525, 0.18380320072174072, -0.2134735882282257, -0.06926783919334412, 0.014586150646209717, -0.04286158084869385, -0.2197437584400177, -0.24472251534461975, -0.28000062704086304, 
-0.3010462820529938, 0.13298621773719788, 0.09352761507034302, 0.1429082453250885, -0.017474889755249023, 0.2462477684020996, 0.23644763231277466, 0.045398205518722534, 0.1676187515258789, 0.05115005373954773, -0.046341001987457275, 0.13247153162956238, 0.19267618656158447, 0.2226163148880005}, {-0.29614946246147156, -0.07169562578201294, 0.03652316331863403, 0.12125691771507263, 0.06905150413513184, -0.08269600570201874, -0.2852576971054077, 0.2639201879501343, 0.21139466762542725, 0.028177410364151, 0.022612959146499634, 0.24554657936096191, -0.17131789028644562, -0.2365601658821106, -0.19441640377044678, -0.2306821346282959, 0.2879871129989624, 0.1760021448135376, -0.035788893699645996, -0.1584118902683258, -0.1611887365579605, -0.20519119501113892, 0.3073367476463318, 0.29482001066207886, -0.23705634474754333, -0.3082103431224823, -0.2798632085323334, -0.28199970722198486, -0.28728777170181274, -0.2724458873271942, -0.07034051418304443, 0.101656973361969, 9.199976921081543e-05, -0.28747349977493286, 0.32538026571273804, 0.08917456865310669, -0.1690398007631302, 0.0025961697101593018, 0.054893285036087036, 0.10311257839202881}, {0.09853002429008484, 0.07092487812042236, 0.15012452006340027, -0.2408362478017807, 0.23686301708221436, 0.15131759643554688, -0.04604828357696533, -0.13971972465515137, -0.187650665640831, 0.20239859819412231, 0.06835803389549255, 0.19882804155349731, 0.16723471879959106, 0.29852062463760376, -0.21917492151260376, 0.2803170680999756, 0.04214662313461304, -0.2786775231361389, 0.2613469362258911, -0.281821072101593, -0.27265799045562744, -0.11274397373199463, -0.14982624351978302, -0.31242743134498596, 0.22651052474975586, -0.3048644959926605, 0.007821261882781982, -0.06491979956626892, -0.223236083984375, -0.3025462329387665, -0.16308428347110748, -0.2127648890018463, -0.32697737216949463, -0.18411096930503845, -0.21061986684799194, 0.23829764127731323, -0.29919156432151794, -0.03528442978858948, 0.16280800104141235, 
0.0010845363140106201}, {-0.2544461786746979, 0.3035287857055664, 0.3214139938354492, -0.25379523634910583, -0.28763771057128906, -0.22116926312446594, 0.08540549874305725, 0.1530439257621765, -0.1517166942358017, 0.12267747521400452, 0.29826849699020386, 0.04914003610610962, -0.17546755075454712, -0.1977802813053131, 0.28278690576553345, 0.06357243657112122, -0.08368799090385437, 0.2384331226348877, -0.0750417709350586, 0.17452633380889893, -0.02728596329689026, 0.13649210333824158, -0.2008959800004959, -0.0089263916015625, 0.14859968423843384, -0.28194791078567505, -0.043793678283691406, -0.09229221940040588, -0.07804720103740692, -0.30755019187927246, 0.3240317106246948, 0.21725302934646606, -0.08568957448005676, 0.04902896285057068, -0.16016000509262085, 0.28713470697402954, -0.30306535959243774, 0.1661771833896637, -0.2489386796951294, -0.1524587869644165}, {0.3225787281990051, -0.0722678005695343, 0.29168999195098877, -0.26409777998924255, 0.2183428406715393, 0.25207746028900146, -0.1434115171432495, -0.3045051395893097, 0.17568659782409668, 0.13391879200935364, 0.26256537437438965, -0.14129699766635895, 0.05537611246109009, 0.30112671852111816, 0.03906169533729553, 0.19529420137405396, -0.2087670862674713, 0.22477513551712036, 0.15530547499656677, 0.231309175491333, -0.029470384120941162, 0.1821277141571045, -0.01585516333580017, 0.17920136451721191, -0.30067309737205505, -0.05515649914741516, 0.30068492889404297, -0.26682257652282715, 0.07289525866508484, 0.02092382311820984, -0.04872250556945801, -0.10232444107532501, 0.27262866497039795, 0.010586321353912354, -0.2126511037349701, 0.1243632435798645, -0.2679103910923004, 0.21442973613739014, -0.009513705968856812, -0.030567646026611328}, {0.01274329423904419, 0.061925917863845825, 0.02312678098678589, -0.27860307693481445, 0.30133700370788574, -0.042556196451187134, 0.2741144299507141, 0.243780255317688, -0.025462687015533447, -0.00994834303855896, 0.17419230937957764, 0.07882440090179443, 
-0.30640316009521484, -0.18100805580615997, 0.21359962224960327, -0.14350204169750214, 0.17430812120437622, 0.24763882160186768, -0.24152767658233643, -0.08985212445259094, 0.09596699476242065, -0.042413145303726196, 0.2433282732963562, -0.19888916611671448, -0.044965386390686035, 0.14925917983055115, -0.28870663046836853, 0.31808775663375854, 0.02797466516494751, 0.13205251097679138, 0.028620123863220215, 0.17123672366142273, 0.17744100093841553, 0.08048862218856812, -0.012821078300476074, -0.07770584523677826, 0.09894254803657532, -0.14119555056095123, -0.2093517780303955, 0.11683058738708496}, {-0.0880342423915863, -0.16364359855651855, -0.31709960103034973, -0.05142560601234436, 0.2600014805793762, 0.15837785601615906, 0.17937153577804565, 0.24516987800598145, -0.06466695666313171, 0.2517629861831665, -0.06173074245452881, -0.15767745673656464, -0.29276514053344727, -0.28528204560279846, -0.16223332285881042, -0.04487493634223938, -0.29766952991485596, 0.05379462242126465, 0.15101394057273865, -0.13530057668685913, -0.26568669080734253, 0.05769026279449463, -0.25164762139320374, 0.16940799355506897, -0.18820030987262726, 0.09384647011756897, 0.2755531668663025, -0.011514216661453247, 0.060526251792907715, 0.21743464469909668, 0.22867953777313232, 0.2765069603919983, -0.2142142653465271, -0.08479546010494232, 0.32202762365341187, -0.09636501967906952, -0.0021964609622955322, 0.23163974285125732, -0.11037200689315796, -0.24356447160243988}, {-0.3167370855808258, 0.30017322301864624, -0.07848328351974487, -0.3098224401473999, -0.07888363301753998, 0.24105793237686157, -0.0685325562953949, 0.214469313621521, 0.16685473918914795, 0.21010524034500122, 0.059457480907440186, -0.1338091492652893, 0.12677127122879028, -0.19873018562793732, -0.29018634557724, 0.06710994243621826, -0.24473798274993896, -0.09682419896125793, 0.07203882932662964, 0.30855458974838257, 0.14253488183021545, 0.07871466875076294, -0.03849175572395325, -0.30217522382736206, 0.11604228615760803, 
0.06006690859794617, -0.1558382660150528, 0.30797380208969116, 0.24169319868087769, -0.2644522488117218, -0.08729410171508789, -0.12122170627117157, 0.0023641586303710938, 0.2149525284767151, -0.04440614581108093, -0.06998223066329956, 0.26038694381713867, -0.09593571722507477, 0.04419833421707153, 0.20667433738708496}, {-0.13085608184337616, 0.28740113973617554, 0.20486599206924438, 0.09832343459129333, -0.16989973187446594, 0.003633350133895874, -0.16878913342952728, -0.05539911985397339, 0.06368705630302429, -0.09225420653820038, 0.2096470594406128, -0.2803211510181427, -0.2159978151321411, 0.20003372430801392, 0.19021159410476685, -0.06429672241210938, -0.15774261951446533, 0.30152428150177, -0.18959738314151764, -0.19082215428352356, -0.1534786969423294, 0.02207234501838684, -0.19847843050956726, -0.12354427576065063, 0.11847817897796631, 0.01883155107498169, 0.10456624627113342, -0.3006700575351715, 0.12268123030662537, 0.038548171520233154, -0.22566145658493042, -0.010022073984146118, 0.2949276566505432, -0.2226477563381195, 0.05210956931114197, 0.18077439069747925, -0.014195919036865234, -0.03624418377876282, -0.08444911241531372, -0.21484476327896118}, {0.08358785510063171, 0.2406042218208313, -0.27732813358306885, -0.28388145565986633, -0.14778532087802887, 0.10833793878555298, 0.2931848168373108, -0.05245348811149597, 0.26646268367767334, -0.059916287660598755, 0.008635908365249634, -0.3244437873363495, 0.08311635255813599, -0.09971670806407928, 0.018724024295806885, -0.07553640007972717, -0.2765367925167084, 0.12158796191215515, 0.08789563179016113, 0.3081898093223572, 0.13620609045028687, 0.23658007383346558, 0.11196920275688171, 0.014192432165145874, -0.19592759013175964, -0.17173299193382263, 0.17186567187309265, -0.28639668226242065, 0.15605631470680237, 0.26586848497390747, 0.20840895175933838, 0.2716386318206787, -0.06808465719223022, 0.030398517847061157, -0.29229408502578735, -0.17325915396213531, -0.2226700633764267, 0.1546078622341156, 
-0.3216899335384369, 0.29681044816970825}, {-0.05840221047401428, -0.20649243891239166, -0.275494247674942, -0.12143155932426453, -0.202173113822937, 0.06901586055755615, 0.019829541444778442, -0.10787002742290497, -0.04485777020454407, -0.20079171657562256, 0.04864555597305298, 0.11073404550552368, -0.1973721981048584, 0.1471342146396637, -0.2561625838279724, -0.21354670822620392, -0.19991172850131989, 0.2222803831100464, -0.15028496086597443, 0.05768096446990967, 0.07148957252502441, 0.2590382695198059, 0.2484537959098816, 0.11829254031181335, 0.05236005783081055, 0.0792408287525177, -0.12460766732692719, -0.31409454345703125, -0.18393318355083466, -0.0767565667629242, -0.17323856055736542, -0.2952617406845093, 0.03940609097480774, -0.1664322316646576, -0.17755046486854553, -0.258358895778656, -0.15036581456661224, 0.11148324608802795, 0.32275885343551636, 0.016039341688156128}, {-0.005480647087097168, 0.13969051837921143, 0.07085278630256653, -0.055675238370895386, 0.06986430287361145, -0.0077575743198394775, -0.27704405784606934, -0.014148861169815063, 0.19072580337524414, 0.24358952045440674, -0.2603682279586792, -0.04012367129325867, -0.1888551414012909, 0.05332168936729431, 0.32439905405044556, 0.0015577077865600586, -0.00466388463973999, 0.3190273642539978, 0.14088940620422363, -0.19484196603298187, -0.01473855972290039, -0.004330098628997803, -0.041361063718795776, 0.09134542942047119, -0.22436347603797913, 0.11300471425056458, -0.2656359076499939, 0.17975813150405884, 0.03892752528190613, 0.2737969160079956, 0.20975875854492188, 0.2659428119659424, -0.23092865943908691, 0.26794272661209106, 0.30342769622802734, 0.11639487743377686, -0.19478389620780945, -0.19419562816619873, -0.22662922739982605, 0.019620239734649658}}
#define LSTM1_RECURRENT_KERNEL {{-0.06749224662780762, -0.19506703317165375, -0.06228947266936302, 0.028524119406938553, 0.20978879928588867, 0.03446773439645767, -0.3255080282688141, -0.017273884266614914, -0.16444146633148193, -0.015307039022445679, 0.0015090096276253462, 0.006872890517115593, 0.15315088629722595, 0.38795894384384155, -0.08264278620481491, 0.07026154547929764, -0.02424474060535431, 3.9128084608819336e-05, 0.16973547637462616, 0.10559914261102676, 0.12403601408004761, -0.19156697392463684, 0.23722656071186066, 0.040726806968450546, 0.07843013852834702, -0.05215204879641533, 0.00534455431625247, -0.03541784733533859, -0.217418372631073, 0.1771443784236908, -0.2445060759782791, -0.2691061198711395, -0.24636699259281158, 0.1691955327987671, 0.03787366300821304, -0.14979086816310883, 0.12061695009469986, 0.260057657957077, 0.06853244453668594, -0.1286240816116333}, {-0.005993622820824385, -0.24427205324172974, 0.030819542706012726, -0.10726135224103928, -0.3564056158065796, -0.18827956914901733, 0.06802006810903549, 0.028958123177289963, -0.00857122428715229, 0.08963152021169662, -0.057909101247787476, 0.386557400226593, -0.08999310433864594, -0.1467428207397461, 0.1265380084514618, 0.0407247468829155, 0.017545832321047783, -0.11421257257461548, 0.21500976383686066, 0.017937077209353447, -0.13396437466144562, 0.35731586813926697, 0.10176851600408554, 0.12072192132472992, 0.39990878105163574, 0.09331593662500381, -0.02685726061463356, 0.09731768071651459, -0.11626307666301727, 0.017388539388775826, 0.0888611301779747, -0.1774260699748993, -0.12422303855419159, 0.10169469565153122, 0.13189202547073364, 0.13475863635540009, 0.11201001703739166, 0.06930758059024811, 0.08589541912078857, -0.056364960968494415}, {0.03497632220387459, -0.3533796966075897, -0.22543489933013916, -0.0485193245112896, 0.01107637770473957, 0.12856844067573547, -0.17017722129821777, -0.032133933156728745, 0.023421963676810265, -0.08931117504835129, 0.33798864483833313, 
-0.018917571753263474, -0.011232766322791576, 0.01574590988457203, 0.042588040232658386, 0.04622801020741463, 0.19672635197639465, 0.06417237222194672, -0.10470334440469742, -0.23864421248435974, -0.059810228645801544, 0.09768777340650558, -0.10040628165006638, -0.14509789645671844, 0.06773771345615387, -0.04104334115982056, -0.022936871275305748, 0.16341258585453033, 0.10451073199510574, -0.27043697237968445, -0.3301188349723816, 0.22100384533405304, -0.3461010158061981, -0.04642774537205696, -0.03309166803956032, 0.22121348977088928, -0.0680709108710289, -0.01595321297645569, 0.030179141089320183, 0.19619694352149963}, {-0.04072313383221626, 0.24905142188072205, 0.08900400251150131, -0.24979913234710693, 0.14574307203292847, -0.007038511801511049, 0.03249103203415871, -0.023222673684358597, -0.22603005170822144, 0.27367135882377625, -0.030985135585069656, 0.15127159655094147, -0.006817997433245182, 0.03027227893471718, -0.038371071219444275, -0.1225065365433693, -0.024442648515105247, -0.28420740365982056, 0.1909944862127304, -0.018619555979967117, -0.030714189633727074, 0.04682669788599014, -0.08282705396413803, -0.17761342227458954, -0.232009157538414, -0.04332009702920914, 0.1740972399711609, 0.2893451452255249, -0.16910985112190247, 0.1439540982246399, -0.10329971462488174, -0.07277991622686386, -0.1448083370923996, -0.28473711013793945, -0.23795932531356812, 0.2505485415458679, -0.10053527355194092, -0.01663699746131897, 0.20010896027088165, 0.09592140465974808}, {-0.11878308653831482, 0.014070812612771988, -0.3277565538883209, -0.12756288051605225, -0.014328980818390846, 0.08745329082012177, -0.07345258444547653, 0.012141863815486431, 0.056553181260824203, 0.0333787202835083, 0.1938624382019043, 0.3911162316799164, -0.13371503353118896, 0.07769668847322464, -0.05835907533764839, 0.04367610067129135, -0.09054629504680634, -0.1580016314983368, -0.011005986481904984, 0.09205211699008942, -0.017222050577402115, -0.2664359211921692, 0.11324514448642731, 
0.09820376336574554, 0.09637799859046936, -0.01808803901076317, 0.2589896023273468, 0.20337450504302979, 0.09971316158771515, -0.1130448579788208, 0.001686878502368927, 0.15768694877624512, 0.17356641590595245, -0.1424150913953781, -0.09710390120744705, -0.18407653272151947, 0.15915937721729279, -0.23702290654182434, -0.24731719493865967, -0.2948751449584961}, {-0.20491445064544678, -0.22212952375411987, 0.2087169736623764, -0.13124299049377441, 0.12392139434814453, -0.25313815474510193, -0.13457736372947693, 0.10151604562997818, -0.03988504782319069, 0.041967134922742844, 0.0976562425494194, -0.036991722881793976, 0.3505118191242218, 0.08411690592765808, -0.06860946863889694, 0.05010578781366348, 0.19989821314811707, -0.06163168326020241, 0.3028776943683624, -0.11431054770946503, -0.06581276655197144, -0.06062182039022446, 0.04184970632195473, 0.12408663332462311, 0.0019230898469686508, 0.17347267270088196, -0.040315110236406326, -0.008183619938790798, 0.11375091969966888, -0.08901193737983704, 0.03588394075632095, 0.12624207139015198, 0.3536500036716461, 0.10561180114746094, -0.029747582972049713, 0.09693260490894318, -0.3880247175693512, -0.18302373588085175, 0.043063320219516754, -0.11229711771011353}, {0.027351543307304382, 0.04188253730535507, -0.1394345760345459, 0.18742041289806366, 0.0509335920214653, 0.06980549544095993, -0.1583179086446762, -0.1564529985189438, 0.33601921796798706, 0.18760834634304047, -0.09778987616300583, -0.006212963256984949, 0.28490301966667175, 0.09592632949352264, 0.21135841310024261, -0.06643474102020264, -0.017779245972633362, 0.037918224930763245, -0.05580649897456169, -0.15591026842594147, -0.1854349821805954, 0.16021153330802917, 0.33094024658203125, -0.291006863117218, -0.031706199049949646, 0.2571420669555664, 0.04593920335173607, -0.04402019456028938, -0.029843956232070923, -0.029097510501742363, 0.1222948282957077, 0.064956896007061, 0.16483257710933685, -0.09251931309700012, -0.22832167148590088, 0.16930581629276276, 
0.22117367386817932, 0.1762712299823761, 0.04597582668066025, -0.12084245681762695}, {-0.15571328997612, 0.1811290830373764, -0.3179478347301483, 0.24534057080745697, 0.06352883577346802, 0.18771778047084808, 0.12269727885723114, 0.04981767013669014, -0.3314366042613983, -0.09279465675354004, -0.2294003963470459, 0.003686012700200081, 0.13516774773597717, 0.07415641844272614, 0.08335330337285995, 0.17025822401046753, -0.12838956713676453, 0.1970968395471573, 0.08094272017478943, -0.07968678325414658, 0.17197929322719574, 0.19906216859817505, -0.027253124862909317, -0.004737058188766241, 0.08020765334367752, 0.2789633870124817, 0.13998930156230927, 0.1551464945077896, 0.08101494610309601, -0.009294496849179268, 0.02137742191553116, -0.07746981084346771, -0.09093019366264343, 0.23701949417591095, -0.012727397494018078, 0.09898868203163147, -0.08753138780593872, -0.36297228932380676, 0.020412322133779526, 0.024156048893928528}, {0.19397608935832977, -0.07172737270593643, -0.07444058358669281, -0.06045854091644287, 0.05608497932553291, -0.02358597330749035, 0.13918089866638184, 0.06724704802036285, 0.018123377114534378, 0.07441994547843933, -0.35466307401657104, 0.006733589340001345, 0.10363553464412689, 0.031110934913158417, 0.27522921562194824, 0.012667804956436157, 0.30571407079696655, 0.08259426057338715, 0.09945833683013916, -0.015772439539432526, -0.4101584553718567, -0.08261121064424515, 0.10978402197360992, 0.09233836829662323, -0.2088344544172287, -0.10765135288238525, 0.17442764341831207, 0.16480182111263275, 0.1634531021118164, 0.032007865607738495, -0.1933029443025589, -0.004149202257394791, -0.020110856741666794, -0.026345063000917435, 0.24563413858413696, -0.3322038948535919, 0.07648780941963196, -0.10684900730848312, 0.05858294293284416, 0.15703007578849792}, {0.016292596235871315, -0.16068997979164124, 0.017371973022818565, 0.059121374040842056, 0.12419939041137695, -0.11636526882648468, 0.06736360490322113, -0.040519434958696365, 0.013860233128070831, 
0.2816285192966461, 0.06843682378530502, -0.06598041951656342, -0.11708489060401917, 0.06599441915750504, 0.08088330924510956, -0.15082156658172607, -0.01781713031232357, -0.08857876807451248, -0.023932160809636116, -0.028023675084114075, 0.08513661473989487, -0.19381387531757355, -0.002677129814401269, 0.24996110796928406, -0.10848896205425262, 0.15301969647407532, -0.05714404955506325, 0.22335082292556763, -0.14497867226600647, -0.15700100362300873, 0.36476466059684753, 0.15421006083488464, -0.14810146391391754, 0.38729873299598694, -0.23341135680675507, -0.10775571316480637, 0.20482878386974335, -0.033480893820524216, 0.04325736314058304, 0.3293023407459259}}
#define LSTM1_BIASES {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}

I only want to implement a single LSTM cell and its inference step in C, so I wrote the following:

main.c

#include <stdio.h>
#include "lstm.h"


int main() {
    // input_size and num_units is defined in lstm.h

    // get weighted cell
    LSTMCell cell = get_weighted_cell();

    // test input(only one time step)
    float inputs[16] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f};

    // zero initialization
    float last_c[num_units] = {.0f};
    float last_h[num_units] = {.0f};

    // test the cell
    lstm_inference(&cell, last_c, last_h, inputs);
    for (int i = 0; i < 10; i++) {
        printf("%f\n", last_h[i]);
    }

    return 0;
}

lstm.h

/* Length of each input vector; must match the last axis of the Keras Input shape (None, 16).
 * NOTE(review): lowercase macro names can collide with ordinary identifiers; consider INPUT_SIZE. */
#define input_size 16
/* Number of LSTM units, i.e. the length of the hidden (h) and cell (c) state vectors. */
#define num_units 10

#include <stdlib.h>
#include <string.h>
#include "activations.h"
#include "weights.h"

/*
 * Weights of one LSTM layer, laid out as exported from Keras.
 * Every matrix/vector has 4 * num_units columns: four num_units-wide gate
 * slices in the order i, f, c, o.
 */
typedef struct lstm_cell {
    float kernel[input_size][4 * num_units];          /* input x      -> gates (W) */
    float recurrent_kernel[num_units][4 * num_units]; /* prev. state h -> gates (U) */
    float biases[4 * num_units];                      /* gate biases (b)            */
} LSTMCell;

LSTMCell get_weighted_cell() {
    LSTMCell cell = {
            LSTM1_KERNEL,
            LSTM1_RECURRENT_KERNEL,
            LSTM1_BIASES
    };
    return cell;
}

void lstm_inference(LSTMCell *cell, float last_c[], float last_h[], float x[]) {
    // only inference one time_step
    float gates[num_units * 4];  // i f c o in order

    // add biases to gates
    memcpy(gates, cell->biases, num_units * 4);

    // compute gates without activation
    // W dot x + U dot last_h + b
    for (int i = 0; i < num_units * 4; i++) {
        for (int j = 0; j < input_size; j++) {
            gates[i] += cell->kernel[j][i] * x[j];
        }
        for (int j = 0; j < num_units; j++) {
            gates[i] += cell->recurrent_kernel[j][i] * last_h[j];
        }
    }

    // compute current cell state and hidden state
    /* formula: last_c = c matmul i + f matmul last_c  `c is input activation`
     * gates[2 * num_units: 3 * num_units] matmul gates[:num_units] +
     * gates[num_units: num_units * 2] matmul last_c */
    for (int i = 0; i < num_units; i++) {
        last_c[i] = tanhf(gates[2 * num_units + i]) * sigmoid(gates[i]) +
                    sigmoid(gates[num_units + i]) * last_c[i];
        last_h[i] = tanhf(last_c[i]) * sigmoid(gates[num_units * 3 + i]);
    }
}

weights.h

// copied from keras
#define LSTM1_KERNEL {{-0.2937382459640503, 0.25691062211990356, -0.30560800433158875, ..., -0.19419562816619873, -0.22662922739982605, 0.019620239734649658}}
#define LSTM1_RECURRENT_KERNEL {{-0.06749224662780762, -0.19506703317165375, -0.06228947266936302, 0.028524119406938553, 0.20978879928588867, 0.03446773439645767, -0.3255080282688141, ..., -0.033480893820524216, 0.04325736314058304, 0.3293023407459259}}
#define LSTM1_BIASES {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}

activations.h

#include <math.h>

float sigmoid(float x) {
    return 1 / (1 + expf(-x));
}

I got the result:

0.155755
0.438359
0.176438
0.034490
-0.147165
-0.073266
-0.049810
-0.210779
0.108767
0.156582

Process finished with exit code 0

As you can see, the last 6 elements are correct, but the first 4 aren't. Can anyone help me solve this problem?

Thanks.


Solution

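  • The last argument of `memcpy` is the number of bytes to copy, not the number of elements. `num_units * 4` is 40, so only 40 bytes are copied. That is the first 10 floats, i.e. the input-gate biases. `gates[10]` to `gates[39]` stay uninitialized and hold whatever was on the stack. Copy the whole array instead: `memcpy(gates, cell->biases, sizeof gates);` (equivalently `num_units * 4 * sizeof(float)`). Only some outputs were wrong, presumably because the leftover garbage happened to be zero in some gate slots, which matches the all-zero c and o biases, and non-zero in others.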

    float gates[num_units * 4];  // i f c o in order
    
    // add biases to gates
    memcpy(gates, cell->biases, num_units * 4); 
    

    cplusplus.com memcpy