demo/perceptron: Support initializing more weights. (#4)

* demo/perceptron: Support initializing more weights.
* demo/deposit cost: Get slightly most accurate deposit cost.
This commit is contained in:
Justin D. Harris 2019-05-27 17:50:20 -04:00 коммит произвёл GitHub
Родитель b3240712f7
Коммит 605048ad3b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 42 добавлений и 39 удалений

Просмотреть файл

@ -32,10 +32,36 @@ contract Perceptron is Classifier64 {
}
}
/**
* Initialize weights for the model.
* Made to be called just after the contract is created and never again.
* @param startIndex The index to start placing `_weights` into the model's weights.
* @param _weights The weights to set for the model.
*/
function initializeWeights(uint64 startIndex, int80[] memory _weights) public onlyOwner {
for (uint64 i = 0; i < _weights.length; ++i) {
weights[startIndex + i] = _weights[i];
}
}
function norm(int64[] memory /* data */) public pure returns (uint) {
revert("Normalization is not required.");
}
function predict(int64[] memory data) public view returns (uint64) {
int m = intercept;
for (uint i = 0; i < data.length; ++i) {
// `update` assumes this check is done.
require(data[i] >= 0, "Not all indices are >= 0.");
m = m.add(weights[uint64(data[i])]);
}
if (m <= 0) {
return 0;
} else {
return 1;
}
}
function update(int64[] memory data, uint64 classification) public onlyOwner {
uint64 prediction = predict(data);
if (prediction != classification) {
@ -59,35 +85,6 @@ contract Perceptron is Classifier64 {
}
}
/**
* Check if two arrays of training data are equal.
*/
function isDataEqual(uint24[] memory d1, uint24[] memory d2) public pure returns (bool) {
if (d1.length != d2.length) {
return false;
}
for (uint i = 0; i < d1.length; ++i) {
if (d1[i] != d2[i]) {
return false;
}
}
return true;
}
function predict(int64[] memory data) public view returns (uint64) {
int m = intercept;
for (uint i = 0; i < data.length; ++i) {
// `update` assumes this check is done.
require(data[i] >= 0, "Not all indices are >= 0.");
m = m.add(weights[uint64(data[i])]);
}
if (m <= 0) {
return 0;
} else {
return 1;
}
}
/**
* Evaluate a batch.
*

Просмотреть файл

@ -7,7 +7,7 @@ const DataHandler64 = artifacts.require("./data/DataHandler64");
const Classifier = artifacts.require("./classification/Perceptron");
const Stakeable64 = artifacts.require("./incentive/Stakeable64");
module.exports = function (deployer) {
module.exports = async function (deployer) {
// Information to persist to the DB.
const modelInfo = {
name: "IMDB Review Sentiment Classifier",
@ -34,12 +34,9 @@ module.exports = function (deployer) {
var data = fs.readFileSync('./src/ml-models/imdb-sentiment-model.json', 'utf8');
var model = JSON.parse(data);
// Don't use all the words since it takes too long to load
// and we don't need them all just for simple testing.
var maxNumWords = 100;
// There are 18 decimal places.
var weights = convertData(model['coef'].slice(0, maxNumWords));
const weights = convertData(model['coef']);
const initNumWords = 250;
const numWordsPerUpdate = 250;
console.log(`Deploying IMDB model with ${weights.length} weights.`);
var intercept = web3.utils.toBN(model['intercept'] * toFloat);
@ -55,8 +52,17 @@ module.exports = function (deployer) {
costWeight
).then(incentiveMechanism => {
console.log(` Deployed incentive mechanism to ${incentiveMechanism.address}.`);
console.log(`Deploying classifier.`);
return deployer.deploy(Classifier,
classifications, weights, intercept, learningRate).then(classifier => {
classifications, weights.slice(0, initNumWords), intercept, learningRate,
{ gas: 7.9E6 }).then(async classifier => {
console.log(` Deployed classifier to ${classifier.address}.`);
for (let i = initNumWords; i < weights.length; i += numWordsPerUpdate) {
await classifier.initializeWeights(i, weights.slice(i, i + numWordsPerUpdate),
{ gas: 7.9E6 });
console.log(` Added weights ${i + numWordsPerUpdate}`);
}
console.log(`Deploying collaborative trainer contract.`);
return deployer.deploy(CollaborativeTrainer64,
dataHandler.address,

Просмотреть файл

@ -286,8 +286,8 @@ class Model extends React.Component {
return this.state.incentiveMechanism.methods.lastUpdateTimeS().call()
.then(parseInt)
.then(lastUpdateTimeS => {
var now = new Date().getTime() / 1000;
var divisor = now - lastUpdateTimeS;
const now = Math.floor(new Date().getTime() / 1000);
let divisor = now - lastUpdateTimeS;
if (divisor === 0) {
divisor = 1;
} else {