
Retrain model scrypt #18

Open
taratuncho opened this issue Feb 21, 2022 · 1 comment


taratuncho commented Feb 21, 2022

Hello, and first of all, thank you for sharing your code.

I'm trying to train a new model, but the results I get from it are very wrong, for example:

```
{'street_name': '168A SEPARATION STREET NO', 'locality_name': 'COTE, VIC 3070'}
```

This is the code I use. Could you share your own training code, or point out where I might be going wrong?

Thank you so much.

```python
import argparse
import datetime
import os

import tensorflow as tf

import addressnet.dataset as dataset
from addressnet.model import model_fn


def _get_estimator(model_fn, model_dir):
    # Fixed seed for reproducibility; log and checkpoint every 2000 steps, keeping the 5 newest checkpoints.
    config = tf.estimator.RunConfig(tf_random_seed=17, keep_checkpoint_max=5, log_step_count_steps=2000,
                                    save_checkpoints_steps=2000)
    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, config=config)


def train(tfrecord_input_file: str, model_output_file: str):
    # Checkpoints are written to a subdirectory named after the input file.
    input_file_only = os.path.basename(tfrecord_input_file)
    model_output_file_path = os.path.join(model_output_file, input_file_only)

    address_net_estimator = _get_estimator(model_fn, model_output_file_path)

    # Estimator.train()/evaluate() take an input function (a callable returning a dataset);
    # the result of dataset.dataset(...) is passed in that role.
    tfdataset = dataset.dataset(tfrecord_input_file)

    start = datetime.datetime.now()
    # Estimator.train() returns the estimator itself, so `model` is the same object.
    model = address_net_estimator.train(tfdataset)
    end = datetime.datetime.now()

    # Note: this evaluates on the same file that was used for training.
    print('Evaluate model...')
    evaluation = model.evaluate(tfdataset)
    print(f'evaluation={evaluation}')

    # end - start is a timedelta; total_seconds() converts it to seconds.
    print(f'Finished training in {(end - start).total_seconds():.0f} sec on file {input_file_only}. '
          f'Model saved to {model_output_file_path}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tfrecord_input_file", required=True,
                        help="TFRecord input file from generate_tf_records.py")
    parser.add_argument("--model_output_file", required=True,
                        help="Directory in which to save model checkpoints")
    args = parser.parse_args()

    train(args.tfrecord_input_file, args.model_output_file)
```
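
The script above calls `evaluate` on the same TFRecord file it trains on, so the reported metrics describe how well the model fits its training data rather than unseen addresses. A minimal sketch of holding out a stable evaluation subset, assuming the raw address records are available as text lines before they are converted with `generate_tf_records.py`; `split_train_eval` is a hypothetical helper, not part of addressnet, and the sample records are synthetic.

```python
import hashlib
from typing import Iterable, List, Tuple


def split_train_eval(records: Iterable[str], eval_fraction: float = 0.1) -> Tuple[List[str], List[str]]:
    """Deterministically assign each record to the training or evaluation set.

    Hypothetical helper: hashing the record text (instead of random sampling)
    keeps the split identical across runs, so an address never moves between sets.
    """
    if not 0.0 < eval_fraction < 1.0:
        raise ValueError("eval_fraction must be strictly between 0 and 1")
    threshold = int(eval_fraction * 2**32)
    train, evaluation = [], []
    for record in records:
        # First 4 bytes of the MD5 digest, read as an unsigned 32-bit bucket number.
        bucket = int.from_bytes(hashlib.md5(record.encode("utf-8")).digest()[:4], "big")
        (evaluation if bucket < threshold else train).append(record)
    return train, evaluation


if __name__ == "__main__":
    # Synthetic records, for illustration only.
    sample = [f"{n} EXAMPLE STREET, SOMEWHERE VIC 3000" for n in range(1, 1001)]
    train_records, eval_records = split_train_eval(sample, eval_fraction=0.1)
    print(len(train_records), len(eval_records))
```

Each part would then be written to its own TFRecord file, one for `train` and one for `evaluate`.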

@dylanhogg

@taratuncho, what address text did you feed in to get that output? Without the input, it is hard to diagnose the result.

I ran the input text "168A SEPARATION STREET NO, COTE, VIC 3070" through a live demo of the model (using the trained model supplied by @jasonrig in this repo) at https://address-app.infocruncher.com/, which returned:

```
{
  "number_first": "168",
  "number_first_suffix": "A",
  "street_name": "SEPARATION",
  "street_type": "STREET",
  "street_suffix": "NORTH",
  "locality_name": ", COTE",
  "state": "VICTORIA",
  "postcode": "3070"
}
```
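
If the retrained model is given exactly the same input text, comparing its output with the pretrained model's output field by field makes the divergence easier to see. A minimal sketch under that assumption (the input behind the first result has not been confirmed in this thread); `field_diff` is a hypothetical helper, and the two dictionaries are copied from the outputs quoted in this issue.

```python
from typing import Dict, List, Tuple

# Output reported for the retrained model in this issue.
RETRAINED = {"street_name": "168A SEPARATION STREET NO", "locality_name": "COTE, VIC 3070"}

# Output of the pretrained model from the live demo.
REFERENCE = {
    "number_first": "168",
    "number_first_suffix": "A",
    "street_name": "SEPARATION",
    "street_type": "STREET",
    "street_suffix": "NORTH",
    "locality_name": ", COTE",
    "state": "VICTORIA",
    "postcode": "3070",
}


def field_diff(reference: Dict[str, str], candidate: Dict[str, str]) -> List[Tuple[str, str, str]]:
    """Return (field, reference value, candidate value) for every field that differs.

    A field missing from one side is reported as an empty string.
    """
    diffs = []
    for field in sorted(reference.keys() | candidate.keys()):
        ref_value = reference.get(field, "")
        cand_value = candidate.get(field, "")
        if ref_value != cand_value:
            diffs.append((field, ref_value, cand_value))
    return diffs


for field, ref_value, cand_value in field_diff(REFERENCE, RETRAINED):
    print(f"{field:20} reference={ref_value!r:14} retrained={cand_value!r}")
```

An empty result would mean the two models agree on every field for that input.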
